Created
April 18, 2026 18:50
-
-
Save pszemraj/601bce9fb6486f3d5d49325024769238 to your computer and use it in GitHub Desktop.
reference/annotated version of Parcae arch (looped LM) from param golf experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| parcae_reference.py | |
| ═══════════════════ | |
| A single-file, heavily annotated reference implementation of the Parcae | |
| architecture (Prairie, Novack, Berg-Kirkpatrick, Fu — arXiv 2604.12946). | |
| PURPOSE | |
| ─────── | |
| This is a **pedagogical reference**, not a training rig. Every architectural | |
| choice is traceable to the source repo at `sandyresearch/parcae` with line | |
| citations. The goal: a developer can read this top-to-bottom and understand | |
| 1. what Parcae actually does mechanically | |
| 2. why each decision is what it is | |
| 3. how to modify it without breaking the stability guarantees | |
| What's included: | |
| • Full model: Embed → Prelude → Loop[Inject → Core] → C → Coda → Head | |
| • Real paper init (scaled-zero / std = sqrt(2/5d), ssm_decay = √(1/5)) | |
| • TBPTT split (no-grad warmup + gradient tail) | |
| • Per-sequence depth sampling (the stability knob the paper emphasizes) | |
| • Spectral-norm diagnostics | |
| • A forward pass that runs on CPU in bf16 autocast | |
| What's removed (read the real repo if you need these): | |
| • Distributed training, DDP, Fabric | |
| • Gradient checkpointing | |
| • Fused cross-entropy kernels (CCE, linear CE, triton) | |
| • Value-embeddings ("ve") — used in nanochat-style tweaks | |
| • Monitoring hooks (`extreme_metrics`, telemetry) | |
| • Compile/dynamo disables | |
| • MuonAdamW optimizer (use AdamW — close enough for understanding) | |
| SOURCE MAP | |
| ────────── | |
| DiagonalInjection ← parcae_lm/modules/injection.py:10-60 | |
| Init tables (scaled*) ← parcae_lm/utils/init.py:75-95 | |
| Parcae.forward ← parcae_lm/models/parcae/parcae.py:139-239 | |
| iterate_forward (loop) ← parcae_lm/models/parcae/parcae.py:244-360 | |
| core_block_forward ← parcae_lm/models/parcae/parcae.py:362-386 | |
| initialize_state ← parcae_lm/models/parcae/parcae.py:537-556 | |
| per-seq depth sampler ← parcae_lm/models/parcae/parcae.py:477-535 | |
| 140M reference config ← parcae_lm/configs/parcae/parcae-small-140m.py | |
| TO RUN | |
| ────── | |
| python parcae_reference.py # build + 1 forward pass + stability check | |
| python parcae_reference.py --tiny # even smaller, fast CPU test | |
| Author: reference extracted & annotated for clarity. Architecture © Prairie et al. 2026. | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch import Tensor | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 1 — CONFIG | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # The real config lives at parcae_lm/models/parcae/config.py. It inherits from | |
| # the base Config in parcae_lm/models/config.py and adds Parcae-specific fields. | |
| # Everything here is flattened and documented. | |
@dataclass
class ParcaeConfig:
    """Flat bundle of every hyperparameter the reference model reads.

    Field order is part of the interface (the dataclass constructor accepts
    positional arguments), so new fields belong at the end of a section only
    if positional construction is not relied upon.
    """

    # ─── Vocabulary & context ──────────────────────────────────────────────
    vocab_size: int = 32768  # parcae-small-140m uses 32768
    block_size: int = 2048  # longest sequence accepted; also the RoPE table length

    # ─── Transformer body ──────────────────────────────────────────────────
    n_embd: int = 768  # hidden width shared by every block built from this config
    intermediate_size: int = 3072  # MLP hidden width (4 * n_embd at the defaults)
    num_attention_heads: int = 6
    num_key_value_heads: int = 6  # choose fewer than num_attention_heads for GQA

    # ─── Prelude / Core / Coda layer counts ────────────────────────────────
    # "Parcae" = the three Roman fates:
    #   Nona   / Prelude (P) — spins the thread (tokens → latent space)
    #   Decima / Core    (R) — measures the thread (looped recurrence)
    #   Morta  / Coda    (C) — cuts the thread (final projection to logits)
    n_layers_in_prelude: int = 2
    n_layers_in_recurrent_block: int = 2  # UNIQUE core layers, re-applied every loop
    n_layers_in_coda: int = 2

    # ─── Recurrent state dimensions ────────────────────────────────────────
    # Width of the recurrent state x_t. The source annotations say released
    # checkpoints keep this equal to n_embd.
    recurrent_embedding_dimension: int = 768
    # Not read by any module visible in this file.
    recurrent_intermediation_embedding_dimension: int = 3072

    # ─── Recurrence depth & TBPTT ──────────────────────────────────────────
    # The core loop runs `mean_recurrence` times per forward pass; only the
    # last `mean_backprop_depth` iterations carry gradients (truncated BPTT).
    # Earlier iterations run under torch.no_grad(), so no activations are kept
    # for them. Without truncation, 8 loops of 2 layers would store
    # activations for 16 block applications — as much as a plain 16-layer
    # transformer, despite holding only 2 layers' worth of core parameters.
    mean_recurrence: int = 8
    mean_backprop_depth: int = 4

    # ─── Injection ─────────────────────────────────────────────────────────
    # Repo options: "diagonal" (SSM-style, default), "linear" (concat +
    # project), "add" (plain x + e, ablation baseline). Only "diagonal" and
    # "add" are implemented in this file.
    injection_type: str = "diagonal"

    # ─── Depth sampling ────────────────────────────────────────────────────
    #   "per-batch":    one (n, k) pair for the whole batch
    #   "per-sequence": every sequence draws its own depth (paper default)
    #   "per-token":    every token draws its own depth
    recurrent_iteration_method: str = "per-sequence"
    sampling_scheme: str = "poisson-truncated-full"

    # ─── Init strategy ─────────────────────────────────────────────────────
    #   "scaled-zero": std = sqrt(2 / (5 * n_embd)); output projections start at 0
    #   "scaled":      same std; output projections divided by sqrt(2 * num_layers)
    init_strategy: str = "scaled-zero"
    ssm_decay: float = math.sqrt(1.0 / 5.0)  # target per-step state decay at init (≈ 0.4472)

    # ─── Misc ──────────────────────────────────────────────────────────────
    norm_eps: float = 1e-5
    bias: bool = False
    qk_norm: bool = True  # per-head RMS-norm on Q and K
    prelude_norm: bool = True  # extra RMSNorm between prelude and core
    tie_embeddings: bool = True
    rope_base: float = 50_000.0

    # ─── Presets ───────────────────────────────────────────────────────────
    @classmethod
    def tiny(cls) -> "ParcaeConfig":
        """Return a small configuration meant for CPU development and unit tests."""
        overrides = {
            "vocab_size": 1024,
            "block_size": 128,
            "n_embd": 128,
            "intermediate_size": 256,
            "num_attention_heads": 4,
            "num_key_value_heads": 4,
            "recurrent_embedding_dimension": 128,
            "recurrent_intermediation_embedding_dimension": 256,
            "n_layers_in_prelude": 1,
            "n_layers_in_recurrent_block": 2,
            "n_layers_in_coda": 1,
            "mean_recurrence": 4,
            "mean_backprop_depth": 2,
        }
        return cls(**overrides)

    # ─── Derived quantities ────────────────────────────────────────────────
    @property
    def head_dim(self) -> int:
        """Per-head width; asserts that n_embd splits evenly across the heads."""
        per_head, leftover = divmod(self.n_embd, self.num_attention_heads)
        assert leftover == 0
        return per_head

    @property
    def total_effective_depth(self) -> int:
        """Number of block applications in one forward pass at mean_recurrence."""
        looped_layers = self.n_layers_in_recurrent_block * self.mean_recurrence
        return self.n_layers_in_prelude + looped_layers + self.n_layers_in_coda
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 2 — INITIALIZATION HELPERS | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # These directly mirror parcae_lm/utils/init.py and parcae_lm/models/parcae/init.py. | |
| # The "scaled" family uses std = sqrt(2/5d), a Nguyen-Salazar-style init that | |
| # keeps pre-activation variance ≈ 1 through a ReLU² MLP. | |
| def _scaled_std(dim: int) -> float: | |
| """The ubiquitous std = sqrt(2/(5d)) from 'scaled' init family.""" | |
| return math.sqrt(2.0 / (5.0 * dim)) | |
| def _out_proj_std(dim: int, num_layers: int, strategy: str) -> float: | |
| """ | |
| Output-projection init depends on strategy: | |
| 'scaled-zero': 0 (residual branches start dead; the model learns to add) | |
| 'scaled': sqrt(2/5d) / sqrt(2 * num_layers) (signal propagation bound) | |
| """ | |
| if strategy == "scaled-zero": | |
| return 0.0 | |
| # 'scaled' family: Le Scao / Biderman output init | |
| return _scaled_std(dim) / math.sqrt(2 * max(num_layers, 1)) | |
| def _trunc_normal_(t: Tensor, std: float): | |
| """All of Parcae's normals are truncated at ±3σ (wrapped_trunc_normal in repo).""" | |
| if std == 0.0: | |
| t.zero_() | |
| return | |
| nn.init.trunc_normal_(t, mean=0.0, std=std, a=-3 * std, b=3 * std) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 3 — DIAGONAL INJECTION | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # THIS IS THE HEART OF PARCAE. Everything else in the model is a standard | |
| # pre-norm transformer. The paper's theoretical contribution is this module. | |
| # | |
| # Mechanism (from parcae_lm/modules/injection.py:42-46): | |
| # | |
| # x_{t+1} = exp(-dt * A) * x_t + dt * (e @ B.T) | |
| # └────────────┬────┘ └────────┬────┘ | |
| # state decay input injection | |
| # | |
| # x_t ∈ R^{state_dim} current recurrent hidden state (per position) | |
| # e ∈ R^{input_dim} PRELUDE OUTPUT, frozen across all loop iterations | |
| # A = exp(A_log), A_log ∈ R^{state_dim} (A > 0 always) | |
| # dt = softplus(dt_bias), dt_bias ∈ R^{state_dim} (dt > 0 always) | |
| # B ∈ R^{state_dim × input_dim} input-to-state projection | |
| # | |
| # WHY THIS IS STABLE | |
| # ────────────────── | |
| # The decay factor exp(-dt*A) is ALWAYS in (0, 1) — guaranteed by construction | |
| # since dt > 0 and A > 0. In dynamical-systems language: the spectral radius | |
| # ρ(Ā) is bounded below 1, so the state contracts to a fixed point in the | |
| # absence of new input. This directly prevents residual explosion, which was | |
| # THE failure mode for prior looped transformers (e.g. Geiping's Huginn). | |
| # | |
| # WHY exp(-dt*A) AND NOT JUST A SCALAR | |
| # ───────────────────────────────────── | |
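| # Zero-order hold (ZOH) discretization of the continuous-time ODE | |
| # dx/dt = -A * x + B * e | |
| # gives the decay term exp(-dt * A) exactly. The input term uses the simpler | |
| # first-order (Euler) form dt * B instead of the exact ZOH expression | |
| # (1 - exp(-dt * A)) / A * B, the same simplification Mamba makes. | |
| # Borrowed from structured SSMs (S4/Mamba). | |
| # The diagonal A (not a general matrix) makes the update elementwise-parallel. | |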
| # | |
| # INITIALIZATION (parcae_lm/models/parcae/init.py:31-52) | |
| # ────────────────────────────────────────────────────── | |
| # A_log = 0 → A = 1 (uniform decay rate across dims) | |
| # dt_bias chosen so softplus(dt_bias) = -log(ssm_decay) | |
| # → exp(-dt * 1) = ssm_decay (default: sqrt(1/5) ≈ 0.447) | |
| # B = identity (if state_dim == input_dim) else orthogonal | |
| # | |
| # HOW MUCH STATE IS RETAINED | |
| # ────────────────────────── | |
| # With default init (decay = 0.447), after k loops with no input e: | |
| # ||x_k|| = decay^k * ||x_0|| | |
| # k=1: 0.447 of state survives | |
| # k=4: 0.040 ~96% of state has been replaced by injected input | |
| # k=8: 0.0016 state is now almost entirely a function of e | |
| # | |
| # This is intentional. Parcae's authors argue for a "fast-mixing" regime where | |
| # the recurrent state quickly forgets its initialization and becomes a | |
| # function of the prelude output. This is NOT a long-horizon RNN; it's a | |
| # contractive dynamical system being driven to a fixed point. | |
| # | |
| # CAVEAT — observed at Parameter Golf scale | |
| # ───────────────────────────────────────── | |
| # Empirically, at 17M params / 15-step training runs, decay ≈ 0.9 beats the | |
| # paper's 0.447 by ~1.0 val-loss units. The paper calibrated decay for 100B-token | |
| # runs; short regimes may want longer state memory. Verify at your scale. | |
class DiagonalInjection(nn.Module):
    """Diagonal, SSM-style state update: ``x' = exp(-dt * A) * x + dt * (e @ B.T)``.

    ``A = exp(A_log)`` and ``dt = softplus(dt_bias)`` are both positive, so the
    per-dimension decay ``exp(-dt * A)`` lies in (0, 1) up to floating-point
    rounding.

    Reads from ``config``: ``recurrent_embedding_dimension`` (state width),
    ``n_embd`` (input width) and ``ssm_decay`` (decay value at init).

    Raises:
        ValueError: if ``ssm_decay`` is not strictly between 0 and 1.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        self.state_dim = config.recurrent_embedding_dimension
        self.input_dim = config.n_embd
        self.decay_target = config.ssm_decay

        self.A_log = nn.Parameter(torch.empty(self.state_dim))
        self.dt_bias = nn.Parameter(torch.empty(self.state_dim))
        self.B = nn.Parameter(torch.empty(self.state_dim, self.input_dim))

        # Flag read by whatever builds optimizer param groups: weight decay
        # would pull A_log and dt_bias toward 0 and so override the learned
        # decay rate. (The real repo marks all three the same way.)
        for param in (self.A_log, self.dt_bias, self.B):
            param._no_weight_decay = True

        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        """Initialise so that every state dimension decays by ``decay_target``."""
        # Outside (0, 1) the inverse softplus below produces -inf or NaN (or
        # math.log raises a bare "math domain error"); fail loudly instead.
        if not 0.0 < self.decay_target < 1.0:
            raise ValueError(
                f"ssm_decay must be strictly between 0 and 1, got {self.decay_target!r}"
            )

        # A_log = 0  →  A = exp(0) = 1.
        self.A_log.zero_()

        # With A = 1 we need softplus(dt_bias) = -log(decay_target).
        # Inverse softplus of x is log(exp(x) - 1) = x + log(1 - exp(-x));
        # expm1 keeps precision when x is close to 0.
        target_dt = -math.log(self.decay_target)
        dt = torch.full_like(self.dt_bias, target_dt)
        self.dt_bias.copy_(dt + torch.log(-torch.expm1(-dt)))

        # B: identity when square, orthogonal otherwise.
        if self.state_dim == self.input_dim:
            self.B.zero_()
            self.B.fill_diagonal_(1.0)
        else:
            nn.init.orthogonal_(self.B)

    def _dt_and_decay(self) -> tuple[Tensor, Tensor]:
        """Return ``(dt, exp(-dt * A))``, each of shape ``(state_dim,)``."""
        dt = F.softplus(self.dt_bias)  # > 0
        A = torch.exp(self.A_log)  # > 0
        return dt, torch.exp(-dt * A)

    def forward(self, x_t: Tensor, e: Tensor) -> Tensor:
        """Advance the recurrent state by one injection step.

        Args:
            x_t: (B, T, state_dim) — current recurrent state.
            e:   (B, T, input_dim) — prelude output, the same tensor every loop step.

        Returns:
            (B, T, state_dim) — the next state.
        """
        dt, decay = self._dt_and_decay()
        return x_t * decay + dt * (e @ self.B.T)

    # ─── Diagnostics (not in the training path) ────────────────────────────
    @torch.no_grad()
    def spectral_radius(self) -> dict:
        """Return ``{"min", "mean", "max"}`` of the diagonal decay values as floats."""
        _, decay = self._dt_and_decay()
        return {
            "min": decay.min().item(),
            "mean": decay.mean().item(),
            "max": decay.max().item(),
        }
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 4 — STANDARD TRANSFORMER PIECES (no Parcae-specific logic) | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # These are bog-standard modern transformer components. Nothing novel here; | |
| # parcae_lm uses the same building blocks as nanochat/litgpt. | |
| class RMSNorm(nn.Module): | |
| """Root-mean-square layer norm. No mean-centering, cheaper than LayerNorm.""" | |
| def __init__(self, dim: int, eps: float = 1e-5): | |
| super().__init__() | |
| self.weight = nn.Parameter(torch.ones(dim)) | |
| self.eps = eps | |
| def forward(self, x: Tensor) -> Tensor: | |
| return F.rms_norm(x, (x.shape[-1],), self.weight, self.eps) | |
def build_rope(
    head_dim: int, seq_len: int, base: float = 50_000.0
) -> tuple[Tensor, Tensor]:
    """Precompute RoPE ``(cos, sin)`` tables, each of shape (1, 1, seq_len, head_dim/2).

    The tables are meant for the split-half rotation (first half / second half
    of the head dimension) rather than the interleaved-complex form. Both are
    valid RoPE but imply different weight layouts, so stay consistent across
    training.

    Args:
        head_dim: per-head width; must be a positive even number.
        seq_len: number of positions to tabulate.
        base: RoPE frequency base.

    Raises:
        ValueError: if ``head_dim`` is odd or smaller than 2. An odd width
            would give tables one column wider than the half-split of the
            head dimension and fail later with an opaque broadcast error.
    """
    if head_dim < 2 or head_dim % 2 != 0:
        raise ValueError(f"head_dim must be a positive even integer, got {head_dim}")

    pair_index = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (base ** (pair_index / head_dim))
    positions = torch.arange(seq_len).float()
    freqs = torch.outer(positions, inv_freq)  # (seq_len, head_dim/2)
    # Two leading singleton axes let the tables broadcast over (batch, heads).
    return (torch.cos(freqs)[None, None, :, :], torch.sin(freqs)[None, None, :, :])
def apply_rope(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate ``x`` (B, n_heads, T, head_dim) using the split-half RoPE layout.

    ``cos`` and ``sin`` must broadcast against each half of the last dimension.
    """
    width = x.shape[-1]
    half = width // 2
    lower, upper = torch.split(x, [half, width - half], dim=-1)
    rotated_lower = lower * cos - upper * sin
    rotated_upper = lower * sin + upper * cos
    return torch.cat([rotated_lower, rotated_upper], dim=-1)
class CausalSelfAttention(nn.Module):
    """Multi-head causal attention with GQA support, optional QK-norm, and RoPE.

    QK-norm (``config.qk_norm``): a weight-free RMS-norm over each head's Q and
    K before the dot product, which bounds the attention logits when Q/K
    magnitudes drift during training.

    Reads from ``config``: ``num_attention_heads``, ``num_key_value_heads``,
    ``head_dim``, ``qk_norm``, ``n_embd`` and ``bias``.

    Raises:
        ValueError: if ``num_key_value_heads`` is not a positive divisor of
            ``num_attention_heads``. Without the check the floor division
            below hides the mismatch until attention fails on head counts.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        self.n_head = config.num_attention_heads
        self.n_kv_head = config.num_key_value_heads
        if self.n_kv_head <= 0 or self.n_head % self.n_kv_head != 0:
            raise ValueError(
                f"num_attention_heads ({self.n_head}) must be a positive multiple "
                f"of num_key_value_heads ({self.n_kv_head})"
            )
        self.head_dim = config.head_dim
        self.n_rep = self.n_head // self.n_kv_head  # GQA repeat factor
        self.qk_norm = config.qk_norm

        # Separate Q / K / V projections; K and V are narrower under GQA.
        self.c_q = nn.Linear(
            config.n_embd, self.n_head * self.head_dim, bias=config.bias
        )
        self.c_k = nn.Linear(
            config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias
        )
        self.c_v = nn.Linear(
            config.n_embd, self.n_kv_head * self.head_dim, bias=config.bias
        )
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """Attend causally over ``x`` (B, T, C) and return a tensor of the same shape.

        ``cos`` / ``sin`` are RoPE tables that broadcast against
        (B, heads, T, head_dim / 2).
        """
        B, T, C = x.shape
        # (B, T, heads * head_dim) → (B, heads, T, head_dim)
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)

        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        if self.qk_norm:
            # Per-head RMSNorm with no learnable weight — magnitude control only.
            q = F.rms_norm(q, (q.shape[-1],))
            k = F.rms_norm(k, (k.shape[-1],))

        # GQA: replicate each KV head so K/V line up with the query heads.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # (B, heads, T, head_dim) → (B, T, C) before the output projection.
        return self.c_proj(y.transpose(1, 2).reshape(B, T, C))
class ReLU2(nn.Module):
    """Squared ReLU: ``max(x, 0) ** 2`` applied elementwise."""

    def forward(self, x: Tensor) -> Tensor:
        rectified = torch.relu(x)
        return rectified.pow(2)
class BaseMLP(nn.Module):
    """Ungated two-layer MLP: ``fc`` → ReLU² → ``proj``.

    Widths come from ``config.n_embd`` and ``config.intermediate_size``;
    ``config.bias`` toggles the bias on both linear layers.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        width, hidden, use_bias = config.n_embd, config.intermediate_size, config.bias
        self.fc = nn.Linear(width, hidden, bias=use_bias)
        self.proj = nn.Linear(hidden, width, bias=use_bias)
        self.act = ReLU2()

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.fc(x)
        activated = self.act(hidden)
        return self.proj(activated)
class TransformerPreNormBlock(nn.Module):
    """Pre-norm residual block used by the prelude, the core, and the coda.

        x = x + Attn(RMSNorm(x))
        x = x + MLP(RMSNorm(x))

    The block itself never loops; the recurrence is driven from outside by
    the model's ``iterate_forward``.
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        width, eps = config.n_embd, config.norm_eps
        self.norm_1 = RMSNorm(width, eps=eps)
        self.attn = CausalSelfAttention(config)
        self.norm_2 = RMSNorm(width, eps=eps)
        self.mlp = BaseMLP(config)

    def forward(self, x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        # Attention sub-layer, added back onto the residual stream.
        attn_out = self.attn(self.norm_1(x), cos, sin)
        x = x + attn_out
        # MLP sub-layer, same pattern.
        mlp_out = self.mlp(self.norm_2(x))
        x = x + mlp_out
        return x
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 5 — PARCAE MODEL | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # The actual assembly. Most lines here are bookkeeping; the interesting bits | |
| # are `iterate_forward` and `_sample_per_sequence_depths`. | |
class Parcae(nn.Module):
    """Looped language model: Embed → Prelude → Loop[Inject → Core] → C → Coda → Head.

    Flow of one forward pass:

        input_ids
          → wte                      token embedding
          → prelude blocks           unique layers, run once
          → ln_prelude (optional)    result is ``e``, the loop's input signal
          → x = initialize_state(e)
          → repeat:  x = adapter(x, e)          injection; ``e`` never changes
                     x = block(x) for each core block   (shared across repeats)
          → C                        linear projection ahead of the coda
          → coda blocks              unique layers, run once
          → ln_f → lm_head           logits (and cross-entropy loss if labels given)

    Raises:
        ValueError: for an unsupported ``injection_type``, or when diagonal
            injection is requested with ``recurrent_embedding_dimension``
            different from ``n_embd`` (the state, the core blocks and ``C``
            are all built at width ``n_embd`` here).
    """

    def __init__(self, config: ParcaeConfig):
        super().__init__()
        self.config = config
        d = config.n_embd

        # Token embedding (the real repo pads the vocabulary; skipped here).
        self.wte = nn.Embedding(config.vocab_size, d)

        # Prelude: UNIQUE layers, run exactly once.
        self.prelude = nn.ModuleList(
            TransformerPreNormBlock(config) for _ in range(config.n_layers_in_prelude)
        )
        # Optional RMSNorm between prelude and core.
        self.ln_prelude = (
            RMSNorm(d, eps=config.norm_eps) if config.prelude_norm else None
        )

        # Injection module.
        if config.injection_type == "diagonal":
            if config.recurrent_embedding_dimension != d:
                # The state is allocated like ``e`` and the core blocks run at
                # width n_embd, so a different state width could only fail
                # later with a broadcast error inside the first forward pass.
                raise ValueError(
                    "recurrent_embedding_dimension "
                    f"({config.recurrent_embedding_dimension}) must equal n_embd "
                    f"({d}) in this reference implementation"
                )
            self.adapter = DiagonalInjection(config)
        elif config.injection_type == "add":
            # Ablation: just x + e. Keeps the prelude/loop/coda structure.
            self.adapter = _AdditiveInjection()
        else:
            raise ValueError(
                f"injection_type={config.injection_type!r} not implemented here"
                " — see parcae_lm/modules/injection.py for 'linear'"
            )

        # Core: SHARED layers, applied once per loop iteration.
        self.core_block = nn.ModuleList(
            TransformerPreNormBlock(config)
            for _ in range(config.n_layers_in_recurrent_block)
        )

        # C: projection applied to the final state before the coda. It is NOT
        # identity-initialised (it gets the default scaled std) and is flagged
        # to be excluded from weight decay.
        self.C = nn.Linear(d, d, bias=config.bias)
        self.C.weight._no_weight_decay = True

        # Coda: UNIQUE layers, run once after the loop.
        self.coda = nn.ModuleList(
            TransformerPreNormBlock(config) for _ in range(config.n_layers_in_coda)
        )
        self.ln_f = RMSNorm(d, eps=config.norm_eps)
        self.lm_head = nn.Linear(d, config.vocab_size, bias=False)
        if config.tie_embeddings:
            self.lm_head.weight = self.wte.weight

        # RoPE tables as non-persistent buffers (rebuilt, never checkpointed).
        cos, sin = build_rope(config.head_dim, config.block_size, config.rope_base)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self._apply_paper_init()
        self.step = 0  # seeds the per-sequence depth sampler; bumped per training forward

    # ─── Initialization ────────────────────────────────────────────────────
    @torch.no_grad()
    def _apply_paper_init(self):
        """Re-initialise every weight matrix according to ``config.init_strategy``.

        Parameters with fewer than two dimensions (RMSNorm gains, ``A_log``,
        ``dt_bias``) are left untouched. Attention and MLP output projections
        use the output-projection std; the injection matrix ``B`` keeps its own
        initialisation; everything else, embeddings included, uses the base
        scaled std.
        """
        d = self.config.n_embd
        base_std = _scaled_std(d)
        out_std = _out_proj_std(
            d, self.config.total_effective_depth, self.config.init_strategy
        )
        out_proj_suffixes = ("c_proj.weight", "mlp.proj.weight")

        for name, p in self.named_parameters():
            if p.dim() < 2:
                continue
            if name.endswith("adapter.B"):
                continue  # DiagonalInjection owns the init of B
            if name.endswith(out_proj_suffixes):
                _trunc_normal_(p, out_std)
            else:
                _trunc_normal_(p, base_std)

        # The loop above never touches the injection parameters; re-running
        # the adapter's own initialiser keeps that guarantee explicit.
        if isinstance(self.adapter, DiagonalInjection):
            self.adapter.reset_parameters()

    # ─── State initialization ──────────────────────────────────────────────
    def initialize_state(self, e: Tensor) -> Tensor:
        """Create the initial recurrent state ``x_0``, shaped and typed like ``e``.

        Values are drawn from the same truncated normal used for weight init
        (std = sqrt(2 / (5 * n_embd))) — the 'like-init' strategy. The source
        annotations list further strategies in the real repo ('zero',
        'normal', 'embed', 'unit'); none of them is implemented here.

        The draw is random on every call, so repeated forward passes differ
        unless the global RNG is seeded by the caller.
        """
        std = _scaled_std(self.config.n_embd)
        x = torch.empty_like(e)
        _trunc_normal_(x, std)
        return x

    # ─── Per-sequence depth sampling ───────────────────────────────────────
    @torch.no_grad()
    def _sample_per_sequence_depths(
        self, batch_size: int, device
    ) -> tuple[Tensor, Tensor]:
        """Sample, for every sequence in the batch, a pair of step counts ``(n, k)``.

            n : number of no-grad warm-up steps
            k : number of gradient-carrying steps (the TBPTT window)

        Scheme ('poisson-truncated-full'): total ~ Poisson(t + s) clamped to
        at least 1, k = min(total, s), n = total - k, where s is
        ``mean_backprop_depth`` and t is ``max(mean_recurrence - s, 0)``.

        Sampling uses a CPU generator seeded from ``self.step``, so the draw
        is a deterministic function of the step counter.

        Returns:
            ``(n, k)`` as int64 tensors of shape ``(batch_size,)`` on ``device``.
        """
        t_target = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0)
        s_target = self.config.mean_backprop_depth
        total_target = float(t_target + s_target)

        gen = torch.Generator(device="cpu")
        gen.manual_seed((514229 + self.step) % (2**31 - 1))
        total = torch.clamp(
            torch.poisson(torch.full((batch_size,), total_target), generator=gen), min=1
        )
        k = torch.clamp(total, max=float(s_target))
        n = total - k
        return n.long().to(device), k.long().to(device)

    # ─── The recurrence loop ───────────────────────────────────────────────
    def _run_fixed_depth(
        self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor, depth: int
    ) -> Tensor:
        """Run ``depth`` core steps; only the last ``mean_backprop_depth`` track gradients."""
        n_nograd = max(depth - self.config.mean_backprop_depth, 0)
        n_grad = depth - n_nograd
        with torch.no_grad():
            for _ in range(n_nograd):
                x = self._core_step(x, e, cos, sin)
        for _ in range(n_grad):
            x = self._core_step(x, e, cos, sin)
        return x

    def _run_masked_steps(
        self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor, steps_per_seq: Tensor
    ) -> Tensor:
        """Run ``max(steps_per_seq)`` core steps, freezing sequences that are done.

        Every step is computed for the whole batch; a sequence whose budget is
        already spent keeps its previous state through ``torch.where``.
        """
        n_seq = steps_per_seq.shape[0]
        for step in range(int(steps_per_seq.max().item())):
            x_new = self._core_step(x, e, cos, sin)
            active = (step < steps_per_seq).view(n_seq, 1, 1)
            x = torch.where(active, x_new, x)
        return x

    def iterate_forward(
        self, e: Tensor, cos: Tensor, sin: Tensor, mu_rec_override: Optional[int] = None
    ) -> Tensor:
        """Run the recurrent core on the prelude output ``e`` and return the final state.

        Three modes:
          • ``mu_rec_override`` given — exactly that many steps (eval / debugging).
          • training with ``recurrent_iteration_method == "per-sequence"`` —
            each sequence gets sampled (no-grad, grad) step counts; the loop
            runs to the batch maximum and masks out finished sequences.
          • otherwise — ``config.mean_recurrence`` steps.

        In the fixed-depth modes the first ``depth - mean_backprop_depth``
        steps run under ``torch.no_grad()`` and the rest track gradients.
        """
        x = self.initialize_state(e)

        if mu_rec_override is not None:
            return self._run_fixed_depth(x, e, cos, sin, mu_rec_override)

        if self.training and self.config.recurrent_iteration_method == "per-sequence":
            n_per_sample, k_per_sample = self._sample_per_sequence_depths(
                e.shape[0], e.device
            )
            with torch.no_grad():
                x = self._run_masked_steps(x, e, cos, sin, n_per_sample)
            return self._run_masked_steps(x, e, cos, sin, k_per_sample)

        return self._run_fixed_depth(x, e, cos, sin, self.config.mean_recurrence)

    def _core_step(self, x: Tensor, e: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
        """One loop iteration: inject ``e`` into the state, then run every core block."""
        x = self.adapter(x, e)
        for block in self.core_block:
            x = block(x, cos, sin)
        return x

    # ─── Full forward pass ─────────────────────────────────────────────────
    def forward(
        self,
        input_ids: Tensor,
        labels: Optional[Tensor] = None,
        mu_rec_override: Optional[int] = None,
    ) -> dict:
        """Compute logits, and the loss when labels are supplied.

        Args:
            input_ids: (B, T) token ids.
            labels: (B, T) targets; positions equal to -100 are ignored. They
                are compared to the logits position by position — no shift is
                applied here.
            mu_rec_override: force a specific recurrence depth (eval / debug).

        Returns:
            A dict with ``"logits"`` of shape (B, T, vocab_size) and, if
            ``labels`` was given, a scalar ``"loss"``.

        Raises:
            ValueError: if T exceeds ``config.block_size``.
        """
        B, T = input_ids.shape
        if T > self.config.block_size:
            # Explicit raise rather than assert so the check survives `python -O`.
            raise ValueError(f"seq {T} > block_size {self.config.block_size}")
        cos = self.rope_cos[:, :, :T]
        sin = self.rope_sin[:, :, :T]

        # 1. Embed
        x = self.wte(input_ids)  # (B, T, d)
        # 2. Prelude
        for block in self.prelude:
            x = block(x, cos, sin)
        # 3. Optional norm between prelude and core
        if self.ln_prelude is not None:
            x = self.ln_prelude(x)
        # 4. The prelude output is the loop's input signal; it is not modified again.
        e = x
        # 5. Iterate the core block
        x = self.iterate_forward(e, cos, sin, mu_rec_override=mu_rec_override)
        # 6. Project the final state ahead of the coda
        x = self.C(x)
        # 7. Coda
        for block in self.coda:
            x = block(x, cos, sin)
        # 8. Final norm + head
        x = self.ln_f(x)
        logits = self.lm_head(x)

        out = {"logits": logits}
        if labels is not None:
            out["loss"] = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(),  # upcast for CE stability
                labels.reshape(-1),
                ignore_index=-100,
            )
        if self.training:
            self.step += 1  # advances the depth-sampler seed
        return out
class _AdditiveInjection(nn.Module):
    """Ablation baseline injection: x_{t+1} = x_t + e.

    There is no state decay and no learned mixing; the module defines no
    parameters of its own.
    """

    def forward(self, x: Tensor, e: Tensor) -> Tensor:
        """Return the sum of the current state ``x`` and the signal ``e``."""
        injected = x + e
        return injected
| # ═══════════════════════════════════════════════════════════════════════════ | |
| # SECTION 6 — DEMO / SELF-TEST | |
| # ═══════════════════════════════════════════════════════════════════════════ | |
def _demo(tiny: bool = False):
    """Build a small Parcae model and run an end-to-end smoke test on CPU.

    Steps, in order:
      1. Seed torch, build the config (``ParcaeConfig.tiny()`` when ``tiny``
         is true, otherwise a small hand-written config) and the model.
      2. Print the config summary and the parameter count.
      3. If the adapter is a ``DiagonalInjection``, print its spectral-radius
         stats and assert that the maximum is below 1.
      4. Run one training-mode forward/backward under bf16 CPU autocast and
         print the gradient norm of ``adapter.A_log``.
      5. Run an eval-mode forward at the configured recurrence depth.
      6. Sweep several recurrence depths at eval time and print each loss.

    Args:
        tiny: use the tiny CPU config instead of the default demo config.

    Raises:
        AssertionError: if the reported spectral radius maximum is >= 1.
    """
    torch.manual_seed(0)
    cfg = (
        ParcaeConfig.tiny()
        if tiny
        else ParcaeConfig(
            vocab_size=2048,
            block_size=128,
            n_embd=192,
            intermediate_size=512,
            num_attention_heads=6,
            num_key_value_heads=6,
            recurrent_embedding_dimension=192,
            recurrent_intermediation_embedding_dimension=512,
            n_layers_in_prelude=2,
            n_layers_in_recurrent_block=2,
            n_layers_in_coda=2,
            mean_recurrence=4,
            mean_backprop_depth=2,
        )
    )
    model = Parcae(cfg)

    # nn.Module.parameters() already yields a shared (tied) Parameter only
    # once, so the tied lm_head/embedding matrix is counted a single time.
    # Subtracting lm_head.weight on top of that would drop the matrix from
    # the total altogether and under-report the model size.
    n_params = sum(p.numel() for p in model.parameters())

    print(f"Parcae config (total_effective_depth={cfg.total_effective_depth}):")
    print(
        f" n_embd={cfg.n_embd} heads={cfg.num_attention_heads} "
        f"µ_rec={cfg.mean_recurrence} bp_depth={cfg.mean_backprop_depth} "
        f"inject={cfg.injection_type} init={cfg.init_strategy}"
    )
    print(f"Params: {n_params:,}\n")

    # Spectral-norm diagnostic
    if isinstance(model.adapter, DiagonalInjection):
        sr = model.adapter.spectral_radius()
        print(
            f"DiagonalInjection @ init: ρ(decay) = "
            f"min={sr['min']:.4f} mean={sr['mean']:.4f} max={sr['max']:.4f}"
        )
        assert sr["max"] < 1.0, "spectral radius must be < 1 by construction!"
        print(" ✓ Spectral radius guarantee holds.\n")

    # One forward pass (bf16 autocast, training mode → exercises per-seq sampler)
    ids = torch.randint(0, cfg.vocab_size, (2, 32))
    labels = torch.randint(0, cfg.vocab_size, (2, 32))
    model.train()
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out = model(ids, labels=labels)
    print(
        f"Forward (training, per-sequence sampling): "
        f"logits={tuple(out['logits'].shape)} loss={out['loss'].item():.4f}"
    )

    # Backward — make sure gradients flow through the gradient-tail of the loop
    out["loss"].backward()
    inj_grad_norm = (
        model.adapter.A_log.grad.norm().item()
        if isinstance(model.adapter, DiagonalInjection)
        else 0.0
    )
    print(
        f"Backward OK. ||grad A_log|| = {inj_grad_norm:.6f} "
        f"(nonzero ⇒ injection is trainable)"
    )

    # Eval — deterministic µ_rec
    model.eval()
    with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        out_eval = model(ids, mu_rec_override=cfg.mean_recurrence)
    print(
        f"Forward (eval, µ_rec={cfg.mean_recurrence} fixed): "
        f"logits={tuple(out_eval['logits'].shape)}"
    )

    # Test-time scaling: verify L(T) saturates as the paper predicts
    # (it will not saturate at init — but the shape of the forward should work
    # for any µ_rec within reason)
    print(
        "\nTest-time µ_rec sweep (from a random-init model, no expectation of monotonicity):"
    )
    for T in [1, 2, 4, 8, 16]:
        with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            o = model(ids, labels=labels, mu_rec_override=T)
        print(f" µ_rec={T:>2d}: loss={o['loss'].item():.4f}")
if __name__ == "__main__":
    # Script entry point: parse the single CLI flag and run the self-test.
    parser = argparse.ArgumentParser()
    parser.add_argument("--tiny", action="store_true", help="Use tiny CPU config")
    cli = parser.parse_args()
    _demo(tiny=cli.tiny)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment