@Birch-san
Created May 27, 2025 01:34
Test stub for comparing jvp of memory-efficient attention against reference implementation
from abc import ABC, abstractmethod
from typing import NamedTuple, Optional
from typing_extensions import override
import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.autograd.function import FunctionCtx
from torch.nn import Linear, Module
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.init import normal_
from einops import rearrange

class AbstractAttn(Module, ABC):
    def __init__(
        self,
        q_in_dim: int,
        kv_in_dim: int,
        out_dim: int,
        q_heads=8,
        kv_heads=8,
        head_dim=64,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        # ask Module.__init__ to forward to the next __init__ in the MRO (co-operative init, since we also inherit from ABC)
        self.call_super_init = True
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.q_heads = q_heads
        self.kv_heads = kv_heads
        self.head_dim = head_dim
        self.scale: float = head_dim ** -0.5
        self.q_proj = Linear(q_in_dim, self.q_heads * head_dim, bias=False, **factory_kwargs)
        self.kv_proj = Linear(kv_in_dim, self.kv_heads * head_dim * 2, bias=False, **factory_kwargs)
        self.o_proj = Linear(self.kv_heads * head_dim, out_dim, bias=False, **factory_kwargs)

    def init(self, generator: Optional[torch.Generator] = None) -> None:
        normal_(self.q_proj.weight, std=self.q_proj.in_features**-.5, generator=generator)
        normal_(self.kv_proj.weight, std=self.kv_proj.in_features**-.5, generator=generator)
        normal_(self.o_proj.weight, std=self.o_proj.in_features**-.5, generator=generator)

    def forward(self, x: Tensor, cross_x: Optional[Tensor] = None) -> Tensor:
        q: Tensor = self.q_proj(x)
        q = rearrange(q, "... seq (heads chan) -> ... heads seq chan", heads=self.q_heads)
        kv: Tensor = self.kv_proj(x if cross_x is None else cross_x)
        k, v = rearrange(kv, "... seq (proj heads chan) -> proj ... heads seq chan", proj=2, heads=self.kv_heads).unbind()
        a = self.attend(q, k, v)
        a = rearrange(a, "... heads seq chan -> ... seq (heads chan)")
        o = self.o_proj(a)
        return o

    @abstractmethod
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        ...

class SDPAAttn(AbstractAttn):
    @override
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        return scaled_dot_product_attention(q, k, v)

class ChunkAttnFn(torch.autograd.Function):
    class AttnChunk(NamedTuple):
        exp_v: Tensor
        sum_exp: Tensor
        max_sim: Tensor

    class FnCtx(FunctionCtx):
        scale: float
        q_chunk_size: int
        kv_chunk_size: int

    generate_vmap_rule = True

    @staticmethod
    def attend_kv_chunk(ctx: 'ChunkAttnFn.FnCtx', q_chunk: Tensor, k_mT_chunk: Tensor, v_chunk: Tensor) -> 'ChunkAttnFn.AttnChunk':
        # per-chunk numerator (exp_sim @ v), denominator (sum of exp_sim), and running max for the log-sum-exp trick
        sim = q_chunk @ k_mT_chunk
        sim = sim * ctx.scale
        max_sim = sim.detach().amax(dim=-1, keepdim=True)
        exp_sim = torch.exp(sim - max_sim)
        exp_v = exp_sim @ v_chunk
        max_sim = max_sim.squeeze(dim=-1)
        return ChunkAttnFn.AttnChunk(exp_v, exp_sim.sum(dim=-1), max_sim)

    @staticmethod
    def attend_q_chunk(ctx: 'ChunkAttnFn.FnCtx', q_chunk: Tensor, k: Tensor, v: Tensor) -> Tensor:
        chunks: list[ChunkAttnFn.AttnChunk] = [ChunkAttnFn.attend_kv_chunk(ctx, q_chunk, k_chunk, v_chunk) for k_chunk, v_chunk in zip(
            k.mT.split(ctx.kv_chunk_size, dim=-1),
            v.split(ctx.kv_chunk_size, dim=-2),
        )]
        acc_chunk = ChunkAttnFn.AttnChunk(*map(torch.stack, zip(*chunks)))
        exp_v, sum_exp, max_sim = acc_chunk
        # renormalize each chunk's statistics onto the global max, then reduce over chunks
        global_max = max_sim.amax(dim=0, keepdim=True)
        max_diffs = torch.exp(max_sim - global_max)
        exp_v *= max_diffs.unsqueeze(-1)
        sum_exp *= max_diffs
        all_exp_v = exp_v.sum(dim=0)
        all_sum_exp = sum_exp.sum(dim=0).unsqueeze(-1)
        return all_exp_v / all_sum_exp

    @staticmethod
    def attend(ctx: 'ChunkAttnFn.FnCtx', q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        q_chunks: tuple[Tensor, ...] = q.split(ctx.q_chunk_size, dim=-2)
        return torch.cat([ChunkAttnFn.attend_q_chunk(ctx, q_chunk, k, v) for q_chunk in q_chunks], dim=-2)

    @staticmethod
    def forward(ctx: 'ChunkAttnFn.FnCtx', q: Tensor, k: Tensor, v: Tensor, scale: float, q_chunk_size: int, kv_chunk_size: int) -> Tensor:
        ctx.scale = scale
        ctx.q_chunk_size = q_chunk_size
        ctx.kv_chunk_size = kv_chunk_size
        ctx.save_for_forward(v)
        return ChunkAttnFn.attend(ctx, q, k, v)

    @staticmethod
    def jvp(ctx: 'ChunkAttnFn.FnCtx', gq: Tensor, gk: Tensor, gv: Tensor, *_) -> Tensor:
        scale: float = ctx.scale
        q_chunk_size: int = ctx.q_chunk_size
        kv_chunk_size: int = ctx.kv_chunk_size
        # raise NotImplementedError("JVP not implemented for ChunkAttnFn")
        print("JVP not implemented for ChunkAttnFn; returning the tangent of q unmodified")
        return gq
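
# For reference, the JVP of unchunked softmax attention (a sketch only, not wired into
# ChunkAttnFn; the chunked jvp above would need to reproduce the same maths from
# per-chunk statistics such as the saved v and the running max/sum):
#   S = scale * q @ k.mT;  P = softmax(S);  O = P @ v
#   dS = scale * (gq @ k.mT + q @ gk.mT)
#   dP = P * (dS - (P * dS).sum(-1, keepdim=True))
#   dO = dP @ v + P @ gv
def reference_attn_jvp(q: Tensor, k: Tensor, v: Tensor, gq: Tensor, gk: Tensor, gv: Tensor, scale: float) -> Tensor:
    sim = scale * (q @ k.mT)
    prob = sim.softmax(dim=-1)
    sim_tan = scale * (gq @ k.mT + q @ gk.mT)
    # softmax JVP: dP = P * (dS - rowsum(P * dS))
    prob_tan = prob * (sim_tan - (prob * sim_tan).sum(dim=-1, keepdim=True))
    return prob_tan @ v + prob @ gv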

class ChunkAttn(AbstractAttn):
    def __init__(
        self,
        q_in_dim: int,
        kv_in_dim: int,
        out_dim: int,
        q_heads=8,
        kv_heads=8,
        head_dim=64,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
        q_chunk_size: int = 64,
        kv_chunk_size: int = 64,
    ):
        super().__init__(
            q_in_dim=q_in_dim,
            kv_in_dim=kv_in_dim,
            out_dim=out_dim,
            q_heads=q_heads,
            kv_heads=kv_heads,
            head_dim=head_dim,
            device=device,
            dtype=dtype,
        )
        self.q_chunk_size = q_chunk_size
        self.kv_chunk_size = kv_chunk_size

    @override
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        return ChunkAttnFn.apply(q, k, v, self.scale, self.q_chunk_size, self.kv_chunk_size)

dtype = torch.float16
model_dim = 320
init_kwargs_common = {
    "q_in_dim": model_dim,
    "kv_in_dim": model_dim,
    "out_dim": model_dim,
    "q_heads": 8,
    "kv_heads": 8,
    "head_dim": 64,
    "dtype": dtype,
}

with torch.device('meta'):
    sdpa_attn = SDPAAttn(**init_kwargs_common).eval()
    chunk_attn = ChunkAttn(**init_kwargs_common, q_chunk_size=128, kv_chunk_size=128).eval()

device = torch.device('cuda')
seed = 42
gen = torch.Generator(device=device)
for attn in (chunk_attn, sdpa_attn):
    attn.to_empty(device=device)
    # attn.requires_grad_(False)
    attn.init(generator=gen.manual_seed(seed))

bsz = 1
seq = 512
cross_seq = 256
x = torch.randn(bsz, seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed))
cross_x = torch.randn(bsz, cross_seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed))

# commented-out because the dual-tensor test will run forward and jvp both, making this forward-only test redundant
# with no_grad():
#     sdpa_out = sdpa_attn(x, cross_x)
#     chunk_out = chunk_attn(x, cross_x)
#     assert torch.allclose(chunk_out, sdpa_out, atol=5e-3), "chunked attn and SDPA don't match closely"

x_tan = torch.randn(bsz, seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed+1))
cross_x_tan = torch.randn(bsz, cross_seq, model_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed+1))

with fwAD.dual_level(), enable_grad(), sdpa_kernel(SDPBackend.MATH):
    dual_x = fwAD.make_dual(x, x_tan)
    dual_cross_x = fwAD.make_dual(cross_x, cross_x_tan)

    sdpa_out_dual = sdpa_attn(dual_x, dual_cross_x)
    sdpa_out_prime, sdpa_out_tangent = fwAD.unpack_dual(sdpa_out_dual)

    chunk_out_dual = chunk_attn(dual_x, dual_cross_x)
    chunk_out_prime, chunk_out_tangent = fwAD.unpack_dual(chunk_out_dual)

    assert torch.allclose(chunk_out_prime, sdpa_out_prime, atol=5e-3), "chunked attn and SDPA primals don't match"
    print("fwd was fine, now let's compare jvp outputs")
    assert torch.allclose(chunk_out_tangent, sdpa_out_tangent, atol=5e-3), "chunked attn and SDPA tangents don't match"
@Birch-san (Author)

the other thing I tried was to get rid of the custom torch autograd function, and do it as a regular pytorch function that happens to output a dual tensor when it receives a dual tensor as input.
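
roughly, the pattern is: unpack the dual inputs, compute the primal, compute the tangent by hand, and re-pack. a minimal sketch of that on a single matmul (dual_matmul is a hypothetical helper, just to illustrate the shape of the idea, not part of the gist above):

def dual_matmul(a: Tensor, b: Tensor) -> Tensor:
    # .tangent is None when the input is a plain (non-dual) tensor
    a_p, a_t = fwAD.unpack_dual(a)
    b_p, b_t = fwAD.unpack_dual(b)
    out = a_p @ b_p
    if a_t is None and b_t is None:
        return out  # plain tensors in, plain tensor out
    # product rule, treating a missing tangent as zero
    tan = torch.zeros_like(out)
    if a_t is not None:
        tan = tan + a_t @ b_p
    if b_t is not None:
        tan = tan + a_p @ b_t
    return fwAD.make_dual(out, tan)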

the full attempt below is a bit of a mess though, as autograd will unnecessarily run on every operation, not just the ones where I explicitly compute tangents.

also, my shapes don't match because I'm not sure what the maths in the paper wants, and I'm not sure any of the variables are computed correctly.

but yeah, this is what I tried, if you want to approach this without a custom autograd function.

class ChunkAttn(AbstractAttn):
    class AttnChunk(NamedTuple):
        exp_v: Tensor
        sum_exp: Tensor
        max_sim: Tensor
        g: Optional[Tensor]
        mu: Optional[Tensor]

    def __init__(
        self,
        q_in_dim: int,
        kv_in_dim: int,
        out_dim: int,
        q_heads = 8,
        kv_heads = 8,
        head_dim = 64,
        device: Optional[torch.device | str | int] = None,
        dtype: Optional[torch.dtype] = None,
        q_chunk_size: int = 64,
        kv_chunk_size: int = 64,
    ):
        super().__init__(
            q_in_dim=q_in_dim,
            kv_in_dim=kv_in_dim,
            out_dim=out_dim,
            q_heads=q_heads,
            kv_heads=kv_heads,
            head_dim=head_dim,
            device=device,
            dtype=dtype,
        )
        self.q_chunk_size = q_chunk_size
        self.kv_chunk_size = kv_chunk_size

    def attend_kv_chunk(self, q_chunk: Tensor, k_mT_chunk: Tensor, v_chunk: Tensor) -> 'ChunkAttn.AttnChunk':
        sim = q_chunk @ k_mT_chunk
        sim = sim * self.scale
        max_sim = sim.detach().amax(dim=-1, keepdim=True)
        exp_sim = torch.exp(sim - max_sim)
        sim_tan: Optional[Tensor] = fwAD.unpack_dual(sim).tangent
        exp_v = exp_sim @ v_chunk
        max_sim = max_sim.squeeze(dim=-1)
        if sim_tan is None:
            g: Optional[Tensor] = None
            mu: Optional[Tensor] = None
        else:
            g_precursor = exp_sim * sim_tan
            g = g_precursor @ v_chunk
            mu = g_precursor.sum(dim=-1)

            # exp_v_tan = exp_sim @ fwAD.unpack_dual(v_chunk).tangent
            # exp_v = fwAD.make_dual(exp_v, exp_v_tan)

        return self.AttnChunk(exp_v, exp_sim.sum(dim=-1), max_sim, g, mu)

    def attend_q_chunk(self, q_chunk: Tensor, k: Tensor, v: Tensor) -> Tensor:
        chunks: list[ChunkAttn.AttnChunk] = [self.attend_kv_chunk(q_chunk, k_chunk, v_chunk) for k_chunk, v_chunk in zip(
            k.mT.split(self.kv_chunk_size, dim=-1),
            v.split(self.kv_chunk_size, dim=-2),
        )]
        # acc_chunk = self.AttnChunk(*map(torch.stack, zip(*chunks)))
        acc_chunk = self.AttnChunk(*[None if t[0] is None else torch.stack(t) for t in zip(*chunks)])
        exp_v, sum_exp, max_sim, g, mu = acc_chunk

        global_max = max_sim.amax(dim=0, keepdim=True)
        max_diffs = torch.exp(max_sim - global_max)
        exp_v *= max_diffs.unsqueeze(-1)
        sum_exp *= max_diffs

        all_exp_v = exp_v.sum(dim=0)
        all_sum_exp = sum_exp.sum(dim=0).unsqueeze(-1)
        a = all_exp_v / all_sum_exp

        if mu is None:
            return a

        all_mu = mu.sum(dim=0)
        # TODO: is this meant to be all_sum_exp or just sum_exp?
        # TODO: should we do this before multiplying by max_diffs?
        prob_tan_mm_v = g / all_sum_exp - (all_mu.unsqueeze(-1) / all_sum_exp) * a
        tan_a = prob_tan_mm_v + fwAD.unpack_dual(exp_v).tangent

        return fwAD.make_dual(a, tan_a)
        

    @override
    def attend(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        q_chunks: tuple[Tensor, ...] = q.split(self.q_chunk_size, dim=-2)
        return torch.cat([self.attend_q_chunk(q_chunk, k, v) for q_chunk in q_chunks], dim=-2)
