# Gist by @yzhangcs, created March 1, 2024 09:04
# -*- coding: utf-8 -*-
import torch
import triton
import triton.language as tl

@triton.jit
def cumsum_matmul_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    o_i = tl.arange(0, BT)
    # lower-triangular mask: m_s[i, j] = 1 iff i >= j
    m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
    # running column sums carried across time blocks
    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1))
        # blockwise cumsum on tensor cores: dot with the lower-triangular mask, plus the carried prefix
        b_sc = b_z[None, :] + tl.dot(m_s.to(b_s.dtype), b_s, allow_tf32=False)
        tl.store(p_z, b_sc.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z = b_z + tl.sum(b_s, 0)
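
# Illustrative sketch (not part of the original gist): the kernel above uses the fact that
# multiplying by a lower-triangular matrix of ones yields an inclusive cumulative sum, so
# the within-block scan can be expressed as a matmul and executed on tensor cores, while a
# running column sum carries the prefix across blocks. The helper name `cumsum_matmul_ref`
# and the single (T, S) input are assumptions made for illustration only.
def cumsum_matmul_ref(x, BT=64):
    # x: [T, S]; scan BT rows at a time, carrying the running column sums forward
    T = x.shape[0]
    m = torch.ones(BT, BT, dtype=x.dtype, device=x.device).tril()
    z, carry = torch.empty_like(x), x.new_zeros(x.shape[1])
    for t in range(0, T, BT):
        b = x[t:t + BT]
        # blockwise inclusive cumsum as a matmul with the lower-triangular mask
        z[t:t + BT] = carry + m[:b.shape[0], :b.shape[0]] @ b
        carry = carry + b.sum(0)
    return z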

@triton.jit
def cumsum_triton_kernel(
    s,
    z,
    s_s_h,
    s_s_t,
    s_s_d,
    T: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
):
    i_s, i_bh = tl.program_id(0), tl.program_id(1)
    # running column sums carried across time blocks
    b_z = tl.zeros([BS], dtype=tl.float32)
    for i_t in range(tl.cdiv(T, BT)):
        p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))
        # [BT, BS]
        b_s = tl.load(p_s, boundary_check=(0, 1))
        # blockwise cumsum via the builtin scan, plus the carried prefix
        b_sc = b_z[None, :] + tl.cumsum(b_s, 0)
        tl.store(p_z, b_sc.to(p_z.dtype.element_ty), boundary_check=(0, 1))
        b_z = b_z + tl.sum(b_s, 0)

def cumsum_torch(s):
    return s.float().cumsum(2).to(s.dtype)

def cumsum_triton(s):
    B, H, T, S = s.shape
    BT, BS = 64, 64
    NS = triton.cdiv(S, BS)
    grid = (NS, B * H)
    z = torch.empty_like(s)
    cumsum_triton_kernel[grid](
        s, z,
        s.stride(1), s.stride(2), s.stride(3),
        T=T, S=S, BT=BT, BS=BS,
        num_warps=2,
        num_stages=1
    )
    return z

def cumsum_matmul(s):
    B, H, T, S = s.shape
    BT, BS = 64, 64
    NS = triton.cdiv(S, BS)
    grid = (NS, B * H)
    z = torch.empty_like(s)
    cumsum_matmul_kernel[grid](
        s, z,
        s.stride(1), s.stride(2), s.stride(3),
        T=T, S=S, BT=BT, BS=BS,
        num_warps=1,
        num_stages=1
    )
    return z
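
# A note on the launch geometry above (worked out for illustration from the test shapes
# below): each program handles one BS-wide slice of the feature dimension for a single
# (batch, head) pair and scans the T rows serially in BT-sized chunks. With
# B, H, T, D = 8, 4, 2048, 256 and BT = BS = 64, that gives NS = ceil(256 / 64) = 4
# feature slices, a grid of (4, 8 * 4) = (4, 32) programs, and ceil(2048 / 64) = 32
# sequential block iterations per program.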
B, H, T, D = 8, 4, 2048, 256
dtype = torch.float
device = 'cuda'
s = torch.randn(B, H, T, D, device=device, dtype=dtype)
print("DIFF\t")
print('triton\t', f"{(cumsum_torch(s) - cumsum_triton(s)).abs().max():>10.6f}")
print('matmul\t', f"{(cumsum_torch(s) - cumsum_matmul(s)).abs().max():>10.6f}")
print('Done!')

@triton.testing.perf_report(
    triton.testing.Benchmark(
        # argument names to use as an x-axis for the plot
        x_names=['seq_len'],
        # different possible values for `x_names`
        x_vals=[128 * 2 ** i for i in range(0, 8)],
        # argument name whose value corresponds to a different line in the plot
        line_arg='provider',
        # possible values for `line_arg`
        line_vals=['torch', 'matmul', 'triton'],
        # label name for the lines
        line_names=['torch', 'matmul', 'triton'],
        # line styles
        styles=[('green', '-'), ('blue', '--'), ('red', '-.')],
        ylabel="Execution Time (ms)",  # label name for the y-axis
        # name for the plot. Used also as a file name for saving the plot.
        plot_name="Performance",
        args={},
    )
)
def benchmark(seq_len, provider):
    device = 'cuda'
    dtype = torch.bfloat16
    s = torch.randn(B, H, seq_len, D, device=device, dtype=dtype)
    quantiles = [0.5, 0.2, 0.8]
    results = 0, 0, 0
    if provider == 'torch':
        results = triton.testing.do_bench(lambda: cumsum_torch(s), quantiles=quantiles)
    elif provider == 'matmul':
        results = triton.testing.do_bench(lambda: cumsum_matmul(s), quantiles=quantiles)
    elif provider == 'triton':
        results = triton.testing.do_bench(lambda: cumsum_triton(s), quantiles=quantiles)
    return results
benchmark.run(print_data=True)
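
# Usage note (not in the original gist): triton.testing's perf_report runner can also
# write the plot and the measured numbers to disk, e.g.:
# benchmark.run(print_data=True, save_path='.')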