Skip to content

Instantly share code, notes, and snippets.

@pszemraj
Created May 12, 2026 05:46
Show Gist options
  • Select an option

  • Save pszemraj/a5617899bc5d1b4a1d517e4d7dd25762 to your computer and use it in GitHub Desktop.

Select an option

Save pszemraj/a5617899bc5d1b4a1d517e4d7dd25762 to your computer and use it in GitHub Desktop.
MIMO-HGDN: initial pass at Olmo3-hybrid style gated delta net (hgdn) incl negative eigenvalues and MIMO from Mamba3

Mamba-3 → OLMo-Hybrid GDN Prototype

Files:

  • MAMBA3_TO_OLMO_HYBRID_GDN.md: architecture proposal and integration plan.
  • hybrid_mimo_gdn.py: pure-PyTorch reference implementation of a hybrid MIMO Gated DeltaNet decoder.

Smoke test:

python hybrid_mimo_gdn.py

The implementation is intentionally not kernel-optimized. It is a correctness/specification prototype for small experiments and for designing a fused rank-R GDN kernel.

"""Hybrid MIMO Gated DeltaNet language model prototype.
This is a research/reference implementation of an OLMo-Hybrid-style
Gated DeltaNet model with the most portable Mamba-3 ideas applied:
1. MIMO / rank-R DeltaNet updates with fixed recurrent state size.
2. Optional trapezoidal two-timestep input mixing.
3. Optional data-dependent rotary rotations on GDN q/k features.
The implementation is deliberately pure PyTorch. It is not a replacement for
Flash Linear Attention kernels, but it gives a correct, hackable baseline for
small-scale experiments and for specifying the fused kernel you would want for
large-scale training or decode.
Shapes use batch-first tensors throughout. The GDN state for each head is
S in R^{d_v x d_k}; for MIMO rank R the state size is unchanged while each
step performs a rank-R update/read.
"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Sequence
import torch
import torch.nn as nn
import torch.nn.functional as F
Tensor = torch.Tensor
def exists(x) -> bool:
    """Return ``True`` for every value except ``None`` (falsy values count as present)."""
    is_missing = x is None
    return not is_missing
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics are computed in float32; the result is cast back to the
    input dtype before the per-feature ``weight`` is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        in_dtype = x.dtype
        x32 = x.float()
        # Mean of squares along the feature axis; eps keeps rsqrt finite for all-zero rows.
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)
        gain = self.weight.to(in_dtype)
        return normed.to(in_dtype) * gain
class SwiGLU(nn.Module):
    """Bias-free gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    When ``hidden_dim`` is omitted, the hidden width is
    ``ffn_dim_multiplier * dim`` (``8/3 * dim`` when no multiplier is given),
    rounded up to a multiple of ``multiple_of``.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        if hidden_dim is None:
            # NOTE(review): the round-up is applied to the derived width only;
            # confirm against the un-mangled source if an explicit hidden_dim
            # is ever passed.
            ratio = ffn_dim_multiplier or (8.0 / 3.0)
            hidden_dim = int(ratio * dim)
            hidden_dim = multiple_of * math.ceil(hidden_dim / multiple_of)
        # Creation order (w1, w2, w3) is kept so seeded initialisation is unchanged.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        return self.w2(gate * up)
class CausalDepthwiseConv1d(nn.Module):
    """Per-channel causal 1-D convolution followed by SiLU.

    Input and output are ``(batch, time, dim)``. The filter is initialised so
    that only the tap aligned with the current timestep is 1, i.e. the layer
    starts out as ``silu(x)``. With ``enabled=False`` or ``kernel_size <= 1``
    the module is a pass-through that returns its input untouched (no SiLU).
    """

    def __init__(self, dim: int, kernel_size: int = 4, enabled: bool = True):
        super().__init__()
        self.enabled = bool(enabled and kernel_size > 1)
        self.kernel_size = int(kernel_size)
        if not self.enabled:
            self.conv = None
            return
        conv = nn.Conv1d(
            dim,
            dim,
            kernel_size=self.kernel_size,
            groups=dim,  # depthwise: one filter per channel
            bias=False,
        )
        with torch.no_grad():
            conv.weight.zero_()
            # The last tap lines up with the current token after left padding.
            conv.weight[:, 0, self.kernel_size - 1] = 1.0
        self.conv = conv

    def forward(self, x: Tensor) -> Tensor:
        if not self.enabled:
            return x
        left_pad = self.kernel_size - 1
        # Conv1d wants (batch, channels, time); pad on the left only so no
        # future timestep leaks into the output.
        channels_first = F.pad(x.transpose(1, 2), (left_pad, 0))
        mixed = self.conv(channels_first).transpose(1, 2)
        return F.silu(mixed)
class RotaryEmbedding(nn.Module):
    """Produce RoPE cosine/sine tables of shape ``(seq_len, dim / 2)``.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE head dim must be even, got {dim}")
        exponents = torch.arange(0, dim, 2).float() / dim
        # Non-persistent: the frequencies are derived, so they stay out of the state dict.
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        positions = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        # Outer product position x frequency -> (seq_len, dim / 2) angles.
        angles = positions[:, None] * self.inv_freq.to(device)[None, :]
        return angles.cos().to(dtype), angles.sin().to(dtype)
def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate interleaved (even, odd) feature pairs of ``x`` by position-indexed angles.

    ``x`` is either ``(B, H, T, D)`` or ``(B, T, H, R, D)``; ``cos``/``sin``
    are ``(T, D/2)`` and are broadcast over every non-time axis.

    Raises:
        ValueError: if ``x`` is neither 4- nor 5-dimensional.
    """
    d = x.shape[-1]
    even, odd = x[..., 0::2], x[..., 1::2]
    rank = x.dim()
    if rank == 4:
        # Time is axis 2.
        index = (None, None, slice(None), slice(None))
    elif rank == 5:
        # Time is axis 1.
        index = (None, slice(None), None, None, slice(None))
    else:
        raise ValueError(f"Unsupported tensor rank for RoPE: {x.dim()}")
    cos_b, sin_b = cos[index], sin[index]
    rot_even = even * cos_b - odd * sin_b
    rot_odd = even * sin_b + odd * cos_b
    # Re-interleave the rotated pairs back into the original feature layout.
    paired = torch.stack((rot_even, rot_odd), dim=-1)
    return paired.reshape(*x.shape[:-1], d)
def apply_data_dependent_rotary(x: Tensor, cumulative_angles: Tensor) -> Tensor:
    """Rotate interleaved feature pairs of ``x`` by per-token, per-head angles.

    ``x`` is ``(B, T, H, R, D)`` and ``cumulative_angles`` is ``(B, T, H, D/2)``;
    the same angle is shared by all ``R`` rank slots of a head.
    """
    d = x.shape[-1]
    even, odd = x[..., 0::2], x[..., 1::2]
    # Insert a singleton rank axis so the angles broadcast across R.
    cos_b = cumulative_angles.cos()[:, :, :, None, :].to(dtype=x.dtype)
    sin_b = cumulative_angles.sin()[:, :, :, None, :].to(dtype=x.dtype)
    rot_even = even * cos_b - odd * sin_b
    rot_odd = even * sin_b + odd * cos_b
    paired = torch.stack((rot_even, rot_odd), dim=-1)
    return paired.reshape(*x.shape[:-1], d)
class Attention(nn.Module):
    """Causal multi-head self-attention with RoPE and optional L2-normalised q/k.

    Input and output are ``(batch, time, dim)``. ``flash_attn`` is stored on
    the module, but ``forward`` does not consult it: attention always goes
    through ``F.scaled_dot_product_attention`` with ``is_causal=True``.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: Optional[int] = None,
        rope_theta: float = 500000.0,
        qk_norm: bool = True,
        flash_attn: bool = True,
    ):
        super().__init__()
        head_width = dim_head or dim // heads
        self.dim = dim
        self.heads = heads
        self.dim_head = head_width
        self.flash_attn = flash_attn
        self.qk_norm = qk_norm
        inner = heads * head_width
        # Creation order (wq, wk, wv, wo) is kept so seeded initialisation is unchanged.
        self.wq = nn.Linear(dim, inner, bias=False)
        self.wk = nn.Linear(dim, inner, bias=False)
        self.wv = nn.Linear(dim, inner, bias=False)
        self.wo = nn.Linear(inner, dim, bias=False)
        self.rotary = RotaryEmbedding(head_width, theta=rope_theta)

    def forward(self, x: Tensor) -> Tensor:
        batch, seq_len, _ = x.shape

        def split_heads(proj: nn.Linear) -> Tensor:
            # (B, T, H*D) -> (B, H, T, D)
            return proj(x).view(batch, seq_len, self.heads, self.dim_head).transpose(1, 2)

        q, k, v = split_heads(self.wq), split_heads(self.wk), split_heads(self.wv)
        if self.qk_norm:
            # Unit-norm q/k per head before the rotary rotation.
            q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        cos, sin = self.rotary(seq_len, x.device, q.dtype)
        q, k = apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)
        context = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        merged = context.transpose(1, 2).contiguous().view(batch, seq_len, self.heads * self.dim_head)
        return self.wo(merged)
@dataclass(frozen=True)
class MIMOGDNConfig:
    """Immutable hyper-parameters for one ``MIMOGatedDeltaNet`` layer."""

    # Model (residual stream) width.
    dim: int
    # Number of recurrent heads.
    heads: int
    # Per-head key/query width d_k.
    head_k_dim: int
    # Per-head value width is round(head_k_dim * expand_v).
    expand_v: float = 2.0
    # Number of rank slots R written/read per token.
    mimo_rank: int = 4
    # Kernel size of the causal depthwise convs on q/k/v.
    conv_size: int = 4
    # When True beta is scaled by 2, i.e. ranges over (0, 2) instead of (0, 1).
    allow_neg_eigval: bool = True
    # When True q/k/v are projected once and expanded with per-rank scale vectors.
    parameter_efficient_mimo: bool = True
    # When True (and R > 1, no orthonormalisation) beta is divided by sqrt(R).
    beta_rank_rescale: bool = True
    # When True the R keys of each token/head are QR-orthonormalised.
    orthonormalize_k: bool = False
    # Enables the two-timestep (trapezoidal) write mixing.
    trapezoidal: bool = False
    # Initial value of the trapezoid mixing weight lambda (stored as a logit bias).
    trapezoid_lambda_init: float = 0.90
    # Enables cumulative data-dependent rotations on q/k.
    data_dependent_qk_rope: bool = False
    # eps of the RMSNorm applied to the per-head output.
    output_norm_eps: float = 1e-5
class MIMOGatedDeltaNet(nn.Module):
    """Rank-R Gated DeltaNet layer with fixed-size state.

    Standard GDN update for one head can be written with a state
    S in R^{d_v x d_k}:

        S_t = alpha_t * (S_{t-1} - beta_t * (S_{t-1} k_t) k_t^T)
              + beta_t * v_t k_t^T
        y_t = S_t q_t

    This layer generalizes k, q, and v to R slots per token:

        K_t, Q_t in R^{R x d_k}, V_t in R^{R x d_v}
        S_t = alpha_t * (S_{t-1} - sum_r beta_{t,r} (S_{t-1} k_{t,r}) k_{t,r}^T)
              + sum_r beta_{t,r} v_{t,r} k_{t,r}^T
        Y_t[:, r] = S_t q_{t,r}

    The R reads are then combined with softmax weights from ``rank_mixer``.
    The state stays d_v x d_k per head, independent of R. Increasing R raises
    per-token arithmetic intensity and expressive write/read rank.

    Input and output of ``forward`` are ``(batch, time, dim)``. The recurrence
    is an explicit Python loop over time (reference implementation, not a
    fused kernel).

    Raises:
        ValueError: if ``data_dependent_qk_rope`` is set with an odd
            ``head_k_dim``, if ``mimo_rank < 1``, or if ``orthonormalize_k``
            is set with ``mimo_rank > head_k_dim``.
    """

    def __init__(self, cfg: MIMOGDNConfig):
        super().__init__()
        if cfg.head_k_dim % 2 != 0 and cfg.data_dependent_qk_rope:
            raise ValueError("data_dependent_qk_rope requires an even head_k_dim")
        self.cfg = cfg
        self.dim = cfg.dim
        self.heads = cfg.heads
        self.head_k_dim = cfg.head_k_dim
        self.head_v_dim = int(round(cfg.head_k_dim * cfg.expand_v))
        self.rank = int(cfg.mimo_rank)
        if self.rank < 1:
            raise ValueError("mimo_rank must be >= 1")
        if cfg.orthonormalize_k and self.rank > self.head_k_dim:
            # Reduced QR of a (d_k x R) matrix yields only min(d_k, R) columns,
            # so more than d_k keys cannot be orthonormalised; without this
            # check the failure surfaces later as a reshape error in forward.
            raise ValueError(
                "orthonormalize_k requires mimo_rank <= head_k_dim, "
                f"got mimo_rank={self.rank}, head_k_dim={self.head_k_dim}"
            )

        h, dk, dv, r = self.heads, self.head_k_dim, self.head_v_dim, self.rank
        self.parameter_efficient_mimo = cfg.parameter_efficient_mimo
        # Width of the q/k and v projections: one copy per head in the
        # parameter-efficient form, R copies per head in the full form.
        rank_mult = 1 if cfg.parameter_efficient_mimo else r
        qk_width = h * rank_mult * dk
        v_width = h * rank_mult * dv

        self.w_q = nn.Linear(cfg.dim, qk_width, bias=False)
        self.w_k = nn.Linear(cfg.dim, qk_width, bias=False)
        self.w_v = nn.Linear(cfg.dim, v_width, bias=False)
        if cfg.parameter_efficient_mimo:
            # Per-rank elementwise scales that turn one projection into R slots.
            self.q_rank_scale = nn.Parameter(torch.empty(h, r, dk))
            self.k_rank_scale = nn.Parameter(torch.empty(h, r, dk))
            self.v_rank_scale = nn.Parameter(torch.empty(h, r, dv))
            nn.init.normal_(self.q_rank_scale, mean=1.0, std=0.02)
            nn.init.normal_(self.k_rank_scale, mean=1.0, std=0.02)
            nn.init.normal_(self.v_rank_scale, mean=1.0, std=0.02)
        else:
            self.q_rank_scale = None
            self.k_rank_scale = None
            self.v_rank_scale = None

        self.q_conv = CausalDepthwiseConv1d(qk_width, cfg.conv_size)
        self.k_conv = CausalDepthwiseConv1d(qk_width, cfg.conv_size)
        self.v_conv = CausalDepthwiseConv1d(v_width, cfg.conv_size)

        self.w_alpha = nn.Linear(cfg.dim, h, bias=False)
        self.w_beta = nn.Linear(cfg.dim, h * r, bias=False)
        self.w_out_gate = nn.Linear(cfg.dim, h * dv, bias=False)
        self.w_out = nn.Linear(h * dv, cfg.dim, bias=False)
        self.out_norm = RMSNorm(dv, eps=cfg.output_norm_eps)
        # Logits of the softmax that mixes the R reads; zeros = uniform mix.
        self.rank_mixer = nn.Parameter(torch.zeros(h, r))

        # FLA-style-ish continuous decay prior. alpha = exp(-exp(A_log) * softplus(dt)).
        self.A_log = nn.Parameter(torch.empty(h))
        self.dt_bias = nn.Parameter(torch.empty(h))
        self.reset_decay_parameters()

        if cfg.trapezoidal:
            self.w_lambda = nn.Linear(cfg.dim, h, bias=False)
            # Clamp away from 0/1 so the logit below stays finite.
            p = min(max(float(cfg.trapezoid_lambda_init), 1e-4), 1.0 - 1e-4)
            self.lambda_bias = nn.Parameter(torch.full((h,), math.log(p / (1.0 - p))))
        else:
            self.w_lambda = None
            self.lambda_bias = None

        if cfg.data_dependent_qk_rope:
            self.w_theta = nn.Linear(cfg.dim, h * (dk // 2), bias=False)
        else:
            self.w_theta = None

    def reset_decay_parameters(self) -> None:
        """Re-sample ``A_log`` and ``dt_bias`` in place.

        A is uniform in [1, 16]; dt is log-uniform in [1e-3, 1e-1] and stored
        through the inverse of softplus so that ``softplus(dt_bias) == dt``.
        """
        with torch.no_grad():
            a = torch.empty(self.heads).uniform_(1.0, 16.0)
            dt = torch.exp(torch.empty(self.heads).uniform_(math.log(1e-3), math.log(1e-1)))
            self.A_log.copy_(a.log())
            self.dt_bias.copy_(torch.log(torch.expm1(dt)))

    def _project_qkv(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Project ``x`` to q, k of shape (B, T, H, R, Dk) and v of shape (B, T, H, R, Dv).

        q and k are L2-normalised along the feature axis; k is optionally
        QR-orthonormalised across the R slots, and q/k are optionally rotated
        by cumulative data-dependent angles.
        """
        b, t, _ = x.shape
        h, r, dk, dv = self.heads, self.rank, self.head_k_dim, self.head_v_dim
        if self.parameter_efficient_mimo:
            q = self.q_conv(self.w_q(x)).view(b, t, h, dk)
            k = self.k_conv(self.w_k(x)).view(b, t, h, dk)
            v = self.v_conv(self.w_v(x)).view(b, t, h, dv)
            # Broadcast the shared projection against the (H, R, D) rank scales.
            q = q[:, :, :, None, :] * self.q_rank_scale.to(dtype=q.dtype, device=q.device)[None, None, :, :, :]
            k = k[:, :, :, None, :] * self.k_rank_scale.to(dtype=k.dtype, device=k.device)[None, None, :, :, :]
            v = v[:, :, :, None, :] * self.v_rank_scale.to(dtype=v.dtype, device=v.device)[None, None, :, :, :]
        else:
            q = self.q_conv(self.w_q(x)).view(b, t, h, r, dk)
            k = self.k_conv(self.w_k(x)).view(b, t, h, r, dk)
            v = self.v_conv(self.w_v(x)).view(b, t, h, r, dv)

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        if self.cfg.orthonormalize_k and r > 1:
            # QR over the R columns in Dk-space. This is useful for a correctness/stability
            # prototype but too expensive to be the default fast path.
            flat = k.reshape(b * t * h, r, dk).transpose(-1, -2)  # (BTH, Dk, R)
            q_orth, _ = torch.linalg.qr(flat.float(), mode="reduced")
            k = q_orth.transpose(-1, -2).to(k.dtype).reshape(b, t, h, r, dk)

        if self.w_theta is not None:
            # Angles accumulate over time (cumsum along T), computed in float32.
            angles = self.w_theta(x).view(b, t, h, dk // 2).float().cumsum(dim=1)
            q = apply_data_dependent_rotary(q, angles)
            k = apply_data_dependent_rotary(k, angles)
        return q, k, v

    def forward(self, x: Tensor) -> Tensor:
        """Run the rank-R gated delta recurrence over ``x`` of shape (B, T, dim)."""
        b, t, _ = x.shape
        h, r, dk, dv = self.heads, self.rank, self.head_k_dim, self.head_v_dim
        q, k, v = self._project_qkv(x)

        # Per-head scalar decay in (0, 1): alpha = exp(-exp(A_log) * softplus(pre + dt_bias)).
        alpha_pre = self.w_alpha(x).float()
        alpha_log = -self.A_log.exp()[None, None, :] * F.softplus(
            alpha_pre + self.dt_bias[None, None, :]
        )
        alpha = alpha_log.exp().to(dtype=x.dtype)

        # Per-slot write strength; doubled to (0, 2) when negative eigenvalues are allowed.
        beta = torch.sigmoid(self.w_beta(x)).view(b, t, h, r).to(dtype=x.dtype)
        if self.cfg.allow_neg_eigval:
            beta = beta * 2.0
        if self.cfg.beta_rank_rescale and r > 1 and not self.cfg.orthonormalize_k:
            # NOTE(review): 1/sqrt(R) bounds the summed beta by 2*sqrt(R); if the
            # R keys are nearly parallel (as at init in the parameter-efficient
            # form) the transition along that direction can exceed magnitude 1
            # before alpha is applied. Confirm this is acceptable for training.
            beta = beta / math.sqrt(r)

        if self.cfg.trapezoidal:
            lam = torch.sigmoid(
                self.w_lambda(x).float() + self.lambda_bias[None, None, :]
            ).to(dtype=x.dtype)
        else:
            lam = None

        # The state is kept in the input dtype.
        state = torch.zeros(b, h, dv, dk, device=x.device, dtype=x.dtype)
        prev_write = torch.zeros_like(state)
        outs: list[Tensor] = []
        for idx in range(t):
            q_t = q[:, idx]  # (B,H,R,Dk)
            k_t = k[:, idx]  # (B,H,R,Dk)
            v_t = v[:, idx]  # (B,H,R,Dv)
            beta_t = beta[:, idx]  # (B,H,R)
            alpha_t = alpha[:, idx, :, None, None]

            # Erase term: sum_r beta_r (S k_r) k_r^T.
            sk = torch.einsum("bhvd,bhrd->bhvr", state, k_t)
            transition = torch.einsum("bhvr,bhrd,bhr->bhvd", sk, k_t, beta_t)
            # Write term: sum_r beta_r v_r k_r^T.
            write = torch.einsum("bhrv,bhrd,bhr->bhvd", v_t, k_t, beta_t)
            candidate = alpha_t * (state - transition)
            if lam is not None:
                # Push the previous step's write through this step's transition
                # and blend it with the current write.
                pwk = torch.einsum("bhvd,bhrd->bhvr", prev_write, k_t)
                transported_prev = alpha_t * (
                    prev_write - torch.einsum("bhvr,bhrd,bhr->bhvd", pwk, k_t, beta_t)
                )
                lam_t = lam[:, idx, :, None, None]
                state = candidate + lam_t * write + (1.0 - lam_t) * transported_prev
            else:
                state = candidate + write
            # Updated on every step so the trapezoidal branch always sees W_{t-1}.
            prev_write = write
            y = torch.einsum("bhvd,bhrd->bhrv", state, q_t)
            outs.append(y)

        y_all = torch.stack(outs, dim=1)  # (B,T,H,R,Dv)
        rank_weights = F.softmax(self.rank_mixer.float(), dim=-1).to(dtype=y_all.dtype)
        y = torch.einsum("bthrv,hr->bthv", y_all, rank_weights)
        y = self.out_norm(y)
        gate = F.silu(self.w_out_gate(x).view(b, t, h, dv))
        y = y * gate
        return self.w_out(y.reshape(b, t, h * dv))
class HybridBlock(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by a SwiGLU feed-forward.

    ``mixer`` is any module mapping ``(batch, time, dim)`` to the same shape.
    """

    def __init__(
        self,
        dim: int,
        mixer: nn.Module,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        self.mixer = mixer
        self.mixer_norm = RMSNorm(dim, eps=norm_eps)
        self.ff_norm = RMSNorm(dim, eps=norm_eps)
        self.ff = SwiGLU(dim, ffn_dim_multiplier=ffn_dim_multiplier)

    def forward(self, x: Tensor) -> Tensor:
        # Each sub-layer reads a normalised copy and adds its output to the residual.
        after_mixer = x + self.mixer(self.mixer_norm(x))
        return after_mixer + self.ff(self.ff_norm(after_mixer))
class HybridMIMOGDNLM(nn.Module):
    """OLMo-Hybrid-style decoder with [GDN, GDN, GDN, Attn] interleaving.

    ``block_pattern`` is cycled over ``depth`` layers; each entry is ``"gdn"``
    or ``"attn"`` (case-insensitive). When ``gdn_head_k_dim`` is omitted it
    defaults to ``max(8, int(0.75 * dim_head))``.

    Raises:
        ValueError: if ``block_pattern`` is empty or a used entry is neither
            ``"gdn"`` nor ``"attn"``.
    """

    def __init__(
        self,
        num_tokens: int,
        dim: int = 512,
        depth: int = 16,
        heads: int = 8,
        dim_head: Optional[int] = None,
        gdn_head_k_dim: Optional[int] = None,
        gdn_expand_v: float = 2.0,
        gdn_mimo_rank: int = 4,
        gdn_trapezoidal: bool = False,
        gdn_data_dependent_qk_rope: bool = False,
        gdn_parameter_efficient_mimo: bool = True,
        block_pattern: Sequence[str] = ("gdn", "gdn", "gdn", "attn"),
        tie_embeddings: bool = True,
        ffn_dim_multiplier: Optional[float] = None,
        rope_theta: float = 500000.0,
        max_seq_len: int = 2048,
        flash_attn: bool = True,
    ):
        super().__init__()
        self.vocab_size = num_tokens
        self.model_dim = dim
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.blocks = nn.ModuleList()
        self.block_types: list[str] = []

        dim_head = dim_head or dim // heads
        gdn_head_k_dim = gdn_head_k_dim or max(8, int(0.75 * dim_head))
        pattern = list(block_pattern)
        if not pattern:
            raise ValueError("block_pattern must be non-empty")

        for i in range(depth):
            typ = pattern[i % len(pattern)].lower()
            if typ not in {"gdn", "attn"}:
                raise ValueError(f"Unsupported block type {typ!r}")
            if typ == "attn":
                mixer = Attention(dim, heads=heads, dim_head=dim_head, rope_theta=rope_theta, flash_attn=flash_attn)
            else:
                mixer = MIMOGatedDeltaNet(
                    MIMOGDNConfig(
                        dim=dim,
                        heads=heads,
                        head_k_dim=gdn_head_k_dim,
                        expand_v=gdn_expand_v,
                        mimo_rank=gdn_mimo_rank,
                        trapezoidal=gdn_trapezoidal,
                        data_dependent_qk_rope=gdn_data_dependent_qk_rope,
                        parameter_efficient_mimo=gdn_parameter_efficient_mimo,
                    )
                )
            self.blocks.append(HybridBlock(dim, mixer, ffn_dim_multiplier=ffn_dim_multiplier))
            self.block_types.append(typ)

        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, num_tokens, bias=False)
        if tie_embeddings:
            # Output projection shares the embedding matrix.
            self.lm_head.weight = self.token_emb.weight
        # Only nn.Linear / nn.Embedding are re-initialised; other parameters
        # keep the values set by their own constructors.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear and Embedding weights from N(0, 0.02^2); zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: Tensor,
        targets: Optional[Tensor] = None,
        return_loss: bool = False,
        mask: Optional[Tensor] = None,
    ):
        """Compute logits and, optionally, the cross-entropy loss.

        Args:
            x: token ids of shape (batch, time).
            targets: optional target ids. They are compared position-by-position
                with the logits, so the caller must shift them if next-token
                prediction is intended.
            return_loss: when True only the loss is returned. If ``targets`` is
                omitted, the loss is next-token prediction on ``x`` itself
                (``logits[:, :-1]`` against ``x[:, 1:]``).
            mask: accepted for interface compatibility and ignored.

        Returns:
            ``logits`` when no loss is requested; ``loss`` when ``return_loss``
            is True; otherwise ``(logits, loss)`` when ``targets`` is given.
        """
        del mask
        h = self.token_emb(x)
        for block in self.blocks:
            h = block(h)
        logits = self.lm_head(self.norm(h))
        if return_loss or targets is not None:
            if targets is None:
                targets = x[:, 1:].contiguous()
                logits_for_loss = logits[:, :-1].contiguous()
            else:
                logits_for_loss = logits
            loss = F.cross_entropy(
                logits_for_loss.view(-1, logits_for_loss.size(-1)),
                targets.reshape(-1),
            )
            return loss if return_loss else (logits, loss)
        return logits

    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        max_length: int,
        temperature: float = 1.0,
        filter_thres: float = 0.9,
        min_p: Optional[float] = None,
    ) -> Tensor:
        """Sample ``max_length`` tokens autoregressively and return prompt + continuation.

        Args:
            prompt: token ids of shape (batch, time).
            max_length: number of new tokens to append.
            temperature: logit divisor, floored at 1e-6.
            filter_thres: nucleus (top-p) mass; values >= 1.0 disable filtering.
                The kept set is the smallest prefix of the probability-sorted
                vocabulary whose cumulative mass reaches ``filter_thres``.
            min_p: accepted for interface compatibility and ignored.
        """
        del min_p
        out = prompt
        for _ in range(max_length):
            # Re-run the full (truncated) context each step; there is no state cache.
            x = out[:, -self.max_seq_len :]
            logits = self(x)[:, -1]
            logits = logits / max(float(temperature), 1e-6)
            if filter_thres < 1.0:
                probs = logits.softmax(dim=-1)
                sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
                # Mass accumulated *before* each token: a token is kept while
                # the tokens ranked above it have not yet reached the
                # threshold, so the token that crosses it is included.
                mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
                keep = mass_before < filter_thres
                keep[..., 0] = True  # never drop the most likely token
                filtered = torch.full_like(logits, float("-inf"))
                filtered.scatter_(
                    dim=-1,
                    index=sorted_idx,
                    src=torch.where(
                        keep,
                        logits.gather(-1, sorted_idx),
                        torch.full_like(sorted_probs, float("-inf")),
                    ),
                )
                logits = filtered
            next_token = torch.distributions.Categorical(logits=logits).sample()[:, None]
            out = torch.cat([out, next_token], dim=1)
        return out
def count_parameters(model: nn.Module) -> int:
    """Return the total number of scalar entries across ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def smoke_test() -> None:
    """Build a tiny two-layer model, run one forward/backward pass and print a summary."""
    torch.manual_seed(0)
    # The pure-PyTorch recurrent loop uses many tiny ops on CPU; one thread is
    # far faster for the smoke test because it avoids threadpool overhead.
    torch.set_num_threads(1)
    vocab = 64
    model_kwargs = dict(
        num_tokens=vocab,
        dim=32,
        depth=2,
        heads=1,
        dim_head=32,
        gdn_head_k_dim=16,
        gdn_expand_v=1.5,
        gdn_mimo_rank=2,
        gdn_trapezoidal=True,
        gdn_data_dependent_qk_rope=True,
        ffn_dim_multiplier=1.0,
        block_pattern=("gdn", "attn"),
        flash_attn=False,
    )
    model = HybridMIMOGDNLM(**model_kwargs)
    tokens = torch.randint(0, vocab, (1, 8))
    # Targets are the inputs themselves (unshifted); this only checks that the
    # forward and backward passes run, not that the model learns anything.
    logits, loss = model(tokens, tokens, return_loss=False)
    loss.backward()
    report = {"params": count_parameters(model), "logits": tuple(logits.shape), "loss": float(loss.detach())}
    print(report)


# Running the file as a script executes the smoke test.
if __name__ == "__main__":
    smoke_test()

Applying Mamba-3 Ideas to an OLMo-Hybrid-Style Gated DeltaNet

Recommendation

The best target architecture is not a Mamba-3 hybrid. It is an OLMo-Hybrid-style interleaved Gated DeltaNet model with Mamba-3's hardware-aware recurrence upgrades ported into the GDN mixer.

The highest-value variant is:

Hybrid MIMO-HGDN: [GDN, GDN, GDN, Attention] × N, with negative-eigenvalue GDN preserved, rank-R MIMO updates added to the GDN state update, short convolutions retained for the first ablation, and optional data-dependent q/k rotations as a lower-priority expressivity experiment.

The three Mamba-3 ideas transfer unevenly:

| Mamba-3 idea | Transfer to OLMo-Hybrid GDN | Priority | Why |
|---|---|---|---|
| MIMO | Direct and valuable | 1 | Same memory-bound decode problem; state size can stay fixed while per-token compute increases. |
| Exponential-trapezoidal input mixing | Useful, but the faithful version requires a new GDN kernel | 2 | Adds an implicit 2-token input convolution; may let us remove the short conv after ablation. |
| Complex/RoPE trick | Possible but not cleanly equivalent | 3 | GDN's low-rank transition does not commute with arbitrary rotations, and negative eigenvalues already solve the key state-tracking issue. |

Baseline to preserve

Use the OLMo-Hybrid design as the reference point:

  • Replace 75% of attention layers with GDN layers.
  • Keep a 3:1 interleaving ratio: GDN, GDN, GDN, Attention.
  • Keep negative-eigenvalue GDN: beta is allowed to produce reflection-like / sign-changing dynamics.
  • Keep value dimension expansion, approximately d_v = 2 d_k.
  • Keep attention layers interleaved rather than concentrated.

Do not remove the short conv as the first step. Mamba-3 can drop it because its trapezoidal SSM recurrence and B/C biases create a convolution-like path. GDN already has a tuned short-conv path in FLA/OLMo; removing it before the recurrence change is a confound.

1. Rank-R MIMO Gated DeltaNet

Standard per-head GDN has state S_t ∈ R^{d_v × d_k} and update:

$$ S_t = \alpha_t \left(S_{t-1} - \beta_t (S_{t-1} k_t) k_t^\top\right) + \beta_t v_t k_t^\top $$

$$ y_t = S_t q_t. $$

A fixed-state MIMO generalization uses R write/read columns per token:

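$$ K_t, Q_t \in \mathbb{R}^{R \times d_k}, \quad V_t \in \mathbb{R}^{R \times d_v}, \quad \beta_t \in \mathbb{R}^{R}. $$

$$ S_t = \alpha_t \left(S_{t-1} - \sum_{r=1}^{R}\beta_{t,r}(S_{t-1}k_{t,r})k_{t,r}^\top\right) + \sum_{r=1}^{R}\beta_{t,r}v_{t,r}k_{t,r}^\top. $$

$$ y_{t,r} = S_t q_{t,r}, \qquad y_t = \sum_{r=1}^{R} \pi_r y_{t,r}. $$

Here $\pi_r$ are the learned rank-mixing weights (a softmax over per-head logits in the reference code).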

The state remains d_v × d_k per head. The extra rank increases update/read compute without increasing recurrent cache size. This is the GDN analogue of Mamba-3's MIMO argument.

Stability choices

For R > 1, the transition

$$ I - \sum_{r=1}^{R} \beta_r k_r k_r^\top $$

is only cleanly Householder-like if the k_r columns are approximately orthonormal. Three options:

  1. Reference/stability mode: QR-orthonormalize K_t per token/head. Correct but expensive.
  2. Fast prototype mode: L2-normalize each column and rescale beta by 1 / sqrt(R) unless a better orthogonality regularizer is used.
  3. Kernel mode: use a small-R tile and add an orthogonality penalty ||K K^T - I||², or construct K with a cheap structured orthogonal parameterization.

I would start with option 2 for model-search and option 1 for synthetic state-tracking sanity checks.

Parameter-efficient MIMO

The naive full projection expands q/k/v by R, which is simple but not ideal. A Mamba-3-like parameter-efficient form is:

$$ q_{t,r} = q_t \odot s^q_r, \quad k_{t,r} = k_t \odot s^k_r, \quad v_{t,r} = v_t \odot s^v_r, $$

where s_r are learned per-rank scale vectors. This gives rank-specific directions without multiplying the main projection matrices by R. A full projection is still useful as an upper-bound ablation.

2. Trapezoidal two-timestep GDN input mixing

A faithful GDN analogue of Mamba-3's exponential-trapezoidal update is:

$$ W_t = \sum_{r=1}^{R} \beta_{t,r}v_{t,r}k_{t,r}^\top $$

$$ A_t(X) = \alpha_t\left(X - \sum_{r=1}^{R}\beta_{t,r}(Xk_{t,r})k_{t,r}^\top\right) $$

$$ S_t = A_t(S_{t-1}) + \lambda_t W_t + (1-\lambda_t)A_t(W_{t-1}). $$

This gives an implicit input convolution of size 2, but unlike Mamba-3 it respects the identity-plus-low-rank transition class. Initialize lambda near the current-only path, for example sigmoid(lambda_bias)=0.9, so the layer starts close to normal GDN.

Short-conv ablation order

Run these variants in order:

  1. Baseline GDN short conv, no trapezoid.
  2. Baseline GDN short conv + trapezoid.
  3. Trapezoid with q/k/v short conv disabled.
  4. Trapezoid with only v short conv disabled.

The likely outcome is that short conv remains useful for GDN even if Mamba-3 can drop it, because GDN's low-rank transition and Mamba's diagonal SSM are not the same object.

3. Data-dependent q/k rotations

The Mamba-3 RoPE trick is exact for a diagonal/scalar SSM transition because rotations commute with decay and can be absorbed into B/C. In GDN, the low-rank term k k^T does not commute with arbitrary rotations. The direct port is therefore not equivalent.

The practical GDN version is simpler:

$$ \theta_t = \sum_{i \le t} \Delta\theta_i, \quad q'_t = R(\theta_t)\,q_t, \quad k'_t = R(\theta_t)\,k_t. $$

Then use q' and k' in the usual GDN kernel. This is easy to add before the existing FLA recurrence because it changes only the projected inputs. It is an expressivity ablation, not a replacement for negative eigenvalues.

Implementation implications for the current fork

Your current parameter-golf branch already has the important OLMo-Hybrid alignment points:

  • public FLA chunk_gated_delta_rule
  • negative eigenvalues by default
  • FLA-style decay initialization and no-weight-decay decay params
  • packed q/k/v projection and packed q/k/v short conv
  • learned output RMSNorm weight
  • 3:1 OLMo-ish configs

That means the next code step should not be another SISO GDN cleanup. It should be a new experimental mixer path:

  • GDN_MIMO_R=1 default, preserving current behavior.
  • GDN_MIMO_PARAM_EFFICIENT=1 default.
  • GDN_MIMO_ORTHONORMALIZE_K=0 default for fast runs, 1 for correctness probes.
  • GDN_TRAPEZOIDAL=0 default.
  • GDN_QK_DATA_ROPE=0 default.

Kernel reality

The existing FLA op only implements rank-1 GDN per token. True shared-state MIMO cannot be obtained by simply folding rank into heads. Folding H → H×R creates R independent states and increases recurrent cache size; it is an ablation, not MIMO.

For a real large-scale run, implement a new rank-R GDN kernel:

  • state: (B, H, d_v, d_k)
  • per-token K/Q/V: (B, T, H, R, d_k/d_v)
  • update uses matrix products S K, V K^T, and Y = S Q
  • chunked training should reduce chunk size approximately as C/R, mirroring the Mamba-3 argument.

Experiment ladder

Stage A: cheap no-kernel experiments

  1. Existing OLMo-ish HGDN baseline.
  2. Add data-dependent q/k rotations before FLA recurrence.
  3. Rank-folding ablation (H×R) only to estimate whether more write/read diversity helps, while noting it increases state.

Stage B: pure PyTorch MIMO sanity

Use the attached hybrid_mimo_gdn.py reference on toy data and synthetic tasks:

  • parity / modular arithmetic
  • A5-style state tracking
  • MQAR / associative recall
  • state-based recall

Expected result: MIMO should help recall-ish compression and should not hurt state tracking if beta scaling/orthogonality is controlled.

Stage C: custom kernel

Port rank-R update into Extended-WY/chunked GDN. Compare:

  • R=1 baseline
  • R=2
  • R=4
  • R=4 + trapezoid
  • R=4 + trapezoid, no short conv

Stage D: scale

Use OLMo-Hybrid-style controlled scaling:

  • 60M, 190M, 760M, 1B first
  • 3:1 interleaved only at first
  • fit fixed-exponent scaling laws and watch the data coefficient B
  • include RULER/NIAH/MQAR plus standard validation loss

My preferred first paper-quality variant

MIMO-HGDN-R4

  • 3:1 interleaved GDN:Attention
  • negative eigenvalues enabled
  • d_k = 0.75 * attention_head_dim
  • d_v = 2 * d_k
  • R = 4
  • parameter-efficient rank scaling for q/k/v
  • learned rank mixer on outputs
  • short conv retained
  • beta rescaled by 1/sqrt(R) unless K orthogonalization is used
  • no data-dependent rotations in the mainline model
  • trapezoid as a second-line ablation

The reason is simple: MIMO targets a real decode bottleneck while preserving the GDN transition class that made OLMo Hybrid compelling. The other two ideas are worth testing, but they are less central and easier to confound.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment