Qwen‑3 Refine‑Loop (Q3‑RL): a pragmatic HRM‑style hybrid

What changes vs. the original HRM idea?

Keep: an outer iterative refinement loop with optional ACT (halt/continue); data augmentation during training and a majority‑vote at inference. These were the biggest drivers of ARC performance in ablations.

Drop/optionalize: the internal H/L hierarchy and inner recurrent loop. ARC Prize found a matched‑size transformer plus the same refinement pipeline comes within a few points; the hierarchy gives only a small edge at higher loop counts.

Remove: reliance on puzzle_id embeddings; replace with context‑derived task conditioning that generalizes to unseen tasks. (ARC Prize notes puzzle_id is a strong, limiting dependency.)


  1. System outline

                          Outer iteration t = 0..T_max
+-----------------------+
|  Task Conditioner C   |-----> soft prompt / control vectors (m tokens)
+-----------------------+       (derived from in-context examples, NOT puzzle_id)
           |                                        ^
           v                                        | refine
+-----------------------+   render/decode   +--------------------------------+
|   Qwen-3 backbone     |-----------------> |  domain head                   |
|   (frozen or LoRA)    |                   |  (text or grid hypothesis y_t) |
+-----------------------+                   +--------------------------------+
           |                                        |
           v                                        v
+-----------------------+               +-----------------------+
|    ACT / Halt Head    |-----p_t-----> |    Loop Controller    |
+-----------------------+               +-----------------------+
                                                    |
                                                    v
                                     augment / de-augment / vote

Backbone: Qwen‑3 (7–8B class). No internal HRM block; we keep the standard transformer and wrap it in a loop.

Task Conditioner (C): learns a soft prompt from the in‑context example pairs (few‑shot)—not a discrete ID. (Details below.)

Domain head: either Text LM head (token generation) or Grid head for ARC (tokenized grids).

Loop controller: runs K refinement passes; uses ACT for early stopping if enabled, else fixed K.

Augmenter & Voter: run train‑time augmentation (≈300 per task) and an inference‑time majority vote over a reduced augmentation set, matching the ablation insights.


  2. Data representations

2.1 ARC (grid) representation

Alphabet: colors 0–9 → 10 tokens; plus . for empty; and [ , ] , ; as separators.

Grid → tokens: row‑major, e.g., [[0 1 1];[. 2 2]] → "[ [ 0 1 1 ] ; [ . 2 2 ] ]".
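A minimal sketch of this serialization and its inverse (helper names grid_to_tokens / tokens_to_grid are illustrative, not a fixed API; empty cells are represented as None):

# Row-major grid serialization sketch; cells are ints 0-9 or None for empty (".").
def grid_to_tokens(grid):
    rows = ["[ " + " ".join("." if c is None else str(c) for c in row) + " ]" for row in grid]
    return "[ " + " ; ".join(rows) + " ]"

def tokens_to_grid(text):
    # Inverse of grid_to_tokens; assumes well-formed input.
    inner = text.strip()[1:-1]                                   # drop outer [ ]
    rows = [r.strip().strip("[]").split() for r in inner.split(";")]
    return [[None if c == "." else int(c) for c in row] for row in rows]

# grid_to_tokens([[0, 1, 1], [None, 2, 2]]) -> "[ [ 0 1 1 ] ; [ . 2 2 ] ]"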

Examples in context:

EXAMPLE_1_IN : [ ... ]
EXAMPLE_1_OUT: [ ... ]
...
TEST_IN : [ ... ]
HYP_t   : [ ... ]          # current hypothesis y_t (absent at t=0)
AUGMENT : rot=90 flipH=0 flipV=1 colperm=0-1-2-3-4-5-6-7-8-9

Text tasks: keep standard LM tokens; “HYP_t” is the partial answer text.

This avoids the puzzle_id embedding the ARC Prize post found critical (and brittle) in HRM; conditioning is carried by the examples + learned soft prompt instead.


  3. Modules & shapes

3.1 Task Conditioner C

Produces m soft‑prompt tokens prepended to the Qwen‑3 token stream.

Inputs: hidden states for each in‑context example pair (the tokens between EXAMPLE_k_IN: and EXAMPLE_k_OUT:) plus the AUGMENT descriptor tokens. Network:

Pooling: Per‑example mean‑pool over token embeddings → 2 vectors per example (in/out).

Aggregator: concatenate [mean_in; mean_out; mean_in – mean_out] across examples → a fixed vector u (dim 3d * n_examples, with d the model dim).

Projector: 2‑layer MLP with GELU to m*d, reshape to (m, d).

Optional: cross‑attention pooling (tiny transformer with 1–2 layers) instead of mean‑pool.

Params: ~5–15M depending on m (m=16 recommended).

3.2 ACT / Halt head

Token: a single <HALT> token; take its final hidden state h_halt ∈ R^d.

Head: p_halt = σ(W2 · GELU(W1 h_halt)).

Loss: BCE to teacher signal; see §6.

3.3 Domain head (ARC grid)

Use standard LM head (token‑level cross‑entropy) over the grid tokens.

(Optional) a small de-tokenizer checker that reconstructs grids and enforces shape/color constraints; this is a non‑parametric post‑processor.

This stays compatible with plain Qwen inference; no custom conv decoder is required, simplifying integration.
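A minimal sketch of that post-check; the rejection criteria (rectangular shape, the 0–9/empty alphabet, a 30×30 size cap) are listed here as assumptions:

def check_grid(grid, max_side=30):
    # Non-parametric post-check: rectangular, legal colors, bounded size.
    if not grid or not all(isinstance(row, list) and row for row in grid):
        return False
    width = len(grid[0])
    if len(grid) > max_side or width > max_side:
        return False
    for row in grid:
        if len(row) != width:                                    # ragged rows -> reject
            return False
        if any(c is not None and c not in range(10) for c in row):
            return False
    return True

# Usage: drop a hypothesis from the vote pool if it fails, e.g.
#   if not check_grid(tokens_to_grid(generated_text)): skip this view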


  4. Outer refinement loop

At iteration t = 0..T_max:

  1. Build prompt with Task Conditioner soft tokens, problem tokens, and (for t>0) the previous hypothesis HYP_{t-1} plus the AUGMENT line describing the transform used for this pass.

  2. Forward pass through Qwen‑3 → generate a new hypothesis y_t (greedy or small beam).

  3. Compute p_halt from the <HALT> token.

  4. If ACT enabled and t ≥ 1 and p_halt > τ, stop; else continue.

  5. At inference, after finishing (by halt or by reaching T_max), de‑augment and add the result to the vote pool.

Why this design? ARC Prize found the outer refinement (and training with many refinement steps) is the key effect; ACT helps but is not essential. Training with higher loop budgets improves even 1‑loop inference.
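A sketch of steps 1–5 as a per-view controller; build_prompt, generate_hypothesis, and halt_probability are assumed helpers wrapping the prompt builder, Qwen‑3 generation, and the halt head:

def refine_loop(task, augment_code, t_max=4, tau=0.7, use_act=True):
    """Outer refinement loop for one augmented view of one task."""
    hyp = None
    for t in range(t_max + 1):
        # 1) soft tokens + examples + AUGMENT line (+ HYP_{t-1} when t > 0)
        prompt = build_prompt(task, augment_code, prev_hyp=hyp, step=t)
        # 2) new hypothesis y_t from the backbone (greedy or small beam)
        hyp = generate_hypothesis(prompt)
        # 3) halt probability from the <HALT> token's hidden state
        p_halt = halt_probability(prompt, hyp)
        # 4) early stop if ACT is enabled (never before one refinement pass)
        if use_act and t >= 1 and p_halt > tau:
            break
    # 5) caller de-augments hyp and adds it to the vote pool
    return hyp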


  5. Augmentations & voting

Train‑time: sample ≈300 augmentation codes per task (rotations, flips, limited color permutations). ARC Prize found 300 ≈ 1000 for near‑max performance.

Inference‑time: run a smaller subset (e.g., 32–64) to save cost; de‑augment predictions; majority vote the final grid/text.

Transform descriptors: include a compact text line (e.g., rot=90 flipH=1 colperm=...) so the model knows which view it’s solving—without an opaque ID.
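A sketch of the inference-time pass over views with de-augmentation and voting; sample_augment_code and invert_augment are assumed helpers, and refine_loop / the grid helpers are the hypothetical sketches above:

from collections import Counter

def solve_task(task, n_views=32):
    votes = Counter()
    for _ in range(n_views):
        aug = sample_augment_code()                  # e.g. "rot=90 flipH=0 flipV=1 colperm=..."
        grid = tokens_to_grid(refine_loop(task, aug))
        if not check_grid(grid):
            continue                                 # discard malformed hypotheses
        grid = invert_augment(grid, aug)             # map back to the original frame
        votes[grid_to_tokens(grid)] += 1             # vote on a canonical serialization
    return tokens_to_grid(votes.most_common(1)[0][0]) if votes else None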


  6. Training objectives

Let ground truth be y*.

  1. Final CE loss on the last hypothesis: L_final = CE(y_T, y*).

  2. Intermediate supervision (optional but stabilizing): L_inter = (1/T) ∑_{t=1..T} CE(y_t, y*) * w_t, with a cosine ramp w_t.

  3. Edit regularizer (keep refinements focused): Encourage small deltas unless needed: L_edit = λ_edit * min( ||y_t - y_{t-1}||₀ , ||y* - y_{t-1}||₀ ) (implemented as token‑level Hamming proxy).

  4. ACT loss (if enabled): Teacher signal: for perfect training sequences, label continue=1 up to the first t where y_t == y*, then halt=1 afterwards. L_halt = BCE(p_halt_t, target_t).

Total: L = L_final + α L_inter + β L_edit + γ L_halt.

ARC Prize indicates training with larger loop budgets (e.g., 16) boosts even single‑loop inference. Use T_train ∈ {8, 16}, then T_infer ∈ {1, 4} depending on SLA.
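A sketch of combining the four terms; the default α/β/γ values and the exact cosine-ramp form here are illustrative choices, not prescribed:

import math
import torch
import torch.nn.functional as F

def q3rl_loss(step_logits, step_tokens, target_tokens, halt_probs, halt_targets,
              alpha=0.3, beta=0.05, gamma=0.1):
    """step_logits / step_tokens: lists of length T with [B, S, V] logits and [B, S] greedy ids;
    target_tokens: [B, S]; halt_probs / halt_targets: float tensors of matching shape."""
    T = len(step_logits)
    ce = lambda lg: F.cross_entropy(lg.flatten(0, 1), target_tokens.flatten())
    l_final = ce(step_logits[-1])
    # cosine ramp: later refinement steps get more weight
    w = [0.5 * (1 - math.cos(math.pi * (t + 1) / T)) for t in range(T)]
    l_inter = sum(w_t * ce(lg) for w_t, lg in zip(w, step_logits)) / T
    # edit regularizer: token-level Hamming proxy between consecutive hypotheses
    l_edit = sum(torch.minimum((step_tokens[t] != step_tokens[t - 1]).float().mean(),
                               (target_tokens != step_tokens[t - 1]).float().mean())
                 for t in range(1, T))
    l_halt = F.binary_cross_entropy(halt_probs, halt_targets)
    return l_final + alpha * l_inter + beta * l_edit + gamma * l_halt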


  7. Optimization schedule

Phase A (adapter‑first):

Freeze Qwen‑3; train the Task Conditioner, (optional) LoRA on the last N transformer blocks (N=8), the LM head, and the ACT head.

LR: 3e‑4 (adapters), 5e‑4 (Task Conditioner / ACT), cosine decay 5–10k steps.

Phase B (light unfreeze):

Unfreeze top N = 8–12 Qwen layers; keep LoRA on the rest.

LR: base 5e‑5, adapters 1e‑4; KL‑penalty (0.1) to previous checkpoint logits to avoid drift.

Batching: gradient accumulation to reach effective batch ≥ 64 (across augmentations).

Precision: bf16 with grad‑scaler; FlashAttention‑2 OK.

Loop budget: T_train = 8 (or 16 if budget allows).

Aug count: 300 per task (sampled with per‑type caps).
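A sketch of the Phase A wiring with a peft-style LoRA config and per-module learning rates; the LoRA hyperparameters and target module names are assumptions that depend on the exact Qwen‑3 implementation (the LRs match the schedule above):

import torch
from peft import LoraConfig, get_peft_model

def build_phase_a(qwen, conditioner, halt_head, n_lora_blocks=8, total_steps=10_000):
    n_layers = qwen.config.num_hidden_layers
    lora = LoraConfig(
        r=16, lora_alpha=32, lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        layers_to_transform=list(range(n_layers - n_lora_blocks, n_layers)),
        modules_to_save=["lm_head"],                 # LM head stays trainable
    )
    qwen = get_peft_model(qwen, lora)                # freezes base weights, adds adapters
    optim = torch.optim.AdamW([
        {"params": [p for p in qwen.parameters() if p.requires_grad], "lr": 3e-4},
        {"params": conditioner.parameters(), "lr": 5e-4},
        {"params": halt_head.parameters(),   "lr": 5e-4},
    ], weight_decay=0.01)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=total_steps)
    return qwen, optim, sched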


  8. Inference policies

Fast mode: T_infer = 1, no ACT, 16–32 augmentations → very fast baseline.

Balanced: T_infer = 4, ACT τ=0.7, 32–64 augmentations.

Max‑quality (ARC eval): T_infer = 8, ACT τ=0.6, 64–128 augmentations (costlier).

ARC Prize’s ablation suggests training loops >> inference loops in impact; prioritize T_train if compute is tight.
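The three policies expressed as presets (field names are illustrative; view counts use the upper end of the ranges above):

from dataclasses import dataclass

@dataclass
class InferencePolicy:
    t_infer: int
    use_act: bool
    tau: float
    n_views: int

POLICIES = {
    "fast":        InferencePolicy(t_infer=1, use_act=False, tau=1.0, n_views=32),
    "balanced":    InferencePolicy(t_infer=4, use_act=True,  tau=0.7, n_views=64),
    "max_quality": InferencePolicy(t_infer=8, use_act=True,  tau=0.6, n_views=128),
}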


  9. Implementation sketch (PyTorch‑style)

This assumes a standard HuggingFace Qwen‑3 class; no custom CUDA needed. For inspiration on ACT and loop scaffolding, the official HRM repo gives reference patterns for halt heads & looped refinement (different internals, same controller idea).

import torch
import torch.nn as nn

class TaskConditioner(nn.Module):
    def __init__(self, d_model, m_tokens=16, hidden=4096):
        super().__init__()
        self.proj1 = nn.Linear(3 * d_model, hidden)
        self.proj2 = nn.Linear(hidden, m_tokens * d_model)
        self.act = nn.GELU()
        self.m = m_tokens
        self.d = d_model

    def forward(self, hs_examples):  # list of (H_in, H_out) tensors, each [S_i, d]
        pooled = []
        for H_in, H_out in hs_examples:
            mu_in  = H_in.mean(dim=0)
            mu_out = H_out.mean(dim=0)
            pooled.append(torch.cat([mu_in, mu_out, mu_in - mu_out], dim=-1))
        u = torch.mean(torch.stack(pooled, dim=0), dim=0)         # [3d]
        z = self.act(self.proj1(u))                                # [hidden]
        soft = self.proj2(z).view(self.m, self.d)                  # [m, d]
        return soft

class HaltHead(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model // 3),
            nn.GELU(),
            nn.Linear(4 * d_model // 3, 1),
        )

    def forward(self, h_halt):  # [B, d]
        return torch.sigmoid(self.ff(h_halt)).squeeze(-1)  # [B]

class Q3RLWrapper(nn.Module):
    def __init__(self, qwen, conditioner, halt_head, tokenizer, m_tokens=16):
        super().__init__()
        self.qwen = qwen
        self.cond = conditioner
        self.halt = halt_head
        self.tok = tokenizer
        self.m = m_tokens

    @torch.no_grad()
    def _extract_example_states(self, input_ids, attn_mask, example_spans):
        # Run one pass to collect hidden states of example regions
        out = self.qwen.model(input_ids=input_ids, attention_mask=attn_mask,
                              output_hidden_states=True, use_cache=False)
        hs = out.hidden_states[-1]  # [B, S, d]
        hs_pairs = []
        for (s_in, e_in, s_out, e_out) in example_spans:
            H_in  = hs[:, s_in:e_in, :].mean(dim=1)   # [B, d]
            H_out = hs[:, s_out:e_out, :].mean(dim=1) # [B, d]
            hs_pairs.append((H_in, H_out))
        return hs_pairs

    def forward_once(self, input_ids, attn_mask, example_spans):
        # 1) get soft prompt
        hs_pairs = self._extract_example_states(input_ids, attn_mask, example_spans)
        soft = self.cond(hs_pairs)                                    # [m, d]
        # 2) prepend soft tokens (prefix tuning)
        soft_ids = self._soft_to_ids(soft)                            # virtual; or pass via inputs_embeds
        # 3) run generation; ensure a <HALT> token exists in the prompt
        outputs = self.qwen(input_ids=input_ids, attention_mask=attn_mask,
                            output_hidden_states=True,                # needed to read the <HALT> state
                            past_key_values=None, return_dict=True)
        logits  = outputs.logits                                       # [B, S, V]
        # locate HALT hidden state
        h_halt = outputs.hidden_states[-1][torch.arange(input_ids.size(0)), self._halt_idx(input_ids), :]
        p_halt = self.halt(h_halt)                                     # [B]
        return logits, p_halt

    # Outer loop: build new prompt with HYP_{t}, rerun forward_once...

Notes

You can implement the soft prompt either by inputs_embeds (recommended) or by reserving m special tokens whose embeddings you overwrite per‑batch.
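A sketch of the inputs_embeds route, assuming the conditioner returns an [m, d] tensor as in the code above:

import torch

def prepend_soft_prompt(qwen, soft, input_ids, attn_mask):
    """soft: [m, d] soft-prompt vectors; returns (inputs_embeds, attention_mask) with the prefix."""
    tok_emb = qwen.get_input_embeddings()(input_ids)                 # [B, S, d]
    B = input_ids.size(0)
    prefix = soft.unsqueeze(0).expand(B, -1, -1).to(tok_emb.dtype)   # [B, m, d]
    inputs_embeds = torch.cat([prefix, tok_emb], dim=1)              # [B, m+S, d]
    prefix_mask = torch.ones(B, soft.size(0), dtype=attn_mask.dtype, device=attn_mask.device)
    return inputs_embeds, torch.cat([prefix_mask, attn_mask], dim=1)

# out = qwen(inputs_embeds=inputs_embeds, attention_mask=mask, output_hidden_states=True)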

For generation, use HF’s generate() but capture the <HALT> position via its special token id.

The Loop Controller is a small Python module handling augmentation sampling, prompt building, de‑augmentation, and voting.


  10. Prompt templates

ARC (per augmented view)

<SOFT_PROMPT m=16>
You are given a set of input-output grid examples. Infer the rule and apply it.
EXAMPLE_1_IN : [ ... ]
EXAMPLE_1_OUT: [ ... ]
EXAMPLE_2_IN : [ ... ]
EXAMPLE_2_OUT: [ ... ]
...
AUGMENT : rot=90 flipH=0 flipV=1 colperm=0312456789

TEST_IN : [ ... ]
{%- if t>0 -%}
HYP_{t-1} : [ ... ]
{%- endif -%}
Predict TEST_OUT as a grid only, no extra text.

Text tasks (generic)

<SOFT_PROMPT m=16>
Task: {{instruction}}
Examples:
Q: ...
A: ...
...
Input: {{input_text}}
{%- if t>0 -%}
Hypothesis_{t-1}: {{previous_answer}}
{%- endif -%}
Output:


  11. Evaluation & logging

Pass@k (k=1,2) over de‑augmented predictions.
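A sketch of Pass@k over the vote pool, under the assumption that the k attempts are the top‑k vote-getters (grid_to_tokens is the hypothetical helper from §2):

def pass_at_k(vote_counter, target_grid, k=2):
    # vote_counter: Counter over canonical grid strings; target_grid: ground-truth grid.
    target = grid_to_tokens(target_grid)
    return float(target in [g for g, _ in vote_counter.most_common(k)])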

Loop telemetry: average T_used, halt entropy, delta edits per step, augmentation skew in votes.

Ablations:

  1. T_train ∈ {1, 4, 8, 16} × T_infer ∈ {1, 4, 8}.

  2. Augmentation count {30, 100, 300, 1000}.

  3. Conditioner on/off; LoRA only vs LoRA+top‑N unfreeze.

  4. ACT on/off. These mirror ARC Prize’s analyses so your results are directly comparable.


  12. Distributed & performance notes

Each refinement pass is an independent forward: no RoPE bias hacks or KV cache reuse across passes needed (simpler than inner HRM loops).

Use pjit‑style data parallelism over augmentations; each GPU handles a subset of views; voting happens on the host.

Caching: cache tokenized prompts per augmentation; only splice HYP_{t-1} at each loop.

Throughput knobs: reduce augment views first; then reduce T_infer.
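A sketch of the view-sharding idea using torch.distributed (a stand-in for the pjit-style setup mentioned above; refine_loop and invert_augment are the hypothetical helpers from earlier sketches):

import torch.distributed as dist
from collections import Counter

def solve_task_distributed(task, views):
    rank, world = dist.get_rank(), dist.get_world_size()
    local = []
    for aug in views[rank::world]:                   # each rank takes a strided slice of views
        grid = invert_augment(tokens_to_grid(refine_loop(task, aug)), aug)
        local.append(grid_to_tokens(grid))
    gathered = [None] * world
    dist.all_gather_object(gathered, local)          # collect every rank's predictions
    if rank != 0:
        return None
    votes = Counter(g for part in gathered for g in part)
    return tokens_to_grid(votes.most_common(1)[0][0]) if votes else None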


  13. Compatibility with prior HRM code

You can borrow ACT utilities and the general train/infer scaffolding patterns from the official HRM repo—but your backbone stays Qwen‑3 and the loop is outer only.

For historical grounding of outer‑loop + ACT, see Universal Transformer (outer refinement + ACT).


  14. Why this design matches the new evidence

Outer refinement dominates: We explicitly center the architecture on it; train with larger loop budgets as recommended by ARC Prize’s ablation, which notably improved even single‑loop inference performance.

Transfer vs memorization: We remove puzzle_id and rely on context‑derived conditioning, improving generalization potential and avoiding a hard train‑time prerequisite on the exact inference IDs.

Augmentation efficiency: Default to ~300 train‑time augmentations, with a smaller inference set, mirroring the finding that 300 ≈ 1000 augmentations for near‑max performance.


  15. Minimal task list to implement

  1. Tokenizer & grid I/O (ARC).

  2. Prompt builder with HYP_t splice + AUGMENT descriptors.

  3. Task Conditioner C (soft‑prompt) + hooks to pass inputs_embeds to Qwen‑3.

  4. Halt head (+ special token and position extraction).

  5. Loop Controller (train/infer) with augmentation sampler, de‑augmenter, voter.

  6. Losses L_final, L_inter, L_edit, L_halt.

  7. Trainer (two‑phase freeze/unfreeze; LoRA config).

  8. Telemetry (per‑loop stats, vote traces).

  9. Ablation scripts to reproduce the key plots (T_train vs T_infer; aug counts).


  16. Useful references

ARC Prize analysis (Aug 15, 2025) — outer‑loop & aug findings; hierarchy marginal; puzzle_id brittleness.

HRM paper (June/July 2025) — for ACT mechanics, training hygiene, and general loop ideas.

HRM official code — useful scaffolding for halting heads / loops (adapt, don’t adopt wholesale).

ARC‑AGI without pretraining (Liao & Gu) — an adjacent test‑time training framing worth cross‑checking as you evaluate transfer.
