| """ | |
| Gated Linear RNNs with state expansion are Linear Transformers with data-dependent cumulative masking | |
| """ | |
| #%% | |
| import torch | |
| from torch import tensor | |
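
# The equivalence, spelled out: a gated linear RNN with outer-product state
# expansion
#   h_t = f_t * h_{t-1} + k_t v_t^T,   y_t = h_t^T q_t
# unrolls to
#   y_s = sum_{t<=s} (q_s . k_t) (prod_{r=t+1..s} f_r) v_t,
# i.e. linear attention whose mask entry for (key t, query s) is the cumulative
# product of the gates between them; pascal() below computes its log.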
def pascal1(g):
    "compute the mask: prescan gates in log space (explicit)"
    N, T = g.shape
    l = g.new_zeros(N, T, T) + float('-inf')
    for t in range(T):
        l[:, t, t] = 0
        for s in range(t):
            l[:, s, t] = sum(g[:, k] for k in range(s+1, t)) + g[:, t]
    return l
def pascal(g):
    "compute the mask: prescan gates in log space (dynamic programming)"
    N, T = g.shape
    l = g.new_zeros(N, T, T) + float('-inf')
    for t in range(T):
        l[:, t, t] = 0
        for s in range(t-1, -1, -1):
            l[:, s, t] = l[:, s+1, t] + g[:, s+1]
    return l
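
# Sanity sketch (not in the original gist; check_pascal is a hypothetical
# helper): the explicit O(T^2)-sum prescan and the dynamic-programming prescan
# should agree; comparing after exp avoids -inf arithmetic in the tolerance.
def check_pascal(g=None):
    "pascal1 and pascal compute the same mask"
    g = torch.randn(2, 5) if g is None else g
    assert torch.allclose(pascal1(g).exp(), pascal(g).exp())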
def masked_attend(q, k, v, g):
    "masked linear attention: mask is data dependent, no softmax -- can be an RNN"
    # g is in log space and indexed (key t, query s): exp(g[n, t, s]) scales the
    # contribution of key/value t to query s, and is 0 (g = -inf) for t > s
    y = torch.einsum('nsk,ntk,ntv,nts->nsv', q, k, v, g.exp())
    return y
def causal_attend(q, k, v):
    "causal attention: the mask is triangular and static"
    qk = torch.einsum('nsk,ntk->nst', q, k)  # computed only to size the mask
    # masked_attend expects a log-space mask indexed (t, s): 0 where key t <=
    # query s and -inf elsewhere, i.e. the log of an upper-triangular ones matrix
    return masked_attend(q, k, v, torch.ones_like(qk).triu().log())
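
# Sanity sketch (not in the original gist; check_causal is a hypothetical
# helper): with all-ones gates the cumulative mask degenerates to the static
# triangular mask, so gated linear attention reduces to causal linear attention.
def check_causal(q, k, v):
    "with f = 1 everywhere, the data-dependent mask equals the static one"
    N, T, D = q.shape
    g = pascal(torch.zeros(N, T))  # log f = 0
    assert torch.allclose(masked_attend(q, k, v, g), causal_attend(q, k, v))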
def alibi_attend(q, k, v, slope=1.0):
    "alibi attention: the mask is static but weird; has softmax, unbounded memory"
    qk = torch.einsum('nsk,ntk->nst', q, k)
    # a minimal sketch of the ALiBi bias (single head; the per-head slope is a
    # free parameter here, an assumption filled in for runnability): penalize
    # each key linearly by its distance to the query, -inf above the diagonal,
    # added to the scores before the softmax
    t = torch.arange(qk.shape[-1])
    m = -slope * (t[:, None] - t[None, :]).clamp(min=0).float()
    m = m.masked_fill(t[None, :] > t[:, None], float('-inf'))
    qk = (qk + m).softmax(dim=-1)
    qk_v = torch.einsum('nst,ntv->nsv', qk, v)
    return qk_v
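
# Sanity sketch (not in the original gist; check_alibi is a hypothetical
# helper): with slope=0 the linear bias vanishes and only the causal -inf mask
# remains, so alibi_attend reduces to ordinary causal softmax attention.
def check_alibi(q, k, v):
    "slope=0: plain causal softmax attention"
    T = q.shape[1]
    qk = torch.einsum('nsk,ntk->nst', q, k)
    qk = qk.masked_fill(torch.ones(T, T, dtype=torch.bool).triu(1), float('-inf'))
    ref = torch.einsum('nst,ntv->nsv', qk.softmax(dim=-1), v)
    assert torch.allclose(alibi_attend(q, k, v, slope=0.0), ref)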
def lscan(q, k, v, f):
    "linear time scan: fast weight programmer-style loop"
    # note: q, k, v arrive time-last here, shaped (N, D, T)
    N, D, T = k.shape
    N, T = f.shape
    h = k.new_zeros(N, D, D, T)  # all T fast-weight states, kept for clarity
    y = k.new_zeros(N, D, T)
    h[..., 0] = torch.einsum('nk,nv->nkv', k[..., 0], v[..., 0])
    y[..., 0] = torch.einsum('nk,nkv->nv', q[..., 0], h[..., 0])
    for i in range(1, T):
        h[..., i] = f[:, None, None, i] * h[..., i-1] + torch.einsum('nk,nv->nkv', k[..., i], v[..., i])
        y[..., i] = torch.einsum('nk,nkv->nv', q[..., i], h[..., i])
    return y
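
# Cost contrast: lscan touches each step once, O(T * D^2) time (and O(T * D^2)
# memory as written, since it keeps every state); masked_attend materializes
# the full T x T mask, O(T^2 * D) time. Same outputs, different scaling.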
if __name__ == '__main__':
    primes = tensor([1, 2, 3, 5, 7, 11, 13, 17, 19])[None, :]
    a = pascal(primes.log()).exp()
    b = primes.cumprod(dim=-1).float()
    assert torch.allclose(a[:, 0, :], b), f'{a[:,0,:]} != {b}'
    torch.manual_seed(0)
    N, T, D = 1, 8, 3
    q = torch.randn(N, T, D)
    k = torch.randn(N, T, D)
    v = torch.randn(N, T, D)
    #f = torch.rand(N, T) # token-level forget gates: "Gated RNN" with outer product state expansion
    f = torch.ones(N, T)*0.9 # same sequence-level forget gate: FWP with Decay
    g = pascal(f.log()) # Prescan of all gates
    y1 = masked_attend(q, k, v, g)
    print(y1, 'masked_attend') # N,T,D
    y2 = lscan(q.mT, k.mT, v.mT, f).mT
    print(y2, 'lscan')
    assert torch.allclose(y1, y2, atol=1e-6), 'masked_attend and lscan should be the same'
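
    # Not in the original run (a sketch exercising the thesis): repeat the
    # equivalence check with token-level, data-dependent gates, the case the
    # commented-out line above points at, then run the hypothetical helpers.
    f2 = torch.rand(N, T)
    g2 = pascal(f2.log())
    assert torch.allclose(masked_attend(q, k, v, g2), lscan(q.mT, k.mT, v.mT, f2).mT, atol=1e-6)
    check_pascal()
    check_causal(q, k, v)
    check_alibi(q, k, v)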