

@norabelrose
Last active October 20, 2023 10:30
Compute covariance matrix in Triton
from itertools import product
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_N': n, 'BLOCK_D': d, 'GROUP_SIZE_D': 8}, num_stages=4, num_warps=4)
        for n, d in product([32, 64, 128, 256], repeat=2)
    ] + [
        triton.Config({'BLOCK_N': 32, 'BLOCK_D': 32, 'GROUP_SIZE_D': 8}, num_stages=5, num_warps=2)
    ],
    key=['N', 'D'],
)
@triton.jit
def cumulant_kernel(
    in_ptr,
    out_ptr,
    # Matrix dimensions
    N: int, D: int,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension
    stride_n: int, stride_d: int,
    # Strides for output tensor
    stride_out1: int, stride_out2: int,
    # Meta-parameters
    BLOCK_D: tl.constexpr,
    BLOCK_N: tl.constexpr,
    GROUP_SIZE_D: tl.constexpr
):
"""Compute covariance matrix of X using Triton."""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Out it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See the matrix multiplication tutorial for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(D, BLOCK_D)
num_pid_n = tl.cdiv(D, BLOCK_D)
num_pid_in_group = GROUP_SIZE_D * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_D
GROUP_SIZE_D = min(num_pid_m - first_pid_m, GROUP_SIZE_D)
pid_m = first_pid_m + (pid % GROUP_SIZE_D)
pid_n = (pid % num_pid_in_group) // GROUP_SIZE_D
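    # Illustration (not part of the original gist): with num_pid_m = num_pid_n = 4 and
    # GROUP_SIZE_D = 2, programs 0..7 are assigned the output tiles
    #   (0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)
    # i.e. they sweep a two-tile-tall band of Out from left to right before moving on to
    # the next band of rows, so consecutive programs reuse the same rows of X from L2.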
    # ----------------------------------------------------------
    # Create block pointers for the first blocks of X^T and X.
    # We will advance these pointers as we move along the N axis and accumulate.
    # See the `Make a Block Pointer` section of the Triton matmul tutorial for details.
    # Construct a transposed block view of X
    x_ptr = tl.make_block_ptr(
        in_ptr,
        shape=(D, N),
        strides=(stride_d, stride_n),
        offsets=(pid_m * BLOCK_D, 0),
        block_shape=(BLOCK_D, BLOCK_N),
        order=(1, 0)
    )
    # Untransposed block view of X
    y_ptr = tl.make_block_ptr(
        in_ptr,
        shape=(N, D),
        strides=(stride_n, stride_d),
        offsets=(0, pid_n * BLOCK_D),
        block_shape=(BLOCK_N, BLOCK_D),
        order=(1, 0)
    )
    # Use float32 during accumulation for higher accuracy
    accumulator = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)

    # Iterate over blocks of the N axis
    for _ in range(0, N, BLOCK_N):
        # See the `Load/Store a Block Pointer` section of the Triton matmul tutorial for details.
        x = tl.load(x_ptr, boundary_check=(0, 1))
        y = tl.load(y_ptr, boundary_check=(0, 1))

        # We accumulate along the N axis.
        accumulator += tl.dot(x, y)

        # Advance the block pointers to the next N block.
        x_ptr = tl.advance(x_ptr, (0, BLOCK_N))
        y_ptr = tl.advance(y_ptr, (BLOCK_N, 0))

    # Normalize with Bessel's correction
    accumulator /= (N - 1)
    out = accumulator.to(out_ptr.dtype.element_ty)
    # ----------------------------------------------------------------
    # Write back the block of the output matrix Out with boundary checks.
    out_block_ptr = tl.make_block_ptr(
        out_ptr,
        shape=(D, D),
        strides=(stride_out1, stride_out2),
        offsets=(pid_m * BLOCK_D, pid_n * BLOCK_D),
        block_shape=(BLOCK_D, BLOCK_D),
        order=(1, 0)
    )
    tl.store(out_block_ptr, out, boundary_check=(0, 1))
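
# In summary (illustration, not part of the original gist): each program produces one
# BLOCK_D x BLOCK_D tile of Out = X^T X / (N - 1), accumulating tl.dot(x, y) over
# ceil(N / BLOCK_N) steps in float32 and casting to the output dtype only at the end.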


def cumulant(x):
    """Return x.T @ x / (n - 1), which is the covariance matrix of x when x is centered."""
    n, d = x.shape
    assert x.is_contiguous(), "X must be contiguous"

    # Allocate output
    out = x.new_empty((d, d))

    # 1D launch grid: one program per BLOCK_D x BLOCK_D tile of the output
    grid = lambda META: (
        triton.cdiv(d, META['BLOCK_D']) ** 2,
    )
    cumulant_kernel[grid](
        x, out,
        n, d,
        *x.stride(),
        *out.stride(),
    )
    return out
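
# A convenience wrapper (not part of the original gist): `covariance` is a hypothetical
# helper that centers its input before calling `cumulant`, so it can also be applied to
# raw, un-centered data.
def covariance(x):
    """Covariance of the columns of a 2D tensor x, computed with the Triton kernel above."""
    x = (x - x.mean(dim=0)).contiguous()  # Center each column; `cumulant` requires contiguity
    return cumulant(x)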


if __name__ == '__main__':
    torch.manual_seed(0)

    n, d = 512, 16
    X = torch.randn((n, d), device='cuda', dtype=torch.float16)
    X -= torch.mean(X, dim=0)  # Center the columns so that X^T X / (n - 1) is the covariance

    triton_output = cumulant(X)
    torch_output = torch.matmul(X.mT, X) / (n - 1)
    print(f"triton_output={triton_output}")
    print(f"torch_output={torch_output}")

    if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0):
        print("✅ Triton and Torch match")
    else:
        print("❌ Triton and Torch differ")

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=['N'],  # Argument names to use as an x-axis for the plot
            x_vals=[128 * i for i in range(2, 33)],  # Different possible values for `x_name`
            line_arg='provider',  # Argument name whose value corresponds to a different line in the plot
            line_vals=['cublas', 'triton'],  # Possible values for `line_arg`
            line_names=["cuBLAS", "Triton"],  # Label name for the lines
            styles=[('green', '-'), ('blue', '-')],  # Line styles
            ylabel="TFLOPS",  # Label name for the y-axis
            plot_name="cov-performance",  # Name for the plot, also used as a file name when saving it
            args={},
        )
    )
    def benchmark(N, provider):
        D = 512
        a = torch.randn((N, D), device='cuda', dtype=torch.float16)
        quantiles = [0.5, 0.2, 0.8]

        if provider == 'cublas':
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: torch.matmul(a.mT, a) / (N - 1), quantiles=quantiles
            )
        if provider == 'triton':
            ms, min_ms, max_ms = triton.testing.do_bench(lambda: cumulant(a), quantiles=quantiles)

        # X^T X on an (N, D) input costs roughly 2 * N * D * D FLOPs
        perf = lambda ms: 2 * N * D * D * 1e-12 / (ms * 1e-3)
        return perf(ms), perf(max_ms), perf(min_ms)

    benchmark.run(show_plots=True, print_data=True, save_path='.')
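
# Usage sketch (assumptions, not from the original gist: the script is saved as e.g.
# cov_triton.py, and a CUDA GPU with a recent Triton install is available):
#   python cov_triton.py
# This prints the correctness check and, via save_path='.', saves the benchmark results
# (plot and data) under the name "cov-performance" in the current directory.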