# Triton implementation of the ReRope forward pass (gist by @chu-tianxiang, 2023-08-31)
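#
# ReRope (Rectified RoPE) caps the relative position seen by the rotary
# embedding at a window size w: query/key pairs less than w apart get the
# usual position-dependent rotation, while all pairs w or more apart share
# the rotation of distance w. To turn this into two plain dot products, the
# kernel takes two query/key sets: q1/k1 carry the normal rotary embedding,
# so their product depends on the true relative position, while q2/k2 are
# rotated so that their product always corresponds to distance w. The kernel
# is a FlashAttention-style streaming softmax that picks, per key block,
# which of the two score matrices to use; how q2/k2 are built is up to the
# caller (an illustrative sketch is appended at the end of this file).
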
import time
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q1, Q2, K1, K2, V, sm_scale,
    L,
    Out,
    stride_q1z, stride_q1h, stride_q1m, stride_q1k,
    stride_q2z, stride_q2h, stride_q2m, stride_q2k,
    stride_k1z, stride_k1h, stride_k1n, stride_k1k,
    stride_k2z, stride_k2h, stride_k2n, stride_k2k,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX, P_SEQ,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
):
    # each program instance handles one BLOCK_M block of query rows for one
    # (batch, head) pair
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    q_offset = off_hz * stride_q1h
    kv_offset = off_hz * stride_k1h
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_q1m, stride_q1k),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    Q2_block_ptr = tl.make_block_ptr(
        base=Q2 + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_q2m, stride_q2k),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
        strides=(stride_k1k, stride_k1n),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    K2_block_ptr = tl.make_block_ptr(
        base=K2 + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX + P_SEQ),
        strides=(stride_k2k, stride_k2n),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(N_CTX + P_SEQ, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q1 = tl.load(Q1_block_ptr)
    q1 = (q1 * qk_scale).to(tl.float16)
    q2 = tl.load(Q2_block_ptr)
    q2 = (q2 * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
    for start_n in range(lo, hi, BLOCK_N):
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
        if IS_CAUSAL:
            qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
        if start_n + BLOCK_N <= start_m * BLOCK_M - WINDOW or start_n >= (start_m + 1) * BLOCK_M + WINDOW:
            # fast path: the key block lies entirely beyond the window for this
            # query block, so only the q2/k2 ("clipped position") score is needed
            k2 = tl.load(K2_block_ptr)
            v = tl.load(V_block_ptr)
            qk += tl.dot(q2, k2, out_dtype=tl.float16)
        else:
            # mixed block: choose per element between the in-window score
            # (q1/k1, true relative rotation) and the clipped score (q2/k2)
            k1 = tl.load(K1_block_ptr)
            k2 = tl.load(K2_block_ptr)
            v = tl.load(V_block_ptr)
            qk += tl.where(P_SEQ + offs_m[:, None] < (start_n + offs_n[None, :] + WINDOW),
                           tl.dot(q1, k1, out_dtype=tl.float16),
                           tl.dot(q2, k2, out_dtype=tl.float16))
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
        K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m (row-wise log-sum-exp in base 2; only needed by a
    # backward pass, which is not implemented here)
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))


# appears to be carried over from the Triton fused-attention tutorial
empty = torch.empty(128, device="cuda")

class _attention(torch.autograd.Function):
    @staticmethod
    def forward(ctx, q1, q2, k1, k2, v, causal, sm_scale, window):
        # shape constraints
        Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q1)
        BLOCK_M = 128
        BLOCK_N = 64 if Lk <= 64 else 32
        num_stages = 4 if Lk <= 64 else 3
        num_warps = 4
        grid = (triton.cdiv(q1.shape[2], BLOCK_M), q1.shape[0] * q1.shape[1], 1)
        L = torch.empty((q1.shape[0] * q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
        # P_SEQ: number of extra (prefix) key/value positions ahead of the queries
        P_SEQ = 0 if q1.shape[-2] == k1.shape[-2] else k1.shape[-2] - q1.shape[-2]
        _fwd_kernel[grid](
            q1, q2, k1, k2, v, sm_scale,
            L,
            o,
            q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
            q2.stride(0), q2.stride(1), q2.stride(2), q2.stride(3),
            k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
            k2.stride(0), k2.stride(1), k2.stride(2), k2.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q1.shape[0], q1.shape[1], q1.shape[2], P_SEQ,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal, WINDOW=window,
            num_warps=num_warps,
            num_stages=num_stages)
        ctx.save_for_backward(q1, q2, k1, k2, v, o, L)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.P_SEQ = P_SEQ
        return o
triton_attention = _attention.apply
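
# Expected call: out = triton_attention(q1, q2, k1, k2, v, causal, sm_scale, window)
# with q1/q2 of shape (Z, H, N_CTX, D_HEAD) and k1/k2/v of shape
# (Z, H, N_CTX + P_SEQ, D_HEAD); the test below uses float16 CUDA tensors.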

def torch_attention(q1, q2, k1, k2, v, causal, sm_scale, window):
    # dense reference implementation: build both score matrices in full,
    # pick per element, then softmax (assumes queries and keys have the
    # same length, i.e. P_SEQ == 0)
    n_ctx = q1.shape[-2]
    M = torch.tril(torch.ones((n_ctx, n_ctx), device="cuda"))
    p1 = torch.matmul(q1, k1.transpose(2, 3)) * sm_scale
    p2 = torch.matmul(q2, k2.transpose(2, 3)) * sm_scale
    if causal:
        p1[:, :, M == 0] = float("-inf")
        p2[:, :, M == 0] = float("-inf")
    x = torch.arange(n_ctx, dtype=torch.int, device="cuda")
    # pairs within the window (|i - j| < window) take the q1/k1 score, the rest q2/k2
    M2 = ((x[:, None] - x[None, :]).abs() < window)[None, None, :]
    p = torch.where(M2, p1, p2)
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    return ref_out
Z = 1
H = 40
N_CTX = 8192
D_HEAD = 128
WINDOW = 2048
sm_scale = 0.5
q1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
q2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
k1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
k2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5)
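# Note: for this numerical check q1/q2/k1/k2 are independent random tensors; in a
# real ReRope setup they would be rotary-embedded projections of the same q/k
# (see the illustrative sketch at the end of this file).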
torch_output = torch_attention(q1, q2, k1, k2, v, True, sm_scale, WINDOW)
triton_output = triton_attention(q1, q2, k1, k2, v, True, sm_scale, WINDOW)
assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)
t0 = benchmark.Timer(
    stmt='torch_attention(q1, q2, k1, k2, v, True, sm_scale, WINDOW)',
    setup='from __main__ import torch_attention',
    globals={'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'WINDOW': WINDOW})
t1 = benchmark.Timer(
    stmt='triton_attention(q1, q2, k1, k2, v, True, sm_scale, WINDOW)',
    setup='from __main__ import triton_attention',
    globals={'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'WINDOW': WINDOW})
print(t0.timeit(100))
print(t1.timeit(100))
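

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original gist): one way the four
# projections could be prepared for ReRope, assuming the usual construction in
# which in-window pairs use the ordinary rotary embedding and out-of-window
# pairs share a rotation fixed at the window size. `rope_rotate` and
# `make_rerope_inputs` are hypothetical helpers written for illustration only.
# ---------------------------------------------------------------------------
def rope_rotate(x, positions, base=10000.0):
    # x: (Z, H, T, D) with even D; positions: (T,) integer positions.
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, device=x.device, dtype=torch.float32) / d))
    angles = positions.to(torch.float32)[:, None] * inv_freq[None, :]  # (T, D/2)
    cos = torch.cos(angles).repeat_interleave(2, dim=-1)               # (T, D)
    sin = torch.sin(angles).repeat_interleave(2, dim=-1)               # (T, D)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((-x_odd, x_even), dim=-1).flatten(-2)        # pairwise 90-degree rotation
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)


def make_rerope_inputs(q, k, window):
    # In-window scores come from q1·k1 (true relative rotation); out-of-window
    # scores come from q2·k2, whose relative rotation is pinned at `window`
    # because q2 is rotated to the fixed position `window` and k2 is left unrotated.
    pos = torch.arange(q.shape[-2], device=q.device)
    q1 = rope_rotate(q, pos)
    k1 = rope_rotate(k, pos)
    q2 = rope_rotate(q, torch.full_like(pos, window))
    k2 = k
    return q1, q2, k1, k2
# e.g. (with hypothetical un-rotated projections q_raw, k_raw):
#   q1, q2, k1, k2 = make_rerope_inputs(q_raw, k_raw, WINDOW)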