# Gist by @davidberard98, created November 25, 2024 22:45
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly;
# OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 A100_vs_4090_test.py
##########################################################################
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
import itertools
def eval_time(fct, params):
    return do_bench(lambda: fct(**params), warmup=200, rep=1000)  # , fast_flush=True, return_mode='min')
@triton.jit()
def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    width = GROUP_SIZE_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
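# swizzle_tile is the grouped ("super-grouping") program ordering from the Triton matmul
# tutorial: consecutive program ids first walk down GROUP_SIZE_M row-tiles before moving to
# the next column-tile, so each group of GROUP_SIZE_M * grid_n programs touches only
# GROUP_SIZE_M row-tiles of A and grid_n column-tiles of B, improving L2 reuse.
# For example, with grid_m=4, grid_n=3, GROUP_SIZE_M=2, pids 0..5 map to
# (pid_m, pid_n) = (0,0), (1,0), (0,1), (1,1), (0,2), (1,2) before the next group starts.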
@triton.jit()
def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
    pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    return pid_m, pid_n
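# linear_tile is the naive row-major alternative to swizzle_tile; the "Swizzle" vs "Linear"
# timings in the comments at the bottom of this file compare the two orderings.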
'''
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k, "GROUP_SIZE_M": 8}, num_stages=ns, num_warps=nw)
        for m, n, k, ns, nw in itertools.product([16, 32], [16, 32, 64], [16, 32, 64, 128], [2, 3], [2, 4, 8])
    ],
    key=[],
)
'''
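# Note: to autotune instead of using the hard-coded launch parameters, uncomment the
# decorator above (it already sits directly above @triton.jit). With triton.autotune, the
# tuned meta-parameters (BLOCK_SIZE_*, GROUP_SIZE_M, num_stages, num_warps) come from the
# chosen config, so they should no longer be passed explicitly at the launch site in forward().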
@triton.jit
def dummy_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    elements_per_sample: tl.constexpr,
    b_type: tl.constexpr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr
):
    pid = tl.program_id(axis=0)
    pid_k = 0
    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # pid_m, pid_n = linear_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # ------------------------------------------------
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)

    # Offsets
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # Vectorized coalesced load
    offs_am = offs_m
    offs_bn = offs_n
    # offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    # offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
    offs_k_quantized = pid_k * BLOCK_SIZE_K // elements_per_sample + tl.arange(0, BLOCK_SIZE_K // elements_per_sample)
    b_ptrs = b_ptr + (offs_k_quantized[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
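    # The K offsets for B are divided by elements_per_sample because W_q is packed:
    # each stored uint8 holds `elements_per_sample` sub-byte values (2 e2m1/fp4 values per
    # byte in the configuration used below), so B advances only
    # BLOCK_SIZE_K // elements_per_sample rows of storage per K step.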
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in tl.range(0, num_pid_k, 1, num_stages=1):
        # Best for 4090 RTX
        b = tl.load(b_ptrs, eviction_policy='evict_first')
        a = tl.load(a_ptrs, eviction_policy='evict_last')

        # Best for A100
        # b = tl.load(b_ptrs, eviction_policy='evict_first')
        # a = tl.load(a_ptrs, eviction_policy='')

        #######################################
        # Shifts
        # q_shifts = offs_k[:, None]      # This is fast
        # q_shifts = offs_k[:, None] % 8  # This is much slower
        # b = (b >> q_shifts) & 0x0F
        #######################################

        # acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=tl.float32)
        # b_scale = 127 is the e8m0 bias, i.e. a per-32-element scale of 2**0 = 1.0,
        # so dot_scaled treats B as unscaled packed e2m1 (fp4) data.
        b_scale = tl.full([BLOCK_SIZE_N, BLOCK_SIZE_K // 32], 127, dtype=tl.uint8)
        acc = tl.dot_scaled(a, None, "bf16", b, b_scale, "e2m1", acc=acc)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk
    acc = acc.to(tl.float16)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, acc)
def forward(x, W_q, elements_per_sample, b_type, debug=False):
    M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
    output = torch.zeros((M, N), device=W_q.device, dtype=torch.float16)

    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 128
    num_stages = 2
    num_warps = 4

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),)
    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)  # overrides the callable above

    k = dummy_kernel[grid](
        x, W_q, output,
        M, N, K,
        elements_per_sample,
        b_type,
        x.stride(0), x.stride(1),
        W_q.stride(0), W_q.stride(1),
        output.stride(0), output.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_stages=num_stages, num_warps=num_warps,
    )

    if debug:
        with open('dequant_simple.txt', 'w') as f:
            print(f"{k.n_regs} registers used, {k.n_spills} spills\n", file=f)
            print("IR", k.asm['ttir'], file=f)
            print("TTGIR", k.asm['ttgir'], file=f)
            print("PTX", k.asm['ptx'], file=f)
    return output
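# With debug=True, forward() dumps register/spill counts plus the compiled kernel's
# TTIR/TTGIR/PTX to dequant_simple.txt (see the commented-out call further below).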
##################################################################################
torch.manual_seed(1)
M, N, K = 1, 4096*4, 4096*4
#TRITON_LLVM_DEBUG_ONLY="triton-matmul-loop-pipeline,axis-info"
#input_dtype, elements_per_sample = torch.float16, 1 #FP16
#input_dtype, elements_per_sample = torch.int8, 1 // 1 #INT8
#input_dtype, elements_per_sample = torch.int8, 8 // 4 #INT4
#input_dtype, elements_per_sample, b_type = torch.int32, 32 // 4, #INT4
#input_dtype, elements_per_sample = torch.int32, 32 // 1 #INT1
input_dtype, elements_per_sample, b_type = torch.uint8, 8 // 4, 'e2m1' #INT4
W = torch.randn((N, K), dtype=torch.bfloat16, device='cuda')
W_q = torch.randint(0, 2**4, (N, K // elements_per_sample), dtype=input_dtype, device='cuda').t().contiguous() #Col-major
print(W_q.shape)
#W_q *= 0
x = torch.randn((M, K), dtype=torch.bfloat16, device='cuda').contiguous() #row-major
#out = forward(x, W_q, elements_per_sample, b_type, debug=True)
ref = eval_time(lambda x: torch.matmul(x, W.T), {'x':x.to(W.dtype)})
new = eval_time(forward, {'x':x, 'W_q':W_q, 'elements_per_sample':elements_per_sample, 'b_type':b_type})
print('ref', ref)
print('took', new, ref / new)
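# Rough effective-bandwidth estimate, as a sketch: assumes the M=1 GEMV case is dominated
# by streaming the packed W_q once from device memory (ignores x, b_scale, and the output).
# `approx_weight_bandwidth_gbs` is a hypothetical helper, not part of the original benchmark.
def approx_weight_bandwidth_gbs(ms):
    bytes_read = N * (K // elements_per_sample)  # one byte per `elements_per_sample` packed values
    return bytes_read / (ms * 1e-3) / 1e9        # do_bench reports milliseconds
# print('approx W_q GB/s', approx_weight_bandwidth_gbs(new))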
# A100 SXM4:
#   Swizzle:
#     q_shifts = offs_k[:, None]      -> 0.17510400712490082
#     q_shifts = offs_k[:, None] % 8  -> 0.30720001459121704
#   Linear:
#     q_shifts = offs_k[:, None]      -> 0.17612800002098083
#     q_shifts = offs_k[:, None] % 8  -> 0.30822399258613586
#######################################################################
# A100 SXM4 (ref fp16 x fp16: 0.3256320059299469 ms)
#   FP16 x FP16: took 0.3256320059299469  (1.00x): OK
#   FP16 x INT4: took 0.289792001247406   (1.13x): poor performance
#   FP16 x INT1: took 0.1802240014076233  (1.81x): poor performance
#
# 4090 RTX (ref fp16 x fp16: 0.5857279896736145 ms)
#   FP16 x FP16: took 0.6256639957427979  (0.94x): OK
#   FP16 x INT4: took 0.17203199863433838 (3.40x): OK
#   FP16 x INT1: took 0.10342399775981903 (5.66x): OK