# Gist by @drisspg, created July 11, 2025
import torch

from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)


def ceil_div(a, b):
    return (a + b - 1) // b


def get_e8_scales(A: torch.Tensor, B: torch.Tensor, use_zeros: bool = False):
    M, K = A.shape
    _, N = B.shape
    # One E8M0 scale per 32-element block along K; rows are padded up to a multiple of 128
    n_a_rows = ceil_div(M, 128) * 128
    n_a_cols = ceil_div(K, 32)
    n_b_rows = ceil_div(N, 128) * 128
    n_b_cols = ceil_div(K, 32)
    # Use zeros or random values based on the flag
    if use_zeros:
        a_scales = torch.zeros(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
        b_scales = torch.zeros(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
    else:
        a_scales = torch.randn(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
        b_scales = torch.randn(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
    return a_scales, b_scales
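
# For the 2048 x 2048 x 2048 case used in main() below, this yields a_scales and
# b_scales of shape (2048, 64): 2048 is already a multiple of 128, and there is one
# E8M0 scale per 32-element block along K (2048 / 32 = 64).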


def main():
    # Create input matrices
    M, K, N = 2048, 2048, 2048
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=False))

    # Get E8M0 scales for MX-FP8
    a_scales, b_scales = get_e8_scales(A_fp8, B_fp8, False)

    print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
    print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
    print(f"A scales shape: {a_scales.shape}, dtype: {a_scales.dtype}")
    print(f"B scales shape: {b_scales.shape}, dtype: {b_scales.dtype}")

    # Example 1: Output as BF16 (typical case)
    result_bf16 = torch._scaled_mm(
        A_fp8,
        B_fp8,
        scale_a=a_scales,
        scale_b=b_scales,
        out_dtype=torch.bfloat16,
    )
    print(f"BF16 output shape: {result_bf16.shape}, dtype: {result_bf16.dtype}")

    # DOESN'T WORK TODAY
    # Example 2: Output as FP8 E4M3 with scale_result
    # Need a dummy output scale for FP8 output
    # output_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    # result_fp8 = torch._scaled_mm(
    #     A_fp8,
    #     B_fp8,
    #     scale_a=a_scales,
    #     scale_b=b_scales,
    #     # scale_result=output_scale,
    #     out_dtype=torch.float8_e4m3fn,
    # )
    # print(f"FP8 output shape: {result_fp8.shape}, dtype: {result_fp8.dtype}")

    # Example 3: Using per-tensor scales instead of E8M0
    a_scale_tensor = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    b_scale_tensor = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    result_simple = torch._scaled_mm(
        A_fp8,
        B_fp8,
        scale_a=a_scale_tensor,
        scale_b=b_scale_tensor,
        # scale_result=output_scale,
        out_dtype=torch.float8_e4m3fn,
    )
    print(f"Simple scaled output shape: {result_simple.shape}, dtype: {result_simple.dtype}")


if __name__ == "__main__":
    main()
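
# Note: the E8M0 block-scaled (MX-FP8) path of torch._scaled_mm is only available on
# recent hardware/software stacks; the examples above assume a CUDA GPU and PyTorch
# build that support it (e.g. a Blackwell-class, SM 10.0 device with a recent
# PyTorch + CUDA toolchain). A minimal guard one could add before calling main(),
# assuming that requirement, would be:
#
#     if torch.cuda.get_device_capability() < (10, 0):
#         raise RuntimeError("MX-FP8 _scaled_mm requires a Blackwell-class GPU")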