Continuous Thinking Machines (CTM) paper thoughts

Key Innovations

  1. Neuron-Level Models (NLMs) - Each neuron has private weights processing temporal history
  2. Neural Synchronization - Uses neuron correlations over time as representations
  3. Internal Ticks - Decoupled temporal dimension for iterative refinement

Repo Completeness: ✅ Fully Implemented

Component            Status   Location
──────────────────   ──────   ─────────────────────────────
Core CTM             ✅       models/ctm.py (604 lines)
NLMs (SuperLinear)   ✅       models/modules.py
SynapseUNET          ✅       models/modules.py
All 6 tasks          ✅       tasks/ subdirs
Pre-trained models   ✅       Google Drive links in README
Notebooks            ✅       examples/01_mnist.ipynb etc.

Quick Start

# Run simplest task (parity)
python -m tasks.parity.train --iterations 75 --memory_length 25

Key Hyperparameters by Task

Task       D (neurons)   T (iterations)   M (memory length)
────────   ───────────   ──────────────   ─────────────────
ImageNet   4096          75               25
Mazes      2048          75               25
Parity     1024          75               25
CIFAR-10   256           50               15

Suggested Learning Path

  1. Start with examples/01_mnist.ipynb - Understand basics
  2. Read models/ctm.py - Core architecture
  3. Run parity task - Simplest to train from scratch
  4. Download ImageNet checkpoint - Test inference

CTM vs Transformers - Key Differences

Aspect           Transformer                           CTM
──────────────   ───────────────────────────────────   ────────────────────────────────────────────────────────
Time axis        Processes sequence positions          Internal "thinking" ticks, independent of the input
Neurons          Shared weights across all positions   Each neuron has private weights
Representation   Token embeddings                      Synchronization (correlations between neurons over time)
Computation      Fixed depth (layers)                  Adaptive: can "think longer" on hard problems
Recurrence       None (or limited)                     Continuous recurrence over internal ticks

The Core Insight

Transformers: Input → Fixed layers → Output (same compute for easy/hard)

CTM: Input → Internal ticks (t = 1, 2, 3, ..., T, e.g. T = 75) → Output when "certain"

  • Each tick refines the representation
  • Model learns to use more ticks for harder inputs
  • Neurons develop temporal patterns that encode information

Architecture Deep Dive

High-Level: What Happens Each Tick

                        ONE INTERNAL TICK (t → t+1)
┌──────────────────────────────────────────────────────────────────────┐
│                                                                      │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐            │
│  │ Previous    │     │   Synapse   │     │    NLM      │            │
│  │ State zᵗ    │────▶│   (U-Net)   │────▶│ (per neuron)│────▶ zᵗ⁺¹ │
│  └─────────────┘     └─────────────┘     └─────────────┘            │
│        │                   ▲                                         │
│        │                   │                                         │
│        │            ┌──────┴──────┐                                  │
│        │            │  Attention  │◀──── Input features              │
│        │            │   output    │      (from image/data)           │
│        │            └─────────────┘                                  │
│        │                                                             │
│        ▼                                                             │
│  ┌─────────────────────────────────────┐                            │
│  │  Synchronization Matrix Sᵗ          │────▶ Output yᵗ             │
│  │  (correlations between neurons)     │────▶ Query qᵗ (attention)  │
│  └─────────────────────────────────────┘                            │
│                                                                      │
└──────────────────────────────────────────────────────────────────────┘
                              │
                              ▼
                      REPEAT FOR t=1,2,3...T

Transformer vs CTM Architecture Comparison

Transformer (for reference):

Input tokens: [A] [B] [C] [D]
                ↓   ↓   ↓   ↓
            ┌───────────────────┐
Layer 1     │  Self-Attention   │  ← All tokens see each other
            └───────────────────┘
                ↓   ↓   ↓   ↓
            ┌───────────────────┐
Layer 2     │  Self-Attention   │
            └───────────────────┘
                ↓   ↓   ↓   ↓
              [A'] [B'] [C'] [D']  → Output

- Fixed compute: always N layers
- Each position uses SAME weights
- Representation = token embeddings

CTM:

                    ┌──────────────────────────────────────────────┐
                    │           INTERNAL TIME AXIS (T ticks)       │
                    │    t=1    t=2    t=3   ...   t=75            │
                    └──────────────────────────────────────────────┘
                         ↓      ↓      ↓            ↓

Input ──→ [Feature    ┌─────────────────────────────────────┐
          Extractor]  │                                     │
              │       │   z¹ ──→ z² ──→ z³ ──→ ... ──→ z⁷⁵  │  ← Neuron activations
              │       │   │      │      │             │     │    evolve over ticks
              ↓       │   ↓      ↓      ↓             ↓     │
         ┌────────┐   │  ┌──────────────────────────────┐   │
         │Attention│◀──│──│  Synchronization Matrix S   │   │
         │  Keys  │   │  │  (neuron-to-neuron corr.)   │   │
         └────────┘   │  └──────────────────────────────┘   │
              │       │         ↓      ↓             ↓      │
              │       │       y¹     y²    ...     y⁷⁵      │  ← Outputs at each tick
              │       └─────────────────────────────────────┘
              │                        ↓
              └──────────────────→  Final output when "certain"

The Three Novel Components

1. Neuron-Level Models (NLMs)

Each neuron is its own mini-network with private weights.

Traditional NN:

  • All neurons share same weights W

CTM NLM:

  • Each neuron d has PRIVATE weights: W₁, b₁ ... W_D, b_D
  • Input to each: last M pre-activations
  • A_d = [a_d^{t-M+1}, a_d^{t-M+2}, ..., a_d^{t}]  (the M most recent values)
┌────────────────────────────────────────────────────┐
│  Neuron 42's history buffer (M=25 ticks):          │
│                                                    │
│  A₄₂ = [a₄₂^{t-24}, a₄₂^{t-23}, ... , a₄₂^{t}]   │
│        ┌───┬───┬───┬─────┬───┬───┐               │
│        │0.2│0.5│0.1│ ... │0.8│0.3│               │
│        └───┴───┴───┴─────┴───┴───┘               │
│                    ↓                              │
│              ┌─────────┐                          │
│              │   W₄₂   │  ← Private weights       │
│              │  (MLP)  │                          │
│              └─────────┘                          │
│                    ↓                              │
│           z₄₂^{t+1} = new activation             │
└────────────────────────────────────────────────────┘

Parameter comparison (D=4096, M=25, hidden=64):

  • Standard MLP layer: 4096 × 4096 ≈ 16.8M params, shared by all neurons
  • CTM NLMs: each of the 4096 neurons has its own small MLP over its M=25-step history, so the parameter count scales with D × (per-neuron MLP size) rather than D × D (see the sketch below)
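
A minimal sketch of how private per-neuron weights can be vectorized with einsum (all D neurons updated at once), assuming a one-hidden-layer MLP per neuron; the repo's SuperLinear module may differ in structure and initialization.

import torch
import torch.nn as nn

class NeuronLevelModels(nn.Module):
    """Each of the D neurons gets its own tiny MLP over its M-length history."""
    def __init__(self, n_neurons: int, memory_length: int, hidden: int = 64):
        super().__init__()
        D, M, H = n_neurons, memory_length, hidden
        # Private weights per neuron: nothing is shared across the D dimension
        self.w1 = nn.Parameter(torch.randn(D, M, H) * M ** -0.5)
        self.b1 = nn.Parameter(torch.zeros(D, H))
        self.w2 = nn.Parameter(torch.randn(D, H) * H ** -0.5)
        self.b2 = nn.Parameter(torch.zeros(D))

    def forward(self, history: torch.Tensor) -> torch.Tensor:
        # history: (batch, D, M) -- last M pre-activations per neuron
        h = torch.relu(torch.einsum('bdm,dmh->bdh', history, self.w1) + self.b1)
        # Collapse each neuron's hidden dim to one new activation
        return torch.einsum('bdh,dh->bd', h, self.w2) + self.b2  # (batch, D)

# Usage with the ImageNet-scale settings (D=4096, M=25)
nlm = NeuronLevelModels(4096, 25)
z_next = nlm(torch.randn(2, 4096, 25))  # -> (2, 4096)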

2. Synchronization (The Key Insight)

Instead of using neuron values as representation, use neuron correlations over time.

Example: 4 neurons over 10 ticks

              t=1  t=2  t=3  t=4  t=5  t=6  t=7  t=8  t=9  t=10
            ┌────────────────────────────────────────────────────┐
  Neuron 1  │ 0.1  0.8  0.2  0.9  0.1  0.8  0.2  0.9  0.1  0.8  │ ╲
            │  ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘   │  ╲ SYNCHRONIZED
  Neuron 2  │ 0.2  0.7  0.3  0.8  0.2  0.7  0.3  0.8  0.2  0.7  │  ╱ (move together)
            │  ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘   │ ╱
            ├────────────────────────────────────────────────────┤
  Neuron 3  │ 0.9  0.2  0.8  0.1  0.9  0.2  0.8  0.1  0.9  0.2  │ ╲
            │  ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗   │  ╲ ANTI-SYNC
  Neuron 4  │ 0.8  0.3  0.7  0.2  0.8  0.3  0.7  0.2  0.8  0.3  │  ╱ (opposite phase)
            │  ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗    ↘    ↗   │ ╱
            └────────────────────────────────────────────────────┘

Resulting Synchronization Matrix S:

            N1      N2      N3      N4
       ┌────────────────────────────────┐
  N1   │  1.0    0.98   -0.95   -0.92  │   N1-N2: highly correlated
  N2   │  0.98   1.0    -0.93   -0.90  │   N1-N3: anti-correlated
  N3   │ -0.95  -0.93    1.0     0.97  │   N3-N4: highly correlated
  N4   │ -0.92  -0.90    0.97    1.0   │
       └────────────────────────────────┘

The pattern of this matrix encodes the representation! Different inputs produce different synchronization patterns, and this is robust to small noise in values.
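
A quick sketch of the idea as plain Pearson correlation over the tick axis; the paper's synchronization uses decay-weighted inner products of neuron traces and only a subset of neuron pairs, so treat this purely as an illustration.

import torch

def synchronization_matrix(Z: torch.Tensor) -> torch.Tensor:
    # Z: (D, T) -- activation trace of D neurons over T internal ticks
    Z = Z - Z.mean(dim=1, keepdim=True)             # center each trace
    Z = Z / (Z.norm(dim=1, keepdim=True) + 1e-8)    # unit length per neuron
    return Z @ Z.T                                   # (D, D), entries in [-1, 1]

# Reproduce the toy example: two in-phase neurons, two in anti-phase
t = torch.arange(10, dtype=torch.float32)
Z = torch.stack([torch.sin(t), 0.9 * torch.sin(t),
                 -torch.sin(t), -0.9 * torch.sin(t)])
print(synchronization_matrix(Z))  # N1-N2 near +1, N1-N3 near -1, N3-N4 near +1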

3. Adaptive Computation

The model can "think longer" on harder problems.

Easy input (clear "7"):           Hard input (ambiguous "7" vs "1"):

t=1:  ░░░ uncertain               t=1:  ░░░ uncertain
t=2:  ▒▒▒ getting clearer         t=2:  ░░░ still uncertain
t=3:  ███ CERTAIN → output        t=3:  ░░░ uncertain
                                  t=4:  ▒▒▒ slightly clearer
      Stop early, save compute    t=5:  ▒▒▒ refining...
                                  ...
                                  t=50: ███ finally certain → output

                                  Uses more "thinking time"

Inside the Synapse: U-Net MLP

The Synapse shares information across all D neurons (D=4096 for ImageNet).

Input: [z^t (4096) ∥ attention_output (1024)] = 5120 dims
                            │
                            ▼
    ┌─────────────────────────────────────────────────────────┐
    │                    SYNAPSE U-NET                        │
    │                                                         │
    │   5120 → 2560 → 1280 → 640 → 320 → 160 → 16            │ (encode)
    │     │      │      │      │      │      │                │
    │   SKIP   SKIP   SKIP   SKIP   SKIP   SKIP              │
    │     │      │      │      │      │      │                │
    │   16 → 160 → 320 → 640 → 1280 → 2560 → 4096            │ (decode)
    │                                                         │
    └─────────────────────────────────────────────────────────┘
                            │
                            ▼
                     Pre-activation aᵗ (4096 dims)
  • Skip connections preserve information
  • Bottleneck forces compression/abstraction
  • Output feeds into NLMs
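
A rough sketch of a U-Net-shaped MLP with skip connections between matching widths, using the dimensions from the diagram; the activation choice and additive (rather than concatenated) skips are assumptions, so the repo's SynapseUNET will differ in detail.

import torch
import torch.nn as nn

class SynapseUNetSketch(nn.Module):
    def __init__(self, in_dim=5120, out_dim=4096,
                 widths=(2560, 1280, 640, 320, 160, 16)):
        super().__init__()
        enc_dims = [in_dim, *widths]
        self.enc = nn.ModuleList(nn.Linear(a, b)
                                 for a, b in zip(enc_dims[:-1], enc_dims[1:]))
        dec_dims = [*widths[::-1], out_dim]   # 16 -> 160 -> ... -> 2560 -> 4096
        self.dec = nn.ModuleList(nn.Linear(a, b)
                                 for a, b in zip(dec_dims[:-1], dec_dims[1:]))
        self.act = nn.SiLU()

    def forward(self, x):                     # x: (batch, 5120) = concat(z^t, o^t)
        skips = []
        for layer in self.enc:                # 5120 -> 2560 -> ... -> 16
            x = self.act(layer(x))
            skips.append(x)
        skips.pop()                           # the bottleneck (16) is the current x
        for layer in self.dec[:-1]:           # decode, adding matching-width skips
            x = self.act(layer(x)) + skips.pop()
        return self.dec[-1](x)                # pre-activation a^t: (batch, 4096)

syn = SynapseUNetSketch()
a_t = syn(torch.randn(2, 5120))  # -> (2, 4096), fed into the NLMs next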

Learnable Temporal Decay

Not all history is equal—recent ticks might matter more! The model learns decay parameter r_ij per neuron pair.

HIGH r_ij (fast decay):          LOW r_ij (slow decay):

Weight on past ticks:            Weight on past ticks:

1.0 │    ████                    1.0 │████████████████████
    │   █████                        │████████████████████
    │  ██████                        │████████████████████
    │ ███████                        │████████████████████
0.0 └─────────▶ time             0.0 └─────────────────────▶ time
     t-M    t (now)                   t-M              t (now)

"Only care about                 "All history equally
 recent ticks"                    important"
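
A tiny sketch of the decay curve itself: the weight on a tick falls off exponentially with how far back it is, normalized over the M-length window. How exactly the learned r_ij enters the synchronization computation is left to the paper; this only illustrates the fast-vs-slow contrast in the figure above.

import torch

def decay_weights(r: torch.Tensor, memory_length: int) -> torch.Tensor:
    # r >= 0 is the learnable decay rate; returns weights over the last M ticks,
    # ordered oldest -> newest and normalized to sum to 1.
    ages = torch.arange(memory_length - 1, -1, -1, dtype=torch.float32)  # t - tau
    w = torch.exp(-r.unsqueeze(-1) * ages)
    return w / w.sum(dim=-1, keepdim=True)

print(decay_weights(torch.tensor(2.0), 5))  # fast decay: weight piles onto recent ticks
print(decay_weights(torch.tensor(0.0), 5))  # r = 0: all history weighted equally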

Complete Forward Pass

INPUT: Image x (224×224×3)
         │
         ▼
┌──────────────────┐
│  Feature Extract │  ResNet-152, output 14×14×1024
│  (run ONCE)      │
└──────────────────┘
         │
         ▼
   K, V for attention (196 tokens × 1024 dims)
         │
         │
═════════════════════════════════════════════════════════════════════
INITIALIZE:
  z⁰ ∈ R^D         (learned, same for all inputs)
  A⁰ = zeros(D, M) (empty history buffer)
  S⁰ = zeros       (no synchronization yet)
═════════════════════════════════════════════════════════════════════
         │
         ▼
FOR t = 1 to T:

  ┌──────────────────────────────────────────────────┐
  │ 1. COMPUTE SYNC → QUERY                          │
  │    S^t = sync(Z^t)         # correlations so far │
  │    q^t = W_q · S^t_action  # attention query     │
  └──────────────────────────────────────────────────┘
                   │
                   ▼
  ┌──────────────────────────────────────────────────┐
  │ 2. CROSS-ATTENTION                               │
  │    o^t = Attention(Q=q^t, K=K, V=V)             │
  │    # "Where should I look in the image?"         │
  └──────────────────────────────────────────────────┘
                   │
                   ▼
  ┌──────────────────────────────────────────────────┐
  │ 3. SYNAPSE (info sharing)                        │
  │    input = concat(z^t, o^t)                      │
  │    a^t = SynapseUNet(input)  # pre-activation    │
  └──────────────────────────────────────────────────┘
                   │
                   ▼
  ┌──────────────────────────────────────────────────┐
  │ 4. UPDATE HISTORY BUFFER                         │
  │    A^t = roll_and_append(A^{t-1}, a^t)          │
  │    # FIFO: drop oldest, add newest               │
  └──────────────────────────────────────────────────┘
                   │
                   ▼
  ┌──────────────────────────────────────────────────┐
  │ 5. NLMs (each neuron updates independently)      │
  │    for d in 1..D:                                │
  │        z^{t+1}_d = NLM_d(A^t_d)  # private W_d  │
  │    # Efficient: vectorized einsum, not loop      │
  └──────────────────────────────────────────────────┘
                   │
                   ▼
  ┌──────────────────────────────────────────────────┐
  │ 6. STORE ACTIVATION & OUTPUT                     │
  │    Z^t = append(Z^{t-1}, z^t)  # full history   │
  │    y^t = W_out · S^t_out       # classification  │
  └──────────────────────────────────────────────────┘

END FOR
═════════════════════════════════════════════════════════════════════

OUTPUTS: y¹, y², y³, ..., y^T   (one prediction per tick)

TRAINING: Loss = (L_{t_min} + L_{t_certain}) / 2

INFERENCE: Use y^{t*} where t* = argmax(certainty)
           OR average predictions after certainty threshold
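
The whole loop above condenses to roughly the following pseudo-PyTorch. The module names (synapse, nlms, attention, sync) and the state bookkeeping are simplifications of models/ctm.py, not its real interface.

import torch

def ctm_forward(kv, z0, synapse, nlms, attention, sync, w_q, w_out, T=75, M=25):
    """One pass of the tick loop. kv: precomputed image features (keys/values)."""
    B, D = z0.shape
    z = z0                                           # learned initial state z^0
    history = torch.zeros(B, D, M)                   # FIFO buffer A of pre-activations
    zs = [z]                                         # full trace Z for synchronization
    outputs = []
    for t in range(T):
        S_out, S_action = sync(torch.stack(zs, -1))  # 1. correlations so far
        q = S_action @ w_q                           #    ... projected to a query
        o = attention(q, kv)                         # 2. "where to look" in the input
        a = synapse(torch.cat([z, o], dim=-1))       # 3. share info across neurons
        history = torch.cat([history[..., 1:],       # 4. drop oldest, append newest
                             a.unsqueeze(-1)], dim=-1)
        z = nlms(history)                            # 5. private per-neuron update
        zs.append(z)
        outputs.append(S_out @ w_out)                # 6. per-tick prediction y^t
    return outputs                                   # y^1 ... y^T, one per tick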

The Training Trick

Loss = average of TWO special ticks:

  1. t_min_loss = tick with lowest classification loss
  2. t_max_cert = tick with highest certainty (1 - entropy)

This teaches the model to:

  • Actually solve the problem (min loss)
  • Know when it's solved (max certainty)
  • Naturally learn adaptive computation
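
A hedged sketch of that loss, assuming per-tick logits of shape (T, batch, classes) and certainty defined as 1 minus normalized entropy; the repo's exact reduction and bookkeeping may differ.

import torch
import torch.nn.functional as F

def ctm_loss(per_tick_logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    T, B, C = per_tick_logits.shape
    # Per-tick, per-sample classification loss
    losses = torch.stack([F.cross_entropy(per_tick_logits[t], targets,
                                          reduction='none') for t in range(T)])  # (T, B)
    # Certainty = 1 - normalized entropy of each tick's prediction
    p = F.softmax(per_tick_logits, dim=-1)
    entropy = -(p * p.clamp_min(1e-12).log()).sum(-1)                            # (T, B)
    certainty = 1.0 - entropy / torch.log(torch.tensor(float(C)))
    t_min = losses.argmin(dim=0)         # per-sample tick with lowest loss
    t_cert = certainty.argmax(dim=0)     # per-sample tick with highest certainty
    idx = torch.arange(B)
    return 0.5 * (losses[t_min, idx] + losses[t_cert, idx]).mean()

# Example: 75 ticks, batch of 8, 10 classes
loss = ctm_loss(torch.randn(75, 8, 10), torch.randint(0, 10, (8,)))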

Emergent Behaviors

1. Adaptive Computation

Easy "7":                    Hard "7" (looks like "1"):

t=1: ░░░ 40%                 t=1:  ░░░ 35%
t=2: ▒▒▒ 70%                 t=2:  ░░░ 40%
t=3: ███ 95% ← STOP          t=3:  ░░░ 42%
                             t=4:  ░░░ 45%
     3 ticks used            ...
                             t=40: ▒▒▒ 75%
                             t=50: ███ 92% ← STOP

                             50 ticks used

2. Emergent Attention Patterns ("Looking Around")

The model learns to scan different parts of the image without being explicitly told to:

t=1:  ┌─────────┐        t=5:  ┌─────────┐
      │ ○       │              │    ○    │
      │   7     │  "top-left"  │   7     │  "middle"
      │         │              │         │
      └─────────┘              └─────────┘

t=10: ┌─────────┐        t=20: ┌─────────┐
      │         │              │    ○    │
      │   7  ○  │  "the hook"  │   7     │  "confirm"
      │         │              │         │
      └─────────┘              └─────────┘

3. Neuron Dynamics

Oscillations emerge naturally through training:

At initialization:           After training:

Neuron activity over time:   Neuron activity over time:

────────────────────         ╱╲  ╱╲  ╱╲  ╱╲  ╱╲
                               ╲╱  ╲╱  ╲╱  ╲╱

Flat, boring                 Rich periodic dynamics!

Diversity of patterns encodes information.


Why CTM Matters

Aspect           Transformer                             CTM
──────────────   ─────────────────────────────────────   ────────────────────────────────────────────────
Time axis        Sequence positions (tied to input)      Internal "thinking" ticks (decoupled from input)
Weight sharing   All positions share attention weights   Each neuron has PRIVATE weights (NLMs)
Representation   Token embeddings (neuron values)        Synchronization matrix (neuron correlations)
Computation      Fixed (N layers, same for all inputs)   Adaptive (1 to T ticks, more for harder inputs)
Recurrence       None (feedforward) or limited           Continuous over ticks with a memory buffer
Bio-inspired?    Not really                              Yes: timing/sync like biological neurons

Analogy

HUMAN VISION
────────────

  Eyes + V1        →      Higher brain       →    Decision
  ─────────────           ────────────            ────────
  "Raw features"          "Thinking"              "Action"
  edges, colors           reasoning               "it's a cat"
  shapes                  over time

CTM
───

  CNN/ResNet       →      CTM Loop           →    Output
  ──────────              ────────                ──────
  "Raw features"          "Thinking"              "Action"
  edges, textures         15-75 ticks             class logits
  spatial patterns        temporal reasoning
