@HDCharles
Created July 27, 2023 07:22
triton mixed matmul kernels
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        # these
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def int8_weight_only_linear_kernel(
        # Pointers to matrices
        x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
        # Matrix dimensions
        M, N, K,
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        stride_b,
        stride_ym, stride_yn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Y it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
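    # Illustrative example (hypothetical sizes): with num_pid_m = 4, num_pid_n = 4 and
    # GROUP_SIZE_M = 2, num_pid_in_group = 8 and pids 0..7 map to (pid_m, pid_n) =
    # (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3). Each group of programs
    # sweeps a 2-row band of output blocks column by column, so the corresponding row
    # blocks of X stay hot in L2 while the column blocks of W stream through.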
    # ----------------------------------------------------------
    # Create pointers for the first blocks of X and W.
    # We will advance these pointers as we move in the K direction
    # and accumulate.
    # `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    # See the `Pointer Arithmetics` section of the Triton matmul tutorial for details.
    offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
    w_ptrs = w_ptr + (offs_k[:, None] * stride_wk + offs_wn[None, :] * stride_wn)
    b_ptrs = b_ptr + (offs_wn * stride_b)
    step_x = BLOCK_SIZE_K * stride_xk
    step_w = BLOCK_SIZE_K * stride_wk
    # -----------------------------------------------------------
    # Iterate to compute a block of the Y matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted back to the output dtype (bf16 here) when it is stored.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next blocks of X and W, generating a mask along the K dimension.
        # Out-of-bounds elements are set to 0.
        x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # Upcast the int8 weights to bf16 and accumulate along the K dimension.
        accumulator += tl.dot(x, w.to(tl.bfloat16))
        # Advance the ptrs to the next K block.
        x_ptrs += step_x
        w_ptrs += step_w
    # Apply the scale (loaded as a single scalar from `s_ptr`) and the per-column bias.
    s = tl.load(s_ptr)
    b = tl.load(b_ptrs)
    y = (accumulator * s + b)
    # y = accumulator
    # -----------------------------------------------------------
    # Write back the block of the output matrix Y with masks.
    offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
    y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
    tl.store(y_ptrs, y, mask=y_mask)

def int8_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0], "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    K, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    int8_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
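
# A minimal sanity-check sketch for int8_weight_only_linear, assuming a CUDA device is
# available; the helper name and sizes are illustrative only. It compares against an fp32
# reference. Note the kernel loads a single scalar from `s_ptr`, so the reference also
# applies only s[0]; the error is printed rather than asserted since y is stored in bf16.
def _check_int8_weight_only_linear(M=64, K=128, N=96):
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    w = torch.randint(-128, 127, (K, N), dtype=torch.int8, device='cuda')
    b = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    s = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    y = int8_weight_only_linear(x, w, b, s)
    # fp32 reference: matmul, scale by the (only used) first element of s, add bias.
    y_ref = x.float() @ w.float() * s[0].float() + b.float()
    rel_err = ((y.float() - y_ref).abs() / y_ref.abs().clamp(min=1.0)).max().item()
    print(f"int8 kernel max rel err vs fp32 reference: {rel_err:.4f}")
# _check_int8_weight_only_linear()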
# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes:
# - A list of `triton.Config` objects that define different configurations of
# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try
# - An auto-tuning *key* whose change in values will trigger evaluation of all the
# provided configs
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def uint4x2_weight_only_linear_kernel(
        # Pointers to matrices
        x_ptr, w_ptr, b_ptr, s_ptr, y_ptr,
        # Matrix dimensions
        M, N, K,  # x is (M, K); w is (K//2, N) with two 4-bit weights packed into each uint8
        # The stride variables represent how much to increase the ptr by when moving by 1
        # element in a particular dimension. E.g. `stride_xm` is how much to increase `x_ptr`
        # by to get the element one row down (X has M rows).
        stride_xm, stride_xk,
        stride_wk, stride_wn,
        stride_b,
        stride_ym, stride_yn,
        # Meta-parameters
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of Y it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of X and W.
# We will advance this pointer as we move in the K direction
# and accumulate
# `x_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `w_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
# See above `Pointer Arithmetics` section for details
offs_xm = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_wn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
x_ptrs = x_ptr + (offs_xm[:, None] * stride_xm + offs_k[None, :] * stride_xk)
w_ptrs = w_ptr + (offs_k[:, None]//2 * stride_wk + offs_wn[None, :] * stride_wn)
w_shifts = (offs_k % 2) * 4
b_ptrs = b_ptr + (offs_wn * stride_b)
step_x = BLOCK_SIZE_K * stride_xk
step_w = BLOCK_SIZE_K//2 * stride_wk
# -----------------------------------------------------------
# Iterate to compute a block of the Y matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
# If it is out of bounds, set it to 0.
x = tl.load(x_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
w = tl.load(w_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
w = ((w >> w_shifts[:, None]) & 0xF) - 8
# We accumulate along the K dimension.
accumulator += tl.dot(x, w.to(tl.bfloat16))
# Advance the ptrs to the next K block.
x_ptrs += step_x
w_ptrs += step_w
s = tl.load(s_ptr)
b = tl.load(b_ptrs)
y = (accumulator * s)+b
# -----------------------------------------------------------
# Write back the block of the output matrix Y with masks.
offs_ym = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_yn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
y_ptrs = y_ptr + stride_ym * offs_ym[:, None] + stride_yn * offs_yn[None, :]
y_mask = (offs_ym[:, None] < M) & (offs_yn[None, :] < N)
tl.store(y_ptrs, y, mask=y_mask)

def uint4x2_weight_only_linear(x, w, b, s):
    # Check constraints.
    assert x.shape[1] == w.shape[0]*2, "Incompatible dimensions"
    # assert x.is_contiguous(), "Matrix x must be contiguous"
    # assert w.is_contiguous(), "Matrix w must be contiguous"
    M, K = x.shape
    _, N = w.shape
    assert b.shape[0] == N
    # Allocates output.
    y = torch.empty((M, N), device=x.device, dtype=x.dtype)
    # 1D launch kernel where each block gets its own program.
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
    )
    uint4x2_weight_only_linear_kernel[grid](
        x, w, b, s, y,
        M, N, K,
        x.stride(0), x.stride(1),
        w.stride(0), w.stride(1),
        b.stride(0),
        y.stride(0), y.stride(1),
    )
    return y
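
# A minimal sanity-check sketch for uint4x2_weight_only_linear, assuming a CUDA device and a
# contiguous packed weight; the helper name and sizes are illustrative only. The fp32 reference
# unpacks each uint8 the same way the kernel does: the low nibble gives logical row 2*j, the
# high nibble row 2*j+1, each mapped from [0, 15] to [-8, 7]. As above, only s[0] is applied.
def _check_uint4x2_weight_only_linear(M=64, K=128, N=96):
    x = torch.randn(M, K, dtype=torch.bfloat16, device='cuda')
    w_packed = torch.randint(0, 256, (K // 2, N), dtype=torch.uint8, device='cuda')
    b = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    s = torch.randn(N, dtype=torch.bfloat16, device='cuda')
    y = uint4x2_weight_only_linear(x, w_packed, b, s)
    # Unpack the reference weights: interleave low/high nibbles along the K dimension.
    w_int = w_packed.to(torch.int32)
    low = ((w_int & 0xF) - 8).float()          # logical rows 0, 2, 4, ...
    high = (((w_int >> 4) & 0xF) - 8).float()  # logical rows 1, 3, 5, ...
    w_unpacked = torch.stack([low, high], dim=1).reshape(K, N)
    y_ref = x.float() @ w_unpacked * s[0].float() + b.float()
    rel_err = ((y.float() - y_ref).abs() / y_ref.abs().clamp(min=1.0)).max().item()
    print(f"uint4x2 kernel max rel err vs fp32 reference: {rel_err:.4f}")
# _check_uint4x2_weight_only_linear()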

# Benchmark the two kernels against torch.nn.functional.linear for a few square sizes,
# with and without transposed inputs.
quantiles = [0.5, 0.2, 0.8]  # do_bench returns times at these quantiles; [0] below is the median
result = {}
for D in [2**8, 2**10, 2**12, 2**14]:
    result[D] = {}
    result[D]["linear"] = {}
    result[D]["int8"] = {}
    result[D]["uint4x2"] = {}
    for t_x in [0, 1]:
        for t_w in [0, 1]:
            x = torch.randn(D, D).to('cuda').to(torch.bfloat16)
            w_bf16 = torch.randn(D, D, dtype=torch.bfloat16).cuda()
            bias = torch.randn(D, dtype=torch.bfloat16).cuda()
            if t_x:
                x = x.t()
            if t_w:
                w_bf16 = w_bf16.t()
            torch.nn.functional.linear(x, w_bf16, bias)  # warmup
            torch.cuda.synchronize()
            result[D]["linear"][(t_x, t_w)] = triton.testing.do_bench(lambda: torch.nn.functional.linear(x, w_bf16, bias), quantiles=quantiles)[0]
            torch.cuda.synchronize()
            del w_bf16
            w_int8 = torch.randint(-128, 127, (D, D), dtype=torch.int8).cuda()
            if t_w:
                w_int8 = w_int8.t()
            scale = torch.randn(D, dtype=torch.bfloat16).cuda()
            int8_weight_only_linear(x, w_int8, bias, scale)  # warmup, triggers autotuning before timing
            torch.cuda.synchronize()
            result[D]["int8"][(t_x, t_w)] = triton.testing.do_bench(lambda: int8_weight_only_linear(x, w_int8, bias, scale), quantiles=quantiles)[0]
            torch.cuda.synchronize()
            del w_int8
            w_uint4x2 = torch.randint(0, 255, (D//2, D), dtype=torch.uint8).cuda()
            if t_w:
                w_uint4x2 = torch.randint(0, 255, (D, D//2), dtype=torch.uint8).cuda().t()
            uint4x2_weight_only_linear(x, w_uint4x2, bias, scale)  # warmup, triggers autotuning before timing
            torch.cuda.synchronize()
            result[D]["uint4x2"][(t_x, t_w)] = triton.testing.do_bench(lambda: uint4x2_weight_only_linear(x, w_uint4x2, bias, scale), quantiles=quantiles)[0]
            torch.cuda.synchronize()
            del w_uint4x2

# Median times in ms for each (transpose x, transpose w) combination.
# (The recorded outputs below repeat the (0,1) column in their last three slots.)
print("(0,0) (0,1) (1,0) (1,1)")
for d in result.keys():
    print(f"({d}, {d}) . ({d}, {d})*({d})+({d})")
    for name in result[d].keys():
        r = result[d][name]
        print(f"{r[(0,0)]:.4f}, {r[(0,1)]:.4f}, {r[(1,0)]:.4f}, {r[(1,1)]:.4f} for {name}")
# using triton version triton-nightly 2.1.0.dev20230726014945
# install: pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# using torch compiled from 2.1.0a0+git9c2122d
# using cuda 12.1, cudnn 8.9.2 on an A100 GPU
# ---------- OUTPUT --------------
# (0,0) (0,1) (0,1) (0,1)
# (256, 256) . (256, 256)*(256)+(256)
# 0.0092, 0.0102, 0.0102, 0.0102 for linear
# 0.0113, 0.0092, 0.0092, 0.0092 for int8
# 0.0143, 0.0133, 0.0133, 0.0133 for uint4x2
# (1024, 1024) . (1024, 1024)*(1024)+(1024)
# 0.0246, 0.0236, 0.0236, 0.0236 for linear
# 0.0338, 0.0543, 0.0543, 0.0543 for int8
# 0.0440, 0.0481, 0.0481, 0.0481 for uint4x2
# (4096, 4096) . (4096, 4096)*(4096)+(4096)
# 0.5837, 0.5806, 0.5806, 0.5806 for linear
# 0.9687, 1.2698, 1.2698, 1.2698 for int8
# 1.0353, 1.1960, 1.1960, 1.1960 for uint4x2
# (16384, 16384) . (16384, 16384)*(16384)+(16384)
# 36.5676, 36.3018, 36.3018, 36.3018 for linear
# 65.3066, 59.8753, 59.8753, 59.8753 for int8
# 64.7076, 81.4725, 81.4725, 81.4725 for uint4x2
# using triton version pytorch-triton 2.1.0+9e3e10c5ed
# install: from the pytorch source dir, run `make triton`
# ---------- OUTPUT --------------
# (0,0) (0,1) (0,1) (0,1)
# (256, 256) . (256, 256)*(256)+(256)
# 0.0092, 0.0113, 0.0113, 0.0113 for linear
# 0.0102, 0.0092, 0.0092, 0.0092 for int8
# 0.0143, 0.0143, 0.0143, 0.0143 for uint4x2
# (1024, 1024) . (1024, 1024)*(1024)+(1024)
# 0.0256, 0.0236, 0.0236, 0.0236 for linear
# 0.0348, 0.0543, 0.0543, 0.0543 for int8
# 0.0430, 0.0481, 0.0481, 0.0481 for uint4x2
# (4096, 4096) . (4096, 4096)*(4096)+(4096)
# 0.5847, 0.5919, 0.5919, 0.5919 for linear
# 0.9708, 1.2646, 1.2646, 1.2646 for int8
# 1.0496, 1.1899, 1.1899, 1.1899 for uint4x2
# (16384, 16384) . (16384, 16384)*(16384)+(16384)
# 36.3909, 36.1605, 36.1605, 36.1605 for linear
# 66.3859, 81.7807, 81.7807, 81.7807 for int8
# 64.8090, 82.2958, 82.2958, 82.2958 for uint4x2