|
"""Hybrid MIMO Gated DeltaNet language model prototype. |
|
|
|
This is a research/reference implementation of an OLMo-Hybrid-style |
|
Gated DeltaNet model with the most portable Mamba-3 ideas applied: |
|
|
|
1. MIMO / rank-R DeltaNet updates with fixed recurrent state size. |
|
2. Optional trapezoidal two-timestep input mixing. |
|
3. Optional data-dependent rotary rotations on GDN q/k features. |
|
|
|
The implementation is deliberately pure PyTorch. It is not a replacement for |
|
Flash Linear Attention kernels, but it gives a correct, hackable baseline for |
|
small-scale experiments and for specifying the fused kernel you would want for |
|
large-scale training or decode. |
|
|
|
Shapes use batch-first tensors throughout. The GDN state for each head is |
|
S in R^{d_v x d_k}; for MIMO rank R the state size is unchanged while each |
|
step performs a rank-R update/read. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import math |
|
from dataclasses import dataclass |
|
from typing import Iterable, Optional, Sequence |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
# Short alias for ``torch.Tensor`` used by the type annotations in this module.
Tensor = torch.Tensor
|
|
|
|
|
def exists(x) -> bool:
    """Return ``True`` for any value other than ``None`` (falsy values included)."""
    if x is None:
        return False
    return True
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The statistics are computed in float32; the normalised values are cast
    back to the input dtype before the (equally cast) gain is applied.

    Args:
        dim: size of the last axis; also the length of the gain vector.
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        """Normalise ``x`` along its last axis and return it in ``x.dtype``."""
        in_dtype = x.dtype
        x32 = x.float()
        mean_square = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normed = (x32 * inv_rms).to(in_dtype)
        gain = self.weight.to(in_dtype)
        return normed * gain
|
|
|
|
|
class SwiGLU(nn.Module):
    """Bias-free SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input and output width.
        hidden_dim: inner width; when ``None`` it is derived as
            ``int(multiplier * dim)`` with ``multiplier`` equal to
            ``ffn_dim_multiplier`` or ``8 / 3`` when that is unset (or zero).
        multiple_of: the inner width is rounded up to a multiple of this.
        ffn_dim_multiplier: optional override of the ``8 / 3`` expansion.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: Optional[int] = None,
        multiple_of: int = 256,
        ffn_dim_multiplier: Optional[float] = None,
    ):
        super().__init__()
        if hidden_dim is None:
            multiplier = ffn_dim_multiplier or (8.0 / 3.0)
            hidden_dim = int(multiplier * dim)
        # Round the inner width up to the next multiple of ``multiple_of``.
        n_chunks = math.ceil(hidden_dim / multiple_of)
        hidden_dim = multiple_of * n_chunks
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated feed-forward transform to ``x`` (last axis ``dim``)."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return self.w2(gate * value)
|
|
|
|
|
class CausalDepthwiseConv1d(nn.Module):
    """Left-padded depthwise 1-D convolution followed by SiLU.

    The input is laid out ``(batch, time, channels)``. The kernel is
    initialised to zeros with a one on its last tap, so a fresh layer computes
    ``silu(x)``. When the layer is disabled (``enabled=False`` or
    ``kernel_size <= 1``) ``forward`` returns its input untouched, without the
    SiLU.
    """

    def __init__(self, dim: int, kernel_size: int = 4, enabled: bool = True):
        super().__init__()
        self.enabled = bool(enabled and kernel_size > 1)
        self.kernel_size = int(kernel_size)
        if self.enabled:
            conv = nn.Conv1d(
                dim,
                dim,
                kernel_size=self.kernel_size,
                groups=dim,
                bias=False,
            )
            with torch.no_grad():
                conv.weight.zero_()
                # The last tap lines up with the current timestep once the
                # input is left-padded by ``kernel_size - 1``.
                conv.weight[:, 0, self.kernel_size - 1] = 1.0
            self.conv = conv
        else:
            self.conv = None

    def forward(self, x: Tensor) -> Tensor:
        """Convolve ``x`` causally along the time axis and apply SiLU."""
        if not self.enabled:
            return x
        left_pad = self.kernel_size - 1
        channels_first = F.pad(x.transpose(1, 2), (left_pad, 0))
        mixed = self.conv(channels_first).transpose(1, 2)
        return F.silu(mixed)
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Cos/sin tables for rotary position embeddings.

    Args:
        dim: head dimension to rotate; must be even.
        theta: base of the geometric frequency progression.

    Raises:
        ValueError: if ``dim`` is odd.
    """

    def __init__(self, dim: int, theta: float = 10000.0):
        super().__init__()
        if dim % 2 != 0:
            raise ValueError(f"RoPE head dim must be even, got {dim}")
        self.dim = dim
        self.theta = theta
        self.register_buffer("inv_freq", self._compute_inv_freq(), persistent=False)

    def _compute_inv_freq(self, device: Optional[torch.device] = None) -> Tensor:
        """Return the ``dim / 2`` inverse frequencies in float32."""
        exponents = torch.arange(0, self.dim, 2, device=device).float() / self.dim
        return 1.0 / (self.theta ** exponents)

    def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` tables of shape ``(seq_len, dim / 2)`` in ``dtype``.

        Angles are always built in at least float32. ``module.half()`` or
        ``module.to(torch.bfloat16)`` also casts the ``inv_freq`` buffer, and
        neither half-precision format can represent every integer position
        (bfloat16 stops at 256), so a down-cast buffer is rebuilt instead of
        being used.
        """
        inv_freq = self.inv_freq
        if inv_freq.dtype in (torch.float16, torch.bfloat16):
            inv_freq = self._compute_inv_freq(device=device)
        else:
            inv_freq = inv_freq.to(device)
        positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        angles = torch.outer(positions, inv_freq)
        return angles.cos().to(dtype), angles.sin().to(dtype)
