# Adapted from the triton implementation of flash-attention v2
# https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py
import time
import torch
import torch.utils.benchmark as benchmark
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel(
    Q1, Q2, K1, K2, V, sm_scale,
    L,
    Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    q_offset = off_hz * stride_qh
    kv_offset = off_hz * stride_kh
    Q1_block_ptr = tl.make_block_ptr(
        base=Q1 + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    Q2_block_ptr = tl.make_block_ptr(
        base=Q2 + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K1_block_ptr = tl.make_block_ptr(
        base=K1 + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    K2_block_ptr = tl.make_block_ptr(
        base=K2 + kv_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + kv_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q1 = tl.load(Q1_block_ptr)
    q1 = (q1 * qk_scale).to(tl.float16)
    q2 = tl.load(Q2_block_ptr)
    q2 = (q2 * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
    for start_n in range(lo, hi, BLOCK_N):
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
        if IS_CAUSAL:
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
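        # Window-based block selection (ReRoPE, inferred from the branch
        # conditions below): key blocks lying entirely outside the
        # |m - n| < WINDOW band use q2/k2, blocks entirely inside the band
        # use q1/k1, and blocks straddling the boundary compute both scores
        # and pick per element with tl.where.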
        if start_n <= start_m * BLOCK_M - WINDOW - BLOCK_N or start_n >= (start_m + 1) * BLOCK_M + WINDOW:
            k2 = tl.load(K2_block_ptr)
            v = tl.load(V_block_ptr)
            qk += tl.dot(q2, k2, out_dtype=tl.float16)
        elif start_n > (start_m + 1) * BLOCK_M - WINDOW and start_n < start_m * BLOCK_M + WINDOW - BLOCK_N:
            k1 = tl.load(K1_block_ptr)
            v = tl.load(V_block_ptr)
            qk += tl.dot(q1, k1, out_dtype=tl.float16)
        else:
            k1 = tl.load(K1_block_ptr)
            k2 = tl.load(K2_block_ptr)
            v = tl.load(V_block_ptr)
            qk1 = tl.dot(q1, k1, out_dtype=tl.float16)
            qk2 = tl.dot(q2, k2, out_dtype=tl.float16)
            qk += tl.where(tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW, qk1, qk2)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v)
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # update pointers
        K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
        K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
    # write back l and m
    acc = acc / l_i[:, None]
    l_ptrs = L + off_hz * N_CTX + offs_m
    tl.store(l_ptrs, m_i + tl.math.log2(l_i))
    # write back O
    O_block_ptr = tl.make_block_ptr(
        base=Out + q_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(O_block_ptr, acc.to(tl.float16))


@triton.jit
def _bwd_preprocess(
    Out, DO,
    Delta,
    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
):
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_n = tl.arange(0, D_HEAD)
    # load
    o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
    # compute
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_m, delta)


@triton.jit
def _bwd_kernel(
    Q1, Q2, K1, K2, V, sm_scale, Out, DO,
    DQ1, DQ2, DK1, DK2, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    num_block_q, num_block_kv,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    CAUSAL: tl.constexpr,
    WINDOW: tl.constexpr,
):
    off_hz = tl.program_id(0)
    off_z = off_hz // H
    off_h = off_hz % H
    qk_scale = sm_scale * 1.44269504
    # offset pointers for batch/head
    Q1 += off_z * stride_qz + off_h * stride_qh
    Q2 += off_z * stride_qz + off_h * stride_qh
    K1 += off_z * stride_kz + off_h * stride_kh
    K2 += off_z * stride_kz + off_h * stride_kh
    V += off_z * stride_vz + off_h * stride_vh
    DO += off_z * stride_qz + off_h * stride_qh
    DQ1 += off_z * stride_qz + off_h * stride_qh
    DQ2 += off_z * stride_qz + off_h * stride_qh
    DK1 += off_z * stride_kz + off_h * stride_kh
    DK2 += off_z * stride_kz + off_h * stride_kh
    DV += off_z * stride_vz + off_h * stride_vh
    for start_n in range(0, num_block_kv):
        if CAUSAL:
            lo = tl.math.max(start_n * BLOCK_N, 0)
        else:
            lo = 0
        # initialize row/col offsets
        offs_qm = lo + tl.arange(0, BLOCK_M)
        offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_m = tl.arange(0, BLOCK_M)
        offs_k = tl.arange(0, BLOCK_DMODEL)
        # initialize pointers to value-like data
        q1_ptrs = Q1 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        q2_ptrs = Q2 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        k1_ptrs = K1 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        k2_ptrs = K2 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq1_ptrs = DQ1 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        dq2_ptrs = DQ2 + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        # pointer to row-wise quantities in value-like data
        D_ptrs = D + off_hz * N_CTX
        l_ptrs = L + off_hz * N_CTX
        # initialize dk and dv
        dk1 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        dk2 = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        dv = tl.zeros([BLOCK_N, BLOCK_DMODEL], dtype=tl.float32)
        # k and v stay in SRAM throughout
        k1 = tl.load(k1_ptrs)
        k2 = tl.load(k2_ptrs)
        v = tl.load(v_ptrs)
        # loop over rows
        for start_m in range(lo, num_block_q * BLOCK_M, BLOCK_M):
            offs_m_curr = start_m + offs_m
            # load q, k, v, do on-chip
            q1 = tl.load(q1_ptrs)
            q2 = tl.load(q2_ptrs)
            # recompute p = softmax(qk, dim=-1).T
            if CAUSAL:
                qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
            else:
                qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
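            # Same window-based selection as in the forward kernel: q2/k2 for
            # blocks entirely outside the |m - n| < WINDOW band, q1/k1 for
            # blocks entirely inside it, and a per-element mix on the boundary.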
            if start_m >= (start_n + 1) * BLOCK_N + WINDOW or start_m <= start_n * BLOCK_N - WINDOW - BLOCK_M:
                q2 = tl.load(q2_ptrs)
                qk += tl.dot(q2, tl.trans(k2))
            elif start_m > (start_n + 1) * BLOCK_N - WINDOW and start_m < start_n * BLOCK_N + WINDOW - BLOCK_M:
                q1 = tl.load(q1_ptrs)
                qk += tl.dot(q1, tl.trans(k1))
            else:
                q1 = tl.load(q1_ptrs)
                q2 = tl.load(q2_ptrs)
                qk1 = tl.dot(q1, tl.trans(k1))
                qk2 = tl.dot(q2, tl.trans(k2))
                qk += tl.where(tl.abs(offs_m_curr[:, None] - offs_n[None, :]) < WINDOW, qk1, qk2)
            qk *= qk_scale
            l_i = tl.load(l_ptrs + offs_m_curr)
            p = tl.math.exp2(qk - l_i[:, None])
            # compute dv
            do = tl.load(do_ptrs)
            dv += tl.dot(tl.trans(p.to(Q1.dtype.element_ty)), do)
            # compute dp = dot(v, do)
            Di = tl.load(D_ptrs + offs_m_curr)
            dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
            dp += tl.dot(do, tl.trans(v))
            # compute ds = p * (dp - delta[:, None])
            ds = p * dp * sm_scale
            # compute dk = dot(ds.T, q)
            if start_m >= (start_n + 1) * BLOCK_N + WINDOW or start_m <= start_n * BLOCK_N - WINDOW - BLOCK_M:
                dk2 += tl.dot(tl.trans(ds.to(Q1.dtype.element_ty)), q2)
                dq2 = tl.load(dq2_ptrs)
                dq2 += tl.dot(ds.to(Q1.dtype.element_ty), k2)
                tl.store(dq2_ptrs, dq2)
            elif start_m > (start_n + 1) * BLOCK_N - WINDOW and start_m < start_n * BLOCK_N + WINDOW - BLOCK_M:
                dk1 += tl.dot(tl.trans(ds.to(Q1.dtype.element_ty)), q1)
                dq1 = tl.load(dq1_ptrs)
                dq1 += tl.dot(ds.to(Q1.dtype.element_ty), k1)
                tl.store(dq1_ptrs, dq1)
            else:
                mask = (tl.abs(offs_m_curr[:, None] - offs_n[None, :]) < WINDOW)
                ds1 = tl.where(mask, ds, float(0.))
                ds2 = tl.where(mask, float(0.), ds)
                dk1 += tl.dot(tl.trans(ds1.to(Q1.dtype.element_ty)), q1)
                dk2 += tl.dot(tl.trans(ds2.to(Q1.dtype.element_ty)), q2)
                dq1 = tl.load(dq1_ptrs)
                dq2 = tl.load(dq2_ptrs)
                dq1 += tl.dot(ds1.to(Q1.dtype.element_ty), k1)
                dq2 += tl.dot(ds2.to(Q1.dtype.element_ty), k2)
                tl.store(dq1_ptrs, dq1)
                tl.store(dq2_ptrs, dq2)
            # increment pointers
            dq1_ptrs += BLOCK_M * stride_qm
            dq2_ptrs += BLOCK_M * stride_qm
            q1_ptrs += BLOCK_M * stride_qm
            q2_ptrs += BLOCK_M * stride_qm
            do_ptrs += BLOCK_M * stride_qm
        # write-back
        dk1_ptrs = DK1 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        dk2_ptrs = DK2 + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
        dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
        tl.store(dk1_ptrs, dk1)
        tl.store(dk2_ptrs, dk2)
        tl.store(dv_ptrs, dv)


class _attention(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q1, q2, k1, k2, v, causal, sm_scale, window):
        # shape constraints
        Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        o = torch.empty_like(q1)
        BLOCK_M = 128
        BLOCK_N = 64 if Lk <= 64 else 32
        num_stages = 4 if Lk <= 64 else 3
        num_warps = 4
        grid = (triton.cdiv(q1.shape[2], BLOCK_M), q1.shape[0] * q1.shape[1], 1)
        L = torch.empty((q1.shape[0] * q1.shape[1], q1.shape[2]), device=q1.device, dtype=torch.float32)
        _fwd_kernel[grid](
            q1, q2, k1, k2, v, sm_scale,
            L,
            o,
            q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
            k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            q1.shape[0], q1.shape[1], q1.shape[2],
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
            IS_CAUSAL=causal, WINDOW=window,
            num_warps=num_warps,
            num_stages=num_stages)
        ctx.save_for_backward(q1, q2, k1, k2, v, o, L)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.window = window
        return o

    @staticmethod
    def backward(ctx, do):
        BLOCK = 128
        q1, q2, k1, k2, v, o, L = ctx.saved_tensors
        do = do.contiguous()
        dq1 = torch.zeros_like(q1, dtype=torch.float32)
        dq2 = torch.zeros_like(q2, dtype=torch.float32)
        dk1 = torch.empty_like(k1)
        dk2 = torch.empty_like(k2)
        dv = torch.empty_like(v)
        delta = torch.empty_like(L)
        _bwd_preprocess[(triton.cdiv(q1.shape[2], BLOCK) * ctx.grid[1], )](
            o, do,
            delta,
            BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
        )
        _bwd_kernel[(ctx.grid[1],)](
            q1, q2, k1, k2, v, ctx.sm_scale,
            o, do,
            dq1, dq2, dk1, dk2, dv,
            L, delta,
            q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
            k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q1.shape[0], q1.shape[1], q1.shape[2],
            triton.cdiv(q1.shape[2], BLOCK), triton.cdiv(k1.shape[2], BLOCK),
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
            CAUSAL=ctx.causal, WINDOW=ctx.window,
            num_stages=1,
        )
        return dq1, dq2, dk1, dk2, dv, None, None, None


triton_attention = _attention.apply


def torch_attention(q1, q2, k1, k2, v, causal, sm_scale, window):
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p1 = torch.matmul(q1, k1.transpose(2, 3)) * sm_scale
    p2 = torch.matmul(q2, k2.transpose(2, 3)) * sm_scale
    if causal:
        p1[:, :, M == 0] = float("-inf")
        p2[:, :, M == 0] = float("-inf")
    x = torch.arange(N_CTX, dtype=torch.int, device="cuda")
    M2 = ((x[:, None] - x[None, :]).abs() < window)[None, None, :]
    p = torch.where(M2, p1, p2)
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    return ref_out


Z = 1
H = 32
N_CTX = 8192
# currently backward is VERY slow for d_head = 128
# https://github.com/openai/triton/issues/1975
D_HEAD = 64
WINDOW = 2048
sm_scale = 0.5
q1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
q2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k1 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k2 = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=torch.float16, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
grad = torch.randn_like(q1)

torch_output = torch_attention(q1, q2, k1, k2, v, False, sm_scale, WINDOW)
torch_output.backward(grad)
torch_dv, v.grad = v.grad.clone(), None
torch_dk1, k1.grad = k1.grad.clone(), None
torch_dk2, k2.grad = k2.grad.clone(), None
torch_dq1, q1.grad = q1.grad.clone(), None
torch_dq2, q2.grad = q2.grad.clone(), None

triton_output = triton_attention(q1, q2, k1, k2, v, False, sm_scale, WINDOW)
triton_output.backward(grad)
triton_dv, v.grad = v.grad.clone(), None
triton_dk1, k1.grad = k1.grad.clone(), None
triton_dk2, k2.grad = k2.grad.clone(), None
triton_dq1, q1.grad = q1.grad.clone(), None
triton_dq2, q2.grad = q2.grad.clone(), None

assert torch.allclose(torch_output, triton_output, atol=2e-2, rtol=0)
assert torch.allclose(torch_dv, triton_dv, atol=1e-2, rtol=0)
assert torch.allclose(torch_dk1, triton_dk1, atol=1e-2, rtol=0)
assert torch.allclose(torch_dk2, triton_dk2, atol=1e-2, rtol=0)
assert torch.allclose(torch_dq1, triton_dq1, atol=1e-2, rtol=0)
assert torch.allclose(torch_dq2, triton_dq2, atol=1e-2, rtol=0)


def f(fn, q1, q2, k1, k2, v, sm_scale, window, grad):
    q1.grad, q2.grad, k1.grad, k2.grad, v.grad = None, None, None, None, None
    out = fn(q1, q2, k1, k2, v, True, sm_scale, window)
    out.backward(grad, retain_graph=True)


t0 = benchmark.Timer(
    stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window, grad)',
    globals={'f': f, 'fn': torch_attention, 'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'window': WINDOW, 'grad': grad},
    num_threads=torch.get_num_threads())
t1 = benchmark.Timer(
    stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window, grad)',
    globals={'f': f, 'fn': triton_attention, 'q1': q1, 'q2': q2, 'k1': k1, 'k2': k2, 'v': v, 'sm_scale': sm_scale, 'window': WINDOW, 'grad': grad},
    num_threads=torch.get_num_threads())
print(t0.timeit(10))
print(t1.timeit(10))

Hi, I am encountering an issue with the following error message while running this code:

python: /root/.triton/llvm/llvm-5e5a22ca-ubuntu-x64/include/llvm/ADT/SmallVector.h:298: const T& llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::operator[](llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type) const [with T = long int; <template-parameter-1-2> = void; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::const_reference = const long int&; llvm::SmallVectorTemplateCommon<T, <template-parameter-1-2> >::size_type = long unsigned int]: Assertion `idx < size()' failed.
My environment is:
python 3.10.12
triton 2.2.0 (from pip or built from source)
Do you have any suggestions about this issue, or would you like to discuss the details of this implementation with me? If you need any additional information from my side, please contact me at [email protected]. Thanks.
I found that the implementations in openai/triton and Tri Dao's flash-attn repo both use fp32 for tl.dot. I am curious whether there is a reason to use out_dtype=tl.float16 in tl.dot in your kernel?
@LouChao98 You're right, I should have used float32 as the accumulator. I don't remember why I chose fp16 though.
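For reference, here is a minimal standalone sketch (not part of the gist above; the kernel name _dot_acc_kernel and the compare helper are made up for illustration) that contrasts tl.dot with out_dtype=tl.float16, as used in the forward kernel here, against the default float32 accumulator on random fp16 tiles:

import torch
import triton
import triton.language as tl


@triton.jit
def _dot_acc_kernel(A, B, C16, C32, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # load one M x K and one K x N fp16 tile and multiply them twice:
    # once with a low-precision output dtype and once with the default fp32 accumulator
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(A + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(B + offs_k[:, None] * N + offs_n[None, :])
    c16 = tl.dot(a, b, out_dtype=tl.float16)  # as in the gist's _fwd_kernel
    c32 = tl.dot(a, b)                        # default: float32 accumulator
    tl.store(C16 + offs_m[:, None] * N + offs_n[None, :], c16)
    tl.store(C32 + offs_m[:, None] * N + offs_n[None, :], c32)


def compare(M=64, N=64, K=64):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    c16 = torch.empty((M, N), device="cuda", dtype=torch.float16)
    c32 = torch.empty((M, N), device="cuda", dtype=torch.float32)
    _dot_acc_kernel[(1,)](a, b, c16, c32, M=M, N=N, K=K)
    ref = a.float() @ b.float()
    print("out_dtype=fp16  max abs err:", (c16.float() - ref).abs().max().item())
    print("fp32 accumulator max abs err:", (c32 - ref).abs().max().item())


compare()

Switching the attention kernel itself over would roughly mean allocating qk with dtype=tl.float32, dropping out_dtype=tl.float16 from the two score tl.dot calls, and keeping only the p.to(tl.float16) cast before the dot with v.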
@MiaTheX I tested with triton=2.2.0 and didn't reproduce the error. I'm not sure where the problem is.
A lot has changed since this was written. Triton has made some major improvements to its flash-attn kernel, and there are new libraries coming out; for example, FlashInfer has better support for computing RoPE for the cache on the fly. This code is just meant to show that integrating ReRoPE with flash-attn is possible, and the performance is lower than expected, partly due to the many if-else conditions.
Hi, we are a team from Moonshot (月之暗面); Su recommended this code to me recently. Would you like to discuss the implementation details with us?
Please contact [email protected] if you like :D