FP8 linear triton with row-wise scaling
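Quantizes the activation to FP8 (e4m3) on the fly with one scale per row, multiplies it against a pre-quantized, row-wise-scaled FP8 weight in a Triton matmul kernel that applies both scales in the epilogue, and benchmarks the result against a plain BF16 matmul.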
import torch
import triton
import triton.language as tl
from torch import Tensor

# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
# (BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)
configs = [
    (128, 256, 64, 3, 8),
    (64, 256, 32, 4, 4),
    (128, 128, 32, 4, 4),
    (128, 64, 32, 4, 4),
    (64, 128, 32, 4, 4),
    (128, 32, 32, 4, 4),
    (64, 32, 32, 5, 2),
    (32, 64, 32, 5, 2),
    # good configs for FP8 inputs
    (128, 256, 128, 3, 8),
    (256, 128, 128, 3, 8),
    (256, 64, 128, 4, 4),
    (64, 256, 128, 4, 4),
    (128, 128, 128, 4, 4),
    (128, 64, 64, 4, 4),
    (64, 128, 64, 4, 4),
    (128, 32, 64, 4, 4),
    # https://github.com/pytorch/pytorch/blob/7868b65c4d4f34133607b0166f08e9fbf3b257c4/torch/_inductor/kernel/mm_common.py#L172
    (64, 64, 32, 2, 4),
    (64, 128, 32, 3, 4),
    (128, 64, 32, 3, 4),
    (64, 128, 32, 4, 8),
    (128, 64, 32, 4, 8),
    (64, 32, 32, 5, 8),
    (32, 64, 32, 5, 8),
    (128, 128, 32, 2, 8),
    (64, 64, 64, 3, 8),
    (128, 256, 128, 3, 8),
    (256, 128, 128, 3, 8),
]
configs = [
    triton.Config(dict(BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K), num_stages=num_stages, num_warps=num_warps)
    for BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps in configs
]
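

# Note (added comment): the autotune key includes stride_ak and stride_bk, so the configs are
# re-evaluated when the input memory layout (e.g. transposed vs. contiguous operands) changes,
# not only when the problem shape (M, N, K) changes.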
@triton.autotune(configs=configs, key=["M", "N", "K", "stride_ak", "stride_bk"])
@triton.jit
def fp8_mm_kernel(
    # fmt: off
    A_ptr, B_ptr, C_ptr,
    A_scale_rowwise_ptr,
    B_scale_colwise_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr = 8,
    EVEN_K: tl.constexpr = True,
    # fmt: on
):
    # based on triton.ops.matmul
    pid = tl.program_id(0)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    rk = tl.arange(0, BLOCK_K)
    A = A_ptr + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B_ptr + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)

    # accumulate the FP8 x FP8 matmul in FP32
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.0)
            b = tl.load(B, mask=rk[:, None] < k, other=0.0)
        acc += tl.dot(a, b)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # epilogue: apply the per-row scale of A and the per-row scale of B
    # (i.e. per-column of B^T) to the FP32 accumulator
    a_scale = tl.load(A_scale_rowwise_ptr + idx_m, mask=idx_m < M).to(tl.float32)
    b_scale = tl.load(B_scale_colwise_ptr + idx_n, mask=idx_n < N).to(tl.float32)
    acc = acc.to(tl.float32) * a_scale * b_scale

    # write the output (same indexing as the suffix that inductor generates)
    xindex = idx_m * stride_cm + idx_n * stride_cn
    tl.store(C_ptr + tl.broadcast_to(xindex, mask.shape), acc, mask)


def quantize_fp8_rowwise(x: Tensor):
    scale = x.abs().amax(1) / torch.finfo(torch.float8_e4m3fn).max
    x = x.float() / scale.view(-1, 1)
    return x.to(torch.float8_e4m3fn), scale
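
# Row-wise scaling recap (added comment): quantize_fp8_rowwise stores x ~= x_fp8 * scale[:, None],
# where scale = amax(|row|) / 448 (the max representable value of FP8 e4m3). With both operands of
# C = A @ B^T quantized this way, the epilogue of fp8_mm_kernel recovers
#   C[m, n] ~= (A_fp8 @ B_fp8^T)[m, n] * a_scale[m] * b_scale[n].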


def grid(meta):
    return (triton.cdiv(meta["M"], meta["BLOCK_M"]) * triton.cdiv(meta["N"], meta["BLOCK_N"]),)


def fp8_linear_dynamic_act(act: Tensor, weight_fp8: Tensor, weight_scale: Tensor):
    M, K = act.shape
    N, _ = weight_fp8.shape
    weight_fp8_t = weight_fp8.T

    # quantize the activation on the fly (dynamic per-row quantization)
    act_fp8, act_scale = torch.compile(quantize_fp8_rowwise)(act)
    out = torch.empty(M, N, device=act.device, dtype=act.dtype)
    fp8_mm_kernel[grid](
        act_fp8,
        weight_fp8_t,
        out,
        act_scale,
        weight_scale,
        M,
        N,
        K,
        *act_fp8.stride(),
        *weight_fp8_t.stride(),
        *out.stride(),
        # skip the K-masking only when K is divisible by every BLOCK_K candidate
        # (128 is the largest BLOCK_K in the autotune configs above)
        EVEN_K=K % 128 == 0,
    )
    return out


if __name__ == "__main__":
    from triton.testing import do_bench

    act_bf16 = torch.randn(1024, 2048).bfloat16().cuda()
    weight_bf16 = torch.randn(4096, 2048).bfloat16().cuda()
    ref = act_bf16 @ weight_bf16.T

    weight_fp8, weight_scale = quantize_fp8_rowwise(weight_bf16)
    out1 = fp8_linear_dynamic_act(act_bf16, weight_fp8, weight_scale)
    # element-wise relative error vs. the BF16 reference
    print((ref - out1).abs() / ref.abs())

    print("BF16 linear:", do_bench(lambda: act_bf16 @ weight_bf16.T, fast_flush=False, return_mode="median"))
    print(
        "FP8-rowwise linear:",
        do_bench(
            lambda: fp8_linear_dynamic_act(act_bf16, weight_fp8, weight_scale), fast_flush=False, return_mode="median"
        ),
    )
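

# --- Not part of the original gist ---
# A minimal sketch of how the helpers above could back a drop-in nn.Linear replacement for
# inference. FP8Linear and from_linear are hypothetical names; only quantize_fp8_rowwise and
# fp8_linear_dynamic_act come from the code above. The input is flattened to 2D because
# fp8_linear_dynamic_act expects an (M, K) activation.
import torch.nn as nn


class FP8Linear(nn.Module):
    def __init__(self, weight_fp8: Tensor, weight_scale: Tensor, bias=None):
        super().__init__()
        self.register_buffer("weight_fp8", weight_fp8)
        self.register_buffer("weight_scale", weight_scale)
        self.bias = bias

    @classmethod
    def from_linear(cls, linear: nn.Linear):
        # quantize the weight once, row-wise, at conversion time
        weight_fp8, weight_scale = quantize_fp8_rowwise(linear.weight.detach())
        return cls(weight_fp8, weight_scale, linear.bias)

    def forward(self, x: Tensor) -> Tensor:
        out = fp8_linear_dynamic_act(x.reshape(-1, x.shape[-1]), self.weight_fp8, self.weight_scale)
        out = out.reshape(*x.shape[:-1], -1)
        if self.bias is not None:
            out = out + self.bias
        return out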