|
|
|
|
|
def apply_rotary(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate interleaved (even, odd) feature pairs of ``x`` by position.

    Args:
        x: ``(B, H, T, D)`` or ``(B, T, H, R, D)``.
        cos: cosine table of shape ``(T, D / 2)``.
        sin: sine table of shape ``(T, D / 2)``.

    Returns:
        A tensor shaped like ``x`` with every ``(x[2i], x[2i + 1])`` pair
        rotated by the angle of its position.

    Raises:
        ValueError: if ``x`` is neither 4- nor 5-dimensional.
    """
    feat = x.shape[-1]
    rank = x.dim()
    if rank == 4:
        # Time is axis 2: broadcast over batch and heads.
        cos_b = cos[None, None, :, :]
        sin_b = sin[None, None, :, :]
    elif rank == 5:
        # Time is axis 1: broadcast over batch, heads and MIMO rank.
        cos_b = cos[None, :, None, None, :]
        sin_b = sin[None, :, None, None, :]
    else:
        raise ValueError(f"Unsupported tensor rank for RoPE: {x.dim()}")
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    rot_even = evens * cos_b - odds * sin_b
    rot_odd = evens * sin_b + odds * cos_b
    pairs = torch.stack((rot_even, rot_odd), dim=-1)
    return pairs.reshape(*x.shape[:-1], feat)
|
|
|
|
|
def apply_data_dependent_rotary(x: Tensor, cumulative_angles: Tensor) -> Tensor:
    """Rotate interleaved feature pairs of ``x`` by per-token angles.

    Args:
        x: ``(B, T, H, R, D)``.
        cumulative_angles: ``(B, T, H, D / 2)``; the same angles are applied
            to every one of the ``R`` columns.

    Returns:
        A tensor shaped like ``x``. Cos/sin are cast to ``x.dtype`` before the
        rotation.
    """
    feat = x.shape[-1]
    # Insert a broadcast axis for the rank dimension of ``x``.
    cos_b = cumulative_angles.cos()[:, :, :, None, :].to(dtype=x.dtype)
    sin_b = cumulative_angles.sin()[:, :, :, None, :].to(dtype=x.dtype)
    evens = x[..., 0::2]
    odds = x[..., 1::2]
    rot_even = evens * cos_b - odds * sin_b
    rot_odd = evens * sin_b + odds * cos_b
    pairs = torch.stack((rot_even, rot_odd), dim=-1)
    return pairs.reshape(*x.shape[:-1], feat)
|
|
|
|
|
class Attention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Args:
        dim: model width of the input and output.
        heads: number of attention heads.
        dim_head: per-head width; defaults to ``dim // heads``.
        rope_theta: base passed to :class:`RotaryEmbedding`.
        qk_norm: L2-normalise q and k per head before the rotation.
        flash_attn: use ``F.scaled_dot_product_attention`` when true;
            otherwise use an explicit softmax(QK^T)V reference path that
            materialises the ``(T, T)`` score matrix.

    NOTE(review): with ``qk_norm`` the q/k vectors have unit norm, so with the
    default ``1 / sqrt(dim_head)`` scale every logit lies in
    ``[-1 / sqrt(dim_head), 1 / sqrt(dim_head)]``. Consider a learnable
    temperature if sharper attention is needed.
    """

    def __init__(
        self,
        dim: int,
        heads: int,
        dim_head: Optional[int] = None,
        rope_theta: float = 500000.0,
        qk_norm: bool = True,
        flash_attn: bool = True,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_head = dim_head or dim // heads
        self.flash_attn = flash_attn
        self.qk_norm = qk_norm
        inner = self.heads * self.dim_head
        self.wq = nn.Linear(dim, inner, bias=False)
        self.wk = nn.Linear(dim, inner, bias=False)
        self.wv = nn.Linear(dim, inner, bias=False)
        self.wo = nn.Linear(inner, dim, bias=False)
        self.rotary = RotaryEmbedding(self.dim_head, theta=rope_theta)

    def _reference_attention(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Unfused causal attention on ``(B, H, T, D)`` tensors.

        Uses the same ``1 / sqrt(D)`` scale and lower-triangular mask as
        ``F.scaled_dot_product_attention(..., is_causal=True)``; the softmax is
        evaluated in float32.
        """
        t = q.shape[-2]
        scores = torch.matmul(q, k.transpose(-2, -1)) * (self.dim_head ** -0.5)
        future = torch.ones(t, t, dtype=torch.bool, device=q.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
        weights = torch.softmax(scores.float(), dim=-1).to(v.dtype)
        return torch.matmul(weights, v)

    def forward(self, x: Tensor) -> Tensor:
        """Attend causally over ``x`` of shape ``(B, T, dim)``; returns the same shape."""
        b, t, _ = x.shape
        q = self.wq(x).view(b, t, self.heads, self.dim_head).transpose(1, 2)
        k = self.wk(x).view(b, t, self.heads, self.dim_head).transpose(1, 2)
        v = self.wv(x).view(b, t, self.heads, self.dim_head).transpose(1, 2)
        if self.qk_norm:
            q = F.normalize(q, dim=-1)
            k = F.normalize(k, dim=-1)
        cos, sin = self.rotary(t, x.device, q.dtype)
        q = apply_rotary(q, cos, sin)
        k = apply_rotary(k, cos, sin)
        if self.flash_attn:
            out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # Honour the constructor flag instead of silently ignoring it.
            out = self._reference_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().view(b, t, self.heads * self.dim_head)
        return self.wo(out)
|
|
|
|
|
@dataclass(frozen=True)
class MIMOGDNConfig:
    """Immutable hyper-parameters consumed by ``MIMOGatedDeltaNet``."""

    # Model width of the layer input/output.
    dim: int
    # Number of recurrent heads.
    heads: int
    # Key/query width per head (d_k).
    head_k_dim: int
    # Value width per head is round(head_k_dim * expand_v).
    expand_v: float = 2.0
    # Number of key/query/value columns written and read per token (R).
    mimo_rank: int = 4
    # Kernel size of the causal depthwise convolutions on q/k/v.
    conv_size: int = 4
    # Scale beta by 2 so it ranges over (0, 2) instead of (0, 1).
    allow_neg_eigval: bool = True
    # Share one q/k/v projection across ranks and apply per-rank scales.
    parameter_efficient_mimo: bool = True
    # Divide beta by sqrt(R) when R > 1 and keys are not orthonormalised.
    beta_rank_rescale: bool = True
    # QR-orthonormalise the R key columns of every token.
    orthonormalize_k: bool = False
    # Blend the current write with the transported previous write.
    trapezoidal: bool = False
    # Initial sigmoid output of the trapezoid blend gate (stored as a logit bias).
    trapezoid_lambda_init: float = 0.9
    # Rotate q/k feature pairs by cumulative, input-dependent angles.
    data_dependent_qk_rope: bool = False
    # Epsilon of the RMSNorm applied to the per-head output.
    output_norm_eps: float = 1e-5
|
|
|
|
|
class MIMOGatedDeltaNet(nn.Module):
    """Rank-R Gated DeltaNet layer with fixed-size state.

    Standard GDN update for one head can be written with a state
    S in R^{d_v x d_k}:

        S_t = alpha_t * (S_{t-1} - beta_t * (S_{t-1} k_t) k_t^T)
              + beta_t * v_t k_t^T
        y_t = S_t q_t

    This layer generalizes k, q, and v to R columns per token:

        K_t, Q_t in R^{R x d_k}, V_t in R^{R x d_v}
        S_t = alpha_t * (S_{t-1} - sum_r beta_{t,r} (S_{t-1} k_{t,r}) k_{t,r}^T)
              + sum_r beta_{t,r} v_{t,r} k_{t,r}^T
        Y_t[:, r] = S_t q_{t,r}

    The state stays d_v x d_k per head, independent of R. Increasing R raises
    per-token arithmetic intensity and expressive write/read rank.

    The R read-outs of each head are merged with a learned softmax over ranks,
    RMS-normalised, gated by ``silu(w_out_gate(x))`` and projected back to
    ``cfg.dim``. The recurrence is evaluated with a Python loop over time.

    Raises:
        ValueError: for an odd ``head_k_dim`` with ``data_dependent_qk_rope``,
            for ``mimo_rank < 1``, or for ``orthonormalize_k`` with
            ``mimo_rank > head_k_dim``.
    """

    def __init__(self, cfg: MIMOGDNConfig):
        super().__init__()
        if cfg.head_k_dim % 2 != 0 and cfg.data_dependent_qk_rope:
            raise ValueError("data_dependent_qk_rope requires an even head_k_dim")
        self.cfg = cfg
        self.dim = cfg.dim
        self.heads = cfg.heads
        self.head_k_dim = cfg.head_k_dim
        self.head_v_dim = int(round(cfg.head_k_dim * cfg.expand_v))
        self.rank = int(cfg.mimo_rank)
        if self.rank < 1:
            raise ValueError("mimo_rank must be >= 1")
        if cfg.orthonormalize_k and self.rank > self.head_k_dim:
            # A reduced QR of a (d_k x R) matrix yields only min(d_k, R)
            # columns, so R > d_k cannot be reshaped back to R key columns.
            raise ValueError("orthonormalize_k requires mimo_rank <= head_k_dim")

        h, dk, dv, r = self.heads, self.head_k_dim, self.head_v_dim, self.rank
        qk_total = h * dk
        v_total = h * dv

        self.parameter_efficient_mimo = cfg.parameter_efficient_mimo
        if cfg.parameter_efficient_mimo:
            # One shared projection per head; ranks differ by learned scales.
            self.w_q = nn.Linear(cfg.dim, qk_total, bias=False)
            self.w_k = nn.Linear(cfg.dim, qk_total, bias=False)
            self.w_v = nn.Linear(cfg.dim, v_total, bias=False)
            self.q_rank_scale = nn.Parameter(torch.empty(h, r, dk))
            self.k_rank_scale = nn.Parameter(torch.empty(h, r, dk))
            self.v_rank_scale = nn.Parameter(torch.empty(h, r, dv))
            nn.init.normal_(self.q_rank_scale, mean=1.0, std=0.02)
            nn.init.normal_(self.k_rank_scale, mean=1.0, std=0.02)
            nn.init.normal_(self.v_rank_scale, mean=1.0, std=0.02)
        else:
            # Independent projection for every rank column.
            self.w_q = nn.Linear(cfg.dim, h * r * dk, bias=False)
            self.w_k = nn.Linear(cfg.dim, h * r * dk, bias=False)
            self.w_v = nn.Linear(cfg.dim, h * r * dv, bias=False)
            self.q_rank_scale = None
            self.k_rank_scale = None
            self.v_rank_scale = None

        qk_width = qk_total if cfg.parameter_efficient_mimo else h * r * dk
        v_width = v_total if cfg.parameter_efficient_mimo else h * r * dv
        self.q_conv = CausalDepthwiseConv1d(qk_width, cfg.conv_size)
        self.k_conv = CausalDepthwiseConv1d(qk_width, cfg.conv_size)
        self.v_conv = CausalDepthwiseConv1d(v_width, cfg.conv_size)

        self.w_alpha = nn.Linear(cfg.dim, h, bias=False)
        self.w_beta = nn.Linear(cfg.dim, h * r, bias=False)
        self.w_out_gate = nn.Linear(cfg.dim, h * dv, bias=False)
        self.w_out = nn.Linear(h * dv, cfg.dim, bias=False)
        self.out_norm = RMSNorm(dv, eps=cfg.output_norm_eps)
        # Zero logits give a uniform softmax over the R read-outs at init.
        self.rank_mixer = nn.Parameter(torch.zeros(h, r))

        # FLA-style-ish continuous decay prior. alpha = exp(-exp(A_log) * softplus(dt)).
        self.A_log = nn.Parameter(torch.empty(h))
        self.dt_bias = nn.Parameter(torch.empty(h))
        self.reset_decay_parameters()

        if cfg.trapezoidal:
            self.w_lambda = nn.Linear(cfg.dim, h, bias=False)
            # Clamp the initial gate away from 0/1 so its logit stays finite.
            p = min(max(float(cfg.trapezoid_lambda_init), 1e-4), 1.0 - 1e-4)
            self.lambda_bias = nn.Parameter(torch.full((h,), math.log(p / (1.0 - p))))
        else:
            self.w_lambda = None
            self.lambda_bias = None

        if cfg.data_dependent_qk_rope:
            self.w_theta = nn.Linear(cfg.dim, h * (dk // 2), bias=False)
        else:
            self.w_theta = None

    def reset_decay_parameters(self) -> None:
        """Re-sample the per-head decay parameters ``A_log`` and ``dt_bias``."""
        # A roughly follows a positive range up to 16; dt is log-uniform in [1e-3, 1e-1].
        with torch.no_grad():
            a = torch.empty(self.heads).uniform_(1.0, 16.0)
            dt = torch.exp(torch.empty(self.heads).uniform_(math.log(1e-3), math.log(1e-1)))
            self.A_log.copy_(a.log())
            # Inverse softplus, so softplus(dt_bias) == dt.
            self.dt_bias.copy_(torch.log(torch.expm1(dt)))

    def _project_qkv(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Project ``x`` (B, T, dim) to q, k of shape (B, T, H, R, Dk) and v of (B, T, H, R, Dv).

        q and k are L2-normalised along the feature axis; k is optionally
        orthonormalised across its R columns, and q/k are optionally rotated by
        cumulative input-dependent angles.
        """
        b, t, _ = x.shape
        h, r, dk, dv = self.heads, self.rank, self.head_k_dim, self.head_v_dim
        q = self.q_conv(self.w_q(x))
        k = self.k_conv(self.w_k(x))
        v = self.v_conv(self.w_v(x))
        if self.parameter_efficient_mimo:
            # Broadcast the shared per-head features against (H, R, D) scales.
            q = q.view(b, t, h, 1, dk) * self.q_rank_scale.to(dtype=q.dtype, device=q.device)
            k = k.view(b, t, h, 1, dk) * self.k_rank_scale.to(dtype=k.dtype, device=k.device)
            v = v.view(b, t, h, 1, dv) * self.v_rank_scale.to(dtype=v.dtype, device=v.device)
        else:
            q = q.view(b, t, h, r, dk)
            k = k.view(b, t, h, r, dk)
            v = v.view(b, t, h, r, dv)
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        if self.cfg.orthonormalize_k and r > 1:
            # QR over the R columns in Dk-space. This is useful for a correctness/stability
            # prototype but too expensive to be the default fast path.
            flat = k.reshape(b * t * h, r, dk).transpose(-1, -2)  # (BTH, Dk, R)
            q_orth, _ = torch.linalg.qr(flat.float(), mode="reduced")
            k = q_orth.transpose(-1, -2).to(k.dtype).reshape(b, t, h, r, dk)
        if self.w_theta is not None:
            # Angles accumulate over time (float32 cumsum along the T axis).
            angles = self.w_theta(x).view(b, t, h, dk // 2).float().cumsum(dim=1)
            q = apply_data_dependent_rotary(q, angles)
            k = apply_data_dependent_rotary(k, angles)
        return q, k, v

    def _compute_gates(self, x: Tensor) -> tuple[Tensor, Tensor, Optional[Tensor]]:
        """Return decay ``alpha`` (B, T, H), write strength ``beta`` (B, T, H, R) and ``lam``.

        ``lam`` is the (B, T, H) trapezoid blend gate, or ``None`` when
        ``cfg.trapezoidal`` is off. All outputs are in ``x.dtype``.
        """
        b, t, _ = x.shape
        h, r = self.heads, self.rank

        # alpha = exp(-exp(A_log) * softplus(w_alpha(x) + dt_bias)), in (0, 1).
        alpha_pre = self.w_alpha(x).float()
        alpha_log = -self.A_log.exp()[None, None, :] * F.softplus(
            alpha_pre + self.dt_bias[None, None, :]
        )
        alpha = alpha_log.exp().to(dtype=x.dtype)

        beta = torch.sigmoid(self.w_beta(x)).view(b, t, h, r).to(dtype=x.dtype)
        if self.cfg.allow_neg_eigval:
            beta = beta * 2.0
        if self.cfg.beta_rank_rescale and r > 1 and not self.cfg.orthonormalize_k:
            beta = beta / math.sqrt(r)

        if self.cfg.trapezoidal:
            lam = torch.sigmoid(
                self.w_lambda(x).float() + self.lambda_bias[None, None, :]
            ).to(dtype=x.dtype)
        else:
            lam = None
        return alpha, beta, lam

    def forward(self, x: Tensor) -> Tensor:
        """Run the recurrence over ``x`` of shape (B, T, dim); returns (B, T, dim).

        The state starts at zero for every call. A zero-length sequence yields
        an empty ``(B, 0, dim)`` tensor.
        """
        b, t, _ = x.shape
        h, r, dk, dv = self.heads, self.rank, self.head_k_dim, self.head_v_dim
        if t == 0:
            # torch.stack cannot handle an empty list of per-step outputs.
            return x.new_zeros(b, 0, self.dim)

        q, k, v = self._project_qkv(x)
        alpha, beta, lam = self._compute_gates(x)

        # beta-weighted keys are loop invariant: compute them once for all T.
        beta_k = k * beta.unsqueeze(-1)  # (B,T,H,R,Dk)

        state = torch.zeros(b, h, dv, dk, device=x.device, dtype=x.dtype)
        prev_write = torch.zeros_like(state)
        outs: list[Tensor] = []

        for idx in range(t):
            q_t = q[:, idx]  # (B,H,R,Dk)
            k_t = k[:, idx]  # (B,H,R,Dk)
            v_t = v[:, idx]  # (B,H,R,Dv)
            bk_t = beta_k[:, idx]  # (B,H,R,Dk)
            alpha_t = alpha[:, idx, :, None, None]

            # Erase term: sum_r beta_r (S k_r) k_r^T.
            sk = torch.einsum("bhvd,bhrd->bhvr", state, k_t)
            transition = torch.einsum("bhvr,bhrd->bhvd", sk, bk_t)
            # Write term: sum_r beta_r v_r k_r^T.
            write = torch.einsum("bhrv,bhrd->bhvd", v_t, bk_t)
            candidate = alpha_t * (state - transition)

            if lam is not None:
                # Push the previous step's write through this step's
                # erase-and-decay transition, then blend it with the new write.
                pwk = torch.einsum("bhvd,bhrd->bhvr", prev_write, k_t)
                transported_prev = alpha_t * (
                    prev_write - torch.einsum("bhvr,bhrd->bhvd", pwk, bk_t)
                )
                lam_t = lam[:, idx, :, None, None]
                state = candidate + lam_t * write + (1.0 - lam_t) * transported_prev
            else:
                state = candidate + write
            prev_write = write

            outs.append(torch.einsum("bhvd,bhrd->bhrv", state, q_t))

        y_all = torch.stack(outs, dim=1)  # (B,T,H,R,Dv)
        rank_weights = F.softmax(self.rank_mixer.float(), dim=-1).to(dtype=y_all.dtype)
        y = torch.einsum("bthrv,hr->bthv", y_all, rank_weights)
        y = self.out_norm(y)
        gate = F.silu(self.w_out_gate(x).view(b, t, h, dv))
        y = y * gate
        return self.w_out(y.reshape(b, t, h * dv))
|
|
|
|
|
class HybridBlock(nn.Module):
    """Pre-norm residual block: a sequence mixer followed by a SwiGLU FFN.

    Args:
        dim: model width.
        mixer: module applied to the normalised input; its output is added
            back to the residual stream.
        ffn_dim_multiplier: forwarded to :class:`SwiGLU`.
        norm_eps: epsilon of both :class:`RMSNorm` layers.
    """

    def __init__(
        self,
        dim: int,
        mixer: nn.Module,
        ffn_dim_multiplier: Optional[float] = None,
        norm_eps: float = 1e-5,
    ):
        super().__init__()
        self.mixer = mixer
        self.mixer_norm = RMSNorm(dim, eps=norm_eps)
        self.ff_norm = RMSNorm(dim, eps=norm_eps)
        self.ff = SwiGLU(dim, ffn_dim_multiplier=ffn_dim_multiplier)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the mixer and feed-forward sub-layers, each with a residual add."""
        mixed = self.mixer(self.mixer_norm(x))
        after_mixer = x + mixed
        fed = self.ff(self.ff_norm(after_mixer))
        return after_mixer + fed
|
|
|
|
|
class HybridMIMOGDNLM(nn.Module):
    """OLMo-Hybrid-style decoder with [GDN, GDN, GDN, Attn] interleaving.

    ``block_pattern`` is cycled over ``depth`` layers; each entry is ``"gdn"``
    (a :class:`MIMOGatedDeltaNet` mixer) or ``"attn"`` (an :class:`Attention`
    mixer), case-insensitive. Every mixer is wrapped in a :class:`HybridBlock`.

    When ``gdn_head_k_dim`` is not given it defaults to
    ``max(8, int(0.75 * dim_head))``, bumped to the next even number if
    ``gdn_data_dependent_qk_rope`` is set (the rotation needs feature pairs).

    Raises:
        ValueError: if ``block_pattern`` is empty or names an unknown block type.
    """

    def __init__(
        self,
        num_tokens: int,
        dim: int = 512,
        depth: int = 16,
        heads: int = 8,
        dim_head: Optional[int] = None,
        gdn_head_k_dim: Optional[int] = None,
        gdn_expand_v: float = 2.0,
        gdn_mimo_rank: int = 4,
        gdn_trapezoidal: bool = False,
        gdn_data_dependent_qk_rope: bool = False,
        gdn_parameter_efficient_mimo: bool = True,
        block_pattern: Sequence[str] = ("gdn", "gdn", "gdn", "attn"),
        tie_embeddings: bool = True,
        ffn_dim_multiplier: Optional[float] = None,
        rope_theta: float = 500000.0,
        max_seq_len: int = 2048,
        flash_attn: bool = True,
    ):
        super().__init__()
        self.vocab_size = num_tokens
        self.model_dim = dim
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.blocks = nn.ModuleList()
        self.block_types: list[str] = []
        dim_head = dim_head or dim // heads
        if not gdn_head_k_dim:
            gdn_head_k_dim = max(8, int(0.75 * dim_head))
            if gdn_data_dependent_qk_rope and gdn_head_k_dim % 2 != 0:
                # e.g. dim_head=18 derives 13, which the rope-enabled GDN
                # layer would reject; round the derived default up instead.
                gdn_head_k_dim += 1

        pattern = list(block_pattern)
        if not pattern:
            raise ValueError("block_pattern must be non-empty")
        for i in range(depth):
            typ = pattern[i % len(pattern)].lower()
            if typ not in {"gdn", "attn"}:
                raise ValueError(f"Unsupported block type {typ!r}")
            if typ == "attn":
                mixer = Attention(dim, heads=heads, dim_head=dim_head, rope_theta=rope_theta, flash_attn=flash_attn)
            else:
                mixer = MIMOGatedDeltaNet(
                    MIMOGDNConfig(
                        dim=dim,
                        heads=heads,
                        head_k_dim=gdn_head_k_dim,
                        expand_v=gdn_expand_v,
                        mimo_rank=gdn_mimo_rank,
                        trapezoidal=gdn_trapezoidal,
                        data_dependent_qk_rope=gdn_data_dependent_qk_rope,
                        parameter_efficient_mimo=gdn_parameter_efficient_mimo,
                    )
                )
            self.blocks.append(HybridBlock(dim, mixer, ffn_dim_multiplier=ffn_dim_multiplier))
            self.block_types.append(typ)

        self.norm = RMSNorm(dim)
        self.lm_head = nn.Linear(dim, num_tokens, bias=False)
        if tie_embeddings:
            # Share one weight matrix between the embedding and the output head.
            self.lm_head.weight = self.token_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        x: Tensor,
        targets: Optional[Tensor] = None,
        return_loss: bool = False,
        mask: Optional[Tensor] = None,
    ):
        """Compute logits and, optionally, the cross-entropy loss.

        Args:
            x: integer token ids of shape ``(B, T)``.
            targets: optional target ids aligned one-to-one with the logits.
                When omitted and ``return_loss`` is set, the targets are
                ``x[:, 1:]`` and the last logit position is dropped.
            return_loss: return only the scalar loss.
            mask: accepted for interface compatibility and ignored.

        Returns:
            ``logits`` of shape ``(B, T, vocab)`` when no loss is requested,
            ``loss`` when ``return_loss`` is true, otherwise ``(logits, loss)``.
        """
        del mask
        h = self.token_emb(x)
        for block in self.blocks:
            h = block(h)
        logits = self.lm_head(self.norm(h))
        if not return_loss and targets is None:
            return logits

        if targets is None:
            # Self-supervised shift: position i predicts token i + 1.
            targets = x[:, 1:].contiguous()
            logits_for_loss = logits[:, :-1].contiguous()
        else:
            logits_for_loss = logits
        loss = F.cross_entropy(
            logits_for_loss.view(-1, logits_for_loss.size(-1)),
            targets.reshape(-1),
        )
        return loss if return_loss else (logits, loss)

    @staticmethod
    def _nucleus_filter(logits: Tensor, top_p: float) -> Tensor:
        """Mask ``logits`` (..., vocab) to the smallest top set with mass >= ``top_p``.

        A token is kept when the probability mass of all strictly more likely
        tokens is below ``top_p``, so the token that crosses the threshold is
        retained. The most likely token is always kept. Dropped entries are
        set to ``-inf``.
        """
        probs = logits.softmax(dim=-1)
        sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
        mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        keep = mass_before < top_p
        keep[..., 0] = True
        sorted_logits = logits.gather(-1, sorted_idx).masked_fill(~keep, float("-inf"))
        filtered = torch.full_like(logits, float("-inf"))
        filtered.scatter_(dim=-1, index=sorted_idx, src=sorted_logits)
        return filtered

    @torch.no_grad()
    def generate(
        self,
        prompt: Tensor,
        max_length: int,
        temperature: float = 1.0,
        filter_thres: float = 0.9,
        min_p: Optional[float] = None,
    ) -> Tensor:
        """Sample ``max_length`` tokens after ``prompt`` and return the full sequence.

        Args:
            prompt: integer token ids of shape ``(B, T0)``.
            max_length: number of new tokens to append.
            temperature: logits are divided by ``max(temperature, 1e-6)``.
            filter_thres: nucleus (top-p) mass; values ``>= 1.0`` disable the
                filter.
            min_p: accepted for interface compatibility and ignored.

        Returns:
            Tensor of shape ``(B, T0 + max_length)``. Each step re-runs the
            model on the last ``max_seq_len`` tokens; no state is cached.
        """
        del min_p
        out = prompt
        temp = max(float(temperature), 1e-6)
        for _ in range(max_length):
            window = out[:, -self.max_seq_len :]
            logits = self(window)[:, -1] / temp
            if filter_thres < 1.0:
                logits = self._nucleus_filter(logits, filter_thres)
            next_token = torch.distributions.Categorical(logits=logits).sample()[:, None]
            out = torch.cat([out, next_token], dim=1)
        return out
|
|
|
|
|
def count_parameters(model: nn.Module) -> int:
    """Return the total element count over ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
def smoke_test() -> None:
    """Build a tiny hybrid model, run one forward/backward pass and print a summary.

    The model is trained-shape only: one GDN block (trapezoidal, data-dependent
    rope, rank 2) and one attention block. Inputs and targets are shifted by
    one position so the printed loss is a next-token loss.

    Raises:
        RuntimeError: if the loss is not finite.
    """
    torch.manual_seed(0)
    # The pure-PyTorch recurrent loop uses many tiny ops on CPU; one thread is
    # far faster for the smoke test because it avoids threadpool overhead.
    previous_threads = torch.get_num_threads()
    torch.set_num_threads(1)
    try:
        model = HybridMIMOGDNLM(
            num_tokens=64,
            dim=32,
            depth=2,
            heads=1,
            dim_head=32,
            gdn_head_k_dim=16,
            gdn_expand_v=1.5,
            gdn_mimo_rank=2,
            gdn_trapezoidal=True,
            gdn_data_dependent_qk_rope=True,
            ffn_dim_multiplier=1.0,
            block_pattern=("gdn", "attn"),
            flash_attn=False,
        )
        # Nine tokens give eight (input, next-token) pairs.
        tokens = torch.randint(0, 64, (1, 9))
        inputs = tokens[:, :-1]
        targets = tokens[:, 1:]
        logits, loss = model(inputs, targets, return_loss=False)
        if not torch.isfinite(loss):
            raise RuntimeError(f"smoke test produced a non-finite loss: {float(loss)}")
        loss.backward()
        print({"params": count_parameters(model), "logits": tuple(logits.shape), "loss": float(loss.detach())})
    finally:
        # Do not leak the single-thread setting into the calling process.
        torch.set_num_threads(previous_threads)
|
|
|
|
|
# Run the smoke test when this file is executed as a script.
if __name__ == "__main__":
    smoke_test()