|
import triton |
|
import torch |
|
import triton.language as tl |
|
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],  # re-run autotuning whenever the problem shape changes
)
@triton.jit
def row_major_ordering(a_ptr, b_ptr, out_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_om, stride_on, GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``out = a @ b``.

    Expects a 1-D launch of ``cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)``
    programs. Program ``pid`` is mapped to output tiles in plain row-major
    order: all tiles of tile-row 0, then all tiles of tile-row 1, and so on.

    ``GROUP_SIZE_M`` is not used by this kernel. It is accepted only so the
    kernel can share its autotune configs with ``group_ordering``.

    The K dimension is NOT masked. The caller must guarantee that K is a
    multiple of BLOCK_SIZE_K (every config above uses BLOCK_SIZE_K=32).
    Accumulation is done in float32 and cast to float16 before the store.
    """
    pid = tl.program_id(0)
    # Number of output tiles along one tile-row (the N axis).
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Row-major mapping: the row index is pid divided by the row length
    # (grid_n). Dividing by grid_m, as the original did, duplicates some
    # tiles and skips others whenever grid_m != grid_n.
    pid_m = pid // grid_n
    pid_n = pid % grid_n

    # Wrap row/column indices so that loads for a ragged edge tile stay
    # inside `a` and `b`. The wrapped rows/columns compute throwaway values
    # that the store mask below discards.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_SIZE_K):
        # No K mask for simplicity: if K is not a multiple of BLOCK_SIZE_K,
        # this reads out-of-bounds memory and gives an error or wrong results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the pointers to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    out = accumulator.to(tl.float16)

    # Write back the tile, using the un-wrapped indices so rows/columns
    # outside the M x N output are masked off.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptrs = out_ptr + stride_om * offs_cm[:, None] + stride_on * offs_cn[None, :]
    out_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(out_ptrs, out, mask=out_mask)
|
|
|
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],  # re-run autotuning whenever the problem shape changes
)
@triton.jit
def group_ordering(a_ptr, b_ptr, out_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_om, stride_on, GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of ``out = a @ b``.

    Expects a 1-D launch of ``cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N)``
    programs. Program ids are grouped: each group covers GROUP_SIZE_M
    tile-rows (fewer for the last group), and inside a group consecutive
    pids walk down a tile-column before moving to the next column. The
    Triton matmul tutorial uses this ordering to improve L2 cache reuse
    relative to plain row-major order.

    The K dimension is NOT masked. The caller must guarantee that K is a
    multiple of BLOCK_SIZE_K (every config above uses BLOCK_SIZE_K=32).
    Accumulation is done in float32 and cast to float16 before the store.
    """
    # program ID
    pid = tl.program_id(axis=0)
    # number of program ids along the M axis
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # number of program ids along the N axis
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # number of programs in a group
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    # id of the group this program is in
    group_id = pid // num_pid_in_group
    # row-id of the first program in the group
    first_pid_m = group_id * GROUP_SIZE_M
    # if `num_pid_m` isn't divisible by `GROUP_SIZE_M`, the last group is smaller
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # *within groups*, programs are ordered in column-major order
    # row-id of the program in the *launch grid*
    pid_m = first_pid_m + (pid % group_size_m)
    # col-id of the program in the *launch grid*
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Wrap row/column indices so that loads for a ragged edge tile stay
    # inside `a` and `b`. The wrapped rows/columns compute throwaway values
    # that the store mask below discards.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_SIZE_K):
        # No K mask for simplicity: if K is not a multiple of BLOCK_SIZE_K,
        # this reads out-of-bounds memory and gives an error or wrong results.
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        # Accumulate along the K dimension.
        accumulator += tl.dot(a, b)
        # Advance the pointers to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    out = accumulator.to(tl.float16)

    # Write back the tile, using the un-wrapped indices so rows/columns
    # outside the M x N output are masked off.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out_ptrs = out_ptr + stride_om * offs_cm[:, None] + stride_on * offs_cn[None, :]
    out_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(out_ptrs, out, mask=out_mask)
|
|
|
|
|
def matmul(a, b, funcname=""):
    """Return ``a @ b`` computed by one of this file's Triton kernels.

    Args:
        a: 2-D contiguous tensor of shape (M, K).
        b: 2-D contiguous tensor of shape (K, N).
        funcname: name of the kernel to launch, either ``"group_ordering"``
            or ``"row_major_ordering"``.

    Returns:
        A freshly allocated (M, N) tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        AssertionError: the inner dimensions differ, an input is not
            contiguous, or K is not a multiple of 32.
        ValueError: ``funcname`` is not one of the known kernel names.
    """
    # Validate the inputs before any allocation or launch.
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    assert a.is_contiguous(), "matrix A must be contiguous"
    assert b.is_contiguous(), "matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    # The kernels do not mask the K dimension, and every autotune config
    # uses BLOCK_SIZE_K == 32.
    assert K % 32 == 0, "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K"

    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    def grid(meta):
        # 1-D launch: one program per output tile, sized by the config the
        # autotuner is currently trying.
        tiles_m = triton.cdiv(M, meta['BLOCK_SIZE_M'])
        tiles_n = triton.cdiv(N, meta['BLOCK_SIZE_N'])
        return (tiles_m * tiles_n,)

    # Pick the kernel first so that a single launch site serves both.
    if funcname == "group_ordering":
        kernel = group_ordering
    elif funcname == "row_major_ordering":
        kernel = row_major_ordering
    else:
        raise ValueError(f"unknown funcname {funcname}")

    kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c
|
|
|
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M', 'N', 'K'],  # argument names to use as an x-axis for the plot
        x_vals=[128 * i for i in range(2, 12)],  # different possible values for `x_name`
        line_arg='provider',  # argument name whose value corresponds to a different line in the plot
        line_vals=['group_ordering', 'row_major_ordering'],  # possible values for `line_arg`
        line_names=['group_ordering', 'row_major_ordering'],  # label name for the lines
        styles=[('green', '-'), ('green', '--')],  # one (color, linestyle) per line
        ylabel="TFLOPS",  # label name for the y-axis
        plot_name="matmul-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={},
    )
)
def benchmark(M, N, K, provider):
    """Time ``matmul`` with kernel ``provider`` on random float16 CUDA inputs.

    Returns ``(median, low, high)`` throughput in TFLOPS, or ``None`` when
    ``provider`` is not one of the two kernel names.
    """
    a = torch.randn((M, K), device='cuda', dtype=torch.float16)
    b = torch.randn((K, N), device='cuda', dtype=torch.float16)

    def perf(ms):
        # A matmul does 2*M*N*K flops; convert ms -> s and flops -> TFLOPS.
        return 2 * M * N * K * 1e-12 / (ms * 1e-3)

    if provider not in ('group_ordering', 'row_major_ordering'):
        return None

    # Recent Triton versions return a single float from do_bench unless
    # `quantiles` is given, which would break the 3-way unpack. Request the
    # median and the 20th/80th percentiles explicitly.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: matmul(a, b, provider), quantiles=[0.5, 0.2, 0.8], rep=100
    )
    # The slower time (max_ms) gives the lower throughput bound.
    return perf(ms), perf(max_ms), perf(min_ms)


benchmark.run(show_plots=True, print_data=True)
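# BUG (row_major_ordering): `pid_m = pid // grid_m` is wrong. For a row-major
# tile order the row index must divide by the number of tiles per row:
# `pid_m = pid // grid_n`.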