from __future__ import annotations | |
""" | |
Fused Attention | |
=============== | |
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) | |
Credits: OpenAI kernel team | |
Extra Credits: | |
* Original flash attention paper (https://arxiv.org/abs/2205.14135) | |
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) | |
Plus modifications to support JVP: | |
- formulation of flash JVP by Cheng Lu, Yang Song in https://arxiv.org/abs/2410.11081 | |
- reference triton implementation by Sofian Mejjoute | |
- reimplementing reference implementation as an autograd function with latest triton tutorial optimizations, by Alex Birch | |
- support for forward to receive tangents, so as to compute fwd and jvp together, autograd workaround by Emily (nshepperd) | |
- support for function transforms (e.g. torch.func.jvp) via the use of setup_context, by Shih-Ying Yeh | |
""" | |
from typing import Any, Literal, NamedTuple, Optional | |
import pytest | |
import torch | |
import os | |
import triton | |
import triton.language as tl | |
from torch import Tensor | |
from torch.autograd import Function | |
from torch.autograd.function import FunctionCtx | |
import torch.autograd.forward_ad as fwAD | |
try: | |
from triton.tools.tensor_descriptor import TensorDescriptor | |
HAS_TENSOR_DESC = True | |
except ModuleNotFoundError: | |
HAS_TENSOR_DESC = False | |
DEVICE = getattr(triton.runtime.driver.active, "get_active_torch_device", lambda: torch.device('cuda'))() | |
def is_hip(): | |
return triton.runtime.driver.active.get_current_target().backend == "hip" | |
def is_cuda(): | |
return triton.runtime.driver.active.get_current_target().backend == "cuda" | |
def supports_host_descriptor(): | |
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 | |
def supports_tma(): | |
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9 | |
def is_blackwell(): | |
return is_cuda() and torch.cuda.get_device_capability()[0] == 10 | |
@triton.jit | |
def _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype: tl.constexpr, start_m, qk_scale, sm_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
K_block_ptr = tl.advance(K_block_ptr, (0, lo)) | |
# NOTE: in fp8 mode, we may want to advance the V_block_ptr differently. | |
# I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) | |
if ENABLE_JVP: | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo)) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0)) | |
# loop over k, v and update accumulator | |
for start_n in range(lo, hi, BLOCK_N): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = tl.load(K_block_ptr) | |
qk = tl.dot(q, k) | |
if ENABLE_JVP: | |
t_k = tl.load(T_K_block_ptr) | |
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) | |
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
if ENABLE_JVP: | |
# Claude says "tangents should be masked with 0.0 since they represent derivatives". | |
t_qk = tl.where(mask, t_qk, 0.0) | |
# TODO: do we need a separate row maximum for qk_t? | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
l_ij = tl.sum(p, 1) | |
# -- update m_i and l_i | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_i = l_i * alpha + l_ij | |
# -- update output accumulator -- | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# update acc | |
v = tl.load(V_block_ptr) | |
# NOTE: we may need to transpose v if dtype == tl.float8e5 | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
p = p.to(dtype) | |
if ENABLE_JVP: | |
p_tqk = p * (t_qk * sm_scale) | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = g_acc.shape[0] | |
BN: tl.constexpr = g_acc.shape[1] | |
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
g_acc0 = g_acc0 * alpha[:, None] | |
g_acc1 = g_acc1 * alpha[:, None] | |
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
g_acc = g_acc * alpha[:, None] | |
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
t_v = tl.load(T_V_block_ptr) | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v) | |
T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0)) | |
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N)) | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
m_i = m_ij | |
# the fp8 PR made a change to how K and V are advanced here but I believe we already have that. | |
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31 | |
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) | |
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) | |
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc | |
@triton.jit | |
def _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype: tl.constexpr, start_m, qk_scale, sm_scale, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, # | |
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # | |
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, | |
ENABLE_JVP: tl.constexpr): | |
# range of values handled by this stage | |
if STAGE == 1: | |
lo, hi = 0, start_m * BLOCK_M | |
elif STAGE == 2: | |
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M | |
lo = tl.multiple_of(lo, BLOCK_M) | |
# causal = False | |
else: | |
lo, hi = 0, N_CTX | |
offsetk_y = offset_y + lo | |
if dtype == tl.float8e5: | |
offsetv_y = offset_y * HEAD_DIM + lo | |
else: | |
offsetv_y = offset_y + lo | |
# loop over k, v and update accumulator | |
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize): | |
start_n = tl.multiple_of(start_n, BLOCK_N) | |
# -- compute qk ---- | |
k = desc_k.load([offsetk_y, 0]).T | |
qk = tl.dot(q, k) | |
if ENABLE_JVP: | |
t_k = desc_k_t.load([offsetk_y, 0]).T | |
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k) | |
if STAGE == 2: | |
mask = offs_m[:, None] >= (start_n + offs_n[None, :]) | |
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) | |
if ENABLE_JVP: | |
# Claude says "tangents should be masked with 0.0 since they represent derivatives". | |
t_qk = tl.where(mask, t_qk, 0.0) | |
m_ij = tl.maximum(m_i, tl.max(qk, 1)) | |
qk -= m_ij[:, None] | |
else: | |
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) | |
qk = qk * qk_scale - m_ij[:, None] | |
p = tl.math.exp2(qk) | |
# -- compute correction factor | |
alpha = tl.math.exp2(m_i - m_ij) | |
l_ij = tl.sum(p, 1) | |
# -- update output accumulator -- | |
if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128: | |
BM: tl.constexpr = acc.shape[0] | |
BN: tl.constexpr = acc.shape[1] | |
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
acc0 = acc0 * alpha[:, None] | |
acc1 = acc1 * alpha[:, None] | |
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
acc = acc * alpha[:, None] | |
# prepare p and v for the dot | |
if dtype == tl.float8e5: | |
v = desc_v.load([0, offsetv_y]).T | |
else: | |
v = desc_v.load([offsetv_y, 0]) | |
p = p.to(dtype) | |
if ENABLE_JVP: | |
p_tqk = p * (t_qk * sm_scale) | |
# this non-transposed v for FP8 is presumably only supported on Blackwell | |
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128): | |
BM: tl.constexpr = g_acc.shape[0] | |
BN: tl.constexpr = g_acc.shape[1] | |
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split() | |
g_acc0 = g_acc0 * alpha[:, None] | |
g_acc1 = g_acc1 * alpha[:, None] | |
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN]) | |
else: | |
g_acc = g_acc * alpha[:, None] | |
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc) | |
mu_ij = tl.sum(p_tqk, 1) | |
mu_i = mu_i * alpha + mu_ij | |
t_v = desc_v_t.load([offsetv_y, 0]) | |
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v) | |
# note that this non-transposed v for FP8 is only supported on Blackwell | |
acc = tl.dot(p, v, acc) | |
# update m_i and l_i | |
# place this at the end of the loop to reduce register pressure | |
l_i = l_i * alpha + l_ij | |
m_i = m_ij | |
offsetk_y += BLOCK_N | |
offsetv_y += BLOCK_N | |
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc | |
def _host_descriptor_pre_hook(nargs): | |
BLOCK_M = nargs["BLOCK_M"] | |
BLOCK_N = nargs["BLOCK_N"] | |
HEAD_DIM = nargs["HEAD_DIM"] | |
if not HAS_TENSOR_DESC or not isinstance(nargs["desc_q"], TensorDescriptor): | |
return | |
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] | |
if nargs["FP8_OUTPUT"]: | |
nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] | |
if is_hip(): | |
NUM_STAGES_OPTIONS = [1] | |
elif supports_host_descriptor(): | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
else: | |
NUM_STAGES_OPTIONS = [2, 3, 4] | |
configs = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \ | |
for BM in [64, 128]\ | |
for BN in [32, 64, 128]\ | |
for s in NUM_STAGES_OPTIONS \ | |
for w in [4, 8]\ | |
] | |
if "PYTEST_VERSION" in os.environ: | |
# Use a single config in testing for reproducibility | |
configs = [ | |
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook), | |
] | |
def keep(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8) | |
def prune_invalid_configs(configs, named_args, **kwargs): | |
N_CTX = kwargs["N_CTX"] | |
# Filter out configs where BLOCK_M > N_CTX | |
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX] | |
@triton.jit | |
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape): | |
if isinstance(desc_or_ptr, tl.tensor_descriptor): | |
return desc_or_ptr | |
else: | |
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape) | |
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd(Q, K, V, T_Q, T_K, T_V, # | |
sm_scale, M, Out, T_Out, # | |
stride_qz, stride_qh, stride_qm, stride_qk, # | |
stride_kz, stride_kh, stride_kn, stride_kk, # | |
stride_vz, stride_vh, stride_vk, stride_vn, # | |
stride_tqz, stride_tqh, stride_tqm, stride_tqk, # | |
stride_tkz, stride_tkh, stride_tkn, stride_tkk, # | |
stride_tvz, stride_tvh, stride_tvk, stride_tvn, # | |
stride_oz, stride_oh, stride_om, stride_on, # | |
stride_toz, stride_toh, stride_tom, stride_ton, # | |
Z, H, N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr, # | |
): | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh | |
# block pointers | |
Q_block_ptr = tl.make_block_ptr( | |
base=Q + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_qm, stride_qk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0) | |
V_block_ptr = tl.make_block_ptr( | |
base=V + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_vk, stride_vn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=v_order, | |
) | |
K_block_ptr = tl.make_block_ptr( | |
base=K + qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_kk, stride_kn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
O_block_ptr = tl.make_block_ptr( | |
base=Out + qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_om, stride_on), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
if ENABLE_JVP: | |
# it's extremely likely we could just re-use qvk_offset, but this seems cheap so whatever | |
t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh | |
T_Q_block_ptr = tl.make_block_ptr( | |
base=T_Q + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tqm, stride_tqk), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# could probably just re-use v_order here | |
t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0) | |
T_V_block_ptr = tl.make_block_ptr( | |
base=T_V + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tvk, stride_tvn), | |
offsets=(0, 0), | |
block_shape=(BLOCK_N, HEAD_DIM), | |
order=t_v_order, | |
) | |
T_K_block_ptr = tl.make_block_ptr( | |
base=T_K + t_qvk_offset, | |
shape=(HEAD_DIM, N_CTX), | |
strides=(stride_tkk, stride_tkn), | |
offsets=(0, 0), | |
block_shape=(HEAD_DIM, BLOCK_N), | |
order=(0, 1), | |
) | |
T_O_block_ptr = tl.make_block_ptr( | |
base=T_Out + t_qvk_offset, | |
shape=(N_CTX, HEAD_DIM), | |
strides=(stride_tom, stride_ton), | |
offsets=(start_m * BLOCK_M, 0), | |
block_shape=(BLOCK_M, HEAD_DIM), | |
order=(1, 0), | |
) | |
# load q_t: it will stay in SRAM throughout | |
t_q = tl.load(T_Q_block_ptr) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
else: | |
t_q = None | |
T_V_block_ptr = None | |
T_K_block_ptr = None | |
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner consistent | |
g_acc = tl.zeros([1, 1], dtype=tl.float32) | |
mu_i = tl.zeros([1], dtype=tl.float32) | |
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
qk_scale = qk_scale * 1.44269504 # = 1/ln(2) = log2(e) | |
# load q: it will stay in SRAM throughout | |
q = tl.load(Q_block_ptr) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# stage 2: on-band | |
if STAGE & 2: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
K_block_ptr, V_block_ptr, # | |
T_K_block_ptr, T_V_block_ptr, # | |
dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
tl.store(O_block_ptr, acc.to(Out.type.element_ty)) | |
if ENABLE_JVP: | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty)) | |
def _tma_pre_hook(nargs): | |
BLOCK_M = nargs["BLOCK_M"] | |
BLOCK_N = nargs["BLOCK_N"] | |
HEAD_DIM = nargs["HEAD_DIM"] | |
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM] | |
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM] | |
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM] | |
# We don't run auto-tuning every time to keep the tutorial fast. Keeping | |
# the code below and commenting out the equivalent parameters is convenient for | |
# re-tuning. | |
configs_tma = [ | |
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_tma_pre_hook) \ | |
for BM in [64, 128, 256]\ | |
for BN in [64, 128]\ | |
for s in [3, 4, 5]\ | |
for w in [4, 8]\ | |
] | |
def keep_tma(conf): | |
BLOCK_M = conf.kwargs["BLOCK_M"] | |
BLOCK_N = conf.kwargs["BLOCK_N"] | |
return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128 | |
and conf.num_warps == 8) | |
@triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"], | |
prune_configs_by={'early_config_prune': prune_invalid_configs}) | |
@triton.jit | |
def _attn_fwd_tma(sm_scale, M, # | |
Z, H, # | |
desc_q, desc_k, desc_v, # | |
desc_q_t, desc_k_t, desc_v_t, # | |
desc_o, desc_o_t, # | |
N_CTX, # | |
HEAD_DIM: tl.constexpr, # | |
BLOCK_M: tl.constexpr, # | |
BLOCK_N: tl.constexpr, # | |
FP8_OUTPUT: tl.constexpr, # | |
STAGE: tl.constexpr, # | |
warp_specialize: tl.constexpr, # | |
ENABLE_JVP: tl.constexpr, # | |
): | |
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16 | |
tl.static_assert(BLOCK_N <= HEAD_DIM) | |
start_m = tl.program_id(0) | |
off_hz = tl.program_id(1) | |
off_z = off_hz // H | |
off_h = off_hz % H | |
y_dim = Z * H * N_CTX | |
desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
if FP8_OUTPUT: | |
v_shape = [HEAD_DIM, y_dim] | |
v_strides = [N_CTX, 1] | |
v_block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
v_shape = [y_dim, HEAD_DIM] | |
v_strides = [HEAD_DIM, 1] | |
v_block_shape = [BLOCK_N, HEAD_DIM] | |
desc_v = _maybe_make_tensor_desc(desc_v, shape=v_shape, strides=v_strides, block_shape=v_block_shape) | |
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
offset_y = off_z * (N_CTX * H) + off_h * N_CTX | |
qo_offset_y = offset_y + start_m * BLOCK_M | |
# initialize offsets | |
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) | |
offs_n = tl.arange(0, BLOCK_N) | |
# initialize pointer to m and l | |
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") | |
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 | |
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
if ENABLE_JVP: | |
desc_q_t = _maybe_make_tensor_desc(desc_q_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
if FP8_OUTPUT: | |
t_v_shape = [HEAD_DIM, y_dim] | |
t_v_strides = [N_CTX, 1] | |
t_v_block_shape = [HEAD_DIM, BLOCK_N] | |
else: | |
t_v_shape = [y_dim, HEAD_DIM] | |
t_v_strides = [HEAD_DIM, 1] | |
t_v_block_shape = [BLOCK_N, HEAD_DIM] | |
desc_v_t = _maybe_make_tensor_desc(desc_v_t, shape=t_v_shape, strides=t_v_strides, block_shape=t_v_block_shape) | |
desc_k_t = _maybe_make_tensor_desc(desc_k_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_N, HEAD_DIM]) | |
desc_o_t = _maybe_make_tensor_desc(desc_o_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1], | |
block_shape=[BLOCK_M, HEAD_DIM]) | |
# load t_q: it will stay in SRAM throughout | |
t_q = desc_q_t.load([qo_offset_y, 0]) | |
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32) | |
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) | |
else: | |
t_q = None | |
desc_k_t = None | |
desc_v_t = None | |
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner_tma consistent | |
g_acc = tl.zeros([1, 1], dtype=tl.float32) | |
mu_i = tl.zeros([1], dtype=tl.float32) | |
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32) | |
# load scales | |
qk_scale = sm_scale | |
qk_scale *= 1.44269504 # = 1/ln(2) = log2(e) | |
# load q: it will stay in SRAM throughout | |
q = desc_q.load([qo_offset_y, 0]) | |
# stage 1: off-band | |
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE | |
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE | |
if STAGE & 1: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
4 - STAGE, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# stage 2: on-band | |
if STAGE & 2: | |
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, # | |
l_i, m_i, # | |
mu_i, p_tv_acc, # | |
q, t_q, # | |
desc_k, desc_v, # | |
desc_k_t, desc_v_t, # | |
offset_y, dtype, start_m, qk_scale, sm_scale, # | |
BLOCK_M, HEAD_DIM, BLOCK_N, # | |
2, offs_m, offs_n, N_CTX, # | |
warp_specialize, | |
ENABLE_JVP) | |
# epilogue | |
m_i += tl.math.log2(l_i) | |
acc = acc / l_i[:, None] | |
m_ptrs = M + off_hz * N_CTX + offs_m | |
tl.store(m_ptrs, m_i) | |
desc_o.store([qo_offset_y, 0], acc.to(dtype)) | |
if ENABLE_JVP: | |
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc | |
t_y_out = t_p_v + p_tv_acc / l_i[:, None] | |
desc_o_t.store([qo_offset_y, 0], t_y_out.to(dtype)) | |
@triton.jit | |
def _attn_bwd_preprocess(O, DO, # | |
Delta, # | |
Z, H, N_CTX, # | |
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr # | |
): | |
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) | |
off_hz = tl.program_id(1) | |
off_n = tl.arange(0, HEAD_DIM) | |
# load | |
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]) | |
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32) | |
delta = tl.sum(o * do, axis=1) | |
# write-back | |
tl.store(Delta + off_hz * N_CTX + off_m, delta) | |
# The main inner-loop logic for computing dK and dV. | |
@triton.jit | |
def _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, # | |
# Filled in by the wrapper. | |
start_n, start_m, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M1) | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
offs_k = tl.arange(0, HEAD_DIM) | |
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d | |
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) | |
curr_m = start_m | |
step_m = BLOCK_M1 | |
for blk_idx in range(num_steps): | |
qT = tl.load(qT_ptrs) | |
# Load m before computing qk to reduce pipeline stall. | |
offs_m = curr_m + tl.arange(0, BLOCK_M1) | |
m = tl.load(M + offs_m) | |
qkT = tl.dot(k, qT) | |
pT = tl.math.exp2(qkT - m[None, :]) | |
# Autoregressive masking. | |
if MASK: | |
mask = (offs_m[None, :] >= offs_n[:, None]) | |
pT = tl.where(mask, pT, 0.0) | |
do = tl.load(do_ptrs) | |
# Compute dV. | |
ppT = pT | |
ppT = ppT.to(tl.float16) | |
dv += tl.dot(ppT, do) | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# Compute dP and dS. | |
dpT = tl.dot(v, tl.trans(do)).to(tl.float32) | |
dsT = pT * (dpT - Di[None, :]) | |
dsT = dsT.to(tl.float16) | |
dk += tl.dot(dsT, tl.trans(qT)) | |
# Increment pointers. | |
curr_m += step_m | |
qT_ptrs += step_m * stride_tok | |
do_ptrs += step_m * stride_tok | |
return dk, dv | |
# the main inner-loop logic for computing dQ | |
@triton.jit | |
def _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, | |
# shared by Q/K/V/DO. | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
HEAD_DIM: tl.constexpr, | |
# Filled in by the wrapper. | |
start_m, start_n, num_steps, # | |
MASK: tl.constexpr): | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
offs_n = start_n + tl.arange(0, BLOCK_N2) | |
offs_k = tl.arange(0, HEAD_DIM) | |
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d | |
# D (= delta) is pre-divided by ds_scale. | |
Di = tl.load(D + offs_m) | |
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. | |
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) | |
curr_n = start_n | |
step_n = BLOCK_N2 | |
for blk_idx in range(num_steps): | |
kT = tl.load(kT_ptrs) | |
vT = tl.load(vT_ptrs) | |
qk = tl.dot(q, kT) | |
p = tl.math.exp2(qk - m) | |
# Autoregressive masking. | |
if MASK: | |
offs_n = curr_n + tl.arange(0, BLOCK_N2) | |
mask = (offs_m[:, None] >= offs_n[None, :]) | |
p = tl.where(mask, p, 0.0) | |
# Compute dP and dS. | |
dp = tl.dot(do, vT).to(tl.float32) | |
ds = p * (dp - Di[:, None]) | |
ds = ds.to(tl.float16) | |
# Compute dQ. | |
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled. | |
dq += tl.dot(ds, tl.trans(kT)) | |
# Increment pointers. | |
curr_n += step_n | |
kT_ptrs += step_n * stride_tok | |
vT_ptrs += step_n * stride_tok | |
return dq | |
@triton.jit | |
def _attn_bwd(Q, K, V, sm_scale, # | |
DO, # | |
DQ, DK, DV, # | |
M, D, | |
# shared by Q/K/V/DO. | |
stride_z, stride_h, stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1: tl.constexpr, # | |
BLOCK_N1: tl.constexpr, # | |
BLOCK_M2: tl.constexpr, # | |
BLOCK_N2: tl.constexpr, # | |
BLK_SLICE_FACTOR: tl.constexpr, # | |
HEAD_DIM: tl.constexpr): | |
LN2: tl.constexpr = 0.6931471824645996 # = ln(2) | |
bhid = tl.program_id(2) | |
off_chz = (bhid * N_CTX).to(tl.int64) | |
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) | |
pid = tl.program_id(0) | |
# offset pointers for batch/head | |
Q += adj | |
K += adj | |
V += adj | |
DO += adj | |
DQ += adj | |
DK += adj | |
DV += adj | |
M += off_chz | |
D += off_chz | |
# load scales | |
offs_k = tl.arange(0, HEAD_DIM) | |
start_n = pid * BLOCK_N1 | |
start_m = start_n | |
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR | |
offs_n = start_n + tl.arange(0, BLOCK_N1) | |
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32) | |
# load K and V: they stay in SRAM throughout the inner loop. | |
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
num_steps = BLOCK_N1 // MASK_BLOCK_M1 | |
dk, dv = _attn_bwd_dkdv(dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=True # | |
) | |
start_m += num_steps * MASK_BLOCK_M1 | |
num_steps = (N_CTX - start_m) // BLOCK_M1 | |
# Compute dK and dV for non-masked blocks. | |
dk, dv = _attn_bwd_dkdv( # | |
dk, dv, # | |
Q, k, v, sm_scale, # | |
DO, # | |
M, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M1, BLOCK_N1, HEAD_DIM, # | |
start_n, start_m, num_steps, # | |
MASK=False # | |
) | |
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dv_ptrs, dv) | |
# Write back dK. | |
dk *= sm_scale | |
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d | |
tl.store(dk_ptrs, dk) | |
# THIS BLOCK DOES DQ: | |
start_m = pid * BLOCK_M2 | |
end_n = start_m + BLOCK_M2 | |
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR | |
offs_m = start_m + tl.arange(0, BLOCK_M2) | |
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32) | |
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) | |
m = tl.load(M + offs_m) | |
m = m[:, None] | |
# Compute dQ for masked (diagonal) blocks. | |
# NOTE: This code scans each row of QK^T backward (from right to left, | |
# but inside each call to _attn_bwd_dq, from left to right), but that's | |
# not due to anything important. I just wanted to reuse the loop | |
# structure for dK & dV above as much as possible. | |
num_steps = BLOCK_M2 // MASK_BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # | |
MASK=True # | |
) | |
end_n -= num_steps * MASK_BLOCK_N2 | |
# stage 2 | |
num_steps = end_n // BLOCK_N2 | |
dq = _attn_bwd_dq(dq, q, K, V, # | |
do, m, D, # | |
stride_tok, stride_d, # | |
H, N_CTX, # | |
BLOCK_M2, BLOCK_N2, HEAD_DIM, # | |
start_m, end_n - num_steps * BLOCK_N2, num_steps, # | |
MASK=False # | |
) | |
# Write back dQ. | |
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d | |
dq *= LN2 | |
tl.store(dq_ptrs, dq) | |
class JVPAttn(Function): | |
class Grid(NamedTuple): | |
M_BLOCKS: int | |
Z_H: int | |
ONE: Literal[1] | |
class FnCtx(FunctionCtx): | |
sm_scale: float | |
HEAD_DIM_K: int | |
causal: bool | |
grid: JVPAttn.Grid | |
class FwdOutCtxContrib(NamedTuple): | |
o_t: Optional[Tensor] | |
M: Tensor | |
grid: JVPAttn.Grid | |
HEAD_DIM_K: int | |
sm_scale: float | |
class FwdOut(NamedTuple): | |
o: Tensor | |
ctx: JVPAttn.FwdOutCtxContrib | |
class JVPOut(NamedTuple): | |
o: Tensor | |
ctx: None | |
class BwdOut(NamedTuple): | |
q: Tensor | |
k: Tensor | |
v: Tensor | |
q_t: None | |
k_t: None | |
v_t: None | |
causal: None | |
sm_scale: None | |
warp_specialize: None | |
USE_TMA: None | |
class Strides(NamedTuple): | |
z: int | |
h: int | |
n_ctx: int | |
head_dim: int | |
@staticmethod | |
def forward( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
q_t: Optional[Tensor], | |
k_t: Optional[Tensor], | |
v_t: Optional[Tensor], | |
causal: bool, | |
sm_scale: Optional[float], | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> JVPAttn.FwdOut: | |
# shape constraints | |
Z, H, N_CTX, HEAD_DIM_Q = q.shape | |
HEAD_DIM_K = k.shape[-1] | |
# when v is in float8_e5m2 it is transposed. | |
HEAD_DIM_V = v.shape[-1] | |
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V | |
assert HEAD_DIM_K in {16, 32, 64, 128, 256} | |
if sm_scale is None: | |
sm_scale = HEAD_DIM_K**-.5 | |
o = torch.empty_like(q) | |
ENABLE_JVP = q_t is not None | |
o_t: Optional[Tensor] = torch.empty_like(q_t) if ENABLE_JVP else None | |
stage = 3 if causal else 1 | |
extra_kern_args = {} | |
# Tuning for AMD target | |
if is_hip(): | |
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2 | |
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True} | |
if is_cuda() and warp_specialize: | |
# we need more registers if we're doing JVP | |
if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP: | |
extra_kern_args["maxnreg"] = 168 | |
else: | |
# TODO: I think for backwards pass of dim=128 this is too low for H100; register allocation fails | |
extra_kern_args["maxnreg"] = 80 | |
M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32) | |
if hasattr(triton, 'set_allocator') and is_cuda(): | |
def alloc_fn(size: int, align: int, _): | |
return torch.empty(size, dtype=torch.int8, device="cuda") | |
triton.set_allocator(alloc_fn) | |
Z_H = Z * H | |
def grid(META: dict[str, Any]) -> JVPAttn.Grid: | |
return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1) | |
# if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9 | |
# and q.dtype == torch.float8_e5m2): | |
if USE_TMA and supports_tma(): | |
# Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor | |
y_dim = Z_H * N_CTX | |
dummy_block = [1, 1] | |
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_q_t = desc_q if q_t is None else TensorDescriptor(q_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
if q.dtype == torch.float8_e5m2: | |
v_shape = [HEAD_DIM_K, y_dim] | |
v_strides = [N_CTX, 1] | |
else: | |
v_shape = [y_dim, HEAD_DIM_K] | |
v_strides = [HEAD_DIM_K, 1] | |
desc_v = TensorDescriptor(v, shape=v_shape, strides=v_strides, block_shape=dummy_block) | |
# probably we could share the shape and strides from above, but whatever | |
if q_t is not None and q_t.dtype == torch.float8_e5m2: # q_t may be None when JVP is disabled | |
t_v_shape = [HEAD_DIM_K, y_dim] | |
t_v_strides = [q_t.shape[2], 1] | |
else: | |
t_v_shape = [y_dim, HEAD_DIM_K] | |
t_v_strides = [HEAD_DIM_K, 1] | |
desc_v_t = desc_v if v_t is None else TensorDescriptor(v_t, shape=t_v_shape, strides=t_v_strides, block_shape=dummy_block) | |
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_k_t = desc_k if k_t is None else TensorDescriptor(k_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
desc_o_t = desc_o if o_t is None else TensorDescriptor(o_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block) | |
_attn_fwd_tma[grid]( | |
sm_scale, M, # | |
Z, H, # | |
desc_q, desc_k, desc_v, # | |
desc_q_t, desc_k_t, desc_v_t, # | |
desc_o, desc_o_t, # | |
N_CTX=N_CTX, # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
ENABLE_JVP=ENABLE_JVP, # | |
**extra_kern_args) | |
else: | |
def strides_zhnd(t: Tensor) -> JVPAttn.Strides: | |
return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3)) | |
_attn_fwd[grid]( | |
q, k, v, q_t, k_t, v_t, # | |
sm_scale, M, o, o_t, # | |
*strides_zhnd(q), # | |
*strides_zhnd(k), # | |
*strides_zhnd(v), # | |
*strides_zhnd(q if q_t is None else q_t), # | |
*strides_zhnd(k if k_t is None else k_t), # | |
*strides_zhnd(v if v_t is None else v_t), # | |
*strides_zhnd(o), # | |
*strides_zhnd(o if o_t is None else o_t), # | |
Z, H, # | |
N_CTX=N_CTX, # | |
HEAD_DIM=HEAD_DIM_K, # | |
FP8_OUTPUT=q.dtype == torch.float8_e5m2, # | |
STAGE=stage, # | |
warp_specialize=warp_specialize, # | |
ENABLE_JVP=ENABLE_JVP, # | |
**extra_kern_args) | |
return JVPAttn.FwdOut(o, JVPAttn.FwdOutCtxContrib(o_t, M, grid, HEAD_DIM_K, sm_scale)) | |
@staticmethod | |
def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor: | |
( | |
q, | |
k, | |
v, | |
q_t, | |
k_t, | |
v_t, | |
causal, | |
sm_scale, | |
warp_specialize, | |
USE_TMA | |
) = inputs | |
o, (o_t, M, grid, HEAD_DIM_K, sm_scale) = outputs | |
ctx.grid = grid | |
ctx.save_for_forward(o_t) | |
ctx.save_for_backward(q, k, v, o, M) | |
ctx.sm_scale = sm_scale | |
ctx.HEAD_DIM_K = HEAD_DIM_K | |
ctx.causal = causal | |
@staticmethod | |
def fwd( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support | |
""" | |
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, None, None, None, causal, sm_scale, warp_specialize, USE_TMA) | |
a, _ = out | |
return a | |
@staticmethod | |
def fwd_dual( | |
q: Tensor, | |
k: Tensor, | |
v: Tensor, | |
causal = False, | |
sm_scale: Optional[float] = None, | |
warp_specialize=True, | |
USE_TMA=True, | |
) -> Tensor: | |
""" | |
This is not an autograd convention, it's a workaround for invoking | |
JVPAttn::forward with the right arguments when you have a dual tensor input. | |
""" | |
q_p, q_t = fwAD.unpack_dual(q) | |
k_p, k_t = fwAD.unpack_dual(k) | |
v_p, v_t = fwAD.unpack_dual(v) | |
# we pass some dualtensor args to ensure jvp() will be called | |
# but we also pass tangents separately, as forward() demotes dual tensor args to primals for some reason | |
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize, USE_TMA) | |
a, _ = out | |
return a | |
@staticmethod | |
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut: | |
return JVPAttn.JVPOut(ctx.saved_for_forward[0], None) | |
@staticmethod | |
def backward(ctx, do, _) -> JVPAttn.BwdOut: | |
q, k, v, o, M = ctx.saved_tensors | |
assert do.is_contiguous() | |
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() | |
dq = torch.empty_like(q) | |
dk = torch.empty_like(k) | |
dv = torch.empty_like(v) | |
BATCH, N_HEAD, N_CTX = q.shape[:3] | |
PRE_BLOCK = 128 | |
NUM_WARPS, NUM_STAGES = 4, 5 | |
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 | |
BLK_SLICE_FACTOR = 2 | |
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) | |
arg_k = k | |
arg_k = arg_k * (ctx.sm_scale * RCP_LN2) | |
PRE_BLOCK = 128 | |
assert N_CTX % PRE_BLOCK == 0 | |
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) | |
delta = torch.empty_like(M) | |
_attn_bwd_preprocess[pre_grid]( | |
o, do, # | |
delta, # | |
BATCH, N_HEAD, N_CTX, # | |
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM_K # | |
) | |
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) | |
_attn_bwd[grid]( | |
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # | |
M, delta, # | |
q.stride(0), q.stride(1), q.stride(2), q.stride(3), # | |
N_HEAD, N_CTX, # | |
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # | |
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # | |
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # | |
HEAD_DIM=ctx.HEAD_DIM_K, # | |
num_warps=NUM_WARPS, # | |
num_stages=NUM_STAGES # | |
) | |
return JVPAttn.BwdOut(dq, dk, dv, None, None, None, None, None, None, None) | |
attention = JVPAttn.fwd | |
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') | |
@pytest.mark.parametrize("Z", [1, 4]) | |
@pytest.mark.parametrize("H", [2, 48]) | |
@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024]) | |
@pytest.mark.parametrize("HEAD_DIM", [64, 128]) | |
@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment. | |
@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False]) | |
@pytest.mark.parametrize("mode", ["fwd", "bwd"]) | |
@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else [])) | |
def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16): | |
if mode == "fwd" and "fp16" in provider: | |
pytest.skip("Avoid running the forward computation twice.") | |
if mode == "bwd" and "fp8" in provider: | |
pytest.skip("Backward pass with FP8 is not supported.") | |
torch.manual_seed(20) | |
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_()) | |
sm_scale = 0.5 | |
# reference implementation | |
ref_dtype = dtype | |
if mode == "fwd" and "fp8" in provider: | |
ref_dtype = torch.float32 | |
q = q.to(ref_dtype) | |
k = k.to(ref_dtype) | |
v = v.to(ref_dtype) | |
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE)) | |
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale | |
if causal: | |
p[:, :, M == 0] = float("-inf") | |
p = torch.softmax(p.float(), dim=-1) | |
p = p.to(ref_dtype) | |
# p = torch.exp(p) | |
ref_out = torch.matmul(p, v).half() | |
if mode == "bwd": | |
dout = torch.randn_like(q) | |
ref_out.backward(dout) | |
ref_dv, v.grad = v.grad.clone(), None | |
ref_dk, k.grad = k.grad.clone(), None | |
ref_dq, q.grad = q.grad.clone(), None | |
# triton implementation | |
if mode == "fwd" and "fp8" in provider: | |
q = q.to(torch.float8_e5m2) | |
k = k.to(torch.float8_e5m2) | |
v = v.permute(0, 1, 3, 2).contiguous() | |
v = v.permute(0, 1, 3, 2) | |
v = v.to(torch.float8_e5m2) | |
tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half() | |
if mode == "fwd": | |
atol = 3 if "fp8" in provider else 1e-2 | |
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0) | |
return | |
tri_out.backward(dout) | |
tri_dv, v.grad = v.grad.clone(), None | |
tri_dk, k.grad = k.grad.clone(), None | |
tri_dq, q.grad = q.grad.clone(), None | |
# compare | |
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0) | |
rtol = 0.0 | |
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU. | |
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices | |
if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a": | |
rtol = 1e-2 | |
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol) | |
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol) | |
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol) | |
try: | |
from flash_attn.flash_attn_interface import \ | |
flash_attn_qkvpacked_func as flash_attn_func | |
HAS_FLASH = True | |
except BaseException: | |
HAS_FLASH = False | |
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2') | |
BATCH, N_HEADS = 4, 32 | |
# vary seq length for fixed head and batch=4 | |
configs = [] | |
for HEAD_DIM in [64, 128]: | |
for mode in ["fwd", "bwd"]: | |
for causal in [True, False]: | |
for warp_specialize in [True, False]: | |
configs.append( | |
triton.testing.Benchmark( | |
x_names=["N_CTX"], | |
x_vals=[2**i for i in range(10, 15)], | |
line_arg="provider", | |
line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) + | |
(["flash"] if HAS_FLASH else []), | |
line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) + | |
(["Flash-2"] if HAS_FLASH else []), | |
styles=[("red", "-"), ("blue", "-"), ("green", "-")], | |
ylabel="TFLOPS", | |
plot_name= | |
f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}", | |
args={ | |
"H": N_HEADS, | |
"BATCH": BATCH, | |
"HEAD_DIM": HEAD_DIM, | |
"mode": mode, | |
"causal": causal, | |
"warp_specialize": warp_specialize, | |
}, | |
)) | |
@triton.testing.perf_report(configs) | |
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE): | |
assert mode in ["fwd", "bwd"] | |
dtype = torch.float16 | |
if "triton" in provider: | |
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
if mode == "fwd" and "fp8" in provider: | |
q = q.to(torch.float8_e5m2) | |
k = k.to(torch.float8_e5m2) | |
v = v.permute(0, 1, 3, 2).contiguous() | |
v = v.permute(0, 1, 3, 2) | |
v = v.to(torch.float8_e5m2) | |
sm_scale = 1.3 | |
fn = lambda: attention(q, k, v, causal, sm_scale, warp_specialize) | |
if mode == "bwd": | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn) | |
if provider == "flash": | |
qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True) | |
fn = lambda: flash_attn_func(qkv, causal=causal) | |
if mode == "bwd": | |
o = fn() | |
do = torch.randn_like(o) | |
fn = lambda: o.backward(do, retain_graph=True) | |
ms = triton.testing.do_bench(fn) | |
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM | |
total_flops = 2 * flops_per_matmul | |
if causal: | |
total_flops *= 0.5 | |
if mode == "bwd": | |
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) | |
return total_flops * 1e-12 / (ms * 1e-3) | |
if __name__ == "__main__": | |
# only works on post-Ampere GPUs right now | |
bench_flash_attention.run(save_path=".", print_data=True) |
@Peterande As far as I could tell, setup_context should not return tensors or values; it should save them via the ctx object.
What happened was not technically a memory leak, i.e. unreferenced allocated memory as detectable in memcheck, but the returned tensor "o" was kept alive somewhere and its memory was never freed, which leads to the program's memory use growing over time.
I'm afraid I can't tell you exactly what happens with the returned tensors, however.
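For illustration, here is a minimal, self-contained sketch of the pattern being discussed (not the gist's actual classes; a toy exp op): anything backward needs should go through ctx.save_for_backward, and setup_context should return nothing, since returning the output tensor is what kept it alive.

import torch
from torch.autograd import Function

class Exp(Function):
    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return x.exp()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Save what backward needs via ctx; do NOT return the output tensor here.
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_out):
        (o,) = ctx.saved_tensors
        return grad_out * o

# usage: y = Exp.apply(x)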
@benjamin-dinkelmann Got it, thanks! That helps a lot.
Thanks @benjamin-dinkelmann for catching the setup_context mistake. I hadn't intended to return o. I've updated the gist without that return statement now. Does this fix the memory leak?
I'm not sure about that change to the backward pass. We already test allclose() against the reference implementation, and for both is_causal=False and is_causal=True those assertions are already satisfied at tighter tolerances than flash-attn's rtol=1e-3, atol=1e-3.
I've updated the test script to test causal attention too. Can you tell me which comparison becomes more accurate (i.e. which allclose tolerance we can tighten) if your change is applied to the backward pass?
I'm a bit suspicious of the change: if it were required for accuracy, why doesn't the official Triton fused-attention example include such handling? And doesn't the multi-stage MASK=True / MASK=False handling exist to support causal attention?
from __future__ import annotations
from _06_fused_attention_blockptr_jvp import JVPAttn
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, NamedTuple
import torch
from torch import Tensor, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.nn import MSELoss
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
NiladicFn = Callable[[], None]
def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
iters_per_second = 1e3 / ms_per_iter
return iters_per_second * flop_count
def fmt_flops(flops: int) -> str:
return f"{flops / 1e12:5.1f} TFLOP/s"
# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int:
flop_counter = FlopCounterMode(display=display_ops)
with flop_counter:
f()
return flop_counter.get_total_flops()
class QKV(NamedTuple):
q: Tensor
k: Tensor
v: Tensor
class UnpackedDualQKV(NamedTuple):
primal: QKV
tangent: QKV
@dataclass
class Args:
bsz: int
model_dim: int
head_dim: int
seq_len: int
@staticmethod
def get_parser() -> ArgumentParser:
parser = ArgumentParser()
parser.add_argument("--bsz", default=1, type=int)
parser.add_argument("--model-dim", default=320, type=int)
parser.add_argument("--head-dim", default=64, type=int)
parser.add_argument("--seq-len", default=128, type=int)
return parser
@staticmethod
def from_namespace(namespace: Namespace) -> Args:
args = Args(**vars(namespace))
return args
def main(args: Args) -> None:
device = torch.device('cuda')
dtype = torch.float16
seed = 42
gen = torch.Generator(device=device)
heads = args.model_dim // args.head_dim
q_p, q_t, k_p, k_t, v_p, v_t, target = (torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(7))
# for t in (q_p, k_p, v_p):
# t.requires_grad = True
# t.retain_grad()
# MSELoss only works for torch.func.jvp(), if we use MSELoss with fwAD invocation, we get this error:
# ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor.
def loss_fn(out: Tensor, target: Tensor) -> Tensor:
return (out - target).square().mean()
def gimme_grads(t: Tensor) -> Tensor:
t.requires_grad = True
t.retain_grad()
return t
def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV:
return QKV(
q=gimme_grads(fwAD.make_dual(q_p, q_t)),
k=gimme_grads(fwAD.make_dual(k_p, k_t)),
v=gimme_grads(fwAD.make_dual(v_p, v_t)),
)
def make_qkv_unpacked(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> UnpackedDualQKV:
return UnpackedDualQKV(
primal=QKV(
q=gimme_grads(q_p),
k=gimme_grads(k_p),
v=gimme_grads(v_p),
),
tangent=QKV(
q=q_t,
k=k_t,
v=v_t,
)
)
for is_causal in (False, True):
print("is_causal:", is_causal)
with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
sdpa_out = scaled_dot_product_attention(q0, k0, v0, is_causal=is_causal)
sdpa_out.retain_grad()
sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out)
loss0: Tensor = loss_fn(sdpa_out, target)
loss0.backward()
q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
dual_out = JVPAttn.fwd_dual(q1, k1, v1, causal=is_causal)
dual_out.retain_grad()
dual_op, dual_ot = fwAD.unpack_dual(dual_out)
torch.testing.assert_close(dual_op, sdpa_op, atol=5e-3 if is_causal else 5e-4, rtol=1e-5)
# TODO: improve this accuracy
torch.testing.assert_close(dual_ot, sdpa_ot, atol=5e-3 if is_causal else 1e-3, rtol=1e-5)
loss1: Tensor = loss_fn(dual_out, target)
torch.testing.assert_close(loss1, loss0, atol=5e-4, rtol=1e-5)
loss1.backward()
torch.testing.assert_close(q1.grad, q0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(k1.grad, k0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(v1.grad, v0.grad, atol=5e-4, rtol=1e-5)
mse_fn = MSELoss()
with enable_grad():
qkv_p, qkv_t = make_qkv_unpacked(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
j_p: Tensor
j_t: Tensor
j_p, j_t = torch.func.jvp(partial(JVPAttn.fwd_dual, causal=is_causal), qkv_p, qkv_t)
j_p.retain_grad()
loss2: Tensor = mse_fn(j_p, target)
torch.testing.assert_close(loss2, loss0, atol=5e-4, rtol=1e-5)
loss2.backward()
torch.testing.assert_close(j_p, sdpa_op, atol=1e-3 if is_causal else 5e-4, rtol=1e-5)
# TODO: improve this accuracy
torch.testing.assert_close(j_t, sdpa_ot, atol=5e-3 if is_causal else 1e-3, rtol=1e-5)
q2, k2, v2 = qkv_p
torch.testing.assert_close(q2.grad, q0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(k2.grad, k0.grad, atol=5e-4, rtol=1e-5)
torch.testing.assert_close(v2.grad, v0.grad, atol=5e-4, rtol=1e-5)
pass
pass
print("is_causal:", is_causal, "passed all assertions.")
if __name__ == "__main__":
parser = Args.get_parser()
args_untyped: Namespace = parser.parse_args()
args: Args = Args.from_namespace(args_untyped)
main(args)
@Birch-san
The memory leak seems fixed. Thank you
I still think my adjustments to the backward are necessary.
However, I am not familiar with the official flash attention example; can you provide a link?
I found this implementation, which is pretty close to the one used here. The major difference is that your implementation runs a number of steps starting on the diagonal (matching the stage handling in the forward), whereas the implementation I found always goes through all blocks and masks as necessary (see for instance line 461ff).
The method used here is more efficient for causal attention, as it skips unnecessary blocks directly and only applies masking on diagonal blocks.
Compare to 835-836 selecting start and end of a diagonal block, and then moving through exactly enough steps to cover that block on lines 853-862. Thereafter only the previous blocks in that line are processed.
So I disagree that it is currently working correctly from a theoretical viewpoint.
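To make the difference concrete, here is a rough host-side paraphrase of the two dK/dV loop structures being compared, for a single key block (this is only my sketch of the control flow, not either implementation verbatim; process is a hypothetical placeholder for the per-block qk / p / dV / dK math in _attn_bwd_dkdv):

def process(start_m, block_m, masked):
    """Hypothetical placeholder for the per-block math in _attn_bwd_dkdv."""
    pass

def dkdv_skip_blocks(start_n, N_CTX, BLOCK_N1, BLOCK_M1, MASK_BLOCK_M1):
    # this gist: a short masked pass over the diagonal tile, then an unmasked
    # pass over the remaining query blocks; earlier blocks are skipped entirely
    start_m = start_n
    for _ in range(BLOCK_N1 // MASK_BLOCK_M1):
        process(start_m, MASK_BLOCK_M1, masked=True)
        start_m += MASK_BLOCK_M1
    for _ in range((N_CTX - start_m) // BLOCK_M1):
        process(start_m, BLOCK_M1, masked=False)
        start_m += BLOCK_M1

def dkdv_mask_all_blocks(N_CTX, BLOCK_M1):
    # the implementation linked above: visit every query block and rely on the
    # mask to zero out disallowed positions
    for start_m in range(0, N_CTX, BLOCK_M1):
        process(start_m, BLOCK_M1, masked=True)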
As for the testing, since we are debating whether the backward is correct, the gradients are the important values.
However, in your example, maximal gradients were around 2e-4 for me, making it practically impossible to violate the tolerances.
One can scale the loss up by about 1e3 to get maximal gradients of around 1e-1 (that seemed like a good number to me).
Then for the current gist I get (for the path "is_causal=False"),
Mismatched elements: 34845 / 40960 (85.1%)
Greatest absolute difference: 0.057159423828125 at index (0, 4, 63, 45) (up to 0.0005 allowed)
Greatest relative difference: 4208.0 at index (0, 1, 32, 49) (up to 1e-05 allowed)
The implementation with my proposed changes, on the other hand, runs through without issue.
Can you confirm this on your machine maybe?
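For reproducing this, the scaling is a one-line tweak to loss_fn in the test script above (the 1e3 factor is just the value I used; anything that pushes the gradients to roughly 1e-1 should do):

def loss_fn(out, target):
    # scale the MSE so backward gradients reach ~1e-1 in magnitude, large enough
    # that a wrong dQ/dK/dV can actually violate the 5e-4 tolerances
    return 1e3 * (out - target).square().mean()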
@benjamin-dinkelmann Hi, do you mean that the updated version of the code will cause a memory leak? I also encountered a memory leak when trying to write FA2 + JVP myself. Do you know the reason for it? Is removing the return of o in setup_context a direct solution to the problem?