"""
Original kernel is from https://github.com/triton-lang/triton/issues/4906.
This kernel is modified to use dot_scaled and fp4. It _should_ be faster than int4 because it skips the int->float conversion, but it's not.
"""
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly;
# OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 A100_vs_4090_test.py
##########################################################################
import torch
# @manual=//triton:triton
import triton
# @manual=//triton:triton
import triton.language as tl
# @manual=//triton:triton
from triton.testing import do_bench
def eval_time(fct, params):
    return do_bench(lambda: fct(**params), warmup=200, rep=1000)
@triton.jit()
def swizzle_tile(
    pid,
    M,
    N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    width = GROUP_SIZE_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
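# Worked example of the swizzle: with grid_m = 4, grid_n = 4 and
# GROUP_SIZE_M = 2, consecutive pids 0..7 map to (pid_m, pid_n) =
# (0,0), (1,0), (0,1), (1,1), (0,2), (1,2), (0,3), (1,3). Each group of
# GROUP_SIZE_M rows is swept across all columns before moving down,
# which improves L2 reuse of the loaded tiles.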
@triton.jit()
def linear_tile(
    pid,
    M,
    N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
    pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    return pid_m, pid_n
@triton.jit
def dummy_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    elements_per_sample: tl.constexpr,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_k = 0
    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # pid_m, pid_n = linear_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # ------------------------------------------------
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
    # Offsets
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
    QUANT_SIZE_K: tl.constexpr = BLOCK_SIZE_K // elements_per_sample
    offs_k_quant = pid_k * QUANT_SIZE_K + tl.arange(0, QUANT_SIZE_K)
    # Vectorized coalesced load
    offs_am = offs_m
    offs_bn = offs_n
    # offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    # offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k_quant[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in tl.range(0, num_pid_k, 1, num_stages=1):
        # Best for RTX 4090
        b = tl.load(b_ptrs, eviction_policy="evict_first")
        a = tl.load(a_ptrs, eviction_policy="evict_last")
        # Best for A100
        # b = tl.load(b_ptrs, eviction_policy='evict_first')
        # a = tl.load(a_ptrs, eviction_policy='')
        #######################################
        # Shifts (left over from the int4 variant)
        # q_shifts = offs_k[:, None]  # This is fast
        # q_shifts = offs_k[:, None] % 8  # This is much slower
        # b = (b >> q_shifts) & 0x0F
        #######################################
        # acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=tl.float32)
        # e8m0 scales: the stored byte is a biased exponent, so 127 encodes
        # 2**(127 - 127) == 1.0, i.e. no rescaling of the fp4 values.
        b_scale = tl.full([BLOCK_SIZE_N, BLOCK_SIZE_K // 32], 127, dtype=tl.uint8)
        # a is plain bf16 (no scale); b is e2m1 packed two-per-byte along K,
        # with one e8m0 scale per 32 elements.
        acc = tl.dot_scaled(a, None, "bf16", b, b_scale, "e2m1", acc)  # pyre-ignore[16]
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk
    acc = acc.to(tl.bfloat16)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, acc)
def forward(x, W_q, elements_per_sample, debug=False):
    M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
    output = torch.zeros((M, N), device=W_q.device, dtype=torch.bfloat16)
    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 128
    num_stages = 2
    num_warps = 4
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
    k = dummy_kernel[grid](
        x,
        W_q,
        output,
        M,
        N,
        K,
        elements_per_sample,
        x.stride(0),
        x.stride(1),
        W_q.stride(0),
        W_q.stride(1),
        output.stride(0),
        output.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    if debug:
        with open("dequant_simple.txt", "w") as f:
            print(f"{k.n_regs} registers used, {k.n_spills} spills\n", file=f)
            print("IR", k.asm["ttir"], file=f)
            print("TTGIR", k.asm["ttgir"], file=f)
            print("PTX", k.asm["ptx"], file=f)
    return output
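# Optional correctness check (a sketch, not part of the original benchmark):
# dequantize the packed e2m1 nibbles on the host and compare against a plain
# bf16 matmul. The nibble order (low nibble = first element along K) is an
# assumption about how tl.dot_scaled consumes packed fp4; flip lo/hi if the
# comparison fails. With b_scale = 127 (e8m0 for 2**0), no rescaling applies.
E2M1_VALUES = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=torch.bfloat16,
    device="cuda",
)
def dequant_e2m1_reference(W_q_packed):
    # W_q_packed: (K // 2, N) uint8, two e2m1 values per byte along K.
    lo = (W_q_packed & 0x0F).long()  # assumed first element of each pair
    hi = (W_q_packed >> 4).long()    # assumed second element of each pair
    K2, n = W_q_packed.shape
    W_deq = torch.empty((K2 * 2, n), dtype=torch.bfloat16, device=W_q_packed.device)
    W_deq[0::2] = E2M1_VALUES[lo]
    W_deq[1::2] = E2M1_VALUES[hi]
    return W_deq
# Usage: err = (forward(x, W_q, 2) - x @ dequant_e2m1_reference(W_q)).abs().max()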
##################################################################################
torch.manual_seed(1)
M, N, K = 1, 4096 * 4, 4096 * 4
# TRITON_LLVM_DEBUG_ONLY="triton-matmul-loop-pipeline,axis-info"
# input_dtype, elements_per_sample = torch.float16, 1 #FP16
# input_dtype, elements_per_sample = torch.int8, 1 // 1 #INT8
# input_dtype, elements_per_sample = torch.int8, 8 // 4 #INT4
# input_dtype, elements_per_sample = torch.int32, 32 // 4 # INT4
# input_dtype, elements_per_sample = torch.int32, 32 // 1 #INT1
input_dtype, elements_per_sample = torch.uint8, 8 // 4 # FP4
W = torch.randn((N, K), dtype=torch.bfloat16, device="cuda")
W_q = (
    torch.randint(
        # 0, 2**4, (N, K // elements_per_sample), dtype=input_dtype, device="cuda"
        0,
        2**8,
        (N, K // elements_per_sample),
        dtype=input_dtype,
        device="cuda",
    )
    .t()
    .contiguous()
)  # Col-major
print(W_q.shape)
# W_q *= 0
x = torch.randn((M, K), dtype=torch.bfloat16, device="cuda").contiguous() # row-major
# out = forward(x, W_q, elements_per_sample, debug=True)
ref = eval_time(lambda x: torch.matmul(x, W.T), {"x": x.to(W.dtype)})
new = eval_time(
forward, {"x": x, "W_q": W_q, "elements_per_sample": elements_per_sample}
)
print("ref", ref)
print("took", new, ref / new)