Qwen‑3 Refine‑Loop (Q3‑RL): a pragmatic HRM‑style hybrid
What changes vs. the original HRM idea?
Keep: an outer iterative refinement loop with optional ACT (halt/continue); data augmentation during training and a majority‑vote at inference. These were the biggest drivers of ARC performance in ablations.
Drop/optionalize: the internal H/L hierarchy and inner recurrent loop. ARC Prize found that a matched‑size transformer run through the same refinement pipeline comes within a few points of HRM; the hierarchy gives only a small edge at higher loop counts.
Remove: reliance on puzzle_id embeddings; replace with context‑derived task conditioning that generalizes to unseen tasks. (ARC Prize notes puzzle_id is a strong, limiting dependency.)
- System outline
```
                              Outer iteration t = 0..T_max
+-----------------------+
|  Task Conditioner C   |-----> soft prompt / control vectors (m tokens)
+-----------------------+       (derived from in-context examples, NOT puzzle_id)
           |                                          ^
           v                                          | render/decode
+-----------------------+         refine              | (domain head)
|  Qwen-3 backbone      |---------------------------->| (text or grid hypothesis y_t)
|  (frozen or LoRA)     |                             |
+-----------------------+                             |
           |                                          |
           v                                          v
+-----------------------+               +-----------------------+
|  ACT / Halt Head      |-----p_t-----> |  Loop Controller      |
+-----------------------+               +-----------------------+
                                                      |
                                                      v
                                        augment / de-augment / vote
```
Backbone: Qwen‑3 (7–8B class). No internal HRM block; we keep the standard transformer and wrap it in a loop.
Task Conditioner (C): learns a soft prompt from the in‑context example pairs (few‑shot)—not a discrete ID. (Details below.)
Domain head: either Text LM head (token generation) or Grid head for ARC (tokenized grids).
Loop controller: runs K refinement passes; uses ACT for early stopping if enabled, else fixed K.
Augmenter & Voter: run train‑time augmentation (≈300 per task) and an inference‑time majority vote over a reduced augmentation set, matching the ablation insights.
- Data representations
2.1 ARC (grid) representation
Alphabet: colors 0–9 → 10 tokens; plus . for empty; [, ], and ; as separators.
Grid → tokens: row‑major, e.g., [[0 1 1];[. 2 2]] → "[ [ 0 1 1 ] ; [ . 2 2 ] ]".
Examples in context:
```
EXAMPLE_1_IN : [ ... ]
EXAMPLE_1_OUT: [ ... ]
...
TEST_IN : [ ... ]
HYP_t   : [ ... ]    # current hypothesis y_t (absent at t=0)
AUGMENT : rot=90 flipH=0 flipV=1 colperm=0-1-2-3-4-5-6-7-8-9
```
Text tasks: keep standard LM tokens; HYP_t is the partial answer text.
This avoids the puzzle_id embedding the ARC Prize post found critical (and brittle) in HRM; conditioning is carried by the examples + learned soft prompt instead.
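For concreteness, here is a minimal serializer/deserializer sketch for this grid format (the function names are illustrative, not part of an existing codebase):

```python
def grid_to_tokens(grid):
    """Serialize a 2-D grid (ints 0-9, None for empty) row-major into the bracketed
    token string above, e.g. [[0, 1, 1], [None, 2, 2]] -> "[ [ 0 1 1 ] ; [ . 2 2 ] ]"."""
    rows = ["[ " + " ".join("." if c is None else str(c) for c in row) + " ]" for row in grid]
    return "[ " + " ; ".join(rows) + " ]"


def tokens_to_grid(s):
    """Inverse of grid_to_tokens; returns rows of ints with None for '.'."""
    body = s.strip()[1:-1]                                   # drop the outer brackets
    rows = [r.strip().strip("[]").split() for r in body.split(";")]
    return [[None if c == "." else int(c) for c in row] for row in rows]
```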
- Modules & shapes
3.1 Task Conditioner C
Produces m soft‑prompt tokens prepended to the Qwen‑3 token stream.
Inputs: hidden states for each in‑context example pair (the tokens between EXAMPLE_k_IN: and EXAMPLE_k_OUT:) plus the AUGMENT descriptor tokens. Network:
Pooling: per‑example mean‑pool over the token hidden states → 2 vectors per example (in/out).
Aggregator: form [mean_in; mean_out; mean_in − mean_out] for each example, then average across examples → a fixed vector u (dim 3d, with d the model dim), so the conditioner handles any number of examples.
Projector: 2‑layer MLP with GELU to m*d, reshape to (m, d).
Optional: cross‑attention pooling (tiny transformer with 1–2 layers) instead of mean‑pool.
Params: ~5–15M depending on m (m=16 recommended).
3.2 ACT / Halt head
Token: a single <HALT> token in the prompt; take its final hidden state h_halt ∈ R^d.
Head: p_halt = σ(W2 · GELU(W1 h_halt)).
Loss: BCE to teacher signal; see §6.
3.3 Domain head (ARC grid)
Use standard LM head (token‑level cross‑entropy) over the grid tokens.
(Optional) a small de-tokenizer checker that reconstructs grids and enforces shape/color constraints; this is a non‑parametric post‑processor.
This stays compatible with plain Qwen inference; no custom conv decoder is required, simplifying integration.
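A sketch of that post-processor, assuming ARC's usual limits (grids up to 30×30, colors 0–9); the function name and signature are illustrative:

```python
def check_grid(grid, max_hw=30):
    """Non-parametric post-check: True iff the decoded hypothesis is rectangular,
    within ARC size limits, and uses only colors 0-9 (or empty cells)."""
    if not grid or any(len(row) == 0 for row in grid):
        return False
    if len({len(row) for row in grid}) != 1:
        return False                                         # ragged rows
    if len(grid) > max_hw or len(grid[0]) > max_hw:
        return False
    return all(c is None or (isinstance(c, int) and 0 <= c <= 9)
               for row in grid for c in row)
```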
- Outer refinement loop
At iteration t = 0..T_max:
- Build the prompt with Task Conditioner soft tokens, problem tokens, and (for t>0) the previous hypothesis HYP_{t-1} plus the AUGMENT line describing the transform used for this pass.
- Forward pass through Qwen‑3 → generate a new hypothesis y_t (greedy or small beam).
- Compute p_halt from the <HALT> token.
- If ACT is enabled and t ≥ 1 and p_halt > τ, stop; else continue.
- At inference, after finishing (by halt or T_max), de‑augment the hypothesis and add it to the vote pool.
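The same steps as a Loop Controller sketch: model.forward_once follows the wrapper in §9, while prompt_builder, decode_hypothesis, and aug (with its .invert, cf. §5) are assumed helpers.

```python
def refine_view(model, prompt_builder, aug, t_max=4, tau=0.7, use_act=True):
    """Run the outer loop for one augmented view (batch size 1 assumed) and
    return the de-augmented hypothesis."""
    hyp = None
    for t in range(t_max + 1):
        input_ids, attn_mask, spans = prompt_builder(aug=aug, prev_hyp=hyp, t=t)
        logits, p_halt = model.forward_once(input_ids, attn_mask, spans)
        hyp = decode_hypothesis(logits)       # greedy decode of the answer region
        if use_act and t >= 1 and p_halt.item() > tau:
            break                             # ACT early stop
    return aug.invert(hyp)                    # de-augment before adding to the vote pool
```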
Why this design? ARC Prize found the outer refinement (and training with many refinement steps) is the key effect; ACT helps but is not essential. Training with higher loop budgets improves even 1‑loop inference.
- Augmentations & voting
Train‑time: sample ≈300 augmentation codes per task (rotations, flips, limited color permutations). ARC Prize found 300 ≈ 1000 for near‑max performance.
Inference‑time: run a smaller subset (e.g., 32–64) to save cost; de‑augment predictions; majority vote the final grid/text.
Transform descriptors: include a compact text line (e.g., rot=90 flipH=1 colperm=...) so the model knows which view it’s solving—without an opaque ID.
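A sketch of the transform descriptor, de-augmentation, and vote (rotation / horizontal flip / color permutation only; the Aug dataclass and its field names are illustrative):

```python
import random
from collections import Counter
from dataclasses import dataclass

import numpy as np

@dataclass
class Aug:
    rot: int            # 0 / 90 / 180 / 270
    flip_h: bool
    colperm: tuple      # permutation of colors 0-9

    def apply(self, grid):
        g = np.rot90(np.array(grid), self.rot // 90)
        if self.flip_h:
            g = np.fliplr(g)
        return np.array(self.colperm)[g].tolist()        # recolor via fancy indexing

    def invert(self, grid):
        inv = np.argsort(self.colperm)                   # inverse color permutation
        g = inv[np.array(grid)]
        if self.flip_h:
            g = np.fliplr(g)
        return np.rot90(g, -(self.rot // 90)).tolist()

    def descriptor(self):
        return f"rot={self.rot} flipH={int(self.flip_h)} colperm={'-'.join(map(str, self.colperm))}"

def sample_aug():
    perm = list(range(10)); random.shuffle(perm)
    return Aug(rot=random.choice([0, 90, 180, 270]), flip_h=random.random() < 0.5, colperm=tuple(perm))

def majority_vote(preds):
    # preds: list of de-augmented grids; vote on a hashable serialization
    counts = Counter(tuple(map(tuple, p)) for p in preds)
    return [list(r) for r in counts.most_common(1)[0][0]]
```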
- Training objectives
Let ground truth be y*.
- Final CE loss on the last hypothesis: L_final = CE(y_T, y*).
- Intermediate supervision (optional but stabilizing): L_inter = (1/T) ∑_{t=1..T} w_t · CE(y_t, y*), with a cosine ramp w_t.
- Edit regularizer (keep refinements focused): encourage small deltas unless they are needed: L_edit = min(‖y_t − y_{t−1}‖₀, ‖y* − y_{t−1}‖₀), implemented as a token‑level Hamming proxy and weighted by β in the total.
- ACT loss (if enabled): teacher signal: the halt target is 0 (continue) until the first t where y_t == y*, and 1 (halt) from that step onward. L_halt = BCE(p_halt_t, target_t).
Total: L = L_final + α L_inter + β L_edit + γ L_halt.
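A sketch of the combined objective under the definitions above; the edit term is computed on hard tokens here (non-differentiable as written; a soft relaxation can be substituted), and the default weights are placeholders:

```python
import math

import torch
import torch.nn.functional as F

def q3rl_loss(step_logits, step_tokens, target_tokens, p_halts, halt_targets,
              alpha=0.3, beta=0.01, gamma=0.5):
    """step_logits: list of [S, V] logits, one per refinement step (last = final);
    step_tokens: list of [S] predicted token ids; target_tokens: [S] ground truth;
    p_halts / halt_targets: [T] halt probabilities and teacher targets."""
    T = len(step_logits)
    l_final = F.cross_entropy(step_logits[-1], target_tokens)
    # intermediate supervision with a cosine ramp w_t
    w = [0.5 * (1 - math.cos(math.pi * t / T)) for t in range(1, T + 1)]
    l_inter = sum(w_t * F.cross_entropy(lg, target_tokens)
                  for w_t, lg in zip(w, step_logits)) / T
    # edit regularizer: token-level Hamming proxy (hard tokens)
    l_edit = torch.tensor(0.0)
    for t in range(1, T):
        d_prev = (step_tokens[t] != step_tokens[t - 1]).float().sum()
        d_need = (target_tokens != step_tokens[t - 1]).float().sum()
        l_edit = l_edit + torch.minimum(d_prev, d_need)
    l_edit = l_edit / max(T - 1, 1)
    l_halt = F.binary_cross_entropy(p_halts, halt_targets)
    return l_final + alpha * l_inter + beta * l_edit + gamma * l_halt
```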
ARC Prize indicates training with larger loop budgets (e.g., 16) boosts even single‑loop inference. Use T_train ∈ {8, 16}, then T_infer ∈ {1, 4} depending on SLA.
- Optimization schedule
Phase A (adapter‑first):
Freeze Qwen‑3; train: Task Conditioner + (optional) LoRA on the last N transformer blocks (N=8), LM head, ACT head.
LR: 3e‑4 (adapters), 5e‑4 (Task Conditioner / ACT), cosine decay 5–10k steps.
Phase B (light unfreeze):
Unfreeze top N = 8–12 Qwen layers; keep LoRA on the rest.
LR: base 5e‑5, adapters 1e‑4; KL‑penalty (0.1) to previous checkpoint logits to avoid drift.
Batching: gradient accumulation to reach effective batch ≥ 64 (across augmentations).
Precision: bf16 (no loss scaler needed); FlashAttention‑2 is fine.
Loop budget: T_train = 8 (or 16 if budget allows).
Aug count: 300 per task (sampled with per‑type caps).
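A Phase‑A setup sketch using HuggingFace peft; the model id, layer indices, and LoRA hyperparameters are illustrative, and TaskConditioner / HaltHead are the classes from the implementation sketch in §9:

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

qwen = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-8B", torch_dtype=torch.bfloat16)
d_model = qwen.config.hidden_size

lora_cfg = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],   # attention projections
    layers_to_transform=list(range(28, 36)),                   # last N=8 blocks (model-dependent)
    task_type="CAUSAL_LM",
)
qwen = get_peft_model(qwen, lora_cfg)       # base weights frozen; only LoRA adapters train

conditioner = TaskConditioner(d_model)      # classes from the implementation sketch (§9)
halt_head = HaltHead(d_model)

optimizer = torch.optim.AdamW([
    {"params": [p for p in qwen.parameters() if p.requires_grad], "lr": 3e-4},
    {"params": list(conditioner.parameters()) + list(halt_head.parameters()), "lr": 5e-4},
])
```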
- Inference policies
Fast mode: T_infer = 1, no ACT, 16–32 augmentations → very fast baseline.
Balanced: T_infer = 4, ACT τ=0.7, 32–64 augmentations.
Max‑quality (ARC eval): T_infer = 8, ACT τ=0.6, 64–128 augmentations (costlier).
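The three policies as a config sketch (values copied from the list above; names are illustrative):

```python
from dataclasses import dataclass

@dataclass
class InferencePolicy:
    t_infer: int        # max refinement passes
    use_act: bool
    tau: float          # halt threshold (ignored if use_act is False)
    n_augs: int         # augmented views to vote over

POLICIES = {
    "fast":        InferencePolicy(t_infer=1, use_act=False, tau=1.0, n_augs=32),
    "balanced":    InferencePolicy(t_infer=4, use_act=True,  tau=0.7, n_augs=64),
    "max_quality": InferencePolicy(t_infer=8, use_act=True,  tau=0.6, n_augs=128),
}
```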
ARC Prize’s ablation suggests training loops >> inference loops in impact; prioritize T_train if compute is tight.
- Implementation sketch (PyTorch‑style)
This assumes a standard HuggingFace Qwen‑3 class; no custom CUDA needed. For inspiration on ACT and loop scaffolding, the official HRM repo gives reference patterns for halt heads & looped refinement (different internals, same controller idea).
```python
import torch
import torch.nn as nn


class TaskConditioner(nn.Module):
    def __init__(self, d_model, m_tokens=16, hidden=4_096):
        super().__init__()
        self.proj1 = nn.Linear(3 * d_model, hidden)
        self.proj2 = nn.Linear(hidden, m_tokens * d_model)
        self.act = nn.GELU()
        self.m = m_tokens
        self.d = d_model

    def forward(self, hs_examples):  # list of (H_in, H_out) tensors, each [S_i, d]
        pooled = []
        for H_in, H_out in hs_examples:
            mu_in = H_in.mean(dim=0)
            mu_out = H_out.mean(dim=0)
            pooled.append(torch.cat([mu_in, mu_out, mu_in - mu_out], dim=-1))
        u = torch.mean(torch.stack(pooled, dim=0), dim=0)   # [3d]
        z = self.act(self.proj1(u))                          # [hidden]
        soft = self.proj2(z).view(self.m, self.d)            # [m, d]
        return soft
```
```python
class HaltHead(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model // 3),
            nn.GELU(),
            nn.Linear(4 * d_model // 3, 1),
        )

    def forward(self, h_halt):                               # [B, d]
        return torch.sigmoid(self.ff(h_halt)).squeeze(-1)    # [B]
```
```python
class Q3RLWrapper(nn.Module):
    def __init__(self, qwen, conditioner, halt_head, tokenizer, m_tokens=16):
        super().__init__()
        self.qwen = qwen            # HF Qwen-3 causal LM
        self.cond = conditioner
        self.halt = halt_head
        self.tok = tokenizer
        self.m = m_tokens

    @torch.no_grad()
    def _extract_example_states(self, input_ids, attn_mask, example_spans):
        # One pass to collect hidden states of the example regions (batch size 1 in this sketch)
        out = self.qwen.model(input_ids=input_ids, attention_mask=attn_mask,
                              output_hidden_states=True, use_cache=False)
        hs = out.hidden_states[-1]                            # [B, S, d]
        hs_pairs = []
        for (s_in, e_in, s_out, e_out) in example_spans:
            H_in = hs[0, s_in:e_in, :]                        # [S_in, d]
            H_out = hs[0, s_out:e_out, :]                     # [S_out, d]
            hs_pairs.append((H_in, H_out))
        return hs_pairs

    def _halt_idx(self, input_ids):
        # Position of the <HALT> token in each sequence
        halt_id = self.tok.convert_tokens_to_ids("<HALT>")
        return (input_ids == halt_id).int().argmax(dim=-1)    # [B]

    def forward_once(self, input_ids, attn_mask, example_spans):
        # 1) derive the soft prompt from the in-context examples
        hs_pairs = self._extract_example_states(input_ids, attn_mask, example_spans)
        soft = self.cond(hs_pairs)                            # [m, d]
        # 2) prepend the soft tokens via inputs_embeds (prefix-tuning style)
        tok_emb = self.qwen.get_input_embeddings()(input_ids)                 # [B, S, d]
        prefix = soft.to(tok_emb.dtype).unsqueeze(0).expand(tok_emb.size(0), -1, -1)
        emb = torch.cat([prefix, tok_emb], dim=1)                             # [B, m+S, d]
        mask = torch.cat([attn_mask.new_ones(tok_emb.size(0), self.m), attn_mask], dim=1)
        # 3) forward; the prompt contains a <HALT> token whose hidden state feeds the halt head
        outputs = self.qwen(inputs_embeds=emb, attention_mask=mask,
                            output_hidden_states=True, return_dict=True)
        logits = outputs.logits                               # [B, m+S, V]
        # locate the <HALT> hidden state (offset by the m soft-prompt tokens)
        idx = self._halt_idx(input_ids) + self.m
        h_halt = outputs.hidden_states[-1][torch.arange(input_ids.size(0)), idx, :]
        p_halt = self.halt(h_halt)                            # [B]
        return logits, p_halt

    # Outer loop: build a new prompt with HYP_t, rerun forward_once, ...
```
Notes
You can implement the soft prompt either by inputs_embeds (recommended) or by reserving m special tokens whose embeddings you overwrite per‑batch.
For generation, use HF's generate(); locate the <HALT> position via its special token id.
The Loop Controller is a small Python module handling augmentation sampling, prompt building, de‑augmentation, and voting.
- Prompt templates
ARC (per augmented view)
```
<SOFT_PROMPT m=16>
You are given a set of input-output grid examples. Infer the rule and apply it.
EXAMPLE_1_IN : [ ... ]
EXAMPLE_1_OUT: [ ... ]
EXAMPLE_2_IN : [ ... ]
EXAMPLE_2_OUT: [ ... ]
...
AUGMENT : rot=90 flipH=0 flipV=1 colperm=0312456789
TEST_IN : [ ... ]
{%- if t>0 -%}
HYP_{t-1} : [ ... ]
{%- endif -%}
Predict TEST_OUT as a grid only, no extra text.
```
Text tasks (generic)
```
<SOFT_PROMPT m=16>
Task: {{instruction}}
Examples:
Q: ...
A: ...
...
Input: {{input_text}}
{%- if t>0 -%}
Hypothesis_{t-1}: {{previous_answer}}
{%- endif -%}
Output:
```
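A minimal builder for the ARC template using plain string formatting (grid_to_tokens is the serializer sketched in §2; the function name and arguments are illustrative):

```python
def build_arc_prompt(examples, test_in, aug_descriptor, prev_hyp=None, t=0):
    """examples: list of (in_grid, out_grid) pairs, already augmented.
    Returns the prompt string that follows the <SOFT_PROMPT> tokens."""
    lines = ["You are given a set of input-output grid examples. Infer the rule and apply it."]
    for k, (g_in, g_out) in enumerate(examples, start=1):
        lines.append(f"EXAMPLE_{k}_IN : {grid_to_tokens(g_in)}")
        lines.append(f"EXAMPLE_{k}_OUT: {grid_to_tokens(g_out)}")
    lines.append(f"AUGMENT : {aug_descriptor}")
    lines.append(f"TEST_IN : {grid_to_tokens(test_in)}")
    if t > 0 and prev_hyp is not None:
        lines.append(f"HYP_{t-1} : {grid_to_tokens(prev_hyp)}")
    lines.append("Predict TEST_OUT as a grid only, no extra text.")
    return "\n".join(lines)
```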
- Evaluation & logging
Pass@k (k=1,2) over de‑augmented predictions.
Loop telemetry: average T_used, halt entropy, delta edits per step, augmentation skew in votes.
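Pass@k over the de-augmented vote pool reduces to a check against the top‑k voted candidates; a minimal sketch:

```python
from collections import Counter

def pass_at_k(vote_pool, ground_truth, k=2):
    """vote_pool: de-augmented predictions (hashable, e.g. serialized grids).
    Returns 1.0 if ground_truth is among the k most-voted candidates, else 0.0."""
    top_k = [cand for cand, _ in Counter(vote_pool).most_common(k)]
    return float(ground_truth in top_k)
```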
Ablations:
- T_train ∈ {1, 4, 8, 16} × T_infer ∈ {1, 4, 8}.
- Augmentation count ∈ {30, 100, 300, 1000}.
- Conditioner on/off; LoRA only vs LoRA + top‑N unfreeze.
- ACT on/off.

These mirror ARC Prize's analyses, so your results are directly comparable.
- Distributed & performance notes
Each refinement pass is an independent forward: no RoPE bias hacks or KV cache reuse across passes needed (simpler than inner HRM loops).
Use pjit‑style data parallelism over augmentations: each GPU handles a subset of views; voting happens on the host.
Caching: cache tokenized prompts per augmentation; only splice HYP_{t-1} at each loop.
Throughput knobs: reduce augment views first; then reduce T_infer.
- Compatibility with prior HRM code
You can borrow ACT utilities and the general train/infer scaffolding patterns from the official HRM repo—but your backbone stays Qwen‑3 and the loop is outer only.
For historical grounding of outer‑loop + ACT, see Universal Transformer (outer refinement + ACT).
- Why this design matches the new evidence
Outer refinement dominates: We explicitly center the architecture on it; train with larger loop budgets as recommended by ARC Prize’s ablation, which notably improved even single‑loop inference performance.
Transfer vs memorization: We remove puzzle_id and rely on context‑derived conditioning, improving generalization potential and avoiding a hard train‑time prerequisite on the exact inference IDs.
Augmentation efficiency: Default to ~300 train‑time augmentations, with a smaller inference set—mirroring the finding that 300 ≈ 1000 for near‑max.
- Minimal task list to implement
- Tokenizer & grid I/O (ARC).
- Prompt builder with HYP_t splice + AUGMENT descriptors.
- Task Conditioner C (soft prompt) + hooks to pass inputs_embeds to Qwen‑3.
- Halt head (+ special token and position extraction).
- Loop Controller (train/infer) with augmentation sampler, de‑augmenter, voter.
- Losses L_final, L_inter, L_edit, L_halt.
- Trainer (two‑phase freeze/unfreeze; LoRA config).
- Telemetry (per‑loop stats, vote traces).
- Ablation scripts to reproduce the key plots (T_train vs T_infer; aug counts).
- Useful references
ARC Prize analysis (Aug 15, 2025) — outer‑loop & aug findings; hierarchy marginal; puzzle_id brittleness.
HRM paper (June/July 2025) — for ACT mechanics, training hygiene, and general loop ideas.
HRM official code — useful scaffolding for halting heads / loops (adapt, don’t adopt wholesale).
ARC‑AGI without pretraining (Liao & Gu) — an adjacent test‑time training framing worth cross‑checking as you evaluate transfer.