Skip to content

Instantly share code, notes, and snippets.

@proger
Last active January 14, 2026 23:10
Show Gist options
  • Select an option

  • Save proger/3f7ab9bf5f6e5950f353c7f9cd20afb4 to your computer and use it in GitHub Desktop.

Select an option

Save proger/3f7ab9bf5f6e5950f353c7f9cd20afb4 to your computer and use it in GitHub Desktop.
"""
Gated Linear RNNs with state expansion are Linear Transformers with data-dependent cumulative masking
"""
#%%
import torch
from torch import tensor
def pascal1(g):
    """Compute the mask: prescan gates in log space (explicit reference version).

    g: (N, T) per-step gates in log space.
    Returns (N, T, T) where entry [n, s, t] = g[n, s+1] + ... + g[n, t]
    for s <= t and -inf elsewhere (so exp() gives 0 below the diagonal).
    Quadratic in T per entry -- kept as the easy-to-check reference.
    """
    N, T = g.shape
    mask = g.new_full((N, T, T), float('-inf'))
    for t in range(T):
        mask[:, t, t] = 0.0  # empty product: no gates between t and itself
        for s in range(t):
            # accumulated log-gates on the span (s, t]: g[s+1] + ... + g[t]
            mask[:, s, t] = g[:, s + 1:t + 1].sum(dim=-1)
    return mask
def pascal(g):
    """Compute the mask: prescan gates in log space (dynamic programming).

    g: (N, T) per-step gates in log space.
    Returns (N, T, T) with [n, s, t] = g[n, s+1] + ... + g[n, t] for s <= t
    and -inf elsewhere. Each entry extends its right neighbour on the same
    column by one gate, so the whole table costs O(T^2) additions.
    """
    N, T = g.shape
    out = g.new_full((N, T, T), float('-inf'))
    for t in range(T):
        out[:, t, t] = 0.0  # base case: empty span contributes log(1) = 0
        # walk the column upward: out[s, t] = out[s+1, t] + g[s+1]
        for s in reversed(range(t)):
            out[:, s, t] = out[:, s + 1, t] + g[:, s + 1]
    return out
def masked_attend(q, k, v, g):
    """Masked linear attention: mask is data dependent, no softmax -- can be an RNN.

    q, k, v: (N, T, D). g: (N, T, T) log-space mask; it is exponentiated
    here, and is read with dim 1 as key time t and dim 2 as query time s.
    Returns (N, T, D).
    """
    scores = torch.einsum('nsk,ntk->nst', q, k)      # query-key inner products
    weights = scores * g.exp().transpose(-2, -1)      # mask: exp(g)[t, s] -> [s, t]
    return torch.einsum('nst,ntv->nsv', weights, v)   # mix values per query step
def causal_attend(q, k, v):
    """Causal attention: the mask is triangular and static.

    q, k, v: (N, T, D). Returns (N, T, D) where each query position s
    attends only to key positions t <= s with uniform weight 1.

    Two fixes vs the naive `tril(ones)` mask:
    - masked_attend exponentiates its mask argument, so the mask must be
      in log space: 0 where attention is allowed, -inf where blocked.
      A raw 0/1 mask would give blocked positions weight exp(0) = 1.
    - masked_attend reads the mask with subscripts 'nts' (dim 1 = key
      time t, dim 2 = query time s), so causality (t <= s) is the
      *upper* triangle in that layout, not the lower one.
    """
    qk = torch.einsum('nsk,ntk->nst', q, k)  # only used for its shape/dtype
    # triu(ones).log(): log(1) = 0 on allowed (t <= s), log(0) = -inf on blocked
    mask = torch.triu(torch.ones_like(qk)).log()
    return masked_attend(q, k, v, mask)
def alibi_attend(q, k, v):
    """alibi attention: the mask is static but weird; has softmax, unbounded memory

    NOTE(review): this is a sketch, not runnable code -- `m` is left as the
    Ellipsis placeholder, so `qk * m` raises a TypeError if called. The
    static ALiBi-style bias construction was never filled in. Kept for
    exposition: it contrasts softmax attention (unbounded state) with the
    linear/masked variants above, which admit a bounded-state RNN form.
    """
    qk = torch.einsum('nsk,ntk->nst', q, k)  # pairwise query-key scores (N, T, T)
    m = ... # some mask with static bias
    qk = (qk * m).softmax(dim=-1)  # full softmax over keys -- no linear-RNN rewrite
    qk_v = torch.einsum('nst,ntv->nsv', qk, v)  # weighted mix of values
    return qk_v
def lscan(q, k, v, f):
    """Linear time scan: fast weight programmer-style loop.

    q, k, v: (N, D, T) queries/keys/values in time-last layout.
    f: (N, T) forget gates; f[:, 0] has no effect because the state
       starts empty (there is nothing to decay at the first step).
    Returns y: (N, D, T).

    Improvements over the original:
    - keeps only the current (N, D, D) fast-weight state instead of
      materializing all T states, cutting memory from O(N*D*D*T) to O(N*D*D);
    - handles T == 0 (returns an empty (N, D, 0) tensor instead of crashing).
    """
    N, D, T = k.shape
    y = k.new_zeros(N, D, T)
    h = k.new_zeros(N, D, D)  # running fast-weight matrix
    for i in range(T):
        # h starts at zero, so decaying it at i == 0 is a no-op
        # (matches the original, which special-cased the first step)
        h = f[:, None, None, i] * h + torch.einsum('nk,nv->nkv', k[..., i], v[..., i])
        y[..., i] = torch.einsum('nk,nkv->nv', q[..., i], h)
    return y
if __name__ == '__main__':
    # sanity check: exponentiating row 0 of the prescan mask must reproduce
    # a running product of the gates
    primes = tensor([1, 2, 3, 5, 7, 11, 13, 17, 19])[None, :]
    a = pascal(primes.log()).exp()
    b = primes.cumprod(dim=-1).float()
    assert torch.allclose(a[:, 0, :], b), f'{a[:,0,:]} != {b}'

    # equivalence demo: masked linear attention == linear-time recurrent scan
    torch.manual_seed(0)
    N, T, D = 1, 8, 3
    q, k, v = (torch.randn(N, T, D) for _ in range(3))
    #f = torch.rand(N, T) # token-level forget gates: "Gated RNN" with outer product state expansion
    f = torch.ones(N, T)*0.9 # same sequence-level forget gate: FWP with Decay
    g = pascal(f.log())  # prescan every gate product once, in log space

    y_attn = masked_attend(q, k, v, g)
    print(y_attn, 'gated_attend') # N,T,D
    y_scan = lscan(q.mT, k.mT, v.mT, f).mT  # scan wants time-last layout
    print(y_scan, 'lscan')
    assert torch.allclose(y_attn, y_scan, atol=1e-6), 'gate_attend and lscan should be the same'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment