# Original kernel from Chao Xu (Dustinpro) and Yuanwei Fang (fywkevin)
#
# Triton official: https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# scatter2scatter: https://github.com/shawntan/scattermoe/blob/main/scattermoe/kernels/ops.py#L58
# OpenAI sparse-autoencoder: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L220
# Apple sparse-CCE: https://github.com/apple/ml-cross-entropy/blob/e43af99cb21ea27e4afe0c90c04e66f9abfd47c6/cut_cross_entropy/cce_lse_forward.py#L26
# Sparse Toolkit: https://github.com/stanford-futuredata/stk/blob/736313768ef697ce13a0594a41b2512a0fbc9884/stk/backend/triton_kernels.py#L34
import torch

# @manual=//triton:triton
import triton

# @manual=//triton:triton
import triton.language as tl


def is_cuda():
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def is_hip_mi200():
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "hip" and target.arch == "gfx90a"


def get_cuda_autotune_config(sparse: bool):
    # Keep a single config for simplicity.
    return [
        triton.Config(
            {
                "BLOCK_SIZE_M": 128,
                "BLOCK_SIZE_N": 256,
                "BLOCK_SIZE_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=6 if sparse else 3,
            num_warps=8,
        )
    ]


def get_autotune_config(sparse: bool):
    return get_cuda_autotune_config(sparse)


# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
#   - A list of `triton.Config` objects that define different configurations of
#     meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
#   - An auto-tuning *key* whose change in values will trigger evaluation of all the
#     provided configs
@triton.autotune(
    configs=get_autotune_config(sparse=False),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
    # by to get the element one row down (A has M rows).
    stride_am,
    stride_ak,  #
    stride_bk,
    stride_bn,  #
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    """Kernel for computing the matmul C = A x B.
    A has shape (M, K), B has shape (K, N) and C has shape (M, N)
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    # See the `L2 Cache Optimizations` section of the Triton tutorial for details.
    # The grouped ordering can cut runtime by ~10%, but it is commented out here
    # for simplicity (see the standalone sketch below this kernel).
    # pid = tl.program_id(axis=0)
    # num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # num_pid_in_group = GROUP_SIZE_M * num_pid_n
    # group_id = pid // num_pid_in_group
    # first_pid_m = group_id * GROUP_SIZE_M
    # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    # pid_n = (pid % num_pid_in_group) // group_size_m
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetic` section of the Triton tutorial for details.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_k = tl.max_contiguous(tl.load(idx_ptr + offs_k), BLOCK_SIZE_K)
    # offs_k = tl.arange(0, BLOCK_SIZE_K)
    # offs_am = tl.load(idx_ptr + tl.arange(0, BLOCK_SIZE_M) * )
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` will be converted back to fp16 after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B. The tutorial kernel masks the K
        # dimension and fills out-of-bounds elements with 0; for simplicity we
        # assume K is always divisible by BLOCK_SIZE_K and skip the mask.
        a = tl.load(a_ptrs)  # , mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs)  # , mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # [Failed attempt] If we only load half of the values, `a @ b` becomes much slower.
        # a = tl.load(a_ptrs, mask=offs_k[None, :] % 2 == 0, other=0.0)
        # b = tl.load(b_ptrs, mask=offs_k[:, None] % 2 == 0, other=0.0)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Mask out-of-bounds rows/columns when writing; without this mask the
    # output does not exactly match the reference.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
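

# A minimal, standalone sketch of the grouped ("swizzled") program-id ordering
# that is commented out inside `matmul_kernel` above. It maps a flat 1D `pid`
# to (pid_m, pid_n) so that GROUP_SIZE_M row-blocks are visited before moving
# to the next column-block, which improves L2 reuse. Plain Python, for
# illustration only; it is not used by the kernels in this file.
def _grouped_pid_sketch(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M):
    num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M  # cdiv(M, BLOCK_SIZE_M)
    num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N  # cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
    return pid_m, pid_n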


@triton.autotune(
    configs=get_autotune_config(sparse=True),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_sparse_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    idx_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    stride_am,
    stride_ak,  #
    stride_bk,
    stride_bn,  #
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
):
    """Kernel for computing the sparse matmul C = A[:, idx] x B[idx, :].
    A has shape (M, K_full), B has shape (K_full, N), C has shape (M, N);
    `idx_ptr` points to the K selected indices along the shared dimension.
    """
    pid_m = tl.program_id(axis=0)
    pid_n = tl.program_id(axis=1)
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + offs_am[:, None] * stride_am  # + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_bn[None, :] * stride_bn  # + offs_k[:, None] * stride_bk
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load idx_ptr to get the selected K indices for this block.
        selected_idx = tl.max_contiguous(
            tl.load(idx_ptr + k * BLOCK_SIZE_K + offs_k), BLOCK_SIZE_K
        )
        # Only load the selected columns of A and rows of B.
        a = tl.load(a_ptrs + selected_idx[None, :] * stride_ak)
        b = tl.load(b_ptrs + selected_idx[:, None] * stride_bk)
        # We accumulate along the K dimension.
        accumulator = tl.dot(a, b, accumulator)
    c = accumulator.to(tl.float16)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C with masks.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # Mask out-of-bounds rows/columns when writing; without this mask the
    # output does not exactly match the reference.
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
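

# For reference, a dense PyTorch equivalent of what `matmul_sparse_kernel`
# computes (a sketch; not used by the kernels): gather the selected K indices
# from A's columns and B's rows, then do a regular matmul. `main()` below
# checks the strided special case idx = arange(0, K, step_size).
def _matmul_sparse_reference(a, b, idx):
    # a: (M, K), b: (K, N), idx: (K_selected,) integer indices into the K dim
    return a[:, idx] @ b[idx, :]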


def matmul(a, b):
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    # assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    b = b.contiguous()
    # Allocate the output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 2D launch grid: each program computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C.
    grid = lambda META: (
        # triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    matmul_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
    )
    return c
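

# Example grid-size arithmetic for the dense wrapper above (a sketch using the
# benchmark shape defined below and the single autotuned config
# BLOCK_SIZE_M=128, BLOCK_SIZE_N=256):
#   M = 8192, N = 13824
#   grid = (cdiv(8192, 128), cdiv(13824, 256)) = (64, 54)
# i.e. 64 * 54 = 3456 programs, each producing one 128x256 tile of C.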


def matmul_sparse(a, b, step_size=1):
    # NOTE: step_size can be set to 1, 2, 4, ... to select different sparsity
    # levels along the K dimension.
    # Check constraints.
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    # assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    b = b.contiguous()
    # Allocate the output.
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # 2D launch grid: each program computes one [BLOCK_SIZE_M, BLOCK_SIZE_N] block of C.
    grid = lambda META: (
        # triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    # One extra argument: the selected K indices (idx_ptr). The kernel is given
    # an effective K of K // step_size, matching the number of selected indices.
    selected_k_idx = torch.arange(0, K, step_size, device=a.device)
    matmul_sparse_kernel[grid](
        a,
        b,
        c,
        selected_k_idx,
        M,
        N,
        K // step_size,
        a.stride(0),
        a.stride(1),  #
        b.stride(0),
        b.stride(1),  #
        c.stride(0),
        c.stride(1),  #
    )
    return c
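

# Usage sketch for the sparse path (illustration only): with step_size = 2,
# every other K index is selected, e.g. for K = 8:
#   selected_k_idx = torch.arange(0, 8, 2)   # tensor([0, 2, 4, 6])
#   effective K passed to the kernel = 8 // 2 = 4
# so the kernel reads columns 0, 2, 4, 6 of A and rows 0, 2, 4, 6 of B.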


TORCH_HAS_FP8 = False
ref_lib = "cuBLAS" if is_cuda() else "rocBLAS"

configs = []
for fp8_inputs in [False, True]:
    if fp8_inputs and (not TORCH_HAS_FP8 or not is_cuda()):
        continue
    configs.append(
        triton.testing.Benchmark(
            x_names=["M", "N", "K"],  # Argument names to use as an x-axis for the plot
            # NOTE: the input shape is fixed to M, N, K = 8192, 13824, 5120
            x_vals=[(8192, 13824, 5120)] * 3,  # The same shape, repeated three times
            line_arg="provider",  # Argument name whose value corresponds to a different line in the plot
            # Possible values for `line_arg`.
            # Don't compare to cuBLAS for fp8 cases as torch.matmul doesn't support fp8 at the moment.
            line_vals=["triton"]
            if fp8_inputs
            else [
                ref_lib.lower(),
                "triton",
                "triton_sparse",
            ],
            line_names=["Triton"]
            if fp8_inputs
            else [ref_lib, "Triton", "Triton_Sparse"],  # Label names for the lines
            styles=[("green", "-"), ("blue", "-"), ("red", "-")],  # Line styles
            ylabel="TFLOPS",  # Label name for the y-axis
            plot_name="matmul-performance-"
            + (
                "fp16" if not fp8_inputs else "fp8"
            ),  # Name for the plot, used also as a file name for saving the plot.
            args={"fp8_inputs": fp8_inputs},
        )
    )


@triton.testing.perf_report(configs)
def benchmark(M, N, K, provider, fp8_inputs):
    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
    # Make A column-major in memory (transpose of a contiguous transpose).
    a = a.T.contiguous().T
    quantiles = [0.5, 0.2, 0.8]
    if provider == ref_lib.lower():
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: torch.matmul(a, b), quantiles=quantiles
        )
    if provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul(a, b), quantiles=quantiles
        )
        print(matmul_kernel.best_config, ms)
    if provider == "triton_sparse":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: matmul_sparse(a, b), quantiles=quantiles
        )
        print(matmul_sparse_kernel.best_config, ms)
    # Higher is better: convert the measured time to TFLOPS.
    perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3)
    return perf(ms), perf(max_ms), perf(min_ms)
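

# Worked example of the TFLOPS formula above (using the H100 numbers reported
# at the bottom of this file): for M = 8192, N = 13824, K = 5120,
#   2 * M * N * K = 2 * 8192 * 13824 * 5120 ≈ 1.16e12 FLOPs,
# so a 1.92 ms run gives 1.16e12 * 1e-12 / (1.92 * 1e-3) ≈ 604 TFLOPS,
# matching the cuBLAS/Triton rows in the performance table below.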


def main() -> None:
    global a, b
    # Unit Test
    torch.manual_seed(0)
    step_size = 2
    # Both inputs are created row-major; `a` is made column-major below.
    # pyre-fixme[10]: Name `a` is used but not defined.
    a = torch.randn((256, 128), device="cuda", dtype=torch.float16)
    # pyre-fixme[10]: Name `b` is used but not defined.
    b = torch.randn((128, 512), device="cuda", dtype=torch.float16)
    # Make `a` column-major in memory.
    a = a.T.contiguous().T
    triton_output = matmul_sparse(a, b, step_size=step_size)
    torch_output = a[:, ::step_size] @ b[::step_size, :]
    print(f"triton_output_with_fp16_inputs={triton_output}")
    print(f"torch_output_with_fp16_inputs={torch_output}")
    # Bigger tolerance for AMD MI200 devices.
    # MI200 devices use reduced precision fp16 and bf16 and flush input and
    # output denormal values to zero. Detailed info is at: https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices
    rtol = 1e-2 if is_hip_mi200() else 0
    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")
    # Benchmark
    benchmark.run(show_plots=True, print_data=True)


if __name__ == "__main__":
    # Do not add code here; it won't be run. Add it to the function called below.
    main()  # pragma: no cover

# Expected Output on an H100:
# for the default: step_size = 1
# $ python matmul_sparse.py
# Monkey patched Triton's _build! See /packages/xlformers_emu_conda_unified/conda/lib/python3.10/site-packages/patch_triton.py
# Monkey patched Triton's nvsmi! See /packages/xlformers_emu_conda_unified/conda/lib/python3.10/site-packages/patch_triton.py
# triton_output_with_fp16_inputs=tensor([[ 2.7754, -0.1849, -1.8604, ..., -12.0938, 16.1406, -14.7500],
#         [ 14.5234, 16.6719, -10.4844, ..., -6.0742, -17.8906, -0.6675],
#         [-15.9922, 6.3867, 5.9883, ..., -1.6328, -7.3750, 18.0625],
#         ...,
#         [ 10.7422, 2.3633, -4.8203, ..., 4.7578, 1.4883, -2.9141],
#         [ 20.5312, 6.5547, 9.5156, ..., 7.6758, 11.3359, 4.1211],
#         [ 10.8359, -2.3418, -1.1533, ..., 3.5332, 8.2109, -10.4531]],
#        device='cuda:0', dtype=torch.float16)
# torch_output_with_fp16_inputs=tensor([[ 2.7754, -0.1849, -1.8604, ..., -12.0938, 16.1406, -14.7500],
#         [ 14.5234, 16.6719, -10.4844, ..., -6.0742, -17.8906, -0.6675],
#         [-15.9922, 6.3867, 5.9883, ..., -1.6328, -7.3750, 18.0625],
#         ...,
#         [ 10.7422, 2.3633, -4.8203, ..., 4.7578, 1.4883, -2.9141],
#         [ 20.5312, 6.5547, 9.5156, ..., 7.6758, 11.3359, 4.1211],
#         [ 10.8359, -2.3418, -1.1533, ..., 3.5332, 8.2109, -10.4531]],
#        device='cuda:0', dtype=torch.float16)
# ✅ Triton and Torch match
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 1.920896053314209
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 6.49567985534668
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 1.9256799221038818
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 6.48799991607666
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 1.9171359539031982
# BLOCK_SIZE_M: 128, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 64, GROUP_SIZE_M: 8, num_warps: 8, num_ctas: 1, num_stages: 3, maxnreg: None 6.476831912994385
# matmul-performance-fp16:
#         M        N       K      cuBLAS      Triton  Triton_Sparse
# 0  8192.0  13824.0  5120.0  605.427753  603.698033     178.524988
# 1  8192.0  13824.0  5120.0  598.325609  602.198297     178.736311
# 2  8192.0  13824.0  5120.0  598.049139  604.882073     179.044506