Triton fused attention tutorial, updated with JVP support, albeit with only atol=1e-3 accuracy on the JVP outputs.
from __future__ import annotations | |
""" | |
Fused Attention | |
=============== | |
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) | |
Credits: OpenAI kernel team | |
Extra Credits: | |
* Original flash attention paper (https://arxiv.org/abs/2205.14135) | |
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) | |
Plus modifications to support JVP: | |
- formulation of flash JVP by Cheng Lu, Yang Song in https://arxiv.org/abs/2410.11081 | |
- reference triton implementation by Sofian Mejjoute | |
- reimplementing reference implementation as an autograd function with latest triton tutorial optimizations, by Alex Birch | |
- support for forward to receive tangents, so as to compute fwd and jvp together, autograd workaround by Emily (nshepperd) | |
- support for function transforms (e.g. torch.func.jvp) via the use of setup_context, by Shih-Ying Yeh | |
""" | |
from typing import Any, Literal, NamedTuple, Optional | |
import pytest | |
import torch | |
import os | |
import triton | |
import triton.language as tl | |
from torch import Tensor | |
from torch.autograd import Function | |
from torch.autograd.function import FunctionCtx | |
import torch.autograd.forward_ad as fwAD | |
try: | |
from triton.tools.tensor_descriptor import TensorDescriptor | |
HAS_TENSOR_DESC = True | |
except ModuleNotFoundError: | |
HAS_TENSOR_DESC = False | |
DEVICE = getattr(triton.runtime.driver.active, "get_active_torch_device", lambda: torch.device('cuda'))() | |
def is_hip(): | |
return triton.runtime.driver.active.get_current_target().backend == "hip" | |
def is_cuda(): | |
return triton.runtime.driver.active.get_current_target().backend == "cuda" | |
def supports_host_descriptor(): | |
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 | |
def supports_tma(): | |
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9 | |
def is_blackwell(): | |
return is_cuda() and torch.cuda.get_device_capability()[0] == 10 | |
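# Illustrative reference (a minimal sketch, not used by the kernels below): naive O(N^2)-memory
# softmax attention and its JVP via torch.func.jvp (PyTorch 2.x). The helper names are
# assumptions added for illustration; handy for eyeballing what the fused kernels should produce.
def _naive_attention(q: Tensor, k: Tensor, v: Tensor, sm_scale: float, causal: bool = False) -> Tensor:
    s = q @ k.transpose(-2, -1) * sm_scale
    if causal:
        n_ctx = s.shape[-1]
        mask = torch.tril(torch.ones(n_ctx, n_ctx, device=s.device)) == 0
        s = s.masked_fill(mask, float("-inf"))
    return torch.softmax(s.float(), dim=-1).to(v.dtype) @ v

def _naive_attention_jvp(q, k, v, t_q, t_k, t_v, sm_scale: float, causal: bool = False):
    # returns (primal output, tangent output)
    return torch.func.jvp(lambda q_, k_, v_: _naive_attention(q_, k_, v_, sm_scale, causal),
                          (q, k, v), (t_q, t_k, t_v))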
@triton.jit | |
def _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype: tl.constexpr, start_m, qk_scale, sm_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
K_block_ptr = tl.advance(K_block_ptr, (0, lo)) | |
# NOTE: in fp8 mode, we may want to advance the V_block_ptr differently. | |
# I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) | |
if ENABLE_JVP: | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo)) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0)) | |
# loop over k, v and update accumulator | |
for start_n in range(lo, hi, BLOCK_N): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = tl.load(K_block_ptr) | |
qk = tl.dot(q, k) | |
if ENABLE_JVP: | |
t_k = tl.load(T_K_block_ptr) | |
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) | |
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
if ENABLE_JVP: | |
# The tangent of a masked-out (constant) logit is zero, so mask t_qk with 0.0 rather than -inf.
t_qk = tl.where(mask, t_qk, 0.0) | |
# TODO: do we need a separate row maximum for qk_t? | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
l_ij = tl.sum(p, 1) | |
# -- update m_i and l_i | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_i = l_i * alpha + l_ij | |
# -- update output accumulator -- | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# update acc | |
v = tl.load(V_block_ptr) | |
# NOTE: we may need to transpose v if dtype == tl.float8e5 | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
p = p.to(dtype) | |
if ENABLE_JVP: | |
p_tqk = p * (t_qk * sm_scale) | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = g_acc.shape[0] | |
BN: tl.constexpr = g_acc.shape[1] | |
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
g_acc0 = g_acc0 * alpha[:, None] | |
g_acc1 = g_acc1 * alpha[:, None] | |
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
g_acc = g_acc * alpha[:, None] | |
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
t_v = tl.load(T_V_block_ptr) | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0)) | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N)) | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
m_i = m_ij | |
# the fp8 PR made a change to how K and V are advanced here but I believe we already have that. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) | |
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) | |
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc | |
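# A minimal PyTorch sketch (illustration only; names are assumptions, not kernel API) of the
# per-block recurrence that _attn_fwd_inner performs for one query block. The kernel uses
# exp2 on log2(e)-scaled logits; natural exp is used here, which is mathematically equivalent.
def _streaming_jvp_sketch(q, t_q, k_blocks, v_blocks, t_k_blocks, t_v_blocks, sm_scale):
    BLOCK_M, HEAD_DIM = q.shape
    m = torch.full((BLOCK_M,), float("-inf"), dtype=torch.float32, device=q.device)  # running row max
    l = torch.zeros(BLOCK_M, dtype=torch.float32, device=q.device)                   # running row sum
    acc = torch.zeros(BLOCK_M, HEAD_DIM, dtype=torch.float32, device=q.device)       # accumulates P V (unnormalized)
    g_acc = torch.zeros_like(acc)   # accumulates (P * tS) V
    mu = torch.zeros_like(l)        # accumulates rowsum(P * tS)
    p_tv = torch.zeros_like(acc)    # accumulates P tV
    for k, v, t_k, t_v in zip(k_blocks, v_blocks, t_k_blocks, t_v_blocks):
        s = (q @ k.T).float() * sm_scale
        t_s = (t_q @ k.T + q @ t_k.T).float() * sm_scale
        m_new = torch.maximum(m, s.max(dim=1).values)
        alpha = torch.exp(m - m_new)
        p = torch.exp(s - m_new[:, None])
        l = l * alpha + p.sum(dim=1)
        acc = acc * alpha[:, None] + p @ v.float()
        g_acc = g_acc * alpha[:, None] + (p * t_s) @ v.float()
        mu = mu * alpha + (p * t_s).sum(dim=1)
        p_tv = p_tv * alpha[:, None] + p @ t_v.float()
        m = m_new
    o = acc / l[:, None]
    t_o = g_acc / l[:, None] - (mu / l)[:, None] * o + p_tv / l[:, None]
    return o, t_o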
@triton.jit | |
def _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype: tl.constexpr, start_m, qk_scale, sm_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, | |
ENABLE_JVP: tl.constexpr): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
offsetk_y = offset_y + lo | |
if dtype == tl.float8e5: | |
offsetv_y = offset_y * HEAD_DIM + lo | |
else: | |
offsetv_y = offset_y + lo | |
# loop over k, v and update accumulator | |
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = desc_k.load([offsetk_y, 0]).T | |
qk = tl.dot(q, k) | |
if ENABLE_JVP: | |
t_k = desc_k_t.load([offsetk_y, 0]).T | |
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) | |
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
if ENABLE_JVP: | |
# The tangent of a masked-out (constant) logit is zero, so mask t_qk with 0.0 rather than -inf.
t_qk = tl.where(mask, t_qk, 0.0) | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
# -- compute correction factor | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_ij = tl.sum(p, 1) | |
# -- update output accumulator -- | |
if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# prepare p and v for the dot | |
if dtype == tl.float8e5: | |
v = desc_v.load([0, offsetv_y]).T | |
else: | |
v = desc_v.load([offsetv_y, 0]) | |
p = p.to(dtype) | |
if ENABLE_JVP: | |
p_tqk = p * (t_qk * sm_scale) | |
# this non-transposed v for FP8 is presumably only supported on Blackwell | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = g_acc.shape[0] | |
BN: tl.constexpr = g_acc.shape[1] | |
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
g_acc0 = g_acc0 * alpha[:, None] | |
g_acc1 = g_acc1 * alpha[:, None] | |
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
g_acc = g_acc * alpha[:, None] | |
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
t_v = desc_v_t.load([offsetv_y, 0]) | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v) | |
# note that this non transposed v for FP8 is only supported on Blackwell | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
# place this at the end of the loop to reduce register pressure | |
l_i = l_i * alpha + l_ij | |
m_i = m_ij | |
offsetk_y += BLOCK_N | |
offsetv_y += BLOCK_N | |
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc | |
def _host_descriptor_pre_hook(nargs): | |
BLOCK_M = nargs["BLOCK_M"] | |
BLOCK_N = nargs["BLOCK_N"] | |
HEAD_DIM = nargs["HEAD_DIM"] | |
if not HAS_TENSOR_DESC or not isinstance(nargs["desc_q"], TensorDescriptor): | |
return | |
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] | |
if nargs["FP8_OUTPUT"]: | |
nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] | |
if is_hip(): | |
NUM_STAGES_OPTIONS = [1] | |
elif supports_host_descriptor(): | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
else: | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
configs = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ | |
for BM in [64, 128]\ | |
for BN in [32, 64, 128]\ | |
for s in NUM_STAGES_OPTIONS \ | |
for w in [4, 8]\ | |
] | |
if "PYTEST_VERSION" in os.environ: | |
# Use a single config in testing for reproducibility | |
configs = [ | |
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), | |
] | |
def keep(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8) | |
def prune_invalid_configs(configs, named_args, **kwargs): | |
N_CTX = kwargs["N_CTX"] | |
# Filter out configs where BLOCK_M > N_CTX | |
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX] | |
@triton.jit | |
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): | |
if isinstance(desc_or_ptr, tl.tensor_descriptor): | |
return desc_or_ptr | |
else: | |
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) | |
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd(Q, K, V, T_Q, T_K, T_V, # | |
sm_scale, M, Out, T_Out, # | |
stride_qz, stride_qh, stride_qm, stride_qk, # | |
stride_kz, stride_kh, stride_kn, stride_kk, # | |
stride_vz, stride_vh, stride_vk, stride_vn, # | |
stride_tqz, stride_tqh, stride_tqm, stride_tqk, # | |
stride_tkz, stride_tkh, stride_tkn, stride_tkk, # | |
stride_tvz, stride_tvh, stride_tvk, stride_tvn, # | |
stride_oz, stride_oh, stride_om, stride_on, # | |
stride_toz, stride_toh, stride_tom, stride_ton, # | |
Z, H, N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr, # | |
): | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh | |
# block pointers | |
Q_block_ptr = tl.make_block_ptr( | |
base=Q + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_qm, stride_qk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) | |
V_block_ptr = tl.make_block_ptr( | |
base=V + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_vk, stride_vn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=v_order, | |
) | |
K_block_ptr = tl.make_block_ptr( | |
base=K + qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_kk, stride_kn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
O_block_ptr = tl.make_block_ptr( | |
base=Out + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_om, stride_on), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
if ENABLE_JVP: | |
# we could almost certainly reuse qvk_offset here, but recomputing the offset is cheap
t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh | |
T_Q_block_ptr = tl.make_block_ptr( | |
base=T_Q + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tqm, stride_tqk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# could probably just re-use v_order here | |
t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0) | |
T_V_block_ptr = tl.make_block_ptr( | |
base=T_V + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tvk, stride_tvn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=t_v_order, | |
) | |
T_K_block_ptr = tl.make_block_ptr( | |
base=T_K + t_qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_tkk, stride_tkn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
T_O_block_ptr = tl.make_block_ptr( | |
base=T_Out + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tom, stride_ton), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# load q_t: it will stay in SRAM throughout | |
t_q = tl.load(T_Q_block_ptr) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
else: | |
t_q = None | |
T_V_block_ptr = None | |
T_K_block_ptr = None | |
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32) | |
mu_i = tl.zeros([1], dtype=tl.float32) | |
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
qk_scale = qk_scale * 1.44269504 # log2(e) = 1/ln(2)
# load q: it will stay in SRAM throughout | |
q = tl.load(Q_block_ptr) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# stage 2: on-band | |
if STAGE & 2: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
tl.store(O_block_ptr, acc.to(Out.type.element_ty)) | |
if ENABLE_JVP: | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty)) | |
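# Reference note on the epilogue above (illustrative, derived from the accumulators): with the
# running row-sum l_i and the unnormalized probabilities p,
#   O  = acc / l_i                                   (= softmax(Q K^T * sm_scale) V)
#   tO = g_acc / l_i - (mu_i / l_i) * O + p_tv_acc / l_i
# which is the attention JVP tO = (P * tS) V - rowsum(P * tS) * O + P tV for normalized P,
# with tS = (tQ K^T + Q tK^T) * sm_scale. M stores the base-2 log-sum-exp for the backward pass.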
def _tma_pre_hook(nargs): | |
BLOCK_M = nargs["BLOCK_M"] | |
BLOCK_N = nargs["BLOCK_N"] | |
HEAD_DIM = nargs["HEAD_DIM"] | |
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] | |
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] | |
# We don't run auto-tuning every time to keep the tutorial fast. Keeping | |
# the code below and commenting out the equivalent parameters is convenient for | |
# re-tuning. | |
configs_tma = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_tma_pre_hook) \ | |
for BM in [64, 128, 256]\ | |
for BN in [64, 128]\ | |
for s in [3, 4, 5]\ | |
for w in [4, 8]\ | |
] | |
def keep_tma(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 | |
and conf.num_warps == 8) | |
@triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd_tma(sm_scale, M, # | |
Z, H, # | |
desc_q, desc_k, desc_v, # | |
desc_q_t, desc_k_t, desc_v_t, # | |
desc_o, desc_o_t, # | |
N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr, # | |
): | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
y_dim = Z * H * N_CTX | |
desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
if FP8_OUTPUT: | |
v_shape = [HEAD_DIM, y_dim] | |
v_strides = [N_CTX, 1] | |
v_block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
v_shape = [y_dim, HEAD_DIM] | |
v_strides = [HEAD_DIM, 1] | |
v_block_shape = [BLOCK_N, HEAD_DIM] | |
desc_v = _maybe_make_tensor_desc(desc_v, shape=v_shape, strides=v_strides, block_shape=v_block_shape) | |
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
offset_y = off_z * (N_CTX * H) + off_h * N_CTX | |
qo_offset_y = offset_y + start_m * BLOCK_M | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
if ENABLE_JVP: | |
desc_q_t = _maybe_make_tensor_desc(desc_q_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
if FP8_OUTPUT: | |
t_v_shape = [HEAD_DIM, y_dim] | |
t_v_strides = [N_CTX, 1] | |
t_v_block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
t_v_shape = [y_dim, HEAD_DIM] | |
t_v_strides = [HEAD_DIM, 1] | |
t_v_block_shape = [BLOCK_N, HEAD_DIM] | |
desc_v_t = _maybe_make_tensor_desc(desc_v_t, shape=t_v_shape, strides=t_v_strides, block_shape=t_v_block_shape) | |
desc_k_t = _maybe_make_tensor_desc(desc_k_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_o_t = _maybe_make_tensor_desc(desc_o_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
# load t_q: it will stay in SRAM throughout | |
t_q = desc_q_t.load([qo_offset_y, 0]) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
else: | |
t_q = None | |
desc_k_t = None | |
desc_v_t = None | |
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner_tma consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32) | |
mu_i = tl.zeros([1], dtype=tl.float32) | |
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
# load q: it will stay in SRAM throughout | |
q = desc_q.load([qo_offset_y, 0]) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# stage 2: on-band | |
if STAGE & 2: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
desc_o.store([qo_offset_y, 0], acc.to(dtype)) | |
if ENABLE_JVP: | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
desc_o_t.store([qo_offset_y, 0], t_y_out.to(dtype)) | |
@triton.jit | |
def _attn_bwd_preprocess(O, DO, # | |
Delta, # | |
Z, H, N_CTX, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # | |
): | |
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) | |
off_hz = tl.program_id(1) | |
off_n = tl.arange(0, HEAD_DIM) | |
# load | |
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) | |
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) | |
delta = tl.sum(o * do, axis=1) | |
# write-back | |
tl.store(Delta + off_hz * N_CTX + off_m, delta) | |
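# Reference note (illustrative): Delta_i = sum_d O[i, d] * dO[i, d] equals rowsum(P * dP),
# since O = P V and dP = dO V^T; it is the softmax-backward correction used below in
# dS = P * (dP - Delta).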
# The main inner-loop logic for computing dK and dV. | |
@triton.jit | |
def _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, # | |
# Filled in by the wrapper. | |
start_n, start_m, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M1) | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
offs_k = tl.arange(0, HEAD_DIM) | |
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d | |
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
curr_m = start_m | |
step_m = BLOCK_M1 | |
for blk_idx in range(num_steps): | |
qT = tl.load(qT_ptrs) | |
# Load m before computing qk to reduce pipeline stall. | |
offs_m = curr_m + tl.arange(0, BLOCK_M1) | |
m = tl.load(M + offs_m) | |
qkT = tl.dot(k, qT) | |
pT = tl.math.exp2(qkT - m[None, :]) | |
# Autoregressive masking. | |
if MASK: | |
mask = (offs_m[None, :] >= offs_n[:, None]) | |
pT = tl.where(mask, pT, 0.0) | |
do = tl.load(do_ptrs) | |
# Compute dV. | |
ppT = pT | |
ppT = ppT.to(tl.float16) | |
dv += tl.dot(ppT, do) | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# Compute dP and dS. | |
dpT = tl.dot(v, tl.trans(do)).to(tl.float32) | |
dsT = pT * (dpT - Di[None, :]) | |
dsT = dsT.to(tl.float16) | |
dk += tl.dot(dsT, tl.trans(qT)) | |
# Increment pointers. | |
curr_m += step_m | |
qT_ptrs += step_m * stride_tok | |
do_ptrs += step_m * stride_tok | |
return dk, dv | |
# the main inner-loop logic for computing dQ | |
@triton.jit | |
def _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, | |
# Filled in by the wrapper. | |
start_m, start_n, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
offs_n = start_n + tl.arange(0, BLOCK_N2) | |
offs_k = tl.arange(0, HEAD_DIM) | |
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
curr_n = start_n | |
step_n = BLOCK_N2 | |
for blk_idx in range(num_steps): | |
kT = tl.load(kT_ptrs) | |
vT = tl.load(vT_ptrs) | |
qk = tl.dot(q, kT) | |
p = tl.math.exp2(qk - m) | |
# Autoregressive masking. | |
if MASK: | |
offs_n = curr_n + tl.arange(0, BLOCK_N2) | |
mask = (offs_m[:, None] >= offs_n[None, :]) | |
p = tl.where(mask, p, 0.0) | |
# Compute dP and dS. | |
dp = tl.dot(do, vT).to(tl.float32) | |
ds = p * (dp - Di[:, None]) | |
ds = ds.to(tl.float16) | |
# Compute dQ. | |
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled. | |
dq += tl.dot(ds, tl.trans(kT)) | |
# Increment pointers. | |
curr_n += step_n | |
kT_ptrs += step_n * stride_tok | |
vT_ptrs += step_n * stride_tok | |
return dq | |
@triton.jit | |
def _attn_bwd(Q, K, V, sm_scale, # | |
DO, # | |
DQ, DK, DV, # | |
M, D, | |
# shared by Q/K/V/DO. | |
stride_z, stride_h, stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
BLK_SLICE_FACTOR: tl.constexpr, # | |
HEAD_DIM: tl.constexpr): | |
LN2: tl.constexpr = 0.6931471824645996 # = ln(2) | |
bhid = tl.program_id(2) | |
off_chz = (bhid * N_CTX).to(tl.int64) | |
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) | |
pid = tl.program_id(0) | |
# offset pointers for batch/head | |
Q += adj | |
K += adj | |
V += adj | |
DO += adj | |
DQ += adj | |
DK += adj | |
DV += adj | |
M += off_chz | |
D += off_chz | |
# load scales | |
offs_k = tl.arange(0, HEAD_DIM) | |
start_n = pid * BLOCK_N1 | |
start_m = start_n | |
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
# load K and V: they stay in SRAM throughout the inner loop. | |
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
num_steps = BLOCK_N1 // MASK_BLOCK_M1 | |
dk, dv = _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=True # | |
) | |
start_m += num_steps * MASK_BLOCK_M1 | |
num_steps = (N_CTX - start_m) // BLOCK_M1 | |
# Compute dK and dV for non-masked blocks. | |
dk, dv = _attn_bwd_dkdv( # | |
dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=False # | |
) | |
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dv_ptrs, dv) | |
# Write back dK. | |
dk *= sm_scale | |
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dk_ptrs, dk) | |
# THIS BLOCK DOES DQ: | |
start_m = pid * BLOCK_M2 | |
end_n = start_m + BLOCK_M2 | |
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) | |
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
m = tl.load(M + offs_m) | |
m = m[:, None] | |
# Compute dQ for masked (diagonal) blocks. | |
# NOTE: This code scans each row of QK^T backward (from right to left, | |
# but inside each call to _attn_bwd_dq, from left to right), but that's | |
# not due to anything important. I just wanted to reuse the loop | |
# structure for dK & dV above as much as possible. | |
num_steps = BLOCK_M2 // MASK_BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # | |
MASK=True # | |
) | |
end_n -= num_steps * MASK_BLOCK_N2 | |
# stage 2 | |
num_steps = end_n // BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * BLOCK_N2, num_steps, # | |
MASK=False # | |
) | |
# Write back dQ. | |
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
dq *= LN2 | |
tl.store(dq_ptrs, dq) | |
class JVPAttn(Function): | |
class Grid(NamedTuple): | |
M_BLOCKS: int | |
Z_H: int | |
ONE: Literal[1] | |
class FnCtx(FunctionCtx): | |
sm_scale: float | |
HEAD_DIM_K: int | |
causal: bool | |
grid: JVPAttn.Grid | |
class FwdOutCtxContrib(NamedTuple): | |
o_t: Optional[Tensor] | |
M: Tensor | |
grid: JVPAttn.Grid | |
HEAD_DIM_K: int | |
sm_scale: float | |
class FwdOut(NamedTuple): | |
o: Tensor | |
ctx: JVPAttn.FwdOutCtxContrib | |
class JVPOut(NamedTuple): | |
o: Tensor | |
ctx: None | |
class BwdOut(NamedTuple): | |
q: Tensor | |
k: Tensor | |
v: Tensor | |
q_t: None | |
k_t: None | |
v_t: None | |
causal: None | |
sm_scale: None | |
warp_specialize: None | |
USE_TMA: None | |
class Strides(NamedTuple): | |
z: int | |
h: int | |
n_ctx: int | |
head_dim: int | |
@staticmethod | |
def forward( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
q_t: Optional[Tensor], | |
k_t: Optional[Tensor], | |
v_t: Optional[Tensor], | |
causal: bool, | |
sm_scale: Optional[float], | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> JVPAttn.FwdOut: | |
# shape constraints | |
Z, H, N_CTX, HEAD_DIM_Q = q.shape | |
HEAD_DIM_K = k.shape[-1] | |
# when v is in float8_e5m2 it is transposed. | |
HEAD_DIM_V = v.shape[-1] | |
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V | |
assert HEAD_DIM_K in {16, 32, 64, 128, 256} | |
if sm_scale is None: | |
sm_scale = HEAD_DIM_K**-.5 | |
o = torch.empty_like(q) | |
ENABLE_JVP = q_t is not None | |
o_t: Optional[Tensor] = torch.empty_like(q_t) if ENABLE_JVP else None | |
stage = 3 if causal else 1 | |
extra_kern_args = {} | |
# Tuning for AMD target | |
if is_hip(): | |
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 | |
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} | |
if is_cuda() and warp_specialize: | |
# we need more registers if we're doing JVP | |
if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP: | |
extra_kern_args["maxnreg"] = 168 | |
else: | |
# TODO: I think for backwards pass of dim=128 this is too low for H100; register allocation fails | |
extra_kern_args["maxnreg"] = 80 | |
M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32) | |
if hasattr(triton, 'set_allocator') and is_cuda(): | |
def alloc_fn(size: int, align: int, _): | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
Z_H = Z * H | |
def grid(META: dict[str, Any]) -> JVPAttn.Grid: | |
return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1) | |
# if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9 | |
# and q.dtype == torch.float8_e5m2): | |
if USE_TMA and supports_tma(): | |
# Note that on Hopper we cannot perform an FP8 dot with a non-transposed second tensor
y_dim = Z_H * N_CTX | |
dummy_block = [1, 1] | |
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_q_t = desc_q if q_t is None else TensorDescriptor(q_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
if q.dtype == torch.float8_e5m2: | |
v_shape = [HEAD_DIM_K, y_dim] | |
v_strides = [N_CTX, 1] | |
else: | |
v_shape = [y_dim, HEAD_DIM_K] | |
v_strides = [HEAD_DIM_K, 1] | |
desc_v = TensorDescriptor(v, shape=v_shape, strides=v_strides, block_shape=dummy_block) | |
# probably we could share the shape and strides from above, but recomputing is cheap.
# guard against q_t being None (JVP disabled), otherwise this crashes in the non-JVP TMA path.
if q_t is not None and q_t.dtype == torch.float8_e5m2:
t_v_shape = [HEAD_DIM_K, y_dim]
t_v_strides = [q_t.shape[2], 1]
else: | |
t_v_shape = [y_dim, HEAD_DIM_K] | |
t_v_strides = [HEAD_DIM_K, 1] | |
desc_v_t = desc_v if v_t is None else TensorDescriptor(v_t, shape=t_v_shape, strides=t_v_strides, block_shape=dummy_block) | |
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_k_t = desc_k if k_t is None else TensorDescriptor(k_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o_t = desc_o if o_t is None else TensorDescriptor(o_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
_attn_fwd_tma[grid]( | |
sm_scale, M, # | |
Z, H, # | |
desc_q, desc_k, desc_v, # | |
desc_q_t, desc_k_t, desc_v_t, # | |
desc_o, desc_o_t, # | |
N_CTX=N_CTX, # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
ENABLE_JVP=ENABLE_JVP, # | |
**extra_kern_args) | |
else: | |
def strides_zhnd(t: Tensor) -> JVPAttn.Strides: | |
return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3)) | |
_attn_fwd[grid]( | |
q, k, v, q_t, k_t, v_t, # | |
sm_scale, M, o, o_t, # | |
*strides_zhnd(q), # | |
*strides_zhnd(k), # | |
*strides_zhnd(v), # | |
*strides_zhnd(q if q_t is None else q_t), # | |
*strides_zhnd(k if k_t is None else k_t), # | |
*strides_zhnd(v if v_t is None else v_t), # | |
*strides_zhnd(o), # | |
*strides_zhnd(o if o_t is None else o_t), # | |
Z, H, # | |
N_CTX=N_CTX, # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
ENABLE_JVP=ENABLE_JVP, # | |
**extra_kern_args) | |
return JVPAttn.FwdOut(o, JVPAttn.FwdOutCtxContrib(o_t, M, grid, HEAD_DIM_K, sm_scale)) | |
@staticmethod | |
def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor: | |
( | |
q, | |
k, | |
v, | |
q_t, | |
k_t, | |
v_t, | |
causal, | |
sm_scale, | |
warp_specialize, | |
USE_TMA | |
) = inputs | |
o, (o_t, M, grid, HEAD_DIM_K, sm_scale) = outputs | |
ctx.grid = grid | |
ctx.save_for_forward(o_t) | |
ctx.save_for_backward(q, k, v, o, M) | |
ctx.sm_scale = sm_scale | |
ctx.HEAD_DIM_K = HEAD_DIM_K | |
ctx.causal = causal | |
@staticmethod | |
def fwd( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention; it's a workaround to get type-hinting and kwarg support.
""" | |
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, None, None, None, causal, sm_scale, warp_specialize, USE_TMA) | |
a, _ = out | |
return a | |
@staticmethod | |
def fwd_dual( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention; it's a workaround for invoking
JVPAttn::forward with the right arguments when you have a dual tensor input. | |
""" | |
q_p, q_t = fwAD.unpack_dual(q) | |
k_p, k_t = fwAD.unpack_dual(k) | |
v_p, v_t = fwAD.unpack_dual(v) | |
# we pass the dual-tensor args themselves to ensure jvp() will be called,
# but we also pass the tangents separately, because forward() demotes dual-tensor args to primals
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize, USE_TMA) | |
a, _ = out | |
return a | |
@staticmethod | |
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut: | |
return JVPAttn.JVPOut(ctx.saved_for_forward[0], None) | |
@staticmethod | |
def backward(ctx, do, _) -> JVPAttn.BwdOut: | |
q, k, v, o, M = ctx.saved_tensors | |
assert do.is_contiguous() | |
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() | |
dq = torch.empty_like(q) | |
dk = torch.empty_like(k) | |
dv = torch.empty_like(v) | |
BATCH, N_HEAD, N_CTX = q.shape[:3] | |
PRE_BLOCK = 128 | |
NUM_WARPS, NUM_STAGES = 4, 5 | |
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 | |
BLK_SLICE_FACTOR = 2 | |
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) | |
arg_k = k | |
arg_k = arg_k * (ctx.sm_scale * RCP_LN2) | |
PRE_BLOCK = 128 | |
assert N_CTX % PRE_BLOCK == 0 | |
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) | |
delta = torch.empty_like(M) | |
_attn_bwd_preprocess[pre_grid]( | |
o, do, # | |
delta, # | |
BATCH, N_HEAD, N_CTX, # | |
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM_K # | |
) | |
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) | |
_attn_bwd[grid]( | |
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # | |
M, delta, # | |
q.stride(0), q.stride(1), q.stride(2), q.stride(3), # | |
N_HEAD, N_CTX, # | |
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # | |
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # | |
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # | |
HEAD_DIM=ctx.HEAD_DIM_K, # | |
num_warps=NUM_WARPS, # | |
num_stages=NUM_STAGES # | |
) | |
return JVPAttn.BwdOut(dq, dk, dv, None, None, None, None, None, None, None) | |
attention = JVPAttn.fwd | |
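# Usage sketch (illustrative; shapes, dtype and flags are assumptions): computing the attention
# output together with its JVP through the fwd_dual workaround and torch.autograd.forward_ad.
def _example_fwd_dual_usage():
    q, k, v = (torch.randn(1, 2, 128, 64, dtype=torch.float16, device=DEVICE) for _ in range(3))
    t_q, t_k, t_v = (torch.randn_like(x) for x in (q, k, v))
    with fwAD.dual_level():
        q_d, k_d, v_d = (fwAD.make_dual(p, t) for p, t in ((q, t_q), (k, t_k), (v, t_v)))
        out_dual = JVPAttn.fwd_dual(q_d, k_d, v_d, causal=True)
        out, out_tangent = fwAD.unpack_dual(out_dual)
    return out, out_tangent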
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') | |
@pytest.mark.parametrize("Z", [1, 4]) | |
@pytest.mark.parametrize("H", [2, 48]) | |
@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024]) | |
@pytest.mark.parametrize("HEAD_DIM", [64, 128]) | |
@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment. | |
@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False]) | |
@pytest.mark.parametrize("mode", ["fwd", "bwd"]) | |
@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else [])) | |
def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16): | |
if mode == "fwd" and "fp16" in provider: | |
pytest.skip("Avoid running the forward computation twice.") | |
if mode == "bwd" and "fp8" in provider: | |
pytest.skip("Backward pass with FP8 is not supported.") | |
torch.manual_seed(20) | |
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
sm_scale = 0.5 | |
# reference implementation | |
ref_dtype = dtype | |
if mode == "fwd" and "fp8" in provider: | |
ref_dtype = torch.float32 | |
q = q.to(ref_dtype) | |
k = k.to(ref_dtype) | |
v = v.to(ref_dtype) | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
if causal: | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1) | |
p = p.to(ref_dtype) | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v).half() | |
if mode == "bwd": | |
dout = torch.randn_like(q) | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
# triton implementation | |
if mode == "fwd" and "fp8" in provider: | |
q = q.to(torch.float8_e5m2) | |
k = k.to(torch.float8_e5m2) | |
v = v.permute(0, 1, 3, 2).contiguous() | |
v = v.permute(0, 1, 3, 2) | |
v = v.to(torch.float8_e5m2) | |
tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half() | |
if mode == "fwd": | |
atol = 3 if "fp8" in provider else 1e-2 | |
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) | |
return | |
tri_out.backward(dout) | |
tri_dv, v.grad = v.grad.clone(), None | |
tri_dk, k.grad = k.grad.clone(), None | |
tri_dq, q.grad = q.grad.clone(), None | |
# compare | |
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0) | |
rtol = 0.0 | |
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU. | |
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices | |
if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": | |
rtol = 1e-2 | |
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol) | |
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol) | |
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol) | |
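# Test sketch (illustrative; parameters and tolerances are assumptions, mirroring the atol=1e-3
# accuracy noted in the gist description): check the fused JVP against torch.func.jvp of a naive
# reference implementation.
def check_jvp_close(Z=1, H=2, N_CTX=128, HEAD_DIM=64, causal=True, dtype=torch.float16):
    torch.manual_seed(0)
    q, k, v = (torch.randn(Z, H, N_CTX, HEAD_DIM, dtype=dtype, device=DEVICE) for _ in range(3))
    t_q, t_k, t_v = (torch.randn_like(x) for x in (q, k, v))
    sm_scale = HEAD_DIM ** -0.5

    def ref(q_, k_, v_):
        s = q_ @ k_.transpose(-2, -1) * sm_scale
        if causal:
            mask = torch.tril(torch.ones(N_CTX, N_CTX, device=DEVICE)) == 0
            s = s.masked_fill(mask, float("-inf"))
        return torch.softmax(s.float(), dim=-1).to(v_.dtype) @ v_

    ref_out, ref_tangent = torch.func.jvp(ref, (q, k, v), (t_q, t_k, t_v))
    with fwAD.dual_level():
        out_dual = JVPAttn.fwd_dual(fwAD.make_dual(q, t_q), fwAD.make_dual(k, t_k),
                                    fwAD.make_dual(v, t_v), causal, sm_scale)
        tri_out, tri_tangent = fwAD.unpack_dual(out_dual)
    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
    torch.testing.assert_close(tri_tangent, ref_tangent, atol=1e-3, rtol=0)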
try: | |
from flash_attn.flash_attn_interface import \ | |
flash_attn_qkvpacked_func as flash_attn_func | |
HAS_FLASH = True | |
except BaseException: | |
HAS_FLASH = False | |
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') | |
BATCH, N_HEADS = 4, 32 | |
# vary seq length for fixed head and batch=4 | |
configs = [] | |
for HEAD_DIM in [64, 128]: | |
for mode in ["fwd", "bwd"]: | |
for causal in [True, False]: | |
for warp_specialize in [True, False]: | |
configs.append( | |
triton.testing.Benchmark( | |
x_names=["N_CTX"], | |
x_vals=[2**i for i in range(10, 15)], | |
line_arg="provider", | |
line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + | |
(["flash"] if HAS_FLASH else []), | |
line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + | |
(["Flash-2"] if HAS_FLASH else []), | |
styles=[("red", "-"), ("blue", "-"), ("green", "-")], | |
ylabel="TFLOPS", | |
plot_name= | |
f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}", | |
args={ | |
"H": N_HEADS, | |
"BATCH": BATCH, | |
"HEAD_DIM": HEAD_DIM, | |
"mode": mode, | |
"causal": causal, | |
"warp_specialize": warp_specialize, | |
}, | |
)) | |
@triton.testing.perf_report(configs) | |
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE): | |
assert mode in ["fwd", "bwd"] | |
dtype = torch.float16 | |
if "triton" in provider: | |
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
if mode == "fwd" and "fp8" in provider: | |
q = q.to(torch.float8_e5m2) | |
k = k.to(torch.float8_e5m2) | |
v = v.permute(0, 1, 3, 2).contiguous() | |
v = v.permute(0, 1, 3, 2) | |
v = v.to(torch.float8_e5m2) | |
sm_scale = 1.3 | |
fn = lambda: attention(q, k, v, causal, sm_scale, warp_specialize) | |
if mode == "bwd": | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn) | |
if provider == "flash": | |
qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
fn = lambda: flash_attn_func(qkv, causal=causal) | |
if mode == "bwd": | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn) | |
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM | |
total_flops = 2 * flops_per_matmul | |
if causal: | |
total_flops *= 0.5 | |
if mode == "bwd": | |
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) | |
return total_flops * 1e-12 / (ms * 1e-3) | |
if __name__ == "__main__": | |
# only works on post-Ampere GPUs right now | |
bench_flash_attention.run(save_path=".", print_data=True) |
@Birch-san
The memory leak seems fixed, thank you.
I still think my adjustments to the backward pass are necessary.
However, I am not familiar with the official flash attention example; could you provide a link?
I found this implementation, which is pretty close to the one used here. The major difference is that your implementation runs a number of steps starting on the diagonal (matching the stage handling in the forward), whereas the implementation I found always goes through all blocks and masks as necessary (see for instance line 461ff).
The method used here is more efficient for causal attention, as it skips unnecessary blocks entirely and only applies masking on the diagonal blocks.
Compare lines 835-836, which select the start and end of a diagonal block, with lines 853-862, which then step through exactly enough blocks to cover it; after that, only the earlier blocks in that row are processed.
So I disagree that the current code is working correctly from a theoretical viewpoint.
As for the testing: since we are debating whether the backward pass is correct, the gradients are the important values.
However, in your example the maximal gradients were around 2e-4 for me, making it practically impossible to violate the tolerances.
One can scale up the loss by about 1e3 to get maximal gradients around 1e-1 (that seemed a reasonable magnitude to me).
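Concretely, something like this inside test_op is what I mean (illustrative only; it reuses the names already defined there):

dout = 1e3 * torch.randn_like(q)  # was torch.randn_like(q); pushes the largest gradients to roughly 1e-1
ref_out.backward(dout)

with the rest of the gradient comparisons left unchanged.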
Then, for the current gist, I get the following (for the path "is_causal=False"):
The implementation with my proposed changes, in contrast, runs through without issue.
Can you confirm this on your machine?