import torch

from torchao.float8.inference import (
    addmm_float8_unwrapped_inference,
    preprocess_data,
    Float8MMConfig,
)

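# NOTE: the block-scaled (MX / E8M0) variant of torch._scaled_mm exercised below
# is hardware and version dependent; it presumably needs a recent PyTorch build
# and an MX-capable GPU (e.g. a Blackwell-class part), whereas the
# per-tensor-scale call in Example 3 should also run on earlier FP8-capable GPUs.
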
def ceil_div(a, b):
    return (a + b - 1) // b
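
# Quick check: ceil_div(2048, 32) == 64 and ceil_div(2049, 32) == 65.
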
def get_e8_scales(A: torch.Tensor, B: torch.Tensor, use_zeros: bool = False):
    M, K = A.shape
    _, N = B.shape
    # One E8M0 scale per 32-element block along K; the row dimension is rounded
    # up to a multiple of 128.
    n_a_rows = ceil_div(M, 128) * 128
    n_a_cols = ceil_div(K, 32)
    n_b_rows = ceil_div(N, 128) * 128
    n_b_cols = ceil_div(K, 32)

    # Use zeros or random values based on the flag
    if use_zeros:
        a_scales = torch.zeros(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
        b_scales = torch.zeros(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
    else:
        a_scales = torch.randn(n_a_rows, n_a_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
        b_scales = torch.randn(n_b_rows, n_b_cols, dtype=torch.float32, device="cuda").to(
            torch.float8_e8m0fnu
        )
    return a_scales, b_scales
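

def compute_e8_scales_from_data(A: torch.Tensor, K_block: int = 32):
    # Conceptual sketch, NOT part of the original gist: derive one E8M0 scale per
    # K_block-element block of each row from the block amax, roughly the MX recipe
    # (block amax divided by the FP8 max value, cast to a power-of-two E8M0 value).
    # The 128-row padding used above and any layout the kernel may additionally
    # require are NOT handled here; this only illustrates where real (non-random)
    # scale values would come from. Assumes K is a multiple of K_block.
    M, K = A.shape
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    blocks = A.to(torch.float32).abs().reshape(M, K // K_block, K_block)
    amax = blocks.amax(dim=-1)  # shape (M, K // K_block)
    return (amax / fp8_max).to(torch.float8_e8m0fnu)
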
def main():
    # Create input matrices
    M, K, N = 2048, 2048, 2048
    A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
    B = torch.randn(K, N, device="cuda", dtype=torch.bfloat16)

    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=False))

    # Get E8M0 scales for MX-FP8
    a_scales, b_scales = get_e8_scales(A_fp8, B_fp8, False)

    print(f"Input A shape: {A_fp8.shape}, dtype: {A_fp8.dtype}")
    print(f"Input B shape: {B_fp8.shape}, dtype: {B_fp8.dtype}")
    print(f"A scales shape: {a_scales.shape}, dtype: {a_scales.dtype}")
    print(f"B scales shape: {b_scales.shape}, dtype: {b_scales.dtype}")

    # Example 1: Output as BF16 (typical case)
    result_bf16 = torch._scaled_mm(
        A_fp8,
        B_fp8,
        scale_a=a_scales,
        scale_b=b_scales,
        out_dtype=torch.bfloat16,
    )
    print(f"BF16 output shape: {result_bf16.shape}, dtype: {result_bf16.dtype}")

    # DOESN'T WORK TODAY
    # Example 2: Output as FP8 E4M3 with scale_result
    # Need dummy output scale for FP8 output
    # output_scale = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    # result_fp8 = torch._scaled_mm(
    #     A_fp8,
    #     B_fp8,
    #     scale_a=a_scales,
    #     scale_b=b_scales,
    #     # scale_result=output_scale,
    #     out_dtype=torch.float8_e4m3fn,
    # )
    # print(f"FP8 output shape: {result_fp8.shape}, dtype: {result_fp8.dtype}")

    # Example 3: Using per-tensor scales instead of E8M0
    a_scale_tensor = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    b_scale_tensor = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    result_simple = torch._scaled_mm(
        A_fp8,
        B_fp8,
        scale_a=a_scale_tensor,
        scale_b=b_scale_tensor,
        # scale_result=output_scale,
        out_dtype=torch.float8_e4m3fn,
    )
    print(f"Simple scaled output shape: {result_simple.shape}, dtype: {result_simple.dtype}")


if __name__ == "__main__":
    main()
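
# Rough sanity check one could add at the end of main() (a sketch, not part of
# the original gist): with per-tensor scales of 1.0, Example 3's FP8 output
# should roughly track a plain matmul on the dequantized inputs, up to FP8
# rounding error:
#
#     ref = A_fp8.to(torch.bfloat16) @ B_fp8.to(torch.bfloat16)
#     max_err = (result_simple.to(torch.bfloat16) - ref).abs().max()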