@proger
Created October 8, 2024 05:17
"""
DeltaNet implementation reference for Accelerated Scan. DeltaNet performs efficient management of a large fixed-sized memory.
For a simple single chunk version see `forward_simple`.
It computes decayed values by a little bit of recurrence (`decay_values`)
and then applies linear attention (`causal_attend`).
`forward_chunkwise` is inspired by Yang 2024. It applies single chunk version pointwise and
then performs chunk-level stitching.
forward_ogloop and forward_scanloop are reference implementations of straightforward recurrences.
References:
[1] The WY Representation for Products of Householder Matrices (Bischof and Van Loan 1985)
Method 1, section 3 guides `decay_values`.
https://ecommons.cornell.edu/items/92a11030-dca1-45d4-a0ba-732cf962b2b2
[2] Parallelizing Linear Transformers with the Delta Rule over Sequence Length (Yang et al 2024)
- equation 5 is a specialization of method 1 of [1] is in `decay_values`
- equation 6 is application of decayed keys to values is also in `decay_values`
- `forward_chunkwise` uses the distributed form of equation 7 and 8
(actually look the two equations before it instead, they are easier to read)
https://arxiv.org/abs/2406.06484
[3] Linear Transformers Are Secretly Fast Weight Programmers (Schlag et al 2021)
Introduction to Transformers as RNNs. Ignore all of the kernel stuff.
https://arxiv.org/abs/2102.11174
"""
#%%
import os
os.environ['TORCH_LOGS'] = 'output_code'
import torch
from torch import einsum, randn, allclose, stack, eye, manual_seed, no_grad, set_float32_matmul_precision, compile, arange
set_float32_matmul_precision('high')

def decay_values(k, v, beta):
    "decay values applying deltanet forgetting rules"
    NH, T, D = shape(None, k, v, beta)
    beta_ = beta.unsqueeze(-1)
    w = beta_ * k.clone()
    u = beta_ * v.clone()
    K = einsum('nsd,ntd->nst', k, k) # (T,T) matrix
    for t in range(1,T):
        w[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], w[:, :t].clone())
        u[:, t] -= beta_[:, t] * einsum('nt,ntd->nd', K[:, :t, t], u[:, :t].clone())
    return w, u
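
# For k, v of shape (NH, T, D) and beta of shape (NH, T), `decay_values` returns w, u of the same
# shape; the rows of w and u are the quantities of equations 5 and 6 of [2] (method 1 of [1]).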

def causal_attend(q, k, v, diagonal=0):
    "apply linear attention with a causal mask"
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal)
    y = einsum("nsi,nti,st,ntj->nsj", q, k, mask, v)
    return y

def forward_simple(q, k, v, beta):
    "simple deltanet: linear attention to decayed values"
    w, u = decay_values(k, v, beta)
    return causal_attend(q, k, u)

def forward_chunkwise(q, k, v, beta, chunk_size=2):
    NH, T, D = shape(q, k, v, beta)
    C = T // chunk_size
    q_, k_, v_, beta_ = (
        q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
        v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
    )
    # evaluate all chunks in parallel
    w, u = decay_values(k_, v_, beta_)
    y = causal_attend(q_, k_, u)
    # stitch chunks sequentially
    y_delta, _ = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    return y.view(NH, T, D) + y_delta.view(NH, T, D)
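
# Usage sketch (assumes T is divisible by chunk_size, as in test_equal below):
#   q, k, v, beta = make_example(NH=6, T=128, D=16)
#   y = forward_chunkwise(q, k, v, beta, chunk_size=8)  # matches forward_ogloop(q, k, v, beta)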

def stitch_forward(q, k, w, u, C, chunk_size):
    "stitch chunks sequentially"
    NH, T, D = shape(q, k, None, None)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    # materialize the state for the leading chunk
    state = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    deltas = [u.new_zeros(NH, chunk_size, D)]
    for c in range(1, C):
        y_delta1, state = stitch1_forward(state, q_[:, c], k_[:, c], w[:, c], u[:, c])
        deltas.append(y_delta1)
    y_delta = torch.stack(deltas, dim=1)
    return y_delta, state
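
# stitch1_forward (below) advances the stitch by one chunk: it reads the chunk's queries against the
# state accumulated from previous chunks, subtracts the within-chunk attention to the state-decay
# terms, and folds (u - state_decays) into the state carried to the next chunk.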

def stitch1_forward(state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    state_add = einsum('ntv,ntk->nvk', u - state_decays, k)
    delta = causal_attend(q, k, state_decays)
    prev_output = einsum('nvk,nsk->nsv', state, q)
    y = prev_output - delta
    return y, state + state_add

def forward_ogloop(q, k, v, beta):
    "reference: w_t = w_{t-1} + beta_t (v_t - w_{t-1} k_t) k_t.T"
    NH, T, D = shape(q, k, v, beta)
    w = k.new_zeros(NH, D, D)
    y = []
    for t in range(T):
        q_ = q[:, t]
        k_ = k[:, t]
        v_ = v[:, t]
        beta_ = beta[:, t].unsqueeze(-1)
        v_old = einsum("nij,nj->ni", w, k_)
        delta = beta_ * (v_ - v_old)
        w = w + einsum("ni,nj->nij", delta, k_)
        y.append(einsum("nij,nj->ni", w, q_))
    return stack(y, dim=1)

def forward_scanloop(q, k, v, beta):
    "reference via linear-time scan: w_t = w_{t-1} (I - beta_t k_t k_t.T) + beta_t v_t k_t.T"
    NH, T, D = shape(q, k, v, beta)
    w = k.new_zeros(NH, D, D)
    id = eye(D, device=w.device).expand(NH, D, D)
    y = []
    for t in range(T):
        q_ = q[:, t]
        k_ = k[:, t]
        v_ = v[:, t]
        beta_ = beta[:, t].unsqueeze(-1).unsqueeze(-1)
        beta_sqrt_ = beta_.squeeze(-1).sqrt()
        forget = id - einsum("ni,nj->nij", beta_sqrt_ * k_, beta_sqrt_ * k_)
        update = beta_ * einsum("ni,nj->nij", v_, k_)
        w = einsum("nik,nkj->nij", w, forget) + update
        y.append(einsum("nij,nj->ni", w, q_))
    return stack(y, dim=1)
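
# Note on forward_scanloop: beta_t is split as sqrt(beta_t) onto both factors of the rank-1 forget
# term, presumably so the outer product stays exactly symmetric in floating point; algebraically the
# forget matrix is the same I - beta_t k_t k_t.T as in the docstring.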

def shape(q, k, v, beta=None):
    NH, T, D = (q if q is not None else k).shape
    if q is not None:
        assert q.shape == (NH, T, D)
    if v is not None:
        assert k.shape == v.shape
    if beta is not None:
        assert beta.shape == (NH, T)
    return NH, T, D

def make_example(NH, T, D):
    manual_seed(0)
    q = randn(NH, T, D) / D**0.5
    q.requires_grad_()
    k = randn(NH, T, D) / D**0.5
    k.requires_grad_()
    v = randn(NH, T, D) / D**0.5
    v.requires_grad_()
    beta = randn(NH, T).sigmoid()
    beta.requires_grad_()
    return q, k, v, beta

def test_equal(atol=1e-6):
    NH, T, D = 2*3, 128, 16
    #NH, T, D = 1, 8, 3
    q, k, v, beta = make_example(NH, T, D)
    y1 = forward_ogloop(q, k, v, beta)
    y2 = forward_scanloop(q, k, v, beta)
    y3 = forward_simple(q, k, v, beta)
    assert allclose(y1, y2, atol=atol), (y1 - y2).abs().max()
    assert allclose(y1, y3, atol=atol), (y1 - y3).abs().max()
    for chunk_size in (1,2,4,8):
        y = forward_chunkwise(q, k, v, beta, chunk_size)
        assert allclose(y1, y, atol=atol), (y1 - y).abs().max()

test_equal()
#%%
@no_grad()
def attend_backward(d_out, q, k, v, g):
    d_q = einsum('nsv,ntk,ntv,nst->nsk', d_out, k, v, g)
    d_k = einsum('nsv,nsk,ntv,nst->ntk', d_out, q, v, g)
    d_v = einsum('nsv,nsk,ntk,nst->ntv', d_out, q, k, g)
    d_g = einsum('nsv,nsk,ntk,ntv,nst->nst', d_out, q, k, v, g)
    return d_q, d_k, d_v, d_g
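
# attend_backward differentiates y[n,s,:] = sum_t g[n,s,t] * (q[n,s,:] . k[n,t,:]) * v[n,t,:]
# with respect to q, k and v (each einsum contracts d_out with the remaining factors). Note that
# d_g appears to carry an extra factor of g, which is harmless for the fixed binary causal mask
# used below but not for learned gates (see the TODO in test_equal_attend_backward).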

@no_grad()
def causal_attend_backward(d_out, q, k, v, diagonal=0):
    NH, T, D = shape(q, k, v)
    mask = q.new_ones(T, T).tril(diagonal=diagonal).unsqueeze(0)
    d_q, d_k, d_v, _d_mask = attend_backward(d_out, q, k, v, mask)
    return d_q, d_k, d_v

class CausalAttend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        return causal_attend(q, k, v)

    @staticmethod
    def backward(ctx, d_out):
        q, k, v = ctx.saved_tensors
        return causal_attend_backward(d_out, q, k, v)

def test_equal_attend_backward(atol=1e-5):
    NH, T, D = 1*1, 512, 64
    q, k, v, beta = make_example(NH, T, D)
    y = causal_attend(q, k, v)
    d_q, d_k, d_v = causal_attend_backward(torch.ones_like(y), q, k, v)
    y.sum().backward()
    assert allclose(q.grad, d_q, atol=atol), 'q.grad is wrong'
    assert allclose(k.grad, d_k, atol=atol), 'k.grad is wrong'
    assert allclose(v.grad, d_v, atol=atol), 'v.grad is wrong'
    ## TODO: test gates (g)
    # print((g_hook.grad - d_g).pow(2).mean(), 'error')
    # print((g_hook.grad - d_g).abs().max(), 'max abs error')
    # assert torch.allclose(g_hook.grad, d_g, atol=1e-1), 'g.grad is wrong'

def test_equal_attend_backward2(atol=1e-5):
    NH, T, D = 1, 2, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = causal_attend(q1, k1, v1)
    (y1 - torch.ones_like(y1)).pow(2).mean().backward()
    q, k, v, beta = make_example(NH, T, D)
    y = CausalAttend.apply(q, k, v)
    (y - torch.ones_like(y)).pow(2).mean().backward()
    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    # print(v.grad, 'v.grad')
    # print(v1.grad, 'v1.grad')
    assert (q1.grad - q.grad).abs().max() < atol, 'q.grad is wrong'
    assert (k1.grad - k.grad).abs().max() < atol, 'k.grad is wrong'
    assert (v1.grad - v.grad).abs().max() < atol, 'v.grad is wrong'

test_equal_attend_backward()
test_equal_attend_backward2()
#%%
@no_grad()
def decay_values_backward(d_out_w, d_out_u, k, v, beta):
    NH, T, D = shape(None, k, v, beta)
    #
    # allocations
    #
    eye = torch.eye(D, device=k.device, dtype=k.dtype)
    eye = eye.unsqueeze(0).expand(NH, D, D)
    # recompute w and u TK-style
    w = k.new_zeros(NH, T, D) # ntk
    u = v.new_zeros(NH, T, D) # ntw
    w_bases = w.clone()
    u_bases = u.clone()
    bk = einsum('nt,ntk->ntk', beta, k)
    bv = einsum('nt,ntw->ntw', beta, v)
    K = einsum('ntd,nsd->nts', k, k) # (T,T) matrix
    K = K.tril(diagonal=-1) # make_causal(0); set_diagonal(0)
    bKl = einsum('nt,nts->nts', -beta, K) # multiply each row of K by beta
    d_k = k.new_zeros(NH, T, D) # nsk
    d_beta = beta.new_zeros(NH, T) # ns
    d_v = v.new_zeros(NH, T, D) # nsv
    d_out_w_backward = d_out_w #.clone() # ntk # doesn't seem to need a copy but why?
    d_out_u_backward = d_out_u.clone() # ntw
    #
    # forward
    #
    for t in range(T):
        c_w = einsum('nts,nsk->ntk', bKl, w)
        w[:, t] = bk[:, t, :] + c_w[:, t, :]
        c_u = einsum('nts,nsw->ntw', bKl, u)
        u[:, t] = bv[:, t, :] + c_u[:, t, :]
    w_bases = k - einsum('nts,nsk->ntk', K, w)
    u_bases = v - einsum('nts,nsw->ntw', K, u)
    w0 = w.clone() # we will be mutating these, so store original w and u here
    u0 = u.clone()
    #
    # backward for d_k, d_v, d_beta
    #
    k = k.clone() # the loop below zeroes k in place; clone so the caller's (saved) tensor stays intact
    for t in range(T-1,-1,-1):
        w[:, t, :] = 0
        k[:, t, :] = 0
        u[:, t, :] = 0
        wk = einsum('njw,njk->nwk', w, k)
        wk = eye - wk
        wk = einsum('n,nwk->nwk', beta[:, t], wk)
        uk = einsum('njw,njk->nwk', u, k)
        uk = einsum('n,nwk->nwk', beta[:, t], uk)
        # d_k
        d_k[:, t] += einsum('nw,nwk->nk', d_out_w_backward[:, t], wk)
        d_k[:, t] -= einsum('nw,nwk->nk', d_out_u_backward[:, t], uk)
        decay_w = einsum('nw,nsw->ns', d_out_w_backward[:, t], w[:, :t])
        decay_u = einsum('nw,nsw->ns', d_out_u_backward[:, t], u[:, :t])
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_w)
        d_k[:, :t] -= einsum('nk,ns->nsk', bk[:, t], decay_u)
        # backpropagate through time
        d_out_w_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_w_backward[:, t])
        d_out_u_backward[:, :t] += einsum('nj,nk->njk', bKl[:, t, :t], d_out_u_backward[:, t])
    # d_beta
    d_beta += einsum('ntk,ntk->nt', w_bases, d_out_w_backward)
    d_beta += einsum('ntk,ntk->nt', u_bases, d_out_u_backward)
    # d_v
    d_v = einsum('nt,ntv->ntv', beta, d_out_u_backward)
    return d_k, d_v, d_beta

class DecayValues(torch.autograd.Function):
    @staticmethod
    def forward(ctx, k, v, beta):
        w, u = decay_values(k, v, beta)
        ctx.save_for_backward(k, v, beta)
        return w, u

    @staticmethod
    def backward(ctx, d_out_w, d_out_u):
        k, v, beta = ctx.saved_tensors
        return decay_values_backward(d_out_w, d_out_u, k, v, beta)

def test_equal_decay_values_backward():
    NH, T, D = 1, 16, 3
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    (w + u - torch.ones_like(w)).pow(2).mean().backward()
    #(w - torch.ones_like(w)).pow(2).mean().backward()
    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = DecayValues.apply(k1, v1, beta1)
    (w1 + u1 - torch.ones_like(w1)).pow(2).mean().backward()
    #(w1 - torch.ones_like(w1)).pow(2).mean().backward()
    # print(v.grad, 'v.grad', v.grad.shape)
    # print(v1.grad, 'v1.grad')
    # print(v.grad - v1.grad, 'v diff')
    assert allclose(v.grad, v1.grad, atol=1e-5), 'v1.grad is wrong'
    # print(beta.grad, 'beta.grad du')
    # print(beta1.grad, 'beta1.grad du')
    assert allclose(beta.grad, beta1.grad, atol=1e-5), 'beta1.grad is wrong'
    # print(k.grad, 'k.grad du')
    # print(k1.grad, 'k1.grad du')
    # print(k.grad - k1.grad, 'diff du')
    assert allclose(k.grad, k1.grad, atol=1e-5), 'k1.grad is wrong'

test_equal_decay_values_backward()
# %%
def stitch1_backward(d_y, d_state, state, q, k, w, u):
    state_decays = einsum('nvk,ntk->ntv', state, w)
    d_q1 = einsum('nsv,nvk->nsk', d_y, state) # prev_output
    d_state1 = einsum('nsv,nsk->nvk', d_y, q) # prev_output
    d_q2, d_k1, d_state_decays1 = causal_attend_backward(-d_y, q, k, state_decays) # delta
    d_k2 = einsum('nvk,ntv->ntk', d_state, u - state_decays) # state_add
    d_u = einsum('nvk,ntk->ntv', d_state, k) # state_add
    d_state_decays2 = einsum('nvk,ntk->ntv', -d_state, k) # state_add
    d_state_decays = d_state_decays1 + d_state_decays2
    d_w = einsum('ntv,nvk->ntk', d_state_decays, state) # state_decays
    d_state2 = einsum('ntv,ntk->nvk', d_state_decays, w) # state_decays
    return d_state + d_state1 + d_state2, d_q1 + d_q2, d_k1 + d_k2, d_w, d_u

class Stitch1(torch.autograd.Function):
    @staticmethod
    def forward(ctx, state, q, k, w, u):
        y, new_state = stitch1_forward(state, q, k, w, u)
        ctx.save_for_backward(state, q, k, w, u)
        return y, new_state

    @staticmethod
    def backward(ctx, d_y, d_state):
        state, q, k, w, u = ctx.saved_tensors
        return stitch1_backward(d_y, d_state, state, q, k, w, u)

def stitch_backward(d_y_delta, q, k, w, u, C, chunk_size):
    NH, T, D = shape(q, k, None, None)
    d_y_delta = d_y_delta.view(NH, C, chunk_size, D)
    q_ = q.view(NH, C, chunk_size, D)
    k_ = k.view(NH, C, chunk_size, D)
    u = u.view(NH, C, chunk_size, D)
    w = w.view(NH, C, chunk_size, D)
    d_q_ = q.new_zeros(NH, C, chunk_size, D)
    d_k_ = k.new_zeros(NH, C, chunk_size, D)
    d_w = w.new_zeros(NH, C, chunk_size, D)
    d_u = u.new_zeros(NH, C, chunk_size, D)
    d_state = w.new_zeros(NH, D, D) # NHVK
    y_delta = u.new_zeros(NH, C, chunk_size, D) # leading chunk has zero delta
    # storing all states for BPTT
    states = k.new_zeros(NH, C, D, D) # NHCVK
    # materialize the state for the leading chunk
    states[:, 0] = einsum('ntv,ntk->nvk', u[:, 0], k_[:, 0])
    for c in range(1, C):
        y_delta[:, c], states[:, c] = stitch1_forward(states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])
    for c in range(C-1, 0, -1):
        (
            d_state, d_q_[:, c], d_k_[:, c], d_w[:, c], d_u[:, c]
        ) = stitch1_backward(d_y_delta[:, c], d_state, states[:, c-1], q_[:, c], k_[:, c], w[:, c], u[:, c])
    (
        d_state, d_q_[:, 0], d_k_[:, 0], d_w[:, 0], d_u[:, 0]
    ) = stitch1_backward(d_y_delta[:, 0], d_state, torch.zeros_like(d_state), q_[:, 0], k_[:, 0], w[:, 0], u[:, 0])
    return d_q_.view(NH, T, D), d_k_.view(NH, T, D), d_w.view(NH, T, D), d_u.view(NH, T, D)

class Stitch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, w, u, C, chunk_size):
        y_delta, state = stitch_forward(q, k, w, u, C, chunk_size)
        ctx.save_for_backward(q, k, w, u)
        ctx.C = C
        ctx.chunk_size = chunk_size
        return y_delta

    @staticmethod
    def backward(ctx, d_y_delta):
        q, k, w, u = ctx.saved_tensors
        return *stitch_backward(d_y_delta, q, k, w, u, ctx.C, ctx.chunk_size), None, None

def test_stitch_all(atol=1e-5):
    NH, T, D = 1, 4, 2
    C, chunk_size = 2, 2
    q, k, v, beta = make_example(NH, T, D)
    w, u = decay_values(k, v, beta)
    q.requires_grad_()
    k.requires_grad_()
    v.requires_grad_()
    beta.requires_grad_()
    w.retain_grad()
    u.retain_grad()
    y, new_state = stitch_forward(q, k, w, u, C=C, chunk_size=chunk_size)
    loss = (y - torch.ones_like(y)).pow(2).mean()
    loss.backward()
    # print(q.grad, 'q.grad')
    # print(k.grad, 'k.grad')
    # print(v.grad, 'v.grad')
    # print(beta.grad, 'beta.grad')
    # print(w.grad, 'w.grad')
    # print(u.grad, 'u.grad')
    q1, k1, v1, beta1 = make_example(NH, T, D)
    w1, u1 = decay_values(k1, v1, beta1)
    q1.requires_grad_()
    k1.requires_grad_()
    v1.requires_grad_()
    beta1.requires_grad_()
    w1.retain_grad()
    u1.retain_grad()
    y1 = Stitch.apply(q1, k1, w1, u1, C, chunk_size)
    loss = (y1 - torch.ones_like(y1)).pow(2).mean()
    loss.backward()
    assert allclose(y, y1, atol=atol), 'y is wrong'
    assert allclose(u.grad, u1.grad, atol=atol), 'u.grad is wrong'
    assert allclose(v.grad, v1.grad, atol=atol), 'v.grad is wrong'
    # print(k.grad, 'k.grad')
    # print(k1.grad, 'k1.grad')
    assert allclose(k.grad, k1.grad, atol=atol), 'k.grad is wrong'
    assert allclose(q.grad, q1.grad, atol=atol), 'q.grad is wrong'
    assert allclose(beta.grad, beta1.grad, atol=atol), 'beta.grad is wrong'
    assert allclose(w.grad, w1.grad, atol=atol), 'w.grad is wrong'

test_stitch_all()
#%%
class DeltaChunkwise(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q, k, v, beta, chunk_size):
        y = forward_chunkwise(q, k, v, beta, chunk_size)
        ctx.save_for_backward(q, k, v, beta)
        ctx.chunk_size = chunk_size
        return y

    @staticmethod
    def backward(ctx, d_y):
        q, k, v, beta = ctx.saved_tensors
        NH, T, D = shape(q, k, v, beta)
        chunk_size = ctx.chunk_size
        C = T // chunk_size
        q_, k_, v_, beta_ = (
            q.view(NH*C, chunk_size, D), k.view(NH*C, chunk_size, D),
            v.view(NH*C, chunk_size, D), beta.view(NH*C, chunk_size)
        )
        w, u = decay_values(k_, v_, beta_)
        d_q_1, d_k_1, d_w, d_u1 = stitch_backward(d_y, q, k, w, u, C=C, chunk_size=chunk_size)
        d_w = d_w.view(NH*C, chunk_size, D)
        d_u1 = d_u1.view(NH*C, chunk_size, D)
        d_y = d_y.view(NH*C, chunk_size, D)
        u = u.view(NH*C, chunk_size, D)
        d_q_2, d_k_2, d_u2 = causal_attend_backward(d_y, q_, k_, u)
        d_u = d_u1 + d_u2
        d_k_3, d_v_, d_beta_ = decay_values_backward(d_w, d_u, k_, v_, beta_)
        d_q_1 = d_q_1.view(NH, T, D)
        d_q_2 = d_q_2.view(NH, C, chunk_size, D)
        d_q_2 = d_q_2.reshape(NH, T, D)
        d_q = d_q_1 + d_q_2
        d_k_2 = d_k_2.reshape(NH, T, D)
        d_k_3 = d_k_3.reshape(NH, T, D)
        d_k = d_k_1 + d_k_2 + d_k_3
        d_v = d_v_.reshape(NH, T, D)
        d_beta = d_beta_.view(NH, T)
        return d_q, d_k, d_v, d_beta, None
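
# Minimal usage sketch (hypothetical sizes, mirroring test_delta_chunkwise_backward below):
def _example_delta_chunkwise():
    q, k, v, beta = make_example(NH=2, T=16, D=2)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)  # chunk_size=2 must divide T
    y.pow(2).mean().backward()                  # populates q.grad, k.grad, v.grad, beta.grad
    return y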

def test_delta_chunkwise_backward():
    NH, T, D = 2, 16, 2
    q1, k1, v1, beta1 = make_example(NH, T, D)
    y1 = forward_chunkwise(q1, k1, v1, beta1, chunk_size=2)
    (y1 - torch.ones_like(y1).detach()).pow(2).mean().backward()
    q, k, v, beta = make_example(NH, T, D)
    y = DeltaChunkwise.apply(q, k, v, beta, 2)
    (y - torch.ones_like(y).detach()).pow(2).mean().backward()
    assert allclose(y1, y, atol=1e-5), 'y is wrong'
    # print(beta1.grad - beta.grad, 'beta.grad diff')
    # print(q1.grad - q.grad, 'q.grad diff')
    # print(k1.grad - k.grad, 'k.grad diff')
    # print(v1.grad - v.grad, 'v.grad diff')
    assert allclose(q1.grad, q.grad, atol=1e-5), 'q.grad is wrong'
    assert allclose(beta1.grad, beta.grad, atol=1e-5), 'beta.grad is wrong'
    assert allclose(k1.grad, k.grad, atol=1e-5), 'k.grad is wrong'
    assert allclose(v1.grad, v.grad, atol=1e-5), 'v.grad is wrong'

test_delta_chunkwise_backward()