FP8 linear triton with row-wise scaling
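The script below implements a bias-free FP8 linear layer with row-wise scaling: the activation is dynamically quantized to torch.float8_e4m3fn with one scale per row (per token), the weight is pre-quantized with one scale per output row, and a Triton matmul kernel accumulates in fp32 and rescales the result by both scales. As a minimal eager-PyTorch sketch of the same computation (illustrative only and kept separate from the script that follows; fp8_linear_reference is a name introduced here, not part of the gist):

import torch

def fp8_linear_reference(act: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Quantize each row so its absmax maps to the float8_e4m3fn max (448),
    # matmul in fp32, then undo both scales. The Triton kernel below fuses
    # the matmul and the rescaling into a single kernel.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    act_scale = act.abs().amax(1) / fp8_max
    weight_scale = weight.abs().amax(1) / fp8_max
    act_fp8 = (act.float() / act_scale[:, None]).to(torch.float8_e4m3fn)
    weight_fp8 = (weight.float() / weight_scale[:, None]).to(torch.float8_e4m3fn)
    out = act_fp8.float() @ weight_fp8.float().T
    return (out * act_scale[:, None] * weight_scale[None, :]).to(act.dtype)
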
import torch
import triton
import triton.language as tl
from torch import Tensor

# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
configs = [
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
    (128, 128, 32, 4, 4),
    (128, 64, 32, 4, 4),
    (64, 128, 32, 4, 4),
    (128, 32, 32, 4, 4),
    (64, 32, 32, 5, 2),
    (32, 64, 32, 5, 2),
    # Good configs for FP8 inputs
    (128, 256, 128, 3, 8),
    (256, 128, 128, 3, 8),
    (256, 64, 128, 4, 4),
    (64, 256, 128, 4, 4),
    (128, 128, 128, 4, 4),
    (128, 64, 64, 4, 4),
    (64, 128, 64, 4, 4),
    (128, 32, 64, 4, 4),
    # https://github.com/pytorch/pytorch/blob/7868b65c4d4f34133607b0166f08e9fbf3b257c4/torch/_inductor/kernel/mm_common.py#L172
    (64, 64, 32, 2, 4),
    (64, 128, 32, 3, 4),
    (128, 64, 32, 3, 4),
    (64, 128, 32, 4, 8),
    (128, 64, 32, 4, 8),
    (64, 32, 32, 5, 8),
    (32, 64, 32, 5, 8),
    (128, 128, 32, 2, 8),
    (64, 64, 64, 3, 8),
    (128, 256, 128, 3, 8),
    (256, 128, 128, 3, 8),
]
configs = [
    triton.Config(dict(BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K), num_stages=num_stages, num_warps=num_warps)
    for BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps in configs
]


@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"])
@triton.jit
def fp8_mm_kernel(
    # fmt: off
    A_ptr, B_ptr, C_ptr,
    A_scale_rowwise_ptr,
    B_scale_colwise_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr = 8,
    EVEN_K: tl.constexpr = True,
    # fmt: on
):
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A_ptr + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B_ptr + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # rescale the fp32 accumulator by the row-wise scale of A and the column-wise scale of B
    a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32)
    b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32)
    acc = acc.to(tl.float32) * a_scale * b_scale

    # inductor generates a suffix
    xindex = idx_m * stride_cm + idx_n * stride_cn
    tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask)


def quantize_fp8_rowwise(x: Tensor):
    # per-row absmax scaling: the largest magnitude in each row maps to the float8_e4m3fn max (448)
    scale = x.abs().amax(1) / torch.finfo(torch.float8_e4m3fn).max
    x = x.float() / scale.view(-1, 1)
    return x.to(torch.float8_e4m3fn), scale
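# Worked example for quantize_fp8_rowwise (illustrative, not in the original gist):
# for the row [1.0, -2.0, 4.0], amax = 4.0, so scale = 4.0 / 448 and the stored
# FP8 row is [112, -224, 448]; multiplying by scale recovers the original values.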


def grid(meta):
    return (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),)


def fp8_linear_dynamic_act(act: Tensor, weight_fp8: Tensor, weight_scale: Tensor):
    M, K = act.shape
    N, _ = weight_fp8.shape
    weight_fp8_t = weight_fp8.T

    # dynamic row-wise quantization of the activation
    act_fp8, act_scale = torch.compile(quantize_fp8_rowwise)(act)

    out = torch.empty(M, N, device=act.device, dtype=act.dtype)
    fp8_mm_kernel[grid](
        act_fp8,
        weight_fp8_t,
        out,
        act_scale,
        weight_scale,
        M,
        N,
        K,
        *act_fp8.stride(),
        *weight_fp8_t.stride(),
        *out.stride(),
        # 128 is the largest BLOCK_K in the autotune configs and every BLOCK_K divides it,
        # so the unmasked-load fast path is only taken when it is safe for all configs.
        EVEN_K=K % 128 == 0,
    )
    return out
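

# Illustrative usage sketch (not part of the original gist): a drop-in replacement for a
# bias-free nn.Linear that quantizes the weight once and runs the FP8 kernel in forward().
# Assumes 2D CUDA inputs of shape (M, K), matching fp8_linear_dynamic_act above.
class FP8Linear(torch.nn.Module):
    def __init__(self, linear: torch.nn.Linear):
        super().__init__()
        assert linear.bias is None
        weight_fp8, weight_scale = quantize_fp8_rowwise(linear.weight.detach())
        self.register_buffer("weight_fp8", weight_fp8)
        self.register_buffer("weight_scale", weight_scale)

    def forward(self, x: Tensor) -> Tensor:
        return fp8_linear_dynamic_act(x, self.weight_fp8, self.weight_scale)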


if __name__ == "__main__":
    from triton.testing import do_bench

    act_bf16 = torch.randn(1024, 2048).bfloat16().cuda()
    weight_bf16 = torch.randn(4096, 2048).bfloat16().cuda()
    ref = act_bf16 @ weight_bf16.T

    weight_fp8, weight_scale = quantize_fp8_rowwise(weight_bf16)
    out1 = fp8_linear_dynamic_act(act_bf16, weight_fp8, weight_scale)
    print("Median relative error:", ((ref - out1).abs() / ref.abs()).median().item())

    print("BF16 linear:", do_bench(lambda: act_bf16 @ weight_bf16.T, fast_flush=False, return_mode="median"))
    print(
        "FP8-rowwise linear:",
        do_bench(
            lambda: fp8_linear_dynamic_act(act_bf16, weight_fp8, weight_scale), fast_flush=False, return_mode="median"
        ),
    )