@davidberard98 · Created November 25, 2024 22:47
--- mobicham.py 2024-11-25 14:02:15.355460967 -0800
+++ mobicham_fp4.py 2024-11-25 14:44:09.015276420 -0800
@@ -42,6 +42,7 @@
a_ptr, b_ptr, c_ptr,
M, N, K,
elements_per_sample: tl.constexpr,
+ b_type: tl.constexpr,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
@@ -69,7 +70,10 @@
# offs_bn = tl.max_contiguous(tl.multiple_of(offs_n, BLOCK_SIZE_N), BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
- b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
+ # b_ptrs = b_ptr + ((offs_k[:, None] // elements_per_sample) * stride_bk + offs_bn[None, :] * stride_bn)
+
+ offs_k_quantized = pid_k * BLOCK_SIZE_K // elements_per_sample + tl.arange(0, BLOCK_SIZE_K // elements_per_sample)
+ b_ptrs = b_ptr + (offs_k_quantized[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
@@ -89,7 +93,9 @@
#b = (b >> q_shifts) & 0x0F
#######################################
- acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=tl.float32)
+ # acc = tl.dot(a, b.to(a.dtype), acc=acc, out_dtype=tl.float32)
+ b_scale = tl.full([BLOCK_SIZE_N, BLOCK_SIZE_K // 32], 127, dtype=tl.uint8)
+ acc = tl.dot_scaled(a, None, "bf16", b, b_scale, "e2m1", acc=acc)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += (BLOCK_SIZE_K // elements_per_sample) * stride_bk
@@ -101,7 +107,7 @@
tl.store(c_ptrs, acc)
-def forward(x, W_q, elements_per_sample, debug=False):
+def forward(x, W_q, elements_per_sample, b_type, debug=False):
M, K, N = x.shape[0], x.shape[1], W_q.shape[1]
@@ -121,6 +127,7 @@
x, W_q, output,
M, N, K,
elements_per_sample,
+ b_type,
x.stride(0), x.stride(1),
W_q.stride(0), W_q.stride(1),
output.stride(0), output.stride(1),
@@ -148,20 +155,21 @@
#input_dtype, elements_per_sample = torch.float16, 1 #FP16
#input_dtype, elements_per_sample = torch.int8, 1 // 1 #INT8
#input_dtype, elements_per_sample = torch.int8, 8 // 4 #INT4
-input_dtype, elements_per_sample = torch.int32, 32 // 4 #INT4
+#input_dtype, elements_per_sample, b_type = torch.int32, 32 // 4, #INT4
#input_dtype, elements_per_sample = torch.int32, 32 // 1 #INT1
+input_dtype, elements_per_sample, b_type = torch.uint8, 8 // 4, 'e2m1' #INT4
-W = torch.randn((N, K), dtype=torch.float16, device='cuda')
+W = torch.randn((N, K), dtype=torch.bfloat16, device='cuda')
W_q = torch.randint(0, 2**4, (N, K // elements_per_sample), dtype=input_dtype, device='cuda').t().contiguous() #Col-major
print(W_q.shape)
#W_q *= 0
-x = torch.randn((M, K), dtype=torch.float16, device='cuda').contiguous() #row-major
+x = torch.randn((M, K), dtype=torch.bfloat16, device='cuda').contiguous() #row-major
#out = forward(x, W_q, debug=True)
ref = eval_time(lambda x: torch.matmul(x, W.T), {'x':x.to(W.dtype)})
-new = eval_time(forward, {'x':x, 'W_q':W_q, 'elements_per_sample':elements_per_sample})
+new = eval_time(forward, {'x':x, 'W_q':W_q, 'elements_per_sample':elements_per_sample, 'b_type':b_type})
print('ref', ref)
print('took', new, ref / new)
#A100 SXM4:
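
For reference (not part of the diff above), here is a minimal host-side sketch of how the packed e2m1 operand consumed by tl.dot_scaled can be dequantized in PyTorch. It assumes the usual MXFP4 conventions: two fp4 values per uint8 along K (the low-nibble-first ordering is an assumption, not confirmed against Triton's packing), and per-32-element E8M0 group scales whose biased exponent 127 encodes a scale of 1.0, matching the kernel's hard-coded b_scale = tl.full(..., 127, ...). The dequant_e2m1 helper is hypothetical and only meant as a sanity-check reference.

import torch

# The 16 representable e2m1 values: bit 3 = sign, bits 2-1 = exponent (bias 1),
# bit 0 = mantissa; exponent 0 is the subnormal range.
E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_e2m1(W_packed, scales):
    """W_packed: (K//2, N) uint8, two fp4 values per byte.
    scales: (N, K//32) uint8, E8M0 biased exponents (127 -> scale 1.0)."""
    lut = E2M1_LUT.to(W_packed.device)
    lo = (W_packed & 0x0F).long()         # assumed: low nibble = even K index
    hi = ((W_packed >> 4) & 0x0F).long()  # assumed: high nibble = odd K index
    K_half, N = W_packed.shape
    vals = torch.empty((K_half * 2, N), dtype=torch.float32, device=W_packed.device)
    vals[0::2] = lut[lo]
    vals[1::2] = lut[hi]
    # E8M0 group scale: 2**(biased_exponent - 127), broadcast over each 32-wide K group.
    group_scale = torch.pow(2.0, scales.float() - 127.0)        # (N, K//32)
    group_scale = group_scale.t().repeat_interleave(32, dim=0)  # (K, N)
    return vals * group_scale

With scales = torch.full((N, K // 32), 127, dtype=torch.uint8, device='cuda'), the product x.float() @ dequant_e2m1(W_q, scales) gives a rough reference for the kernel output, modulo the nibble-order assumption above.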