sm_120.py
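# Minimal repro: a block-scaled fp8 (mxfp) matmul written with tl.dot_scaled,
# launched with num_ctas=2. The script only reports whether the launch
# succeeds; it does not validate numerics. (The "sm_120" in the gist name
# presumably refers to an SM 12.0 / Blackwell target, but the script itself
# is not architecture-specific.)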
import torch
import triton
import triton.language as tl


@triton.jit
def mxfp_matmul(
        a_ptr, b_ptr, output_ptr,
        a_scale, b_scale,
        M, N, K,
        stride_scale: tl.constexpr,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
        NUM_STAGES: tl.constexpr):
    # Map the 1D program id onto a 2D grid of output tiles.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    # One scale per 32 fp8 elements along K (the MX block size).
    offs_scale_k = tl.arange(0, BLOCK_K // 32)
    a_scale_ptr = a_scale + offs_am[:, None] * stride_scale + offs_scale_k[None, :]
    b_scale_ptr = b_scale + offs_bn[:, None] * stride_scale + offs_scale_k[None, :]
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        scale_a = tl.load(a_scale_ptr)
        scale_b = tl.load(b_scale_ptr)
        # Block-scaled dot: both operands are fp8 e5m2 with per-group scales.
        accumulator = tl.dot_scaled(a, scale_a, 'e5m2', b, scale_b, 'e5m2', accumulator)
        # Advance to the next K tile and the matching group of scales.
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
        a_scale_ptr += BLOCK_K // 32
        b_scale_ptr += BLOCK_K // 32
    c_ptrs = output_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
    tl.store(c_ptrs, accumulator)
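
# --- Reference check (not part of the original gist) ---
# A minimal fp32 reference for the kernel above, assuming tl.dot_scaled
# interprets the uint8 scales as e8m0 exponents (scale = 2.0 ** (s - 127),
# as in the OCP microscaling spec), one scale per 32 elements along K.
# `ref_mxfp_matmul` is a hypothetical helper name, not a Triton API.
def ref_mxfp_matmul(a, b, a_scale, b_scale):
    # Expand per-group scales to per-element scales along K.
    sa = torch.pow(2.0, a_scale.to(torch.float32) - 127).repeat_interleave(32, dim=1)  # (M, K)
    sb = torch.pow(2.0, b_scale.to(torch.float32) - 127).repeat_interleave(32, dim=1)  # (N, K)
    a_f32 = a.to(torch.float32) * sa     # (M, K)
    b_f32 = b.to(torch.float32) * sb.T   # (K, N): b_scale is indexed (n, k // 32)
    return a_f32 @ b_f32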
# Test with NUM_CTAS=2
device = 'cuda'
M, N, K = 128, 128, 128
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 128
NUM_STAGES = 1

torch.manual_seed(42)
# Random fp8 e5m2 operands, built by reinterpreting small uint8 bit patterns.
a = torch.randint(20, 40, (M, K), dtype=torch.uint8, device=device).view(torch.float8_e5m2)
b = torch.randint(20, 40, (K, N), dtype=torch.uint8, device=device).view(torch.float8_e5m2)
# One uint8 scale per 32-element group along K, drawn near the e8m0 bias (127).
a_scale = torch.randint(64, 130, (M, K // 32), dtype=torch.uint8, device=device)
b_scale = torch.randint(64, 130, (N, K // 32), dtype=torch.uint8, device=device)
output = torch.empty((M, N), dtype=torch.float32, device=device)

grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
print("Testing with NUM_CTAS=2...")
try:
    # The launch writes the result into `output`; its return value is the
    # compiled-kernel handle, not the matmul result.
    mxfp_matmul[grid](
        a, b, output, a_scale, b_scale, M, N, K,
        a_scale.stride(0), a.stride(0), a.stride(1),
        b.stride(0), b.stride(1), output.stride(0), output.stride(1),
        BLOCK_M, BLOCK_N, BLOCK_K,
        NUM_STAGES=NUM_STAGES, num_warps=4, num_ctas=2
    )
    torch.cuda.synchronize()  # surface asynchronous launch errors here
    print('Success with NUM_CTAS=2')
except Exception as e:  # covers Triton compile errors as well as CUDA RuntimeError
    print(f'Failed with NUM_CTAS=2:\n{str(e)[:2000]}')
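
# Optional numeric cross-check (a sketch, not in the original gist): only
# meaningful when the launch above succeeded; otherwise `output` still holds
# the uninitialized memory from torch.empty.
try:
    ref = ref_mxfp_matmul(a, b, a_scale, b_scale)
    print('max |output - ref| =', (output - ref).abs().max().item())
except Exception as e:
    print(f'Reference check skipped: {e}')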