
@Birch-san
Last active August 11, 2025 08:35
Triton fused attention tutorial, updated with JVP support (albeit with only atol=1e-3 accuracy on the JVP).
from __future__ import annotations
"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team
Extra Credits:
* Original flash attention paper (https://arxiv.org/abs/2205.14135)
* Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
Plus modifications to support JVP:
- formulation of flash JVP by Cheng Lu, Yang Song in https://arxiv.org/abs/2410.11081
- reference triton implementation by Sofian Mejjoute
- reimplementation of the reference as an autograd function with the latest Triton tutorial optimizations, by Alex Birch
- support for forward() to receive tangents, so that fwd and JVP can be computed together; autograd workaround by Emily (nshepperd)
- support for function transforms (e.g. torch.func.jvp) via the use of setup_context, by Shih-Ying Yeh
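
A brief recap of the JVP being computed (stated to match the kernels below, not the paper's exact notation):
with S = Q K^T * sm_scale, P = softmax(S) applied row-wise, and O = P V, the tangent of the output is
    t_S = (t_Q K^T + Q t_K^T) * sm_scale
    t_O = (P * t_S) V - rowsum(P * t_S) * O + P t_V
The kernels accumulate these three terms online (as g_acc, mu_i and p_tv_acc) alongside the usual
flash-attention running max m_i and normalizer l_i, and apply the normalization in the epilogue.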
"""
from typing import Any, Literal, NamedTuple, Optional
import pytest
import torch
import os
import triton
import triton.language as tl
from torch import Tensor
from torch.autograd import Function
from torch.autograd.function import FunctionCtx
import torch.autograd.forward_ad as fwAD
try:
from triton.tools.tensor_descriptor import TensorDescriptor
HAS_TENSOR_DESC = True
except ModuleNotFoundError:
HAS_TENSOR_DESC = False
DEVICE = getattr(triton.runtime.driver.active, "get_active_torch_device", lambda: torch.device('cuda'))()
def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
def is_cuda():
return triton.runtime.driver.active.get_current_target().backend == "cuda"
def supports_host_descriptor():
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
def supports_tma():
return HAS_TENSOR_DESC and is_cuda() and torch.cuda.get_device_capability()[0] >= 9
def is_blackwell():
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
@triton.jit
def _attn_fwd_inner(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype: tl.constexpr, start_m, qk_scale, sm_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr, warp_specialize: tl.constexpr, #
ENABLE_JVP: tl.constexpr):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
# NOTE: in fp8 mode, we may want to advance the V_block_ptr differently.
# I did try advancing by (0, lo) instead for fp8, but I got an illegal memory access.
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
if ENABLE_JVP:
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, lo))
T_V_block_ptr = tl.advance(T_V_block_ptr, (lo, 0))
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(K_block_ptr)
qk = tl.dot(q, k)
if ENABLE_JVP:
t_k = tl.load(T_K_block_ptr)
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
if ENABLE_JVP:
# Claude says "tangents should be masked with 0.0 since they represent derivatives".
t_qk = tl.where(mask, t_qk, 0.0)
# TODO: do we need a separate row maximum for qk_t?
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
# -- update output accumulator --
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
BM: tl.constexpr = acc.shape[0]
BN: tl.constexpr = acc.shape[1]
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
else:
acc = acc * alpha[:, None]
# update acc
v = tl.load(V_block_ptr)
# NOTE: we may need to transpose v if dtype == tl.float8e5
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
p = p.to(dtype)
if ENABLE_JVP:
p_tqk = p * (t_qk * sm_scale)
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
BM: tl.constexpr = g_acc.shape[0]
BN: tl.constexpr = g_acc.shape[1]
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
g_acc0 = g_acc0 * alpha[:, None]
g_acc1 = g_acc1 * alpha[:, None]
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN])
else:
g_acc = g_acc * alpha[:, None]
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc)
mu_ij = tl.sum(p_tqk, 1)
mu_i = mu_i * alpha + mu_ij
t_v = tl.load(T_V_block_ptr)
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v)
T_V_block_ptr = tl.advance(T_V_block_ptr, (BLOCK_N, 0))
T_K_block_ptr = tl.advance(T_K_block_ptr, (0, BLOCK_N))
acc = tl.dot(p, v, acc)
# update m_i and l_i
m_i = m_ij
# the fp8 PR made a change to how K and V are advanced here but I believe we already have that.
# https://github.com/triton-lang/triton/commit/75d27b0b425329bad8c13b9cd47177d93590ec31
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
@triton.jit
def _attn_fwd_inner_tma(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
desc_k, desc_v, #
desc_k_t, desc_v_t, #
offset_y, dtype: tl.constexpr, start_m, qk_scale, sm_scale, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, #
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, #
N_CTX: tl.constexpr, warp_specialize: tl.constexpr,
ENABLE_JVP: tl.constexpr):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
# causal = False
else:
lo, hi = 0, N_CTX
offsetk_y = offset_y + lo
if dtype == tl.float8e5:
offsetv_y = offset_y * HEAD_DIM + lo
else:
offsetv_y = offset_y + lo
# loop over k, v and update accumulator
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = desc_k.load([offsetk_y, 0]).T
qk = tl.dot(q, k)
if ENABLE_JVP:
t_k = desc_k_t.load([offsetk_y, 0]).T
t_qk = tl.dot(t_q, k) + tl.dot(q, t_k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk * qk_scale + tl.where(mask, 0, -1.0e6)
if ENABLE_JVP:
# Claude says "tangents should be masked with 0.0 since they represent derivatives".
t_qk = tl.where(mask, t_qk, 0.0)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
qk = qk * qk_scale - m_ij[:, None]
p = tl.math.exp2(qk)
# -- compute correction factor
alpha = tl.math.exp2(m_i - m_ij)
l_ij = tl.sum(p, 1)
# -- update output accumulator --
if warp_specialize and BLOCK_M == 128 and HEAD_DIM == 128:
BM: tl.constexpr = acc.shape[0]
BN: tl.constexpr = acc.shape[1]
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
acc0 = acc0 * alpha[:, None]
acc1 = acc1 * alpha[:, None]
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
else:
acc = acc * alpha[:, None]
# prepare p and v for the dot
if dtype == tl.float8e5:
v = desc_v.load([0, offsetv_y]).T
else:
v = desc_v.load([offsetv_y, 0])
p = p.to(dtype)
if ENABLE_JVP:
p_tqk = p * (t_qk * sm_scale)
# this non-transposed v for FP8 is presumably only supported on Blackwell
if warp_specialize and (BLOCK_M == 128 and HEAD_DIM == 128):
BM: tl.constexpr = g_acc.shape[0]
BN: tl.constexpr = g_acc.shape[1]
g_acc0, g_acc1 = g_acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
g_acc0 = g_acc0 * alpha[:, None]
g_acc1 = g_acc1 * alpha[:, None]
g_acc = tl.join(g_acc0, g_acc1).permute(0, 2, 1).reshape([BM, BN])
else:
g_acc = g_acc * alpha[:, None]
g_acc = tl.dot(p_tqk.to(v.dtype), v, g_acc)
mu_ij = tl.sum(p_tqk, 1)
mu_i = mu_i * alpha + mu_ij
t_v = desc_v_t.load([offsetv_y, 0])
p_tv_acc = p_tv_acc * alpha[:, None] + tl.dot(p, t_v)
# note that this non transposed v for FP8 is only supported on Blackwell
acc = tl.dot(p, v, acc)
# update m_i and l_i
# place this at the end of the loop to reduce register pressure
l_i = l_i * alpha + l_ij
m_i = m_ij
offsetk_y += BLOCK_N
offsetv_y += BLOCK_N
return acc, g_acc, l_i, m_i, mu_i, p_tv_acc
def _host_descriptor_pre_hook(nargs):
BLOCK_M = nargs["BLOCK_M"]
BLOCK_N = nargs["BLOCK_N"]
HEAD_DIM = nargs["HEAD_DIM"]
if not HAS_TENSOR_DESC or not isinstance(nargs["desc_q"], TensorDescriptor):
return
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
if nargs["FP8_OUTPUT"]:
nargs["desc_v"].block_shape = [HEAD_DIM, BLOCK_N]
else:
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
if is_hip():
NUM_STAGES_OPTIONS = [1]
elif supports_host_descriptor():
NUM_STAGES_OPTIONS = [2, 3, 4]
else:
NUM_STAGES_OPTIONS = [2, 3, 4]
configs = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_host_descriptor_pre_hook) \
for BM in [64, 128]\
for BN in [32, 64, 128]\
for s in NUM_STAGES_OPTIONS \
for w in [4, 8]\
]
if "PYTEST_VERSION" in os.environ:
# Use a single config in testing for reproducibility
configs = [
triton.Config(dict(BLOCK_M=128, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
]
def keep(conf):
BLOCK_M = conf.kwargs["BLOCK_M"]
BLOCK_N = conf.kwargs["BLOCK_N"]
return not (BLOCK_M * BLOCK_N < 128 * 128 and conf.num_warps == 8)
def prune_invalid_configs(configs, named_args, **kwargs):
N_CTX = kwargs["N_CTX"]
# Filter out configs where BLOCK_M > N_CTX
return [conf for conf in configs if conf.kwargs.get("BLOCK_M", 0) <= N_CTX]
@triton.jit
def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
if isinstance(desc_or_ptr, tl.tensor_descriptor):
return desc_or_ptr
else:
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
@triton.autotune(configs=list(filter(keep, configs)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit
def _attn_fwd(Q, K, V, T_Q, T_K, T_V, #
sm_scale, M, Out, T_Out, #
stride_qz, stride_qh, stride_qm, stride_qk, #
stride_kz, stride_kh, stride_kn, stride_kk, #
stride_vz, stride_vh, stride_vk, stride_vn, #
stride_tqz, stride_tqh, stride_tqm, stride_tqk, #
stride_tkz, stride_tkh, stride_tkn, stride_tkk, #
stride_tvz, stride_tvh, stride_tvk, stride_tvn, #
stride_oz, stride_oh, stride_om, stride_on, #
stride_toz, stride_toh, stride_tom, stride_ton, #
Z, H, N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
STAGE: tl.constexpr, #
warp_specialize: tl.constexpr, #
ENABLE_JVP: tl.constexpr, #
):
tl.static_assert(BLOCK_N <= HEAD_DIM)
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=v_order,
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
if ENABLE_JVP:
# it's extremely likely we could just re-use qvk_offset, but this seems cheap so whatever
t_qvk_offset = off_z.to(tl.int64) * stride_tqz + off_h.to(tl.int64) * stride_tqh
T_Q_block_ptr = tl.make_block_ptr(
base=T_Q + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tqm, stride_tqk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# could probably just re-use v_order here
t_v_order: tl.constexpr = (0, 1) if T_V.dtype.element_ty == tl.float8e5 else (1, 0)
T_V_block_ptr = tl.make_block_ptr(
base=T_V + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tvk, stride_tvn),
offsets=(0, 0),
block_shape=(BLOCK_N, HEAD_DIM),
order=t_v_order,
)
T_K_block_ptr = tl.make_block_ptr(
base=T_K + t_qvk_offset,
shape=(HEAD_DIM, N_CTX),
strides=(stride_tkk, stride_tkn),
offsets=(0, 0),
block_shape=(HEAD_DIM, BLOCK_N),
order=(0, 1),
)
T_O_block_ptr = tl.make_block_ptr(
base=T_Out + t_qvk_offset,
shape=(N_CTX, HEAD_DIM),
strides=(stride_tom, stride_ton),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, HEAD_DIM),
order=(1, 0),
)
# load q_t: it will stay in SRAM throughout
t_q = tl.load(T_Q_block_ptr)
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
else:
t_q = None
T_V_block_ptr = None
T_K_block_ptr = None
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32)
mu_i = tl.zeros([1], dtype=tl.float32)
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale = qk_scale * 1.44269504 # log2(e) = 1/ln(2)
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc,
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# stage 2: on-band
if STAGE & 2:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
K_block_ptr, V_block_ptr, #
T_K_block_ptr, T_V_block_ptr, #
dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
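# JVP epilogue: acc has just been normalized to O = P V. With t_S = (t_Q K^T + Q t_K^T) * sm_scale,
# the output tangent is t_O = (P * t_S) V - rowsum(P * t_S) * O + P t_V. g_acc, mu_i and p_tv_acc hold
# the unnormalized accumulations of those three terms, so each is divided by l_i below.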
if ENABLE_JVP:
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
t_y_out = t_p_v + p_tv_acc / l_i[:, None]
tl.store(T_O_block_ptr, t_y_out.to(T_Out.type.element_ty))
def _tma_pre_hook(nargs):
BLOCK_M = nargs["BLOCK_M"]
BLOCK_N = nargs["BLOCK_N"]
HEAD_DIM = nargs["HEAD_DIM"]
nargs["desc_q"].block_shape = [BLOCK_M, HEAD_DIM]
nargs["desc_v"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_k"].block_shape = [BLOCK_N, HEAD_DIM]
nargs["desc_o"].block_shape = [BLOCK_M, HEAD_DIM]
# We don't run auto-tuning every time to keep the tutorial fast. Keeping
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
configs_tma = [
triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN}, num_stages=s, num_warps=w, pre_hook=_tma_pre_hook) \
for BM in [64, 128, 256]\
for BN in [64, 128]\
for s in [3, 4, 5]\
for w in [4, 8]\
]
def keep_tma(conf):
BLOCK_M = conf.kwargs["BLOCK_M"]
BLOCK_N = conf.kwargs["BLOCK_N"]
return not (is_cuda() and torch.cuda.get_device_capability()[0] == 9 and BLOCK_M * BLOCK_N < 128 * 128
and conf.num_warps == 8)
@triton.autotune(configs=list(filter(keep_tma, configs_tma)), key=["N_CTX", "HEAD_DIM", "FP8_OUTPUT", "warp_specialize"],
prune_configs_by={'early_config_prune': prune_invalid_configs})
@triton.jit
def _attn_fwd_tma(sm_scale, M, #
Z, H, #
desc_q, desc_k, desc_v, #
desc_q_t, desc_k_t, desc_v_t, #
desc_o, desc_o_t, #
N_CTX, #
HEAD_DIM: tl.constexpr, #
BLOCK_M: tl.constexpr, #
BLOCK_N: tl.constexpr, #
FP8_OUTPUT: tl.constexpr, #
STAGE: tl.constexpr, #
warp_specialize: tl.constexpr, #
ENABLE_JVP: tl.constexpr, #
):
dtype = tl.float8e5 if FP8_OUTPUT else tl.float16
tl.static_assert(BLOCK_N <= HEAD_DIM)
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
off_z = off_hz // H
off_h = off_hz % H
y_dim = Z * H * N_CTX
desc_q = _maybe_make_tensor_desc(desc_q, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
if FP8_OUTPUT:
v_shape = [HEAD_DIM, y_dim]
v_strides = [N_CTX, 1]
v_block_shape = [HEAD_DIM, BLOCK_N]
else:
v_shape = [y_dim, HEAD_DIM]
v_strides = [HEAD_DIM, 1]
v_block_shape = [BLOCK_N, HEAD_DIM]
desc_v = _maybe_make_tensor_desc(desc_v, shape=v_shape, strides=v_strides, block_shape=v_block_shape)
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_M
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
if ENABLE_JVP:
desc_q_t = _maybe_make_tensor_desc(desc_q_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
if FP8_OUTPUT:
t_v_shape = [HEAD_DIM, y_dim]
t_v_strides = [N_CTX, 1]
t_v_block_shape = [HEAD_DIM, BLOCK_N]
else:
t_v_shape = [y_dim, HEAD_DIM]
t_v_strides = [HEAD_DIM, 1]
t_v_block_shape = [BLOCK_N, HEAD_DIM]
desc_v_t = _maybe_make_tensor_desc(desc_v_t, shape=t_v_shape, strides=t_v_strides, block_shape=t_v_block_shape)
desc_k_t = _maybe_make_tensor_desc(desc_k_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_o_t = _maybe_make_tensor_desc(desc_o_t, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])
# load t_q: it will stay in SRAM throughout
t_q = desc_q_t.load([qo_offset_y, 0])
g_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
mu_i = tl.zeros([BLOCK_M], dtype=tl.float32)
p_tv_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
else:
t_q = None
desc_k_t = None
desc_v_t = None
# Allocate minimal dummy tensors to keep the return signature of _attn_fwd_inner_tma consistent
g_acc = tl.zeros([1, 1], dtype=tl.float32)
mu_i = tl.zeros([1], dtype=tl.float32)
p_tv_acc = tl.zeros([1, 1], dtype=tl.float32)
# load scales
qk_scale = sm_scale
qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
# load q: it will stay in SRAM throughout
q = desc_q.load([qo_offset_y, 0])
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
desc_k, desc_v, #
desc_k_t, desc_v_t, #
offset_y, dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
4 - STAGE, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# stage 2: on-band
if STAGE & 2:
acc, g_acc, l_i, m_i, mu_i, p_tv_acc = _attn_fwd_inner_tma(acc, g_acc, #
l_i, m_i, #
mu_i, p_tv_acc, #
q, t_q, #
desc_k, desc_v, #
desc_k_t, desc_v_t, #
offset_y, dtype, start_m, qk_scale, sm_scale, #
BLOCK_M, HEAD_DIM, BLOCK_N, #
2, offs_m, offs_n, N_CTX, #
warp_specialize,
ENABLE_JVP)
# epilogue
m_i += tl.math.log2(l_i)
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(m_ptrs, m_i)
desc_o.store([qo_offset_y, 0], acc.to(dtype))
if ENABLE_JVP:
t_p_v = g_acc / l_i[:, None] - (mu_i / l_i)[:, None] * acc
t_y_out = t_p_v + p_tv_acc / l_i[:, None]
desc_o_t.store([qo_offset_y, 0], t_y_out.to(dtype))
@triton.jit
def _attn_bwd_preprocess(O, DO, #
Delta, #
Z, H, N_CTX, #
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, HEAD_DIM)
# load
o = tl.load(O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :])
do = tl.load(DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_hz * N_CTX + off_m, delta)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
HEAD_DIM: tl.constexpr, #
# Filled in by the wrapper.
start_n, start_m, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, HEAD_DIM)
qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(qT_ptrs)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
pT = tl.where(mask, pT, 0.0)
do = tl.load(do_ptrs)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
qT_ptrs += step_m * stride_tok
do_ptrs += step_m * stride_tok
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V, #
do, m, D,
# shared by Q/K/V/DO.
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
HEAD_DIM: tl.constexpr,
# Filled in by the wrapper.
start_m, start_n, num_steps, #
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, HEAD_DIM)
kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(kT_ptrs)
vT = tl.load(vT_ptrs)
qk = tl.dot(q, kT)
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = (offs_m[:, None] >= offs_n[None, :])
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
kT_ptrs += step_n * stride_tok
vT_ptrs += step_n * stride_tok
return dq
@triton.jit
def _attn_bwd(Q, K, V, sm_scale, #
DO, #
DQ, DK, DV, #
M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1: tl.constexpr, #
BLOCK_N1: tl.constexpr, #
BLOCK_M2: tl.constexpr, #
BLOCK_N2: tl.constexpr, #
BLK_SLICE_FACTOR: tl.constexpr, #
HEAD_DIM: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
# load scales
offs_k = tl.arange(0, HEAD_DIM)
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
start_n, start_m, num_steps, #
MASK=True #
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv( #
dk, dv, #
Q, k, v, sm_scale, #
DO, #
M, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M1, BLOCK_N1, HEAD_DIM, #
start_n, start_m, num_steps, #
MASK=False #
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
MASK=True #
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V, #
do, m, D, #
stride_tok, stride_d, #
H, N_CTX, #
BLOCK_M2, BLOCK_N2, HEAD_DIM, #
start_m, end_n - num_steps * BLOCK_N2, num_steps, #
MASK=False #
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
class JVPAttn(Function):
class Grid(NamedTuple):
M_BLOCKS: int
Z_H: int
ONE: Literal[1]
class FnCtx(FunctionCtx):
sm_scale: float
HEAD_DIM_K: int
causal: bool
grid: JVPAttn.Grid
class FwdOutCtxContrib(NamedTuple):
o_t: Optional[Tensor]
M: Tensor
grid: JVPAttn.Grid
HEAD_DIM_K: int
sm_scale: float
class FwdOut(NamedTuple):
o: Tensor
ctx: JVPAttn.FwdOutCtxContrib
class JVPOut(NamedTuple):
o: Tensor
ctx: None
class BwdOut(NamedTuple):
q: Tensor
k: Tensor
v: Tensor
q_t: None
k_t: None
v_t: None
causal: None
sm_scale: None
warp_specialize: None
USE_TMA: None
class Strides(NamedTuple):
z: int
h: int
n_ctx: int
head_dim: int
@staticmethod
def forward(
q: Tensor,
k: Tensor,
v: Tensor,
q_t: Optional[Tensor],
k_t: Optional[Tensor],
v_t: Optional[Tensor],
causal: bool,
sm_scale: Optional[float],
warp_specialize=True,
USE_TMA=True,
) -> JVPAttn.FwdOut:
# shape constraints
Z, H, N_CTX, HEAD_DIM_Q = q.shape
HEAD_DIM_K = k.shape[-1]
# when v is in float8_e5m2 it is transposed.
HEAD_DIM_V = v.shape[-1]
assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
assert HEAD_DIM_K in {16, 32, 64, 128, 256}
if sm_scale is None:
sm_scale = HEAD_DIM_K**-.5
o = torch.empty_like(q)
ENABLE_JVP = q_t is not None
o_t: Optional[Tensor] = torch.empty_like(q_t) if ENABLE_JVP else None
stage = 3 if causal else 1
extra_kern_args = {}
# Tuning for AMD target
if is_hip():
waves_per_eu = 3 if HEAD_DIM_K <= 64 else 2
extra_kern_args = {"waves_per_eu": waves_per_eu, "allow_flush_denorm": True}
if is_cuda() and warp_specialize:
# we need more registers if we're doing JVP
if (HEAD_DIM_K == 128 and q.dtype == torch.float16) or ENABLE_JVP:
extra_kern_args["maxnreg"] = 168
else:
# TODO: I think for backwards pass of dim=128 this is too low for H100; register allocation fails
extra_kern_args["maxnreg"] = 80
M = torch.empty((Z, H, N_CTX), device=q.device, dtype=torch.float32)
if hasattr(triton, 'set_allocator') and is_cuda():
def alloc_fn(size: int, align: int, _):
return torch.empty(size, dtype=torch.int8, device="cuda")
triton.set_allocator(alloc_fn)
Z_H = Z * H
def grid(META: dict[str, Any]) -> JVPAttn.Grid:
return JVPAttn.Grid(triton.cdiv(N_CTX, META["BLOCK_M"]), Z_H, 1)
# if USE_TMA and supports_tma() and not (torch.cuda.get_device_capability()[0] == 9
# and q.dtype == torch.float8_e5m2):
if USE_TMA and supports_tma():
# Note that on Hopper we cannot perform a FP8 dot with a non-transposed second tensor
y_dim = Z_H * N_CTX
dummy_block = [1, 1]
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_q_t = desc_q if q_t is None else TensorDescriptor(q_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
if q.dtype == torch.float8_e5m2:
v_shape = [HEAD_DIM_K, y_dim]
v_strides = [N_CTX, 1]
else:
v_shape = [y_dim, HEAD_DIM_K]
v_strides = [HEAD_DIM_K, 1]
desc_v = TensorDescriptor(v, shape=v_shape, strides=v_strides, block_shape=dummy_block)
# probably we could share the shape and strides from above, but whatever
# guard: q_t is None when JVP is disabled; desc_v_t then falls back to desc_v below anyway
if q_t is not None and q_t.dtype == torch.float8_e5m2:
t_v_shape = [HEAD_DIM_K, y_dim]
t_v_strides = [q_t.shape[2], 1]
else:
t_v_shape = [y_dim, HEAD_DIM_K]
t_v_strides = [HEAD_DIM_K, 1]
desc_v_t = desc_v if v_t is None else TensorDescriptor(v_t, shape=t_v_shape, strides=t_v_strides, block_shape=dummy_block)
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_k_t = desc_k if k_t is None else TensorDescriptor(k_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
desc_o_t = desc_o if o_t is None else TensorDescriptor(o_t, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
_attn_fwd_tma[grid](
sm_scale, M, #
Z, H, #
desc_q, desc_k, desc_v, #
desc_q_t, desc_k_t, desc_v_t, #
desc_o, desc_o_t, #
N_CTX=N_CTX, #
HEAD_DIM=HEAD_DIM_K, #
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
STAGE=stage, #
warp_specialize=warp_specialize, #
ENABLE_JVP=ENABLE_JVP, #
**extra_kern_args)
else:
def strides_zhnd(t: Tensor) -> JVPAttn.Strides:
return JVPAttn.Strides(t.stride(0), t.stride(1), t.stride(2), t.stride(3))
_attn_fwd[grid](
q, k, v, q_t, k_t, v_t, #
sm_scale, M, o, o_t, #
*strides_zhnd(q), #
*strides_zhnd(k), #
*strides_zhnd(v), #
*strides_zhnd(q if q_t is None else q_t), #
*strides_zhnd(k if k_t is None else k_t), #
*strides_zhnd(v if v_t is None else v_t), #
*strides_zhnd(o), #
*strides_zhnd(o if o_t is None else o_t), #
Z, H, #
N_CTX=N_CTX, #
HEAD_DIM=HEAD_DIM_K, #
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
STAGE=stage, #
warp_specialize=warp_specialize, #
ENABLE_JVP=ENABLE_JVP, #
**extra_kern_args)
return JVPAttn.FwdOut(o, JVPAttn.FwdOutCtxContrib(o_t, M, grid, HEAD_DIM_K, sm_scale))
@staticmethod
def setup_context(ctx: JVPAttn.FnCtx, inputs, outputs: JVPAttn.FwdOut) -> Tensor:
(
q,
k,
v,
q_t,
k_t,
v_t,
causal,
sm_scale,
warp_specialize,
USE_TMA
) = inputs
o, (o_t, M, grid, HEAD_DIM_K, sm_scale) = outputs
ctx.grid = grid
ctx.save_for_forward(o_t)
ctx.save_for_backward(q, k, v, o, M)
ctx.sm_scale = sm_scale
ctx.HEAD_DIM_K = HEAD_DIM_K
ctx.causal = causal
@staticmethod
def fwd(
q: Tensor,
k: Tensor,
v: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
USE_TMA=True,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround to get type-hinting and kwarg support
"""
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, None, None, None, causal, sm_scale, warp_specialize, USE_TMA)
a, _ = out
return a
@staticmethod
def fwd_dual(
q: Tensor,
k: Tensor,
v: Tensor,
causal = False,
sm_scale: Optional[float] = None,
warp_specialize=True,
USE_TMA=True,
) -> Tensor:
"""
This is not an autograd convention, it's a workaround for invoking
JVPAttn::forward with the right arguments when you have a dual tensor input.
"""
q_p, q_t = fwAD.unpack_dual(q)
k_p, k_t = fwAD.unpack_dual(k)
v_p, v_t = fwAD.unpack_dual(v)
# we pass some dualtensor args to ensure jvp() will be called
# but we also pass tangents separately, as forward() demotes dual tensor args to primals for some reason
out: JVPAttn.FwdOut = JVPAttn.apply(q, k, v, q_t, k_t, v_t, causal, sm_scale, warp_specialize, USE_TMA)
a, _ = out
return a
@staticmethod
def jvp(ctx: JVPAttn.FnCtx, gq: Tensor, gk: Tensor, gv: Tensor, *_) -> JVPAttn.JVPOut:
return JVPAttn.JVPOut(ctx.saved_for_forward[0], None)
@staticmethod
def backward(ctx, do, _) -> JVPAttn.BwdOut:
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
PRE_BLOCK = 128
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do, #
delta, #
BATCH, N_HEAD, N_CTX, #
BLOCK_M=PRE_BLOCK, HEAD_DIM=ctx.HEAD_DIM_K #
)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, #
M, delta, #
q.stride(0), q.stride(1), q.stride(2), q.stride(3), #
N_HEAD, N_CTX, #
BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, #
BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
HEAD_DIM=ctx.HEAD_DIM_K, #
num_warps=NUM_WARPS, #
num_stages=NUM_STAGES #
)
return JVPAttn.BwdOut(dq, dk, dv, None, None, None, None, None, None, None)
attention = JVPAttn.fwd
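# Example usage (a sketch, not part of the test suite; tensors are (Z, H, N_CTX, HEAD_DIM) fp16 CUDA tensors):
#   out = attention(q, k, v, True, None)                      # forward only: causal, default 1/sqrt(d) scale
#   out_dual = JVPAttn.fwd_dual(q_d, k_d, v_d, causal=True)   # forward + JVP when q_d/k_d/v_d are fwAD dual tensors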
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
@pytest.mark.parametrize("Z", [1, 4])
@pytest.mark.parametrize("H", [2, 48])
@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment.
@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
@pytest.mark.parametrize("mode", ["fwd", "bwd"])
@pytest.mark.parametrize("provider", ["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []))
def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, dtype=torch.float16):
if mode == "fwd" and "fp16" in provider:
pytest.skip("Avoid running the forward computation twice.")
if mode == "bwd" and "fp8" in provider:
pytest.skip("Backward pass with FP8 is not supported.")
torch.manual_seed(20)
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())
sm_scale = 0.5
# reference implementation
ref_dtype = dtype
if mode == "fwd" and "fp8" in provider:
ref_dtype = torch.float32
q = q.to(ref_dtype)
k = k.to(ref_dtype)
v = v.to(ref_dtype)
M = torch.tril(torch.ones((N_CTX, N_CTX), device=DEVICE))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1)
p = p.to(ref_dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v).half()
if mode == "bwd":
dout = torch.randn_like(q)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
if mode == "fwd" and "fp8" in provider:
q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)
v = v.to(torch.float8_e5m2)
tri_out = attention(q, k, v, causal, sm_scale, warp_specialize).half()
if mode == "fwd":
atol = 3 if "fp8" in provider else 1e-2
torch.testing.assert_close(tri_out, ref_out, atol=atol, rtol=0)
return
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0)
rtol = 0.0
# Relative tolerance workaround for known hardware limitation of CDNA2 GPU.
# For details see https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
if torch.version.hip is not None and triton.runtime.driver.active.get_current_target().arch == "gfx90a":
rtol = 1e-2
torch.testing.assert_close(tri_dv, ref_dv, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tri_dk, ref_dk, atol=1e-2, rtol=rtol)
torch.testing.assert_close(tri_dq, ref_dq, atol=1e-2, rtol=rtol)
try:
from flash_attn.flash_attn_interface import \
flash_attn_qkvpacked_func as flash_attn_func
HAS_FLASH = True
except BaseException:
HAS_FLASH = False
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
BATCH, N_HEADS = 4, 32
# vary seq length for fixed head and batch=4
configs = []
for HEAD_DIM in [64, 128]:
for mode in ["fwd", "bwd"]:
for causal in [True, False]:
for warp_specialize in [True, False]:
configs.append(
triton.testing.Benchmark(
x_names=["N_CTX"],
x_vals=[2**i for i in range(10, 15)],
line_arg="provider",
line_vals=["triton-fp16"] + (["triton-fp8"] if TORCH_HAS_FP8 else []) +
(["flash"] if HAS_FLASH else []),
line_names=["Triton [FP16]"] + (["Triton [FP8]"] if TORCH_HAS_FP8 else []) +
(["Flash-2"] if HAS_FLASH else []),
styles=[("red", "-"), ("blue", "-"), ("green", "-")],
ylabel="TFLOPS",
plot_name=
f"fused-attention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-{mode}-causal={causal}-warp_specialize={warp_specialize}",
args={
"H": N_HEADS,
"BATCH": BATCH,
"HEAD_DIM": HEAD_DIM,
"mode": mode,
"causal": causal,
"warp_specialize": warp_specialize,
},
))
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, HEAD_DIM, causal, warp_specialize, mode, provider, device=DEVICE):
assert mode in ["fwd", "bwd"]
dtype = torch.float16
if "triton" in provider:
q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
if mode == "fwd" and "fp8" in provider:
q = q.to(torch.float8_e5m2)
k = k.to(torch.float8_e5m2)
v = v.permute(0, 1, 3, 2).contiguous()
v = v.permute(0, 1, 3, 2)
v = v.to(torch.float8_e5m2)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale, warp_specialize)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn)
if provider == "flash":
qkv = torch.randn((BATCH, N_CTX, 3, H, HEAD_DIM), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, causal=causal)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn)
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
if mode == "bwd":
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops * 1e-12 / (ms * 1e-3)
if __name__ == "__main__":
# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path=".", print_data=True)
@benjamin-dinkelmann

Hi,
first of all, many thanks for this implementation. Working implementations of JVP + attention seem to be very rare at the moment.
I don't know if anyone else has tried this over many iterations.
I would just like to point out that the current gist, unlike the alterations by @KohakuBlueleaf, leaks memory like crazy in my tests.
Or rather, it leaks exactly the memory of the base output (o) tensor in every iteration.

Specifically, it occurs if setup_context returns this tensor.

Maybe this will help someone else.
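
For anyone hitting the same thing: the pattern that avoided it for me is to stash everything on ctx and return nothing from setup_context. A minimal sketch using the standard setup_context-style Function API (illustrative names, not the gist's code):

import torch

class Doubler(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)  # saved only to illustrate the API
        # return nothing here; returning the output tensor from setup_context is what kept memory alive

    @staticmethod
    def backward(ctx, grad_out):
        return 2 * grad_out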

@benjamin-dinkelmann

benjamin-dinkelmann commented Jul 30, 2025

Also, as noted in the tests, only causal attention seems to work in the backward pass.
That is because the backward pass does not consider whether the forward pass was causal or not.
Rather, it is specialized for the causal case in the way blocks are processed.
However, with the slight changes below, I can get similar or better accuracy on non-causal gradients.

If there is anything wrong with this approach, please let me know, as I have no prior experience with triton.

@@ -759,7 +762,8 @@ def _attn_bwd(Q, K, V, sm_scale,  #
               BLOCK_M2: tl.constexpr,  #
               BLOCK_N2: tl.constexpr,  #
               BLK_SLICE_FACTOR: tl.constexpr,  #
-              HEAD_DIM: tl.constexpr):
+              HEAD_DIM: tl.constexpr,
+              causal: tl.constexpr):
     LN2: tl.constexpr = 0.6931471824645996  # = ln(2)

     bhid = tl.program_id(2)
@@ -782,7 +786,6 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     offs_k = tl.arange(0, HEAD_DIM)

     start_n = pid * BLOCK_N1
-    start_m = start_n

     MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
     offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -793,21 +796,24 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     # load K and V: they stay in SRAM throughout the inner loop.
     k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
     v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
-
-    num_steps = BLOCK_N1 // MASK_BLOCK_M1
-
-    dk, dv = _attn_bwd_dkdv(dk, dv,  #
-                            Q, k, v, sm_scale,  #
-                            DO,  #
-                            M, D,  #
-                            stride_tok, stride_d,  #
-                            H, N_CTX,  #
-                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
-                            start_n, start_m, num_steps,  #
-                            MASK=True  #
-                            )
-
-    start_m += num_steps * MASK_BLOCK_M1
+    if causal:
+        start_m = start_n
+        num_steps = BLOCK_N1 // MASK_BLOCK_M1
+
+        dk, dv = _attn_bwd_dkdv(dk, dv,  #
+                                Q, k, v, sm_scale,  #
+                                DO,  #
+                                M, D,  #
+                                stride_tok, stride_d,  #
+                                H, N_CTX,  #
+                                MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM,  #
+                                start_n, start_m, num_steps,  #
+                                MASK=True  #
+                                )
+
+        start_m += num_steps * MASK_BLOCK_M1
+    else:
+        start_m = 0
     num_steps = (N_CTX - start_m) // BLOCK_M1

     # Compute dK and dV for non-masked blocks.
@@ -833,7 +839,6 @@ def _attn_bwd(Q, K, V, sm_scale,  #

     # THIS BLOCK DOES DQ:
     start_m = pid * BLOCK_M2
-    end_n = start_m + BLOCK_M2

     MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
     offs_m = start_m + tl.arange(0, BLOCK_M2)
@@ -850,16 +855,21 @@ def _attn_bwd(Q, K, V, sm_scale,  #
     # but inside each call to _attn_bwd_dq, from left to right), but that's
     # not due to anything important.  I just wanted to reuse the loop
     # structure for dK & dV above as much as possible.
-    num_steps = BLOCK_M2 // MASK_BLOCK_N2
-    dq = _attn_bwd_dq(dq, q, K, V,  #
-                      do, m, D,  #
-                      stride_tok, stride_d,  #
-                      H, N_CTX,  #
-                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
-                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
-                      MASK=True  #
-                      )
-    end_n -= num_steps * MASK_BLOCK_N2
+    if causal:
+        end_n = start_m + BLOCK_M2
+        num_steps = BLOCK_M2 // MASK_BLOCK_N2
+        dq = _attn_bwd_dq(dq, q, K, V,  #
+                          do, m, D,  #
+                          stride_tok, stride_d,  #
+                          H, N_CTX,  #
+                          BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM,  #
+                          start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,  #
+                          MASK=True  #
+                          )
+        end_n -= num_steps * MASK_BLOCK_N2
+    else:
+        end_n = N_CTX
+
     # stage 2
     num_steps = end_n // BLOCK_N2
     dq = _attn_bwd_dq(dq, q, K, V,  #



@@ -1143,7 +1161,8 @@ class JVPAttn(Function):
             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,  #
             HEAD_DIM=ctx.HEAD_DIM_K,  #
             num_warps=NUM_WARPS,  #
-            num_stages=NUM_STAGES  #
+            num_stages=NUM_STAGES,  #
+            causal=ctx.causal,
         )

Edit: Oh and obviously, the backward function of JVPAttn needs to supply the causal flag from the context.

@Peterande

Peterande commented Aug 4, 2025

@benjamin-dinkelmann Hi~ Do you mean that the updated version of the code will cause a memory leak? I also encountered a memory leak when trying to write FA2 + JVP myself. Do you know the reason for this? Is removing the return of o in setup_context a direct solution to the problem?

@benjamin-dinkelmann

@Peterande As far as I could tell, setup_context should not return tensors/values; it should save them via the ctx object.
What happened was not technically a memory leak (i.e. unreferenced allocated memory, as can be detected in memcheck): the returned tensor "o" was saved somewhere and its memory was never freed, which leads to the program's memory exploding over time.
I'm afraid I can't tell you exactly what happens to the returned tensors, however.

@Peterande

@benjamin-dinkelmann Got it, thanks! That helps a lot.

@Birch-san

Birch-san commented Aug 5, 2025

thanks @benjamin-dinkelmann for catching the setup_context mistake. I hadn't intended to return o. I've updated the gist without that return statement now. does this fix the memory leak?

I'm not sure about that change to the backwards pass. we already test allclose() vs reference implementation, and for is_causal=False and True, these assertions are already satisfied at tighter tolerances than flashattn's rtol=1e-3, atol=1e-3.

I've updated the test script to test causal attention too. can you tell me which comparison becomes more accurate (i.e. which allclose can we tighten up the tolerance on) if your change is applied to the backwards pass?

I'm a bit suspicious of the change, because if this change is required for accuracy, then why does the official triton example for fused attention not include such handling? and what about the multi-stage mask=True, mask=False handling, does that not exist to support causal attention?

from __future__ import annotations
from _06_fused_attention_blockptr_jvp import JVPAttn
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, NamedTuple
import torch
from torch import Tensor, enable_grad
import torch.autograd.forward_ad as fwAD
from torch.nn import MSELoss
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode

NiladicFn = Callable[[], None]

def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * flop_count


def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], Any], display_ops=True) -> int:
    flop_counter = FlopCounterMode(display=display_ops)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()

class QKV(NamedTuple):
    q: Tensor
    k: Tensor
    v: Tensor

class UnpackedDualQKV(NamedTuple):
    primal: QKV
    tangent: QKV

@dataclass
class Args:
    bsz: int
    model_dim: int
    head_dim: int
    seq_len: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--bsz", default=1, type=int)
        parser.add_argument("--model-dim", default=320, type=int)
        parser.add_argument("--head-dim", default=64, type=int)
        parser.add_argument("--seq-len", default=128, type=int)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args
    
def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)

    heads = args.model_dim // args.head_dim
    q_p, q_t, k_p, k_t, v_p, v_t, target = (torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix)) for ix in range(7))
    # for t in (q_p, k_p, v_p):
    #     t.requires_grad = True
    #     t.retain_grad()

    # MSELoss only works for torch.func.jvp(), if we use MSELoss with fwAD invocation, we get this error:
    # ZeroTensors are immutable. Please use the materialized zero tensor obtained using .clone() if you want a mutable tensor.
    def loss_fn(out: Tensor, target: Tensor) -> Tensor:
        return (out - target).square().mean()

    def gimme_grads(t: Tensor) -> Tensor:
        t.requires_grad = True
        t.retain_grad()
        return t

    def make_qkv(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> QKV:
        return QKV(
            q=gimme_grads(fwAD.make_dual(q_p, q_t)),
            k=gimme_grads(fwAD.make_dual(k_p, k_t)),
            v=gimme_grads(fwAD.make_dual(v_p, v_t)),
        )

    def make_qkv_unpacked(q_p: Tensor, k_p: Tensor, v_p: Tensor, q_t: Tensor, k_t: Tensor, v_t: Tensor) -> UnpackedDualQKV:
        return UnpackedDualQKV(
            primal=QKV(
                q=gimme_grads(q_p),
                k=gimme_grads(k_p),
                v=gimme_grads(v_p),
            ),
            tangent=QKV(
                q=q_t,
                k=k_t,
                v=v_t,
            )
        )

    for is_causal in (False, True):
        print("is_causal:", is_causal)
        with sdpa_kernel(SDPBackend.MATH), fwAD.dual_level(), enable_grad():
            q0, k0, v0 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())

            sdpa_out = scaled_dot_product_attention(q0, k0, v0, is_causal=is_causal)
            sdpa_out.retain_grad()
            sdpa_op, sdpa_ot = fwAD.unpack_dual(sdpa_out)

            loss0: Tensor = loss_fn(sdpa_out, target)
            loss0.backward()

            q1, k1, v1 = make_qkv(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())

            dual_out = JVPAttn.fwd_dual(q1, k1, v1, causal=is_causal)
            dual_out.retain_grad()
            dual_op, dual_ot = fwAD.unpack_dual(dual_out)

            torch.testing.assert_close(dual_op, sdpa_op, atol=5e-3 if is_causal else 5e-4, rtol=1e-5)
            # TODO: improve this accuracy
            torch.testing.assert_close(dual_ot, sdpa_ot, atol=5e-3 if is_causal else 1e-3, rtol=1e-5)

            loss1: Tensor = loss_fn(dual_out, target)
            torch.testing.assert_close(loss1, loss0, atol=5e-4, rtol=1e-5)
            loss1.backward()
            torch.testing.assert_close(q1.grad, q0.grad, atol=5e-4, rtol=1e-5)
            torch.testing.assert_close(k1.grad, k0.grad, atol=5e-4, rtol=1e-5)
            torch.testing.assert_close(v1.grad, v0.grad, atol=5e-4, rtol=1e-5)

        mse_fn = MSELoss()
        with enable_grad():
            qkv_p, qkv_t = make_qkv_unpacked(q_p.clone(), k_p.clone(), v_p.clone(), q_t.clone(), k_t.clone(), v_t.clone())
            j_p: Tensor
            j_t: Tensor
            j_p, j_t = torch.func.jvp(partial(JVPAttn.fwd_dual, causal=is_causal), qkv_p, qkv_t)
            j_p.retain_grad()
            loss2: Tensor = mse_fn(j_p, target)
            torch.testing.assert_close(loss2, loss0, atol=5e-4, rtol=1e-5)
            loss2.backward()
            torch.testing.assert_close(j_p, sdpa_op, atol=1e-3 if is_causal else 5e-4, rtol=1e-5)
            # TODO: improve this accuracy
            torch.testing.assert_close(j_t, sdpa_ot, atol=5e-3 if is_causal else 1e-3, rtol=1e-5)
            q2, k2, v2 = qkv_p
            torch.testing.assert_close(q2.grad, q0.grad, atol=5e-4, rtol=1e-5)
            torch.testing.assert_close(k2.grad, k0.grad, atol=5e-4, rtol=1e-5)
            torch.testing.assert_close(v2.grad, v0.grad, atol=5e-4, rtol=1e-5)
            pass
        pass
        print("is_causal:", is_causal, "passed all assertions.")

if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)

@benjamin-dinkelmann

benjamin-dinkelmann commented Aug 11, 2025

@Birch-san
The memory leak seems fixed. Thank you!

I still think my adjustments to the backward are necessary.
However, I am not familiar with the official flash attention example; can you provide a link?
I found this implementation, which is pretty close to the one used here. The major difference is that your implementation has a number of steps starting on the diagonal (matching the stage handling in the forward), whereas the implementation I found always goes through all blocks and masks as necessary (see for instance line 461ff).
The method used here is more efficient for causal attention, as it skips unnecessary blocks directly and only applies masking on diagonal blocks.
Compare to lines 835-836, which select the start and end of a diagonal block and then move through exactly enough steps to cover that block on lines 853-862; thereafter only the earlier blocks in that row are processed.
So I disagree that it is currently working correctly from a theoretical viewpoint.

As for the testing: since we are debating whether the backward is correct, the gradients are the important values.
However, in your example the maximal gradients were around 2e-4 for me, making it practically impossible to violate the tolerances.
One can scale the loss up by about 1e3 to get maximal gradients around 1e-1 (that seemed a good number to me).
Then, for the current gist, I get (for the is_causal=False path):

Mismatched elements: 34845 / 40960 (85.1%)
Greatest absolute difference: 0.057159423828125 at index (0, 4, 63, 45) (up to 0.0005 allowed)
Greatest relative difference: 4208.0 at index (0, 1, 32, 49) (up to 1e-05 allowed)

Whereas the implementation with my proposed changes runs through without issue.
Can you maybe confirm this on your machine?
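
For concreteness, the scaling is nothing fancier than multiplying the test script's loss by 1e3, e.g. (illustrative sketch only):

def loss_fn(out: Tensor, target: Tensor) -> Tensor:
    return 1e3 * (out - target).square().mean()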
