@Birch-san
Last active June 14, 2025 01:07
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional
import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
import triton
import triton.language as tl
from triton.testing import do_bench

NiladicFn = Callable[[], None]

def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * flop_count
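
# Illustrative numbers: 1.0 TFLOP of work per iteration measured at 10 ms/iter gives
# (1e3 / 10) * 1e12 = 100 TFLOP/s.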

def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int:
    flop_counter = FlopCounterMode(display=display_ops)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()

# --- Triton Multi-Head JVP Kernel ---

@triton.autotune(
    configs=[
        # Ultra-conservative configs for maximum compatibility
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1),
    ],
    key=['B', 'H', 'L', 'D_head'],
)
@triton.jit
def _flash_attention_jvp_multihead_kernel(
    # Input tensors
    Q, K, V, T_Q, T_K, T_V,
    # Output tensors
    Y, T_Y,
    # Tensor strides
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_kb, stride_kh, stride_kl, stride_kd,
    stride_vb, stride_vh, stride_vl, stride_vd,
    stride_tqb, stride_tqh, stride_tql, stride_tqd,
    stride_tkb, stride_tkh, stride_tkl, stride_tkd,
    stride_tvb, stride_tvh, stride_tvl, stride_tvd,
    stride_yb, stride_yh, stride_yl, stride_yd,
    stride_tyb, stride_tyh, stride_tyl, stride_tyd,
    # Problem dimensions
    B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr,
    # Scale factor
    scale: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """
    Flash Attention JVP kernel following the reference implementation pattern.
    Grid: (B*H, triton.cdiv(L, BLOCK_M))
    """
    # Get program IDs
    pid_bh = tl.program_id(0)  # Combined batch and head index
    pid_m = tl.program_id(1)   # Query block index
    # Decompose batch and head indices
    pid_b = pid_bh // H
    pid_h = pid_bh % H
    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, D_head)
    # Base pointers for this (batch, head)
    q_base = Q + pid_b * stride_qb + pid_h * stride_qh
    k_base = K + pid_b * stride_kb + pid_h * stride_kh
    v_base = V + pid_b * stride_vb + pid_h * stride_vh
    tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh
    tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh
    tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh
    y_base = Y + pid_b * stride_yb + pid_h * stride_yh
    ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh
    # Load query block
    q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd
    tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd
    mask_m = offs_m < L
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
    tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0)
    # Initialize accumulators following the Flash Attention pattern
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    # Scale factor for the exp2 optimization (like the reference): 1.44269504 ≈ log2(e) = 1/ln(2)
    qk_scale = scale * 1.44269504
    # Loop over key/value blocks
    for start_n in range(0, L, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        offs_n_curr = start_n + offs_n
        mask_n = offs_n_curr < L
        # Load key and value blocks
        k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd
        v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
        tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd
        tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd
        k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
        tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0)
        tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0)
        # Compute attention scores
        qk = tl.dot(q, tl.trans(k))
        tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk))
        # Mask invalid positions first
        qk = tl.where(mask_n[None, :], qk, float('-inf'))
        tqk = tl.where(mask_n[None, :], tqk, 0.0)
        # Online softmax computation following Flash Attention
        m_ij = tl.maximum(m_i, tl.max(qk * scale, 1))
        qk = qk * qk_scale - m_ij[:, None]  # Scale and subtract max
        p = tl.math.exp2(qk)  # Use exp2 like the reference
        # Correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        # Update normalization
        l_i = l_i * alpha + l_ij
        # NOTE: this downcast of p is a new change compared to the reference implementation
        # Cast p back to input dtype for matmul
        p_typed = p.to(q.dtype)
        # Update output accumulator
        acc = acc * alpha[:, None] + tl.dot(p_typed, v)
        # JVP accumulator: (p * tqk) @ v
        p_tqk = p * (tqk * scale)  # Apply scale to tangent scores
        # NOTE: this downcast of p_tqk is a new change compared to the reference implementation
        p_tqk_typed = p_tqk.to(q.dtype)  # Cast tangent weights too
        g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk_typed, v)
        # Update mu: sum(p * tqk)
        mu_ij = tl.sum(p_tqk, 1)
        mu_i = mu_i * alpha + mu_ij
        # Update p @ tv accumulator
        p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p_typed, tv)
        # Update max
        m_i = m_ij
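    # At this point every accumulator holds its sum with weights exp2(s_j * log2(e) - m_i),
    # where s_j are the scaled logits and m_i is the running row max. The common exp2(-m_i)
    # factor appears in both the accumulators and l_i, so it cancels in the division below,
    # which is what makes the single-pass rescaling valid.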
    # Epilogue: fold the log-sum into m_i (kept for parity with the reference; the LSE is not
    # stored by this kernel), then normalize by the softmax denominator
    m_i += tl.math.log2(l_i)
    y_out = acc / l_i[:, None]
    t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out
    t_y_out = t_p_v + p_tv_acc / l_i[:, None]
    # Store outputs
    y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd
    ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd
    tl.store(y_ptrs, y_out, mask=mask_m[:, None])
    tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None])

def flash_attention_jvp_multihead_triton_kernel_wrapper(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    t_Q: Tensor,
    t_K: Tensor,
    t_V: Tensor,
    scale: Optional[float] = None,
) -> tuple[Tensor, Tensor]:
    """
    Python wrapper for the Multi-head Flash Attention JVP Triton kernel.
    """
    device = Q.device
    dtype = Q.dtype
    B, H, L, D_head = Q.shape
    # Check minimum dimension requirements for Triton
    if D_head < 16:
        raise ValueError(f"D_head must be >= 16 for efficient Triton kernel, got {D_head}")
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)
    # Ensure input shapes are correct
    assert Q.shape == (B, H, L, D_head), f"Q shape mismatch: {Q.shape}"
    assert K.shape == (B, H, L, D_head), f"K shape mismatch: {K.shape}"
    assert V.shape == (B, H, L, D_head), f"V shape mismatch: {V.shape}"
    assert t_Q.shape == (B, H, L, D_head), f"t_Q shape mismatch: {t_Q.shape}"
    assert t_K.shape == (B, H, L, D_head), f"t_K shape mismatch: {t_K.shape}"
    assert t_V.shape == (B, H, L, D_head), f"t_V shape mismatch: {t_V.shape}"
    # Create output tensors
    y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    # Make tensors contiguous
    Qc = Q.contiguous()
    Kc = K.contiguous()
    Vc = V.contiguous()
    t_Qc = t_Q.contiguous()
    t_Kc = t_K.contiguous()
    t_Vc = t_V.contiguous()
    # Compute strides
    stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride()
    stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride()
    stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride()
    stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride()
    stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride()
    stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride()
    stride_yb, stride_yh, stride_yl, stride_yd = y.stride()
    stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride()
    # Use a block-based grid like Flash Attention. BLOCK_M is chosen by the autotuner,
    # so derive the grid from the selected config to ensure every query row is covered.
    grid = lambda META: (B * H, triton.cdiv(L, META["BLOCK_M"]))
    _flash_attention_jvp_multihead_kernel[grid](
        Qc, Kc, Vc, t_Qc, t_Kc, t_Vc,
        y, t_y,
        stride_qb, stride_qh, stride_ql, stride_qd,
        stride_kb, stride_kh, stride_kl, stride_kd,
        stride_vb, stride_vh, stride_vl, stride_vd,
        stride_tqb, stride_tqh, stride_tql, stride_tqd,
        stride_tkb, stride_tkh, stride_tkl, stride_tkd,
        stride_tvb, stride_tvh, stride_tvl, stride_tvd,
        stride_yb, stride_yh, stride_yl, stride_yd,
        stride_tyb, stride_tyh, stride_tyl, stride_tyd,
        B, H, L, D_head,
        scale,
    )
    return y, t_y
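
# Bridge into PyTorch forward-mode AD: unpack the primal/tangent pair from each dual
# tensor, run the fused Triton kernel on them, and re-pack the result as a dual tensor
# so downstream ops inside fwAD.dual_level() keep propagating the tangent.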
def dual_jvp_mha(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    scale: Optional[float] = None,
):
    q_p, q_t = fwAD.unpack_dual(Q)
    k_p, k_t = fwAD.unpack_dual(K)
    v_p, v_t = fwAD.unpack_dual(V)
    with no_grad():
        a_p, a_t = flash_attention_jvp_multihead_triton_kernel_wrapper(q_p, k_p, v_p, q_t, k_t, v_t, scale)
    return fwAD.make_dual(a_p, a_t)
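
# A minimal correctness sketch (not invoked by the benchmark): it checks the Triton kernel
# against torch.func.jvp over the math-backend SDPA on a small fp32 problem. The problem
# size and tolerances here are illustrative assumptions, not measured error bounds.
def sanity_check_triton_jvp(B: int = 1, H: int = 2, L: int = 128, D_head: int = 64) -> None:
    q, k, v, t_q, t_k, t_v = (torch.randn(B, H, L, D_head, device='cuda', dtype=torch.float32) for _ in range(6))
    y, t_y = flash_attention_jvp_multihead_triton_kernel_wrapper(q, k, v, t_q, t_k, t_v)
    with sdpa_kernel(SDPBackend.MATH):
        y_ref, t_y_ref = torch.func.jvp(scaled_dot_product_attention, (q, k, v), (t_q, t_k, t_v))
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
    torch.testing.assert_close(t_y, t_y_ref, atol=1e-2, rtol=1e-2)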

@dataclass
class Args:
    bsz: int
    model_dim: int
    head_dim: int
    seq_len: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--bsz", default=1, type=int)
        parser.add_argument("--model-dim", default=320, type=int)
        parser.add_argument("--head-dim", default=64, type=int)
        parser.add_argument("--seq-len", default=128, type=int)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args

def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)
    heads = args.model_dim // args.head_dim
    q_p, q_t, k_p, k_t, v_p, v_t = (torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(6))

    with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
        print("Math, fwd+jvp")
        q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t)))
        flop_count_math_jvp = get_flop_count(partial(scaled_dot_product_attention, q, k, v), display_ops=True)
        q.grad = None
        k.grad = None
        v.grad = None

        @torch.compile()
        def sdpa_jvp() -> Tensor:
            return scaled_dot_product_attention(q, k, v)

        # torch.compile warned that f32 was being used in this benchmark… somehow. let's enable tf32 to get the best-case speed.
        torch.set_float32_matmul_precision('high')
        ms_per_iter_sdpa_jvp: float = do_bench(sdpa_jvp, grad_to_none=(q, k, v))
        print(f"SDPA JVP: {fmt_flops(mpi_to_flops(ms_per_iter_sdpa_jvp, flop_count_math_jvp))}")

        # guess there's no point in torch.compile() because triton will do its own compilation
        def triton_jvp() -> Tensor:
            return dual_jvp_mha(q, k, v)

        ms_per_iter_triton_jvp: float = do_bench(triton_jvp, grad_to_none=(q, k, v))
        print(f"Triton JVP: {fmt_flops(mpi_to_flops(ms_per_iter_triton_jvp, flop_count_math_jvp))}")
print("sanity-checking whether our latencies make sense by testing a known-fast operation")
with sdpa_kernel(SDPBackend.FLASH_ATTENTION), no_grad():
print("Flash, fwd only")
flop_count_flash_fwd = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)
@torch.compile()
def sdpa_fwd() -> Tensor:
return scaled_dot_product_attention(q, k, v)
ms_per_iter_sdpa_fwd: float = do_bench(sdpa_fwd, grad_to_none=(q, k, v))
print(f"Flash fwd: {fmt_flops(mpi_to_flops(ms_per_iter_sdpa_fwd, flop_count_flash_fwd))}")
pass

if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)
Birch-san commented Jun 14, 2025

python -m jvp_bench --bsz 1 --model-dim 2048 --head-dim 64 --seq-len 8192
Math, fwd+jvp
Module            FLOP    % Total
-----------  ---------  ---------
Global       1649.267B    100.00%
 - aten.bmm  1649.267B    100.00%

SDPA JVP: 138.9 TFLOP/s
Triton JVP: 497.6 TFLOP/s

sanity-checking whether our latencies make sense by testing a known-fast operation

Flash, fwd only
Module                                           FLOP    % Total
-------------------------------------------  --------  ---------
Global                                       549.756B    100.00%
 - aten._scaled_dot_product_flash_attention  549.756B    100.00%

Flash fwd: 314.1 TFLOP/s
