Test stub for comparing jvp of memory-efficient attention against reference implementation
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional
from typing_extensions import override
import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.autograd.function import FunctionCtx
from torch.nn import Linear, Module
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.init import normal_
from einops import rearrange

class AbstractAttn(Module, ABC):
    def __init__(
        self,
        q_in_dim: int,
        kv_in_dim: int,
        out_dim: int,
        q_heads = 8,
        kv_heads = 8,
        head_dim = 64,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        self.call_super_init = True
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_heads = q_heads
        self.kv_heads = kv_heads
        self.head_dim = head_dim
        self.scale: float = head_dim ** -0.5
        self.q_proj = Linear(q_in_dim, self.q_heads * head_dim, bias=False, **factory_kwargs)
        self.kv_proj = Linear(kv_in_dim, self.kv_heads * head_dim * 2, bias=False, **factory_kwargs)
        # the attention output concatenates q_heads heads, so o_proj's in_features follows q_heads
        self.o_proj = Linear(self.q_heads * head_dim, out_dim, bias=False, **factory_kwargs)

    def init(self, generator: Optional[torch.Generator] = None) -> None:
        normal_(self.q_proj.weight, std=self.q_proj.in_features**-.5, generator=generator)
        normal_(self.kv_proj.weight, std=self.kv_proj.in_features**-.5, generator=generator)
        normal_(self.o_proj.weight, std=self.o_proj.in_features**-.5, generator=generator)

    def forward(self, x: Tensor, cross_x: Optional[Tensor] = None) -> Tensor:
        q: Tensor = self.q_proj(x)
        q = rearrange(q, "... seq (heads chan) -> ... heads seq chan", heads=self.q_heads)
        kv: Tensor = self.kv_proj(x if cross_x is None else cross_x)
        k, v = rearrange(kv, "... seq (proj heads chan) -> proj ... heads seq chan", proj=2, heads=self.kv_heads).unbind()
        a = self.attend(q, k, v)
        a = rearrange(a, "... heads seq chan -> ... seq (heads chan)")
        o = self.o_proj(a)
        return o

    @abstractmethod
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        ...

class SDPAAttn(AbstractAttn):
    @override
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        return scaled_dot_product_attention(q, k, v)

class ChunkAttnFn(torch.autograd.Function):
    class AttnChunk(NamedTuple):
        exp_v: Tensor
        sum_exp: Tensor
        max_sim: Tensor

    class FnCtx(FunctionCtx):
        scale: float
        q_chunk_size: int
        kv_chunk_size: int

    generate_vmap_rule = True

    @staticmethod
    def attend_kv_chunk(ctx: 'ChunkAttnFn.FnCtx', q_chunk: Tensor, k_mT_chunk: Tensor, v_chunk: Tensor) -> 'ChunkAttnFn.AttnChunk':
        sim = q_chunk @ k_mT_chunk
        sim = sim * ctx.scale
        # subtract the chunk-local max before exponentiating, for numerical stability
        max_sim = sim.detach().amax(dim=-1, keepdim=True)
        exp_sim = torch.exp(sim - max_sim)
        exp_v = exp_sim @ v_chunk
        max_sim = max_sim.squeeze(dim=-1)
        return ChunkAttnFn.AttnChunk(exp_v, exp_sim.sum(dim=-1), max_sim)

    @staticmethod
    def attend_q_chunk(ctx: 'ChunkAttnFn.FnCtx', q_chunk: Tensor, k: Tensor, v: Tensor) -> Tensor:
        chunks: list[ChunkAttnFn.AttnChunk] = [ChunkAttnFn.attend_kv_chunk(ctx, q_chunk, k_chunk, v_chunk) for k_chunk, v_chunk in zip(
            k.mT.split(ctx.kv_chunk_size, dim=-1),
            v.split(ctx.kv_chunk_size, dim=-2),
        )]
        acc_chunk = ChunkAttnFn.AttnChunk(*map(torch.stack, zip(*chunks)))
        exp_v, sum_exp, max_sim = acc_chunk
        # rescale each kv-chunk's partial numerator/denominator to the global max, then reduce (log-sum-exp accumulation)
        global_max = max_sim.amax(dim=0, keepdim=True)
        max_diffs = torch.exp(max_sim - global_max)
        exp_v *= max_diffs.unsqueeze(-1)
        sum_exp *= max_diffs
        all_exp_v = exp_v.sum(dim=0)
        all_sum_exp = sum_exp.sum(dim=0).unsqueeze(-1)
        return all_exp_v / all_sum_exp

    @staticmethod
    def attend(ctx: 'ChunkAttnFn.FnCtx', q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        q_chunks: tuple[Tensor, ...] = q.split(ctx.q_chunk_size, dim=-2)
        return torch.cat([ChunkAttnFn.attend_q_chunk(ctx, q_chunk, k, v) for q_chunk in q_chunks], dim=-2)

    @staticmethod
    def forward(ctx: 'ChunkAttnFn.FnCtx', q: Tensor, k: Tensor, v: Tensor, scale: float, q_chunk_size: int, kv_chunk_size: int) -> Tensor:
        ctx.scale = scale
        ctx.q_chunk_size = q_chunk_size
        ctx.kv_chunk_size = kv_chunk_size
        ctx.save_for_forward(v)
        return ChunkAttnFn.attend(ctx, q, k, v)

    @staticmethod
    def jvp(ctx: 'ChunkAttnFn.FnCtx', gq: Tensor, gk: Tensor, gv: Tensor, *_) -> Tensor:
        scale: float = ctx.scale
        q_chunk_size: int = ctx.q_chunk_size
        kv_chunk_size: int = ctx.kv_chunk_size
        # raise NotImplementedError("JVP not implemented for ChunkAttnFn")
        print("JVP not implemented for ChunkAttnFn; returning the query tangent unmodified as a placeholder")
        return gq

class ChunkAttn(AbstractAttn):
    def __init__(
        self,
        q_in_dim: int,
        kv_in_dim: int,
        out_dim: int,
        q_heads = 8,
        kv_heads = 8,
        head_dim = 64,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
        q_chunk_size: int = 64,
        kv_chunk_size: int = 64,
    ):
        super().__init__(
            q_in_dim=q_in_dim,
            kv_in_dim=kv_in_dim,
            out_dim=out_dim,
            q_heads=q_heads,
            kv_heads=kv_heads,
            head_dim=head_dim,
            device=device,
            dtype=dtype,
        )
        self.q_chunk_size = q_chunk_size
        self.kv_chunk_size = kv_chunk_size

    @override
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        return ChunkAttnFn.apply(q, k, v, self.scale, self.q_chunk_size, self.kv_chunk_size)

dtype = torch.float16
model_dim = 320
init_kwargs_common = {
    "q_in_dim": model_dim,
    "kv_in_dim": model_dim,
    "out_dim": model_dim,
    "q_heads": 8,
    "kv_heads": 8,
    "head_dim": 64,
    "dtype": dtype,
}
with torch.device('meta'):
    sdpa_attn = SDPAAttn(**init_kwargs_common).eval()
    chunk_attn = ChunkAttn(**init_kwargs_common, q_chunk_size=128, kv_chunk_size=128).eval()

device = torch.device('cuda')
seed = 42
gen = torch.Generator(device=device)
for attn in (chunk_attn, sdpa_attn):
    attn.to_empty(device=device)
    # attn.requires_grad_(False)
    attn.init(generator=gen.manual_seed(seed))

bsz = 1
seq = 512
cross_seq = 256
x = torch.randn(bsz, seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed))
cross_x = torch.randn(bsz, cross_seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed))

# commented-out because the dual-tensor test will run forward and jvp both, making this forward-only test redundant
# with no_grad():
#     sdpa_out = sdpa_attn(x, cross_x)
#     chunk_out = chunk_attn(x, cross_x)
#     assert torch.allclose(chunk_out, sdpa_out, atol=5e-3), "chunked attn and SDPA don't match closely"

x_tan = torch.randn(bsz, seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed+1))
cross_x_tan = torch.randn(bsz, cross_seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed+1))

with fwAD.dual_level(), enable_grad(), sdpa_kernel(SDPBackend.MATH):
    dual_x = fwAD.make_dual(x, x_tan)
    dual_cross_x = fwAD.make_dual(cross_x, cross_x_tan)
    sdpa_out_dual = sdpa_attn(dual_x, dual_cross_x)
    sdpa_out_prime, sdpa_out_tangent = fwAD.unpack_dual(sdpa_out_dual)
    chunk_out_dual = chunk_attn(dual_x, dual_cross_x)
    chunk_out_prime, chunk_out_tangent = fwAD.unpack_dual(chunk_out_dual)
    assert torch.allclose(chunk_out_prime, sdpa_out_prime, atol=5e-3), "chunked attn and SDPA primals don't match"
    print("fwd was fine, now let's compare jvp outputs")
    assert torch.allclose(chunk_out_tangent, sdpa_out_tangent, atol=5e-3), "chunked attn and SDPA tangents don't match"
The other thing I tried was to get rid of the custom torch autograd Function and do it as a regular PyTorch function that happens to output a dual tensor when it receives dual tensors as input.
This is a bit of a mess though, as autograd will run unnecessarily on every operation here, not just the ones where I explicitly compute tangents.
Also, my shapes don't match, because I'm not sure what the maths in the paper wants, and I'm not sure any of the variables are computed correctly.
But yeah, this is what I tried, if you want to approach this without a custom autograd Function.
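
For reference, the shape of that approach looks roughly like the sketch below. This is a hedged sketch rather than the code from my attempt: it is dense (unchunked, so it saves no memory), the helper name attn_with_manual_jvp is just a placeholder, and it assumes it is called inside fwAD.dual_level(). The tangent follows the standard softmax-attention JVP: dS = scale * (dQ K^T + Q dK^T), dP = P * (dS - rowsum(P * dS)), dO = dP V + P dV.

import torch
import torch.autograd.forward_ad as fwAD
from torch import Tensor

def attn_with_manual_jvp(q: Tensor, k: Tensor, v: Tensor, scale: float) -> Tensor:
    # Dense attention that hand-propagates forward-mode tangents when any input is a dual tensor.
    # Assumes it runs inside fwAD.dual_level(); unpack_dual on a plain tensor just returns tangent=None.
    q_p, q_t = fwAD.unpack_dual(q)
    k_p, k_t = fwAD.unpack_dual(k)
    v_p, v_t = fwAD.unpack_dual(v)

    # primal: ordinary softmax attention
    sim = q_p @ k_p.mT * scale
    p = sim.softmax(dim=-1)
    out_p = p @ v_p

    if q_t is None and k_t is None and v_t is None:
        return out_p

    q_t = torch.zeros_like(q_p) if q_t is None else q_t
    k_t = torch.zeros_like(k_p) if k_t is None else k_t
    v_t = torch.zeros_like(v_p) if v_t is None else v_t

    # tangent of the pre-softmax scores: dS = scale * (dQ K^T + Q dK^T)
    sim_t = (q_t @ k_p.mT + q_p @ k_t.mT) * scale
    # JVP of a row-wise softmax: dP = P * (dS - rowsum(P * dS))
    p_t = p * (sim_t - (p * sim_t).sum(dim=-1, keepdim=True))
    # product rule on P @ V: dO = dP V + P dV
    out_t = p_t @ v_p + p @ v_t
    return fwAD.make_dual(out_p, out_t)

A chunked version would additionally have to carry the tangent accumulators through the same max-rescaling / log-sum-exp merge that the primal accumulators go through.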