@Chillee
Last active February 22, 2025 17:13
Merge Attention
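This snippet shows that attention over a set of keys/values can be split into pieces and recombined exactly, as long as each partial result carries the log-sum-exp (LSE) of its attention scores: each partial output is weighted by exp(lse - max_lse) and the weighted sum is renormalized. Below, full (unmasked) flex_attention over a 1024-token sequence is recovered by merging a causal pass with its complement, and the gradients of the merged result match the reference as well.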
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
torch.set_default_device('cuda')
# (batch, heads, seq_len, head_dim) query/key/value tensors
q, k, v = [torch.randn(8, 8, 1024, 64, requires_grad=True) for _ in range(3)]
# Complementary masks: together they cover every (q_idx, kv_idx) pair exactly once.
causal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx >= kv_idx, None, None, 1024, 1024)
uncausal_mask = create_block_mask(lambda b, h, q_idx, kv_idx: q_idx < kv_idx, None, None, 1024, 1024)
# Full (unmasked) attention as the reference, plus the two partial passes with their log-sum-exps.
ref_out = flex_attention(q, k, v)
causal_out, causal_lse = flex_attention(q, k, v, block_mask=causal_mask, return_lse=True)
uncausal_out, uncausal_lse = flex_attention(q, k, v, block_mask=uncausal_mask, return_lse=True)
# merge_attention(*attention(q, k1, v1), *attention(q, k2, v2)) == attention(q, cat(k1, k2), cat(v1, v2))
def merge_attention(a, lse_a, b, lse_b):
    # Subtract the shared max before exponentiating for numerical stability.
    max_lse = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - max_lse)
    w_b = torch.exp(lse_b - max_lse)
    # Weight each partial output by its softmax mass and renormalize.
    return (a * w_a[..., None] + b * w_b[..., None]) / (w_a + w_b)[..., None]
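# A possible extension (a sketch, not part of the original gist): torch.logaddexp gives the
# log-sum-exp of the merged score set directly, so returning it alongside the merged output
# lets partial results be folded in one at a time (e.g. more than two KV chunks).
def merge_attention_with_lse(a, lse_a, b, lse_b):
    merged_lse = torch.logaddexp(lse_a, lse_b)       # log(exp(lse_a) + exp(lse_b)), computed stably
    w_a = torch.exp(lse_a - merged_lse)[..., None]   # fraction of softmax mass from partial result a
    w_b = torch.exp(lse_b - merged_lse)[..., None]   # fraction of softmax mass from partial result b
    return a * w_a + b * w_b, merged_lse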
merge_out = merge_attention(causal_out, causal_lse, uncausal_out, uncausal_lse)
# Forward check: the merged partial results match full attention.
assert (ref_out - merge_out).abs().max() < 1e-5
# Backward check: gradients through the merge also match full attention.
ref_out.sum().backward()
ref_q_grad = q.grad
q.grad = None
merge_out.sum().backward()
assert (q.grad - ref_q_grad).abs().max() < 1e-5
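# A further sanity check (a sketch added here, not in the original gist): the same identity
# from the comment above, with the keys/values split along the sequence dimension instead of
# by mask -- merging attention over the two halves reproduces full attention.
k1, v1 = k[:, :, :512], v[:, :, :512]
k2, v2 = k[:, :, 512:], v[:, :, 512:]
left_out, left_lse = flex_attention(q, k1, v1, return_lse=True)
right_out, right_lse = flex_attention(q, k2, v2, return_lse=True)
split_out = merge_attention(left_out, left_lse, right_out, right_lse)
assert (ref_out - split_out).abs().max() < 1e-5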