from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional

import torch
from torch import Tensor, no_grad, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode

import triton
import triton.language as tl
from triton.testing import do_bench

NiladicFn = Callable[[], None]

def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * flop_count

def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int:
    flop_counter = FlopCounterMode(display=display_ops)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()
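
# e.g. (illustrative numbers): a workload counted at 4e9 FLOPs that do_bench times at
# 1.0 ms/iter works out to mpi_to_flops(1.0, 4e9) = (1e3 / 1.0) * 4e9 = 4e12 FLOP/s,
# which fmt_flops renders as "  4.0 TFLOP/s".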
# --- Triton Multi-Head JVP Kernel ---
@triton.autotune(
    configs=[
        # Ultra-conservative configs for maximum compatibility
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 32}, num_warps=2, num_stages=1),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 16}, num_warps=4, num_stages=1),
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 64}, num_warps=4, num_stages=1),
    ],
    key=['B', 'H', 'L', 'D_head'],
)
@triton.jit
def _flash_attention_jvp_multihead_kernel(
    # Input tensors
    Q, K, V, T_Q, T_K, T_V,
    # Output tensors
    Y, T_Y,
    # Tensor strides
    stride_qb, stride_qh, stride_ql, stride_qd,
    stride_kb, stride_kh, stride_kl, stride_kd,
    stride_vb, stride_vh, stride_vl, stride_vd,
    stride_tqb, stride_tqh, stride_tql, stride_tqd,
    stride_tkb, stride_tkh, stride_tkl, stride_tkd,
    stride_tvb, stride_tvh, stride_tvl, stride_tvd,
    stride_yb, stride_yh, stride_yl, stride_yd,
    stride_tyb, stride_tyh, stride_tyl, stride_tyd,
    # Problem dimensions
    B: tl.constexpr, H: tl.constexpr, L: tl.constexpr, D_head: tl.constexpr,
    # Scale factor
    scale: tl.constexpr,
    # Block sizes
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    """
    Flash Attention JVP kernel following the reference implementation pattern.
    Grid: (B*H, triton.cdiv(L, BLOCK_M))
    """
    # Get program IDs
    pid_bh = tl.program_id(0)  # Combined batch and head index
    pid_m = tl.program_id(1)  # Query block index
    # Decompose batch and head indices
    pid_b = pid_bh // H
    pid_h = pid_bh % H
    # Compute offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, D_head)
    # Base pointers for this (batch, head)
    q_base = Q + pid_b * stride_qb + pid_h * stride_qh
    k_base = K + pid_b * stride_kb + pid_h * stride_kh
    v_base = V + pid_b * stride_vb + pid_h * stride_vh
    tq_base = T_Q + pid_b * stride_tqb + pid_h * stride_tqh
    tk_base = T_K + pid_b * stride_tkb + pid_h * stride_tkh
    tv_base = T_V + pid_b * stride_tvb + pid_h * stride_tvh
    y_base = Y + pid_b * stride_yb + pid_h * stride_yh
    ty_base = T_Y + pid_b * stride_tyb + pid_h * stride_tyh
    # Load query block
    q_ptrs = q_base + offs_m[:, None] * stride_ql + offs_d[None, :] * stride_qd
    tq_ptrs = tq_base + offs_m[:, None] * stride_tql + offs_d[None, :] * stride_tqd
    mask_m = offs_m < L
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0)
    tq = tl.load(tq_ptrs, mask=mask_m[:, None], other=0.0)
    # Initialize accumulators following the Flash Attention pattern
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    g_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    p_tv_acc = tl.zeros([BLOCK_M, D_head], dtype=tl.float32)
    # Scale factor for exp2 optimization (like reference)
    qk_scale = scale * 1.44269504  # 1/log(2)
    # Loop over key/value blocks
    for start_n in range(0, L, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        offs_n_curr = start_n + offs_n
        mask_n = offs_n_curr < L
        # Load key and value blocks
        k_ptrs = k_base + offs_n_curr[:, None] * stride_kl + offs_d[None, :] * stride_kd
        v_ptrs = v_base + offs_n_curr[:, None] * stride_vl + offs_d[None, :] * stride_vd
        tk_ptrs = tk_base + offs_n_curr[:, None] * stride_tkl + offs_d[None, :] * stride_tkd
        tv_ptrs = tv_base + offs_n_curr[:, None] * stride_tvl + offs_d[None, :] * stride_tvd
        k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)
        v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)
        tk = tl.load(tk_ptrs, mask=mask_n[:, None], other=0.0)
        tv = tl.load(tv_ptrs, mask=mask_n[:, None], other=0.0)
        # Compute attention scores
        qk = tl.dot(q, tl.trans(k))
        tqk = tl.dot(tq, tl.trans(k)) + tl.dot(q, tl.trans(tk))
        # Mask invalid positions first
        qk = tl.where(mask_n[None, :], qk, float('-inf'))
        tqk = tl.where(mask_n[None, :], tqk, 0.0)
        # Online softmax computation following Flash Attention
        m_ij = tl.maximum(m_i, tl.max(qk * scale, 1))
        qk = qk * qk_scale - m_ij[:, None]  # Scale and subtract max
        p = tl.math.exp2(qk)  # Use exp2 like reference
        # Correction factor
        alpha = tl.math.exp2(m_i - m_ij)
        l_ij = tl.sum(p, 1)
        # Update normalization
        l_i = l_i * alpha + l_ij
        # NOTE: this downcast of p is a new change compared to the reference implementation
        # Cast p back to input dtype for matmul
        p_typed = p.to(q.dtype)
        # Update output accumulator
        acc = acc * alpha[:, None] + tl.dot(p_typed, v)
        # JVP accumulator: (p * tqk) @ v
        p_tqk = p * (tqk * scale)  # Apply scale to tangent scores
        # NOTE: this downcast of p_tqk is a new change compared to the reference implementation
        p_tqk_typed = p_tqk.to(q.dtype)  # Cast tangent weights too
        g_acc = g_acc * alpha[:, None] + tl.dot(p_tqk_typed, v)
        # Update mu: sum(p * tqk)
        mu_ij = tl.sum(p_tqk, 1)
        mu_i = mu_i * alpha + mu_ij
        # Update p @ tv accumulator
        p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p_typed, tv)
        # Update max
        m_i = m_ij
    # Final computation - add log normalization and divide
    m_i += tl.math.log2(l_i)
    y_out = acc / l_i[:, None]
    t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * y_out
    t_y_out = t_p_v + p_tv_acc / l_i[:, None]
    # Store outputs
    y_ptrs = y_base + offs_m[:, None] * stride_yl + offs_d[None, :] * stride_yd
    ty_ptrs = ty_base + offs_m[:, None] * stride_tyl + offs_d[None, :] * stride_tyd
    tl.store(y_ptrs, y_out, mask=mask_m[:, None])
    tl.store(ty_ptrs, t_y_out, mask=mask_m[:, None])
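
# In dense (non-streaming) form, the kernel above computes, per (batch, head):
#   S   = scale * Q @ K^T,    P = softmax(S, dim=-1)
#   t_S = scale * (t_Q @ K^T + Q @ t_K^T)
#   Y   = P @ V
#   t_Y = (P * (t_S - rowsum(P * t_S))) @ V + P @ t_V
# acc, g_acc, mu_i, and p_tv_acc hold the unnormalized running versions of
# P @ V, (P * t_S) @ V, rowsum(P * t_S), and P @ t_V respectively, each rescaled
# by alpha whenever the running row-max changes.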
def flash_attention_jvp_multihead_triton_kernel_wrapper(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    t_Q: Tensor,
    t_K: Tensor,
    t_V: Tensor,
    scale: Optional[float] = None,
) -> tuple[Tensor, Tensor]:
    """
    Python wrapper for the multi-head Flash Attention JVP Triton kernel.
    """
    device = Q.device
    dtype = Q.dtype
    B, H, L, D_head = Q.shape
    # Check minimum dimension requirements for Triton
    if D_head < 16:
        raise ValueError(f"D_head must be >= 16 for an efficient Triton kernel, got {D_head}")
    if scale is None:
        scale = 1.0 / (D_head ** 0.5)
    # Ensure input shapes are correct
    assert Q.shape == (B, H, L, D_head), f"Q shape mismatch: {Q.shape}"
    assert K.shape == (B, H, L, D_head), f"K shape mismatch: {K.shape}"
    assert V.shape == (B, H, L, D_head), f"V shape mismatch: {V.shape}"
    assert t_Q.shape == (B, H, L, D_head), f"t_Q shape mismatch: {t_Q.shape}"
    assert t_K.shape == (B, H, L, D_head), f"t_K shape mismatch: {t_K.shape}"
    assert t_V.shape == (B, H, L, D_head), f"t_V shape mismatch: {t_V.shape}"
    # Create output tensors
    y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    t_y = torch.zeros((B, H, L, D_head), dtype=dtype, device=device)
    # Make tensors contiguous
    Qc = Q.contiguous()
    Kc = K.contiguous()
    Vc = V.contiguous()
    t_Qc = t_Q.contiguous()
    t_Kc = t_K.contiguous()
    t_Vc = t_V.contiguous()
    # Compute strides
    stride_qb, stride_qh, stride_ql, stride_qd = Qc.stride()
    stride_kb, stride_kh, stride_kl, stride_kd = Kc.stride()
    stride_vb, stride_vh, stride_vl, stride_vd = Vc.stride()
    stride_tqb, stride_tqh, stride_tql, stride_tqd = t_Qc.stride()
    stride_tkb, stride_tkh, stride_tkl, stride_tkd = t_Kc.stride()
    stride_tvb, stride_tvh, stride_tvl, stride_tvd = t_Vc.stride()
    stride_yb, stride_yh, stride_yl, stride_yd = y.stride()
    stride_tyb, stride_tyh, stride_tyl, stride_tyd = t_y.stride()
    # Use a block-based grid like Flash Attention.
    # BLOCK_M is chosen by the autotuner, so derive the grid from the selected config
    # to ensure every query row is covered.
    grid = lambda meta: (B * H, triton.cdiv(L, meta['BLOCK_M']))
    _flash_attention_jvp_multihead_kernel[grid](
        Qc, Kc, Vc, t_Qc, t_Kc, t_Vc,
        y, t_y,
        stride_qb, stride_qh, stride_ql, stride_qd,
        stride_kb, stride_kh, stride_kl, stride_kd,
        stride_vb, stride_vh, stride_vl, stride_vd,
        stride_tqb, stride_tqh, stride_tql, stride_tqd,
        stride_tkb, stride_tkh, stride_tkl, stride_tkd,
        stride_tvb, stride_tvh, stride_tvl, stride_tvd,
        stride_yb, stride_yh, stride_yl, stride_yd,
        stride_tyb, stride_tyh, stride_tyl, stride_tyd,
        B, H, L, D_head,
        scale,
    )
    return y, t_y
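
# Example (illustrative shapes): given fp16 CUDA tensors q, k, v of shape (B, H, L, D_head),
# e.g. (1, 5, 128, 64), and same-shaped tangents t_q, t_k, t_v,
#   y, t_y = flash_attention_jvp_multihead_triton_kernel_wrapper(q, k, v, t_q, t_k, t_v)
# returns the attention output and its directional derivative, both of shape (B, H, L, D_head).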
def dual_jvp_mha(
    Q: Tensor,
    K: Tensor,
    V: Tensor,
    scale: Optional[float] = None,
) -> Tensor:
    q_p, q_t = fwAD.unpack_dual(Q)
    k_p, k_t = fwAD.unpack_dual(K)
    v_p, v_t = fwAD.unpack_dual(V)
    with no_grad():
        a_p, a_t = flash_attention_jvp_multihead_triton_kernel_wrapper(q_p, k_p, v_p, q_t, k_t, v_t, scale)
    return fwAD.make_dual(a_p, a_t)
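
# A plain-PyTorch reference for the same JVP, handy for spot-checking the Triton kernel
# on small shapes. This helper is a minimal sketch (the name and the float32 upcast are
# illustrative choices); it is not called by the benchmark below.
def reference_sdpa_jvp(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    t_q: Tensor,
    t_k: Tensor,
    t_v: Tensor,
    scale: Optional[float] = None,
) -> tuple[Tensor, Tensor]:
    if scale is None:
        scale = q.shape[-1] ** -0.5
    # Primal and tangent attention scores
    s = scale * (q.float() @ k.float().transpose(-2, -1))
    t_s = scale * (t_q.float() @ k.float().transpose(-2, -1) + q.float() @ t_k.float().transpose(-2, -1))
    # Softmax and its JVP: d softmax(s)[t] = p * (t - rowsum(p * t))
    p = s.softmax(dim=-1)
    t_p = p * (t_s - (p * t_s).sum(dim=-1, keepdim=True))
    # Output and its JVP
    y = p @ v.float()
    t_y = t_p @ v.float() + p @ t_v.float()
    return y.to(q.dtype), t_y.to(q.dtype)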
@dataclass
class Args:
    bsz: int
    model_dim: int
    head_dim: int
    seq_len: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--bsz", default=1, type=int)
        parser.add_argument("--model-dim", default=320, type=int)
        parser.add_argument("--head-dim", default=64, type=int)
        parser.add_argument("--seq-len", default=128, type=int)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args

def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)
    heads = args.model_dim // args.head_dim
    q_p, q_t, k_p, k_t, v_p, v_t = (
        torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix))
        for ix in range(6)
    )
    with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
        print("Math, fwd+jvp")
        q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t)))
        flop_count_math_jvp = get_flop_count(partial(scaled_dot_product_attention, q, k, v), display_ops=True)
        q.grad = None
        k.grad = None
        v.grad = None

        @torch.compile()
        def sdpa_jvp() -> Tensor:
            return scaled_dot_product_attention(q, k, v)

        # torch.compile warned that f32 was being used in this benchmark… somehow. let's enable tf32 to get the best-case speed.
        torch.set_float32_matmul_precision('high')
        ms_per_iter_sdpa_jvp: float = do_bench(sdpa_jvp, grad_to_none=(q, k, v))
        print(f"SDPA JVP: {fmt_flops(mpi_to_flops(ms_per_iter_sdpa_jvp, flop_count_math_jvp))}")

        # guess there's no point in torch.compile() because triton will do its own compilation
        def triton_jvp() -> Tensor:
            return dual_jvp_mha(q, k, v)

        ms_per_iter_triton_jvp: float = do_bench(triton_jvp, grad_to_none=(q, k, v))
        print(f"Triton JVP: {fmt_flops(mpi_to_flops(ms_per_iter_triton_jvp, flop_count_math_jvp))}")

    print("sanity-checking whether our latencies make sense by testing a known-fast operation")
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION), no_grad():
        print("Flash, fwd only")
        flop_count_flash_fwd = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)

        @torch.compile()
        def sdpa_fwd() -> Tensor:
            return scaled_dot_product_attention(q, k, v)

        ms_per_iter_sdpa_fwd: float = do_bench(sdpa_fwd, grad_to_none=(q, k, v))
        print(f"Flash fwd: {fmt_flops(mpi_to_flops(ms_per_iter_sdpa_fwd, flop_count_flash_fwd))}")

if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)
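
# Example invocation (assuming the file is saved as flash_jvp_bench.py):
#   python flash_jvp_bench.py --bsz 1 --model-dim 320 --head-dim 64 --seq-len 128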