Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Created April 18, 2026 18:50
Show Gist options
  • Select an option

  • Save pszemraj/601bce9fb6486f3d5d49325024769238 to your computer and use it in GitHub Desktop.

Select an option

Save pszemraj/601bce9fb6486f3d5d49325024769238 to your computer and use it in GitHub Desktop.
reference/annotated version of Parcae arch (looped LM) from param golf experiments
"""
parcae_reference.py
═══════════════════
A single-file, heavily annotated reference implementation of the Parcae
architecture (Prairie, Novack, Berg-Kirkpatrick, Fu — arXiv 2604.12946).
PURPOSE
───────
This is a **pedagogical reference**, not a training rig. Every architectural
choice is traceable to the source repo at `sandyresearch/parcae` with line
citations. The goal: a developer can read this top-to-bottom and understand
1. what Parcae actually does mechanically
2. why each decision is what it is
3. how to modify it without breaking the stability guarantees
What's included:
• Full model: Embed → Prelude → Loop[Inject → Core] → C → Coda → Head
• Real paper init (scaled-zero / std = sqrt(2/5d), ssm_decay = √(1/5))
• TBPTT split (no-grad warmup + gradient tail)
• Per-sequence depth sampling (the stability knob the paper emphasizes)
• Spectral-norm diagnostics
• A forward pass that runs on CPU in bf16 autocast
What's removed (read the real repo if you need these):
• Distributed training, DDP, Fabric
• Gradient checkpointing
• Fused cross-entropy kernels (CCE, linear CE, triton)
• Value-embeddings ("ve") — used in nanochat-style tweaks
• Monitoring hooks (`extreme_metrics`, telemetry)
• Compile/dynamo disables
• MuonAdamW optimizer (use AdamW — close enough for understanding)
SOURCE MAP
──────────
DiagonalInjection ← parcae_lm/modules/injection.py:10-60
Init tables (scaled*) ← parcae_lm/utils/init.py:75-95
Parcae.forward ← parcae_lm/models/parcae/parcae.py:139-239
iterate_forward (loop) ← parcae_lm/models/parcae/parcae.py:244-360
core_block_forward ← parcae_lm/models/parcae/parcae.py:362-386
initialize_state ← parcae_lm/models/parcae/parcae.py:537-556
per-seq depth sampler ← parcae_lm/models/parcae/parcae.py:477-535
140M reference config ← parcae_lm/configs/parcae/parcae-small-140m.py
TO RUN
──────
python parcae_reference.py # build + 1 forward pass + stability check
python parcae_reference.py --tiny # even smaller, fast CPU test
Author: reference extracted & annotated for clarity. Architecture © Prairie et al. 2026.
"""
from __future__ import annotations
import argparse
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 1 — CONFIG
# ═══════════════════════════════════════════════════════════════════════════
# The real config lives at parcae_lm/models/parcae/config.py. It inherits from
# the base Config in parcae_lm/models/config.py and adds Parcae-specific fields.
# Everything here is flattened and documented.
@dataclass
class ParcaeConfig:
    """Flat, documented hyperparameters for a Parcae model.

    Mirrors parcae_lm/models/parcae/config.py (plus the base Config it inherits
    from). ``__post_init__`` validates the attention-head geometry so a bad
    config fails at construction time instead of deep inside attention/RoPE.

    Raises:
        ValueError: on non-positive head counts, ``n_embd`` not divisible by
            ``num_attention_heads``, ``num_key_value_heads`` not dividing
            ``num_attention_heads``, or an odd ``head_dim``.
    """

    # ─── Vocabulary & context ──────────────────────────────────────────────
    vocab_size: int = 32768  # parcae-small-140m uses 32768
    block_size: int = 2048  # max sequence length (RoPE table size)
    # ─── Transformer body ──────────────────────────────────────────────────
    n_embd: int = 768  # hidden width of prelude/core/coda blocks
    intermediate_size: int = 3072  # MLP hidden width (= 4 * n_embd for parcae-140m)
    num_attention_heads: int = 6
    num_key_value_heads: int = 6  # set < num_attention_heads (and dividing it) for GQA
    # ─── Prelude / Core / Coda layer counts ────────────────────────────────
    # The name "Parcae" = the 3 Roman fates:
    # Nona/Prelude (P) — spins the thread (projects tokens into latent space)
    # Decima/Core (R) — measures the thread (looped recurrence)
    # Morta/Coda (C) — cuts the thread (final projection to logits)
    n_layers_in_prelude: int = 2
    n_layers_in_recurrent_block: int = 2  # UNIQUE layers in the core; applied mu_rec times
    n_layers_in_coda: int = 2
    # ─── Recurrent state dimensions ────────────────────────────────────────
    # The recurrent state x_t lives in R^{recurrent_embedding_dimension}. In all
    # released checkpoints this equals n_embd, and the upstream architecture
    # supports a separate recurrent width (e.g. for low-rank recurrence
    # experiments). NOTE: this single-file reference keeps the state at n_embd
    # (core blocks and the C projection are n_embd wide), so keep them equal.
    recurrent_embedding_dimension: int = 768
    # NOTE: not read by this reference; every MLP uses intermediate_size.
    recurrent_intermediation_embedding_dimension: int = 3072
    # ─── Recurrence depth & TBPTT ──────────────────────────────────────────
    # Core loop runs `mean_recurrence` times per forward pass (approximately, on
    # average, when depths are sampled per sequence during training).
    # Of those, the last `mean_backprop_depth` steps carry gradients (TBPTT).
    # Earlier steps run under torch.no_grad() — cheap, no activations stored.
    # This is critical: naive backprop through 8 loops of 2 layers each stores
    # activations for 16 layer applications (8× the 2 unique core layers); with
    # mean_backprop_depth=4 only the last 4 loops (8 applications) are kept.
    mean_recurrence: int = 8
    mean_backprop_depth: int = 4
    # ─── Injection ─────────────────────────────────────────────────────────
    # Three options in the repo: "diagonal" (the SSM-style one, default),
    # "linear" (concat+project, heavier), "add" (just x + e, ablation baseline).
    # This reference implements only "diagonal" and "add".
    # "diagonal" is the only one with the spectral-norm stability guarantee.
    injection_type: str = "diagonal"
    # ─── Depth sampling ────────────────────────────────────────────────────
    # "per-batch": one (n, k) pair per whole batch — cheapest, can cause spikes
    # "per-sequence": each sequence in the batch gets its own depth — paper default
    # "per-token": each token independently — most expensive
    # The paper argues per-sequence is the stability sweet spot.
    # In this reference only "per-sequence" samples depths; any other value
    # (and eval mode) runs a fixed depth of mean_recurrence.
    recurrent_iteration_method: str = "per-sequence"
    # Only "poisson-truncated-full" is implemented here (this field is not read).
    sampling_scheme: str = "poisson-truncated-full"
    # ─── Init strategy (picks std for every layer) ─────────────────────────
    # "scaled-zero" (used by parcae-small-140m):
    #   std = sqrt(2 / (5 * n_embd))
    #   out_proj = 0, out_attn = 0 (zero-init residual branches)
    # "scaled": same std, but out_proj scaled by sqrt(2*num_layers) instead of 0
    # Both specify ssm_decay = sqrt(1/5) ≈ 0.4472 as the diagonal decay target.
    init_strategy: str = "scaled-zero"
    ssm_decay: float = math.sqrt(1.0 / 5.0)  # ρ(Ā) target at init; must be in (0, 1)
    # ─── Misc ──────────────────────────────────────────────────────────────
    norm_eps: float = 1e-5
    bias: bool = False
    qk_norm: bool = True  # per-head RMS-norm on Q and K
    prelude_norm: bool = True  # extra RMSNorm between prelude and core
    tie_embeddings: bool = True
    rope_base: float = 50_000.0

    def __post_init__(self) -> None:
        """Validate head geometry up front with actionable error messages."""
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                "num_attention_heads and num_key_value_heads must be positive, got "
                f"{self.num_attention_heads} and {self.num_key_value_heads}"
            )
        if self.n_embd % self.num_attention_heads != 0:
            raise ValueError(
                f"n_embd={self.n_embd} must be divisible by "
                f"num_attention_heads={self.num_attention_heads}"
            )
        # GQA tiles each KV head over n_head // n_kv_head query heads.
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_key_value_heads={self.num_key_value_heads} must divide "
                f"num_attention_heads={self.num_attention_heads}"
            )
        # Split-half RoPE pairs channel i with channel i + head_dim // 2.
        if self.head_dim % 2 != 0:
            raise ValueError(f"head_dim={self.head_dim} must be even for split-half RoPE")

    # ─── Tiny preset for CPU development ───────────────────────────────────
    @classmethod
    def tiny(cls) -> "ParcaeConfig":
        """A CPU-friendly config for development / unit tests."""
        return cls(
            vocab_size=1024,
            block_size=128,
            n_embd=128,
            intermediate_size=256,
            num_attention_heads=4,
            num_key_value_heads=4,
            recurrent_embedding_dimension=128,
            recurrent_intermediation_embedding_dimension=256,
            n_layers_in_prelude=1,
            n_layers_in_recurrent_block=2,
            n_layers_in_coda=1,
            mean_recurrence=4,
            mean_backprop_depth=2,
        )

    # ─── Derived ───────────────────────────────────────────────────────────
    @property
    def head_dim(self) -> int:
        """Per-head width, ``n_embd // num_attention_heads``.

        Raises:
            ValueError: if ``n_embd`` is not divisible by ``num_attention_heads``
                (possible if fields were mutated after construction).
        """
        if self.n_embd % self.num_attention_heads != 0:
            raise ValueError(
                f"n_embd={self.n_embd} must be divisible by "
                f"num_attention_heads={self.num_attention_heads}"
            )
        return self.n_embd // self.num_attention_heads

    @property
    def total_effective_depth(self) -> int:
        """How many block-applications happen per forward at train µ_rec."""
        return (
            self.n_layers_in_prelude
            + self.n_layers_in_recurrent_block * self.mean_recurrence
            + self.n_layers_in_coda
        )
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 2 — INITIALIZATION HELPERS
# ═══════════════════════════════════════════════════════════════════════════
# These directly mirror parcae_lm/utils/init.py and parcae_lm/models/parcae/init.py.
# The "scaled" family uses std = sqrt(2/5d), a Nguyen-Salazar-style init that
# keeps pre-activation variance ≈ 1 through a ReLU² MLP.
def _scaled_std(dim: int) -> float:
    """Return the 'scaled' init-family standard deviation, sqrt(2 / (5 * dim))."""
    variance = 2.0 / (5.0 * dim)
    return math.sqrt(variance)
def _out_proj_std(dim: int, num_layers: int, strategy: str) -> float:
    """Return the init std for residual-branch output projections.

    'scaled-zero' yields 0.0, so the residual branch starts as a no-op. Any other
    strategy (notably 'scaled') divides the base std sqrt(2 / (5 * dim)) by
    sqrt(2 * num_layers), with num_layers floored at 1.
    """
    if strategy != "scaled-zero":
        # 'scaled' family: Le Scao / Biderman output init
        depth_scale = math.sqrt(2 * max(num_layers, 1))
        return _scaled_std(dim) / depth_scale
    return 0.0
def _trunc_normal_(t: Tensor, std: float):
    """Fill *t* in place from N(0, std²) truncated to [-3·std, 3·std].

    A std of exactly 0.0 zero-fills the tensor instead (the sampler is not
    called). Mirrors wrapped_trunc_normal in the repo. Returns None.
    """
    if std != 0.0:
        bound = 3 * std
        nn.init.trunc_normal_(t, mean=0.0, std=std, a=-bound, b=bound)
    else:
        t.zero_()
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 3 — DIAGONAL INJECTION
# ═══════════════════════════════════════════════════════════════════════════
# THIS IS THE HEART OF PARCAE. Everything else in the model is a standard
# pre-norm transformer. The paper's theoretical contribution is this module.
#
# Mechanism (from parcae_lm/modules/injection.py:42-46):
#
# x_{t+1} = exp(-dt * A) * x_t + dt * (e @ B.T)
# └────────────┬────┘ └────────┬────┘
# state decay input injection
#
# x_t ∈ R^{state_dim} current recurrent hidden state (per position)
# e ∈ R^{input_dim} PRELUDE OUTPUT, frozen across all loop iterations
# A = exp(A_log), A_log ∈ R^{state_dim} (A > 0 always)
# dt = softplus(dt_bias), dt_bias ∈ R^{state_dim} (dt > 0 always)
# B ∈ R^{state_dim × input_dim} input-to-state projection
#
# WHY THIS IS STABLE
# ──────────────────
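# The decay factor exp(-dt*A) is ALWAYS in (0, 1) — guaranteed by construction
# since dt > 0 and A > 0 (in exact arithmetic). In dynamical-systems language:
# the spectral radius ρ(Ā) of the injection's state map is bounded below 1, so
# with e held fixed, repeated injection alone contracts x toward the fixed point
# x* = dt * (e @ B.T) / (1 - exp(-dt*A)) (toward 0 when e = 0). This bounds the
# linear part of the recurrence, which is what guards against the residual
# explosion that was THE failure mode for prior looped transformers (e.g.
# Geiping's Huginn). Note the bound covers the injection step only; the core
# blocks applied after it in each loop iteration are not constrained by it.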
#
# WHY exp(-dt*A) AND NOT JUST A SCALAR
# ─────────────────────────────────────
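# Discretizing the continuous-time ODE
# dx/dt = -A * x + B * e
# with step size dt gives this form: the state decay exp(-dt * A) is the exact
# zero-order-hold (ZOH) term, while the input term uses the simplified dt * B
# (Euler-style) instead of the exact ZOH factor (1 - exp(-dt*A)) / A * B — the
# same shortcut Mamba takes. Borrowed from structured SSMs (S4/Mamba).
# The diagonal A (not general matrix) makes the update elementwise-parallel.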
#
# INITIALIZATION (parcae_lm/models/parcae/init.py:31-52)
# ──────────────────────────────────────────────────────
# A_log = 0 → A = 1 (uniform decay rate across dims)
# dt_bias chosen so softplus(dt_bias) = -log(ssm_decay)
# → exp(-dt * 1) = ssm_decay (default: sqrt(1/5) ≈ 0.447)
# B = identity (if state_dim == input_dim) else orthogonal
#
# HOW MUCH STATE IS RETAINED
# ──────────────────────────
# With default init (decay = 0.447), after k loops with no input e:
# ||x_k|| = decay^k * ||x_0||
# k=1: 0.447 of state survives
# k=4: 0.040 ~96% of state has been replaced by injected input
# k=8: 0.0016 state is now almost entirely a function of e
#
# This is intentional. Parcae's authors argue for a "fast-mixing" regime where
# the recurrent state quickly forgets its initialization and becomes a
# function of the prelude output. This is NOT a long-horizon RNN; it's a
# contractive dynamical system being driven to a fixed point.
#
# CAVEAT — observed at Parameter Golf scale
# ─────────────────────────────────────────
# Empirically, at 17M params / 15-step training runs, decay ≈ 0.9 beats the
# paper's 0.447 by ~1.0 val-loss units. The paper calibrated decay for 100B-token
# runs; short regimes may want longer state memory. Verify at your scale.
class DiagonalInjection(nn.Module):
    """SSM-style diagonal injection of the prelude output into the loop state.

    Computes ``x_{t+1} = exp(-dt * A) * x_t + dt * (e @ B.T)`` with
    ``A = exp(A_log)`` and ``dt = softplus(dt_bias)``. Both are positive, so every
    per-channel decay factor lies in (0, 1). At init the decay equals
    ``config.ssm_decay`` on every channel.

    Args:
        config: reads ``recurrent_embedding_dimension`` (state width),
            ``n_embd`` (width of e) and ``ssm_decay`` (initial decay factor).

    Raises:
        ValueError: if ``ssm_decay`` is not strictly between 0 and 1.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        self.state_dim = config.recurrent_embedding_dimension
        self.input_dim = config.n_embd
        self.decay_target = config.ssm_decay
        # reset_parameters sets dt_bias = softplus^{-1}(-log(decay_target)). That is
        # only finite for 0 < decay < 1: decay == 1 gives -inf, decay > 1 gives NaN
        # (silently), and decay <= 0 has no logarithm at all.
        if not 0.0 < self.decay_target < 1.0:
            raise ValueError(
                f"ssm_decay must be strictly between 0 and 1, got {self.decay_target!r}"
            )
        # Parameters — all three are marked _no_weight_decay in the real repo
        # (parcae_lm/modules/injection.py:31,35,40). Weight decay would pull
        # A_log and dt_bias toward 0 and drift the decay rate away from its target.
        self.A_log = nn.Parameter(torch.empty(self.state_dim))
        self.dt_bias = nn.Parameter(torch.empty(self.state_dim))
        self.B = nn.Parameter(torch.empty(self.state_dim, self.input_dim))
        self.A_log._no_weight_decay = True
        self.dt_bias._no_weight_decay = True
        self.B._no_weight_decay = True
        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Paper init: A = 1, exp(-dt) = decay_target, B = identity/orthogonal."""
        # A_log = 0 → A = exp(0) = 1
        # (parcae_lm/models/parcae/init.py:31)
        self.A_log.zero_()
        # dt_bias: target softplus(dt_bias) * A = -log(decay_target).
        # Since A = 1 at init, we need softplus(dt_bias) = -log(decay_target).
        # Inverse softplus of x: log(exp(x) - 1) = x + log(1 - exp(-x)); the
        # expm1 form avoids precision loss for small x.
        # (parcae_lm/models/parcae/init.py:36-40)
        target_dt = -math.log(self.decay_target)
        dt = torch.full_like(self.dt_bias, target_dt)
        self.dt_bias.copy_(dt + torch.log(-torch.expm1(-dt)))  # inverse softplus
        # B: identity if square, else orthogonal
        # (parcae_lm/models/parcae/init.py:44-51)
        if self.state_dim == self.input_dim:
            self.B.zero_()
            self.B.fill_diagonal_(1.0)
        else:
            nn.init.orthogonal_(self.B)

    def _discretize(self) -> tuple[Tensor, Tensor]:
        """Return (dt, decay), each of shape (state_dim,): softplus(dt_bias), exp(-dt*A)."""
        dt = F.softplus(self.dt_bias)  # > 0
        A = torch.exp(self.A_log)  # > 0
        return dt, torch.exp(-dt * A)  # decay ∈ (0, 1)

    def forward(self, x_t: Tensor, e: Tensor) -> Tensor:
        """
        Args:
            x_t: (B, T, state_dim) — current recurrent state
            e: (B, T, input_dim) — prelude output, SAME tensor every loop step
        Returns:
            x_{t+1}: (B, T, state_dim)
        """
        dt, decay = self._discretize()
        return x_t * decay + dt * (e @ self.B.T)

    # ─── Diagnostics (not in the training path) ────────────────────────────
    @torch.no_grad()
    def spectral_radius(self) -> dict:
        """Summarize the per-channel decay factors exp(-dt * A).

        Returns:
            dict mapping "min", "mean" and "max" to Python floats. The state map
            is diagonal with positive entries, so "max" is its spectral radius;
            all values must be < 1.
        """
        _, decay = self._discretize()
        return {
            "min": decay.min().item(),
            "mean": decay.mean().item(),
            "max": decay.max().item(),
        }
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 4 — STANDARD TRANSFORMER PIECES (no Parcae-specific logic)
# ═══════════════════════════════════════════════════════════════════════════
# These are bog-standard modern transformer components. Nothing novel here;
# parcae_lm uses the same building blocks as nanochat/litgpt.
class RMSNorm(nn.Module):
"""Root-mean-square layer norm. No mean-centering, cheaper than LayerNorm."""
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim))
self.eps = eps
def forward(self, x: Tensor) -> Tensor:
return F.rms_norm(x, (x.shape[-1],), self.weight, self.eps)
def build_rope(head_dim: int, seq_len: int, base: float = 50_000.0):
    """Build split-half RoPE tables.

    Returns a (cos, sin) pair, each shaped (1, 1, seq_len, head_dim // 2) so they
    broadcast over (batch, heads). Position p and frequency index i give the
    angle p / base**(2i / head_dim). The split-half layout (first half / second
    half) and the interleaved-complex layout are equivalent RoPE variants, but
    their weight layouts differ, so stay consistent across training.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (base**exponents)
    positions = torch.arange(seq_len).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    cos_table = torch.cos(angles).unsqueeze(0).unsqueeze(0)
    sin_table = torch.sin(angles).unsqueeze(0).unsqueeze(0)
    return cos_table, sin_table
def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Apply split-half rotary embedding to *x*, laid out as (B, n_heads, T, head_dim).

    Channel i is rotated together with channel i + head_dim // 2. *cos* and *sin*
    must broadcast against each half.
    """
    first, second = torch.tensor_split(x, [x.shape[-1] // 2], dim=-1)
    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos
    return torch.cat((rotated_first, rotated_second), dim=-1)
class CausalSelfAttention(nn.Module):
    """Multi-head causal attention with GQA support + optional QK-norm + RoPE.

    QK-NORM (qk_norm=True in parcae_lm):
        Weightless per-head RMS-norm applied to Q and K before the dot product.
        With SDPA's default 1/sqrt(head_dim) scale this caps each attention
        logit at about sqrt(head_dim), so logits cannot blow up when Q/K
        magnitudes drift during training. It is a stability tweak used in e.g.
        ViT-22B and Chameleon, and is nearly free compute-wise. Normalizing after
        RoPE is equivalent to normalizing before, because the split-half
        rotation preserves each head vector's norm.

    Raises:
        ValueError: if num_key_value_heads is not a positive divisor of
            num_attention_heads (GQA could not tile KV heads over Q heads).
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.n_kv_head = config.num_key_value_heads
        # GQA maps query head h to KV head h // n_rep; that only works when the
        # KV head count divides the query head count. Fail early and clearly
        # rather than with an SDPA shape error on the first forward.
        if self.n_kv_head <= 0 or self.n_head % self.n_kv_head != 0:
            raise ValueError(
                f"num_key_value_heads={self.n_kv_head} must be a positive divisor "
                f"of num_attention_heads={self.n_head}"
            )
        self.head_dim = config.head_dim
        self.n_rep = self.n_head // self.n_kv_head  # GQA repeat factor
        self.qk_norm = config.qk_norm
        # Separate Q / K / V projections. K and V are narrower when doing GQA.
        self.c_q = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=config.bias)
        self.c_k = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
        self.c_v = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias)
        # n_head * head_dim == n_embd because head_dim = n_embd // n_head.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """x: (B, T, n_embd); cos/sin: RoPE tables sliced to length T. Returns (B, T, n_embd)."""
        B, T, C = x.shape
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)
        if self.qk_norm:
            # Per-head RMSNorm — no learnable weights, just magnitude control.
            q = F.rms_norm(q, (q.shape[-1],))
            k = F.rms_norm(k, (k.shape[-1],))
        # GQA: replicate KV heads to match Q heads (head h reads KV head h // n_rep).
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)
        # Fused causal attention; PyTorch picks the backend (flash / mem-efficient / math).
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.c_proj(y.transpose(1, 2).reshape(B, T, C))
class ReLU2(nn.Module):
    """Squared ReLU, relu(x)**2 — the MLP activation parcae-small-140m uses (not SwiGLU)."""

    def forward(self, x: Tensor) -> Tensor:
        return F.relu(x) ** 2
class BaseMLP(nn.Module):
    """Ungated MLP: n_embd → intermediate_size → ReLU² → n_embd.

    parcae_lm uses this plain BaseMLP rather than SwiGLU.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        width, hidden = config.n_embd, config.intermediate_size
        self.fc = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)
        self.act = ReLU2()

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.fc(x)
        hidden = self.act(hidden)
        return self.proj(hidden)
class TransformerPreNormBlock(nn.Module):
    """Pre-norm residual block shared by every Prelude, Core and Coda layer.

        x = x + Attn(RMSNorm(x))
        x = x + MLP (RMSNorm(x))

    Parcae uses this one block class everywhere; looping is done outside the
    block (in `iterate_forward`), never inside it.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        # Registration order (norm_1, attn, norm_2, mlp) fixes parameter order,
        # which the model's sequential init walks; keep it stable.
        self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps)
        self.mlp = BaseMLP(config)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        attn_out = self.attn(self.norm_1(x), cos, sin)
        h = x + attn_out
        mlp_out = self.mlp(self.norm_2(h))
        return h + mlp_out
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 5 — PARCAE MODEL
# ═══════════════════════════════════════════════════════════════════════════
# The actual assembly. Most lines here are bookkeeping; the interesting bits
# are `iterate_forward` and `_sample_per_sequence_depths`.
class Parcae(nn.Module):
    """Looped LM: Embed → Prelude → Loop[Inject → Core] → C → Coda → Head.

    Flow (mirrors parcae_lm/models/parcae/parcae.py):

        input_ids
          │ wte (embedding)
        embed
          │ prelude blocks (run once)          (2 unique layers for 140M)
        e = ln_prelude(prelude_out)            e is the INPUT SIGNAL, injected at
          │                                    every loop step and FROZEN for the
          │ initialize_state(e) → x            whole loop.
          ▼
        ┌──────────────────────────────────┐
        │ REPEAT µ_rec TIMES:              │
        │   x = adapter(x, e)              │  ← DiagonalInjection
        │   for block in core_block:       │
        │     x = block(x)                 │  (2 unique layers, reused)
        └──────────────────────────────────┘
          │ C (linear d → d projection)
        coda blocks (run once)                 (2 unique layers for 140M)
          │ ln_f + lm_head
        logits → cross_entropy

    Args:
        config: model hyperparameters.

    Raises:
        ValueError: if ``recurrent_embedding_dimension != n_embd`` (not supported
            by this reference) or ``injection_type`` is not "diagonal"/"add".
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        # The loop state is created with e's shape (initialize_state), the core
        # blocks are n_embd wide and C maps d -> d, so a separate recurrent width
        # cannot run here. Fail now rather than with a shape error mid-forward.
        if config.recurrent_embedding_dimension != config.n_embd:
            raise ValueError(
                "this reference requires recurrent_embedding_dimension == n_embd, got "
                f"{config.recurrent_embedding_dimension} and {config.n_embd}"
            )
        self.config = config
        d = config.n_embd
        # NOTE: module construction order below fixes named_parameters() order,
        # which _apply_paper_init walks sequentially with the global RNG.
        # Token embedding (padded_vocab_size in real repo; skipping padding here).
        self.wte = nn.Embedding(config.vocab_size, d)
        # Prelude: UNIQUE layers, run exactly once.
        self.prelude = nn.ModuleList(
            TransformerPreNormBlock(config) for _ in range(config.n_layers_in_prelude)
        )
        # Optional RMSNorm between prelude and core — paper-default True.
        self.ln_prelude = RMSNorm(d, eps=config.norm_eps) if config.prelude_norm else None
        # Injection module — the mechanism.
        if config.injection_type == "diagonal":
            self.adapter = DiagonalInjection(config)
        elif config.injection_type == "add":
            # Ablation: just x + e. Keeps prelude/loop/coda structure.
            self.adapter = _AdditiveInjection()
        else:
            raise ValueError(
                f"injection_type={config.injection_type!r} not implemented here"
                " — see parcae_lm/modules/injection.py for 'linear'"
            )
        # Core: SHARED layers, applied µ_rec times per forward.
        self.core_block = nn.ModuleList(
            TransformerPreNormBlock(config)
            for _ in range(config.n_layers_in_recurrent_block)
        )
        # C: project recurrent state back to n_embd for the coda.
        # NOTE: This is NOT identity-initialized; it gets the default "scaled" std
        # = sqrt(2/5d). It's marked _no_weight_decay in the real repo.
        # (parcae_lm/models/parcae/parcae.py:33-39)
        self.C = nn.Linear(d, d, bias=config.bias)
        self.C.weight._no_weight_decay = True
        # Coda: UNIQUE layers, run once after the loop.
        self.coda = nn.ModuleList(
            TransformerPreNormBlock(config) for _ in range(config.n_layers_in_coda)
        )
        self.ln_f = RMSNorm(d, eps=config.norm_eps)
        self.lm_head = nn.Linear(d, config.vocab_size, bias=False)
        if config.tie_embeddings:
            self.lm_head.weight = self.wte.weight
        # RoPE frequencies as non-trainable, non-persistent buffers.
        cos, sin = build_rope(config.head_dim, config.block_size, config.rope_base)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._apply_paper_init()
        self.step = 0  # training-forward counter; seeds the per-sequence depth sampler

    # ─── Initialization ────────────────────────────────────────────────────
    @torch.no_grad()
    def _apply_paper_init(self):
        """Apply the config's 'scaled' / 'scaled-zero' init (parcae_lm/utils/init.py).

        Every >=2-D parameter gets a ±3σ truncated normal with std sqrt(2/5d),
        except residual output projections (attn c_proj, mlp.proj), which get
        the strategy's out-proj std (0 for 'scaled-zero'). The injection's B is
        left to DiagonalInjection. 1-D parameters (RMSNorm gains, A_log,
        dt_bias, and Linear biases when config.bias=True) are not touched here.
        """
        d = self.config.n_embd
        n_layers = self.config.total_effective_depth
        strategy = self.config.init_strategy
        std_std = _scaled_std(d)
        std_outproj = _out_proj_std(d, n_layers, strategy)
        # named_parameters() de-duplicates, so a tied lm_head shows up only as wte.weight.
        for name, p in self.named_parameters():
            if p.dim() < 2:
                continue  # RMSNorm gains, A_log, dt_bias, biases
            if name in ("wte.weight", "lm_head.weight"):
                _trunc_normal_(p, std_std)  # embedding uses std_std
            elif name.endswith(("c_proj.weight", "mlp.proj.weight")):
                # Attention output proj & MLP down-proj are "out_proj"-style.
                _trunc_normal_(p, std_outproj)
            elif name.endswith("adapter.B"):
                continue  # DiagonalInjection: already identity/orthogonal
            else:
                _trunc_normal_(p, std_std)  # everything else: default std
        # Re-run DiagonalInjection's own scheme for A_log/dt_bias/B.
        if isinstance(self.adapter, DiagonalInjection):
            self.adapter.reset_parameters()

    # ─── State initialization ──────────────────────────────────────────────
    def initialize_state(self, e: Tensor) -> Tensor:
        """
        Create x_0. The real repo offers several strategies (state_init):
            'zero'      : x_0 = 0                         (deterministic)
            'normal'    : x_0 ~ N(0, 1)                   (unit variance)
            'embed'     : x_0 ~ N(0, 1/sqrt(d))           (matches embedding scale)
            'like-init' : x_0 ~ trunc_normal(std_std)     (default, matches init dist)
            'unit'      : random, then normalized to unit L2 per token
        This reference always uses 'like-init': a fresh tensor shaped like e,
        drawn from the global RNG. Source: parcae_lm/models/parcae/parcae.py:537-556
        """
        std = _scaled_std(self.config.n_embd)
        x = torch.empty_like(e)
        _trunc_normal_(x, std)
        return x

    # ─── Per-sequence depth sampling ───────────────────────────────────────
    @torch.no_grad()
    def _sample_per_sequence_depths(self, batch_size: int, device) -> tuple[Tensor, Tensor]:
        """
        For each sequence in the batch, sample a recurrence depth (n, k).
            n : number of no-grad "convergence" steps
            k : number of gradient-bearing steps (TBPTT window)

        Why per-sequence instead of per-batch?
        Per-batch sampling gave the paper's authors periodic loss spikes: when
        the sampler picked a short depth, the model's recurrent state hadn't
        converged, and the gradient signal was way off-distribution.
        Per-sequence averages this out within a batch. See Figure 10/11 of the paper.

        Scheme: 'poisson-truncated-full' — total = Poisson(t+s), clamped >= 1;
        k = min(total, s); n = total - k. A private CPU generator seeded from
        self.step keeps draws reproducible and independent of the global RNG.
        Source: parcae_lm/models/parcae/parcae.py:520-524

        Returns:
            (n, k): int64 tensors of shape (batch_size,) on *device*.
        """
        t_target = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
        s_target = self.config.mean_backprop_depth
        total_target = float(t_target + s_target)
        gen = torch.Generator(device="cpu")
        gen.manual_seed((514229 + self.step) % (2**31 - 1))
        total = torch.clamp(
            torch.poisson(torch.full((batch_size,), total_target), generator=gen), min=1
        )
        k = torch.clamp(total, max=float(s_target))
        n = total - k
        return n.long().to(device), k.long().to(device)

    # ─── The recurrence loop ───────────────────────────────────────────────
    def iterate_forward(
        self, e: Tensor, cos: Tensor, sin: Tensor, mu_rec_override: Optional[int] = None
    ) -> Tensor:
        """
        Apply the core block µ_rec times, with:
            • no_grad warm-up for the first (µ_rec - backprop_depth) steps
            • gradient tracking for the last (backprop_depth) steps
            • per-sequence depth masking so each sequence uses its sampled depth

        Depth selection:
            • mu_rec_override given → that fixed depth (eval / debugging).
            • training with recurrent_iteration_method == "per-sequence" →
              sampled depths. We run max(sampled) steps and, via torch.where,
              leave a sequence's state untouched once its own depth is exhausted.
              That costs O(max_depth) full-batch core steps.
            • anything else (eval, or another method) → fixed mean_recurrence.
              This reference does not sample per-batch depths.
        Source: parcae_lm/models/parcae/parcae.py:317-356

        Raises:
            ValueError: if mu_rec_override is negative.
        """
        if mu_rec_override is not None and mu_rec_override < 0:
            raise ValueError(f"mu_rec_override must be >= 0, got {mu_rec_override}")
        x = self.initialize_state(e)
        if mu_rec_override is not None:
            return self._run_fixed_depth(x, e, cos, sin, mu_rec_override)
        if self.training and self.config.recurrent_iteration_method == "per-sequence":
            return self._run_per_sequence(x, e, cos, sin)
        return self._run_fixed_depth(x, e, cos, sin, self.config.mean_recurrence)

    def _run_fixed_depth(self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor, depth: int) -> Tensor:
        """Run `depth` core steps; all but the last mean_backprop_depth are no-grad."""
        n_nograd = max(depth - self.config.mean_backprop_depth, 0)
        n_grad = depth - n_nograd
        with torch.no_grad():
            for _ in range(n_nograd):
                x = self._core_step(x, e, cos, sin)
        for _ in range(n_grad):
            x = self._core_step(x, e, cos, sin)
        return x

    def _run_per_sequence(self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """Sample (n, k) per sequence; run n no-grad then k grad steps per row via masking."""
        n_per_sample, k_per_sample = self._sample_per_sequence_depths(e.shape[0], e.device)
        max_n = int(n_per_sample.max().item())
        max_k = int(k_per_sample.max().item())
        with torch.no_grad():
            for step in range(max_n):
                x = self._masked_core_step(x, e, cos, sin, step < n_per_sample)
        for step in range(max_k):
            x = self._masked_core_step(x, e, cos, sin, step < k_per_sample)
        return x

    def _masked_core_step(
        self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor, active: Tensor
    ) -> Tensor:
        """One core step for the whole batch, kept only for rows where `active` (B,) is True."""
        x_new = self._core_step(x, e, cos, sin)
        return torch.where(active.view(-1, 1, 1), x_new, x)

    def _core_step(self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """One iteration: INJECT first (updates state with prelude signal),
        then run every core block. Source: parcae_lm/models/parcae/parcae.py:362-386
        """
        x = self.adapter(x, e)
        for block in self.core_block:
            x = block(x, cos, sin)
        return x

    # ─── Full forward pass ─────────────────────────────────────────────────
    def forward(
        self,
        input_ids: Tensor,
        labels: Optional[Tensor] = None,
        mu_rec_override: Optional[int] = None,
    ) -> dict:
        """
        Args:
            input_ids: (B, T) int64 token ids, T <= config.block_size
            labels: (B, T) int64 — if given, the output also has "loss"
                (positions labelled -100 are ignored)
            mu_rec_override: int >= 0 — force a specific recurrence depth (eval/debug)
        Returns:
            {"logits": (B, T, vocab_size)} plus "loss" when labels are given.
        Raises:
            ValueError: if T exceeds config.block_size (no RoPE table beyond it).
        """
        _batch, T = input_ids.shape
        if T > self.config.block_size:
            raise ValueError(f"seq {T} > block_size {self.config.block_size}")
        cos = self.rope_cos[:, :, :T]
        sin = self.rope_sin[:, :, :T]
        # 1. Embed
        x = self.wte(input_ids)  # (B, T, d)
        # 2. Prelude
        for block in self.prelude:
            x = block(x, cos, sin)
        # 3. Optional norm between prelude and core
        if self.ln_prelude is not None:
            x = self.ln_prelude(x)
        # 4. The prelude output is the INPUT SIGNAL for the loop. Frozen from here on.
        e = x
        # 5. Iterate the core block
        x = self.iterate_forward(e, cos, sin, mu_rec_override=mu_rec_override)
        # 6. Project state back to n_embd
        x = self.C(x)
        # 7. Coda
        for block in self.coda:
            x = block(x, cos, sin)
        # 8. Final norm + head
        x = self.ln_f(x)
        logits = self.lm_head(x)
        out = {"logits": logits}
        if labels is not None:
            out["loss"] = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),  # upcast for CE stability
                labels.reshape(-1),
                ignore_index=-100,
            )
        if self.training:
            # Advance the sampler seed on every training forward (with or without
            # labels) so the next call reseeds the per-sequence depth sampler.
            self.step += 1
        return out
class _AdditiveInjection(nn.Module):
    """Ablation baseline injection: x_{t+1} = x_t + e.

    Has no parameters: no state decay and no learned mixing of e.
    """

    def forward(self, x: Tensor, e: Tensor) -> Tensor:
        combined = x + e
        return combined
# ═══════════════════════════════════════════════════════════════════════════
# SECTION 6 — DEMO / SELF-TEST
# ═══════════════════════════════════════════════════════════════════════════
def _demo(tiny: bool = False):
    """Smoke-test a freshly initialized Parcae model on random tokens.

    Builds the model and checks the injection's decay bound. Then runs a
    training-mode forward + backward (exercising the per-sequence depth
    sampler), an eval forward at fixed depth, and a test-time µ_rec sweep,
    printing each result.

    Args:
        tiny: use ``ParcaeConfig.tiny()`` instead of the default demo config.
    """
    torch.manual_seed(0)
    cfg = (
        ParcaeConfig.tiny()
        if tiny
        else ParcaeConfig(
            vocab_size=2048,
            block_size=128,
            n_embd=192,
            intermediate_size=512,
            num_attention_heads=6,
            num_key_value_heads=6,
            recurrent_embedding_dimension=192,
            recurrent_intermediation_embedding_dimension=512,
            n_layers_in_prelude=2,
            n_layers_in_recurrent_block=2,
            n_layers_in_coda=2,
            mean_recurrence=4,
            mean_backprop_depth=2,
        )
    )
    model = Parcae(cfg)
    # model.parameters() already de-duplicates shared tensors, so a tied
    # wte / lm_head matrix is counted exactly once in the total.
    n_params = sum(p.numel() for p in model.parameters())
    n_embed_params = model.wte.weight.numel()
    if not cfg.tie_embeddings:
        n_embed_params += model.lm_head.weight.numel()
    print(f"Parcae config (total_effective_depth={cfg.total_effective_depth}):")
    print(
        f"  n_embd={cfg.n_embd} heads={cfg.num_attention_heads} "
        f"µ_rec={cfg.mean_recurrence} bp_depth={cfg.mean_backprop_depth} "
        f"inject={cfg.injection_type} init={cfg.init_strategy}"
    )
    print(f"Params: {n_params:,} (non-embedding: {n_params - n_embed_params:,})\n")
    # Spectral-norm diagnostic
    if isinstance(model.adapter, DiagonalInjection):
        sr = model.adapter.spectral_radius()
        print(
            f"DiagonalInjection @ init: ρ(decay) = "
            f"min={sr['min']:.4f} mean={sr['mean']:.4f} max={sr['max']:.4f}"
        )
        assert sr["max"] < 1.0, "spectral radius must be < 1 by construction!"
        print("  ✓ Spectral radius guarantee holds.\n")
    # One forward pass (bf16 autocast, training mode → exercises per-seq sampler)
    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    labels = torch.randint(0, cfg.vocab_size, (2, 32))
    model.train()
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(ids, labels=labels)
    print(
        f"Forward (training, per-sequence sampling): "
        f"logits={tuple(out['logits'].shape)} loss={out['loss'].item():.4f}"
    )
    # Backward — make sure gradients flow through the gradient-tail of the loop.
    out["loss"].backward()
    a_log_grad = (
        model.adapter.A_log.grad if isinstance(model.adapter, DiagonalInjection) else None
    )
    # grad stays None when no gradient-bearing loop step ran (e.g. bp_depth == 0).
    inj_grad_norm = a_log_grad.norm().item() if a_log_grad is not None else 0.0
    print(
        f"Backward OK. ||grad A_log|| = {inj_grad_norm:.6f} "
        f"(nonzero ⇒ injection is trainable)"
    )
    # Eval — deterministic µ_rec
    model.eval()
    with torch.no_grad():
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            out_eval = model(ids, mu_rec_override=cfg.mean_recurrence)
    print(
        f"Forward (eval, µ_rec={cfg.mean_recurrence} fixed): "
        f"logits={tuple(out_eval['logits'].shape)}"
    )
    # Test-time scaling: the paper predicts L(T) saturates with depth. It will
    # not at init, but the forward should work for any µ_rec within reason.
    print(
        "\nTest-time µ_rec sweep (from a random-init model, no expectation of monotonicity):"
    )
    for T in [1, 2, 4, 8, 16]:
        with torch.no_grad():
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                o = model(ids, labels=labels, mu_rec_override=T)
        print(f"  µ_rec={T:>2d}: loss={o['loss'].item():.4f}")
if __name__ == "__main__":
    # CLI entry point: `--tiny` switches to the smallest CPU-friendly config.
    cli = argparse.ArgumentParser()
    cli.add_argument("--tiny", action="store_true", help="Use tiny CPU config")
    cli_args = cli.parse_args()
    _demo(tiny=cli_args.tiny)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment