# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# OMP_NUM_THREADS=16 CUDA_VISIBLE_DEVICES=0 ipython3 A100_vs_4090_test.py
##########################################################################
import torch
import triton
import triton.language as tl
from triton.testing import do_bench
import itertools

def eval_time(fct, params):
    return do_bench(lambda: fct(**params), warmup=200, rep=1000)  # alternatives: fast_flush=True, return_mode='min'
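# Note: do_bench's warmup/rep arguments are time budgets in milliseconds
# (not iteration counts); by default it returns the median runtime in ms.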
@triton.jit()
def swizzle_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    grid_m = tl.cdiv(M, BLOCK_SIZE_M)
    grid_n = tl.cdiv(N, BLOCK_SIZE_N)
    width = GROUP_SIZE_M * grid_n
    group_id = pid // width
    group_size = tl.minimum(grid_m - group_id * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m = group_id * GROUP_SIZE_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
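# A minimal host-side sketch of the same mapping (hypothetical helper, not part
# of the original kernel) -- useful for eyeballing the swizzled launch order:
def _swizzle_tile_ref(pid, grid_m, grid_n, group_size_m):
    width = group_size_m * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * group_size_m, group_size_m)
    pid_m = group_id * group_size_m + (pid % group_size)
    pid_n = (pid % width) // group_size
    return pid_m, pid_n
# e.g. [_swizzle_tile_ref(p, 4, 4, 2) for p in range(8)]
# -> [(0, 0), (1, 0), (0, 1), (1, 1), (0, 2), (1, 2), (0, 3), (1, 3)]
# i.e. consecutive programs sweep GROUP_SIZE_M tile rows column by column,
# which improves L2 reuse compared to the linear mapping below.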
@triton.jit()
def linear_tile(pid, M, N, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)
    pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)
    return pid_m, pid_n
'''
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_N": n, "BLOCK_SIZE_K": k, "GROUP_SIZE_M": 8}, num_stages=ns, num_warps=nw)
        for m, n, k, ns, nw in itertools.product([16, 32], [16, 32, 64], [16, 32, 64, 128], [2, 3], [2, 4, 8])
    ],
    key=[],
)
'''
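# To use the autotuner instead of the fixed config in forward(), uncomment the
# decorator above and drop the explicit BLOCK_SIZE_* / GROUP_SIZE_M /
# num_stages / num_warps arguments from the launch site (meta-parameters
# covered by the autotune configs must not also be passed at the call).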
@triton.jit
def dummy_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    elements_per_sample: tl.constexpr,
    b_type: tl.constexpr,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_k = 0
    pid_m, pid_n = swizzle_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # pid_m, pid_n = linear_tile(pid, M, N, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M)
    # ------------------------------------------------
    num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)

    # Offsets
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)

    # Vectorized coalesced load
    offs_am = offs_m
    offs_bn = offs_n
    # offs_am = tl.max_contiguous(tl.multiple_of(offs_m, BLOCK_SIZE_M), BLOCK_SIZE_M)
    # offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)

    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    # b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
    offs_k_quantized = pid_k * BLOCK_SIZE_K // elements_per_sample + tl.arange(0, BLOCK_SIZE_K // elements_per_sample)
    b_ptrs = b_ptr + (offs_k_quantized[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
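    # B is packed along K: each uint8 row of b holds `elements_per_sample`
    # 4-bit values, so the packed row index is offs_k // elements_per_sample.
    # The bytes are loaded as-is; tl.dot_scaled below unpacks them as e2m1.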
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    mask_m = offs_am[:, None] < M  # guard the M dim: M may be smaller than BLOCK_SIZE_M (here M=1)
    for k in tl.range(0, num_pid_k, 1, num_stages=1):
        # Best for 4090 RTX
        b = tl.load(b_ptrs, eviction_policy='evict_first')
        a = tl.load(a_ptrs, mask=mask_m, other=0.0, eviction_policy='evict_last')
        # Best for A100
        # b = tl.load(b_ptrs, eviction_policy='evict_first')
        # a = tl.load(a_ptrs, mask=mask_m, other=0.0, eviction_policy='')
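        # evict_first hints that the streamed (read-once) B tiles should leave
        # the cache first; evict_last tries to keep the reused A tiles resident.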
        #######################################
        # Shifts
        # q_shifts = offs_k[:, None]      # This is fast
        # q_shifts = offs_k[:, None] % 8  # This is much slower
        # b = (b >> q_shifts) & 0x0F
        #######################################
        # acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=tl.float32)
        b_scale = tl.full([BLOCK_SIZE_N, BLOCK_SIZE_K // 32], 127, dtype=tl.uint8)
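        # dot_scaled scales are e8m0 (a bare biased exponent), one byte per
        # group of 32 K-elements: 127 decodes to 2**(127 - 127) = 1.0, so the
        # e2m1 values are consumed unscaled here.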
        acc = tl.dot_scaled(a, None, "bf16", b, b_scale, "e2m1", acc=acc)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk

    acc = acc.to(tl.float16)
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + (offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))  # mask rows/cols past M, N
def forward(x, W_q, elements_per_sample, b_type, debug=False):
    M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
    output = torch.zeros((M, N), device=W_q.device, dtype=torch.float16)

    BLOCK_SIZE_M = 16
    BLOCK_SIZE_N = 64
    BLOCK_SIZE_K = 128
    num_stages = 2
    num_warps = 4

    # Meta-based grid for the (commented-out) autotune path; the fixed tuple
    # below shadows it for the hand-picked config.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),)

    grid = (triton.cdiv(M, BLOCK_SIZE_M) * triton.cdiv(N, BLOCK_SIZE_N),)
    k = dummy_kernel[grid](
        x, W_q, output,
        M, N, K,
        elements_per_sample,
        b_type,
        x.stride(0), x.stride(1),
        W_q.stride(0), W_q.stride(1),
        output.stride(0), output.stride(1),
        BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=8,
        num_stages=num_stages, num_warps=num_warps,
    )

    if debug:
        with open('dequant_simple.txt', 'w') as f:
            print(f"{k.n_regs} registers used, {k.n_spills} spills\n", file=f)
            print("IR", k.asm['ttir'], file=f)
            print("TTGIR", k.asm['ttgir'], file=f)
            print("PTX", k.asm['ptx'], file=f)

    return output
##################################################################################
torch.manual_seed(1)
M, N, K = 1, 4096*4, 4096*4

# TRITON_LLVM_DEBUG_ONLY="triton-matmul-loop-pipeline,axis-info"

# input_dtype, elements_per_sample = torch.float16, 1                     # FP16
# input_dtype, elements_per_sample = torch.int8, 1 // 1                   # INT8
# input_dtype, elements_per_sample = torch.int8, 8 // 4                   # INT4
# input_dtype, elements_per_sample, b_type = torch.int32, 32 // 4,        # INT4
# input_dtype, elements_per_sample = torch.int32, 32 // 1                 # INT1
input_dtype, elements_per_sample, b_type = torch.uint8, 8 // 4, 'e2m1'    # INT4

W = torch.randn((N, K), dtype=torch.bfloat16, device='cuda')
W_q = torch.randint(0, 2**4, (N, K // elements_per_sample), dtype=input_dtype, device='cuda').t().contiguous()  # Col-major
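# W_q now has shape (K // elements_per_sample, N): each uint8 along K is
# consumed by dot_scaled as two packed e2m1 (fp4) nibbles. The values are
# random (this is a perf benchmark), so only the layout matters, not the data.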
print(W_q.shape)
# W_q *= 0
x = torch.randn((M, K), dtype=torch.bfloat16, device='cuda').contiguous()  # row-major

# out = forward(x, W_q, elements_per_sample, b_type, debug=True)
ref = eval_time(lambda x: torch.matmul(x, W.T), {'x': x.to(W.dtype)})
new = eval_time(forward, {'x': x, 'W_q': W_q, 'elements_per_sample': elements_per_sample, 'b_type': b_type})
print('ref', ref)
print('took', new, ref / new)
# A100 SXM4:
# Swizzle:
#   q_shifts = offs_k[:, None]      -> 0.17510400712490082
#   q_shifts = offs_k[:, None] % 8  -> 0.30720001459121704
# Linear:
#   q_shifts = offs_k[:, None]      -> 0.17612800002098083
#   q_shifts = offs_k[:, None] % 8  -> 0.30822399258613586
#######################################################################
# A100 SXM4 (ref fp16 x fp16: 0.3256320059299469)
# FP16 x FP16: took 0.3256320059299469  (1.00x): OK
# FP16 x INT4: took 0.289792001247406   (1.12x): Poor performance
# FP16 x INT1: took 0.1802240014076233  (1.81x): Poor performance
#
# 4090 RTX (ref fp16 x fp16: 0.5857279896736145)
# FP16 x FP16: took 0.6256639957427979  (0.94x): OK
# FP16 x INT4: took 0.17203199863433838 (3.40x): OK
# FP16 x INT1: took 0.10342399775981903 (5.66x): OK