@chu-tianxiang
Last active March 8, 2024 02:00
Triton implementation of ReRoPE
# Adapted from the triton implementation of flash-attention v2
# https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
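# ReRoPE computes two sets of attention scores: q1·k1, rotated with the true
# relative positions and used where |i - j| < WINDOW, and q2·k2, which encodes the
# clamped, constant relative position used everywhere else. The kernels below fuse
# that per-element selection into a flash-attention style streaming softmax; the
# two pre-rotated query/key tensors are assumed to be prepared by the caller
# (the test at the bottom just uses random tensors).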
import time
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl
@triton.jit
def _fwd_kernel(
Q1, Q2, K1, K2, V, sm_scale,
L,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
WINDOW: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
q_offset = off_hz * stride_qh
kv_offset = off_hz * stride_kh
Q1_block_ptr = tl.make_block_ptr(
base=Q1 + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
Q2_block_ptr = tl.make_block_ptr(
base=Q2 + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K1_block_ptr = tl.make_block_ptr(
base=K1 + kv_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
K2_block_ptr = tl.make_block_ptr(
base=K2 + kv_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + kv_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
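# (exp(x) = 2**(x * log2(e)) and log2(e) ≈ 1.44269504, hence the constant below)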
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q1 = tl.load(Q1_block_ptr)
q1 = (q1 * qk_scale).to(tl.float16)
q2 = tl.load(Q2_block_ptr)
q2 = (q2 * qk_scale).to(tl.float16)
# loop over k, v and update accumulator
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
for start_n in range(lo, hi, BLOCK_N):
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
if IS_CAUSAL:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
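# Three cases for each (query block, key block) pair:
# - the key block lies entirely outside the window: only q2·k2 is needed
# - the key block lies entirely inside the window: only q1·k1 is needed
# - the block straddles the window boundary: compute both score tiles and select
#   element-wise with |i - j| < WINDOW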
if start_n <= start_m * BLOCK_M - WINDOW - BLOCK_N or start_n >= (start_m + 1) * BLOCK_M + WINDOW:
k2 = tl.load(K2_block_ptr)
v = tl.load(V_block_ptr)
qk += tl.dot(q2, k2, out_dtype=tl.float16)
elif start_n > (start_m + 1) * BLOCK_M - WINDOW and start_n < start_m * BLOCK_M + WINDOW - BLOCK_N:
k1 = tl.load(K1_block_ptr)
v = tl.load(V_block_ptr)
qk += tl.dot(q1, k1, out_dtype=tl.float16)
else:
k1 = tl.load(K1_block_ptr)
k2 = tl.load(K2_block_ptr)
v = tl.load(V_block_ptr)
qk1 = tl.dot(q1, k1, out_dtype=tl.float16)
qk2 = tl.dot(q2, k2, out_dtype=tl.float16)
qk += tl.where(tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW, qk1, qk2)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p.to(tl.float16), v)
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# update pointers
K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
l_ptrs = L + off_hz * N_CTX + offs_m
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(O_block_ptr, acc.to(tl.float16))
@triton.jit
def _bwd_preprocess(
Out, DO,
Delta,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
# compute
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(Delta + off_m, delta)
@triton.jit
def _bwd_kernel(
Q1, Q2, K1, K2, V, sm_scale, Out, DO,
DQ1, DQ2, DK1, DK2, DV,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block_q, num_block_kv,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
CAUSAL: tl.constexpr,
WINDOW: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
qk_scale = sm_scale * 1.44269504
# offset pointers for batch/head
Q1 += off_z * stride_qz + off_h * stride_qh
Q2 += off_z * stride_qz + off_h * stride_qh
K1 += off_z * stride_kz + off_h * stride_kh
K2 += off_z * stride_kz + off_h * stride_kh
V += off_z * stride_vz + off_h * stride_vh
DO += off_z * stride_qz + off_h * stride_qh
DQ1 += off_z * stride_qz + off_h * stride_qh
DQ2 += off_z * stride_qz + off_h * stride_qh
DK1 += off_z * stride_kz + off_h * stride_kh
DK2 += off_z * stride_kz + off_h * stride_kh
DV += off_z * stride_vz + off_h * stride_vh
for start_n in range(0, num_block_kv):
if CAUSAL:
lo = tl.math.max(start_n * BLOCK_N, 0)
else:
lo = 0
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_m = tl.arange(0, BLOCK_M)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q1_ptrs = Q1 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
q2_ptrs = Q2 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k1_ptrs = K1 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
k2_ptrs = K2 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)  # use V's own strides rather than reusing Q's
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq1_ptrs = DQ1 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq2_ptrs = DQ2 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
# initialize dk and dv
dk1 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k1 = tl.load(k1_ptrs)
k2 = tl.load(k2_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q1 = tl.load(q1_ptrs)
q2 = tl.load(q2_ptrs)
# recompute p = softmax(qk, dim=-1).T
if CAUSAL:
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
else:
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
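# same three-way window branch as the forward pass, recomputing only the score
# tile(s) actually needed for this block pair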
if start_m >= (start_n + 1) * BLOCK_N + WINDOW or start_m <= start_n * BLOCK_N - WINDOW - BLOCK_M:
q2 = tl.load(q2_ptrs)
qk += tl.dot(q2, tl.trans(k2))
elif start_m > (start_n + 1) * BLOCK_N - WINDOW and start_m < start_n * BLOCK_N + WINDOW - BLOCK_M:
q1 = tl.load(q1_ptrs)
qk += tl.dot(q1, tl.trans(k1))
else:
q1 = tl.load(q1_ptrs)
q2 = tl.load(q2_ptrs)
qk1 = tl.dot(q1, tl.trans(k1))
qk2 = tl.dot(q2, tl.trans(k2))
qk += tl.where(tl.abs(offs_m_curr[:, None] - offs_n[None, :]) < WINDOW, qk1, qk2)
qk *= qk_scale
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q1.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
if start_m >= (start_n + 1) * BLOCK_N + WINDOW or start_m <= start_n * BLOCK_N - WINDOW - BLOCK_M:
dk2 += tl.dot(tl.trans(ds.to(Q1.dtype.element_ty)), q2)
dq2 = tl.load(dq2_ptrs)
dq2 += tl.dot(ds.to(Q1.dtype.element_ty), k2)
tl.store(dq2_ptrs, dq2)
elif start_m > (start_n + 1) * BLOCK_N - WINDOW and start_m < start_n * BLOCK_N + WINDOW - BLOCK_M:
dk1 += tl.dot(tl.trans(ds.to(Q1.dtype.element_ty)), q1)
dq1 = tl.load(dq1_ptrs)
dq1 += tl.dot(ds.to(Q1.dtype.element_ty), k1)
tl.store(dq1_ptrs, dq1)
else:
mask = (tl.abs(offs_m_curr[:, None] - offs_n[None, :]) < WINDOW)
ds1 = tl.where(mask, ds, float(0.))
ds2 = tl.where(mask, float(0.), ds)
dk1 += tl.dot(tl.trans(ds1.to(Q1.dtype.element_ty)), q1)
dk2 += tl.dot(tl.trans(ds2.to(Q1.dtype.element_ty)), q2)
dq1 = tl.load(dq1_ptrs)
dq2 = tl.load(dq2_ptrs)
dq1 += tl.dot(ds1.to(Q1.dtype.element_ty), k1)
dq2 += tl.dot(ds2.to(Q1.dtype.element_ty), k2)
tl.store(dq1_ptrs, dq1)
tl.store(dq2_ptrs, dq2)
# increment pointers
dq1_ptrs += BLOCK_M * stride_qm
dq2_ptrs += BLOCK_M * stride_qm
q1_ptrs += BLOCK_M * stride_qm
q2_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dk1_ptrs = DK1 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
dk2_ptrs = DK2 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn)  # use V's own strides rather than reusing Q's
tl.store(dk1_ptrs, dk1)
tl.store(dk2_ptrs, dk2)
tl.store(dv_ptrs, dv)
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q1, q2, k1, k2, v, causal, sm_scale, window):
# shape constraints
Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q1)
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
grid = (triton.cdiv(q1.shape[2], BLOCK_M), q1.shape[0] * q1.shape[1], 1)
L = torch.empty((q1.shape[0] * q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
_fwd_kernel[grid](
q1, q2, k1, k2, v, sm_scale,
L,
o,
q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q1.shape[0], q1.shape[1], q1.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal, WINDOW=window,
num_warps=num_warps,
num_stages=num_stages)
ctx.save_for_backward(q1, q2, k1, k2, v, o, L)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.window = window
return o
@staticmethod
def backward(ctx, do):
BLOCK = 128
q1, q2, k1, k2, v, o, L = ctx.saved_tensors
do = do.contiguous()
dq1 = torch.zeros_like(q1, dtype=torch.float32)
dq2 = torch.zeros_like(q2, dtype=torch.float32)
dk1 = torch.empty_like(k1)
dk2 = torch.empty_like(k2)
dv = torch.empty_like(v)
delta = torch.empty_like(L)
_bwd_preprocess[(triton.cdiv(q1.shape[2], BLOCK) * ctx.grid[1], )](
o, do,
delta,
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
_bwd_kernel[(ctx.grid[1],)](
q1, q2, k1, k2, v, ctx.sm_scale,
o, do,
dq1, dq2, dk1, dk2, dv,
L, delta,
q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q1.shape[0], q1.shape[1], q1.shape[2],
triton.cdiv(q1.shape[2], BLOCK), triton.cdiv(k1.shape[2], BLOCK),
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
CAUSAL=ctx.causal, WINDOW=ctx.window,
num_stages=1,
)
return dq1, dq2, dk1, dk2, dv, None, None, None
triton_attention = _attention.apply
def torch_attention(q1, q2, k1, k2, v, causal, sm_scale, window):
# reference implementation
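# scores use q1·k1 inside the window (|i - j| < window) and q2·k2 outside,
# mirroring the fused kernel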
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p1 = torch.matmul(q1, k1.transpose(2, 3)) * sm_scale
p2 = torch.matmul(q2, k2.transpose(2, 3)) * sm_scale
if causal:
p1[:, :, M == 0] = float("-inf")
p2[:, :, M == 0] = float("-inf")
x = torch.arange(N_CTX, dtype=torch.int, device="cuda")
M2 = ((x[:, None] - x[None, :]).abs() < window)[None, None, :]
p = torch.where(M2, p1, p2)
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
return ref_out
Z = 1
H = 32
N_CTX = 8192
# currently backward is VERY slow for d_head = 128
# https://github.com/openai/triton/issues/1975
D_HEAD = 64
WINDOW = 2048
sm_scale = 0.5
q1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
q2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
grad = torch.randn_like(q1)
torch_output = torch_attention(q1, q2, k1, k2, v, False, sm_scale, WINDOW)
torch_output.backward(grad)
torch_dv, v.grad = v.grad.clone(), None
torch_dk1, k1.grad = k1.grad.clone(), None
torch_dk2, k2.grad = k2.grad.clone(), None
torch_dq1, q1.grad = q1.grad.clone(), None
torch_dq2, q2.grad = q2.grad.clone(), None
triton_output = triton_attention(q1, q2, k1, k2, v, False, sm_scale, WINDOW)
triton_output.backward(grad)
triton_dv, v.grad = v.grad.clone(), None
triton_dk1, k1.grad = k1.grad.clone(), None
triton_dk2, k2.grad = k2.grad.clone(), None
triton_dq1, q1.grad = q1.grad.clone(), None
triton_dq2, q2.grad = q2.grad.clone(), None
assert torch.allclose(torch_output, triton_output, atol=2e-2, rtol=0)
assert torch.allclose(torch_dv, triton_dv, atol=1e-2, rtol=0)
assert torch.allclose(torch_dk1, triton_dk1, atol=1e-2, rtol=0)
assert torch.allclose(torch_dk2, triton_dk2, atol=1e-2, rtol=0)
assert torch.allclose(torch_dq1, triton_dq1, atol=1e-2, rtol=0)
assert torch.allclose(torch_dq2, triton_dq2, atol=1e-2, rtol=0)
def f(fn, q1, q2, k1, k2, v, sm_scale, window, grad):
q1.grad, q2.grad, k1.grad, k2.grad, v.grad = None, None, None, None, None
out = fn(q1, q2, k1, k2, v, True, sm_scale, window)
out.backward(grad, retain_graph=True)
t0 = benchmark.Timer(
stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window, grad)',
globals={'f': f, 'fn': torch_attention, 'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'window': WINDOW, 'grad': grad},
num_threads=torch.get_num_threads())
t1 = benchmark.Timer(
stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window, grad)',
globals={'f': f, 'fn': triton_attention, 'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'window': WINDOW, 'grad': grad},
num_threads=torch.get_num_threads())
print(t0.timeit(10))
print(t1.timeit(10))
@MiaTheX commented Feb 26, 2024:

Hi, I am encountering an issue with the following error message while running this code:
python: /root/.triton/llvm/llvm-5e5a22ca-ubuntu-x64/include/llvm/ADT/SmallVector.h:298: const T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) const [with T = long int; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::const_reference = const long int&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.

My environment is:
python 3.10.12
triton 2.2.0 (installed from pip or built from source)

Do you have any suggestions about this issue, or would you be willing to discuss the details of this implementation with me?
If there is any additional information you need from my side, please contact me at [email protected].

Thanks.

@LouChao98 commented:

I found that the implementations in openai/triton and Tri Dao's flash-attn repo both use fp32 for tl.dot. I am curious whether there is a reason to use out_dtype=tl.float16 in tl.dot in your kernel?

@chu-tianxiang (Author) commented:

@LouChao98 You're right, I should have used float32 as the accumulator. I don't remember why I chose fp16, though.
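Something like this (untested sketch, just to illustrate the change) is what I mean - in the k/v loop of _fwd_kernel the affected lines would become:

qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # was tl.float16
qk += tl.dot(q2, k2)  # drop out_dtype=tl.float16 so tl.dot accumulates in fp32
qk += tl.dot(q1, k1)
qk1 = tl.dot(q1, k1)
qk2 = tl.dot(q2, k2)

with everything else (including acc += tl.dot(p.to(tl.float16), v)) left as is. The backward kernel already relies on tl.dot's default fp32 accumulator, so it wouldn't need to change.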

@MiaTheX I tested with triton==2.2.0 and couldn't reproduce the error. I'm not sure where the problem is.

A lot has changed since this was written. Triton has made some major improvements to the flash-attn kernel, and there are new libraries coming out, e.g. FlashInfer has better support for computing RoPE for the KV cache on the fly. This code is only meant to show the possibility of integrating ReRoPE with flash-attn, and its performance is lower than expected, partly because of the many if-else branches.
