Random Kernel Microbenchmarks
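A standalone script that benchmarks the softmax forward pass from quack against a torch.compile reference across a sweep of input shapes, printing kernel time and effective memory bandwidth for each (a backward-pass benchmark is included but disabled). The bandwidth numbers come from a simple traffic model: the forward kernel is assumed to read x and write out exactly once (factor 2), and the backward kernel to read y and dy and write dx (factor 3). A worked example of that arithmetic, with an illustrative rather than measured kernel time:

M, N, bytes_per_elem = 32768, 32768, 2        # bf16: dtype.width // 8 == 2 bytes
avg_time_ms = 1.0                             # hypothetical kernel time, not a measurement
bytes_moved = 2 * M * N * bytes_per_elem      # read x once + write out once
mem_bw_gbps = bytes_moved / (avg_time_ms / 1000) / 1e9   # ~4295 GB/s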
import argparse
import time
from typing import Type

import torch
import torch.nn.functional as F
import torch._inductor.config

# Let Inductor generate multiple kernel variants (e.g. persistent and
# non-persistent reductions) and pick the faster one at runtime.
torch._inductor.config.triton.multi_kernel = True

from triton.testing import do_bench

import cutlass
import cutlass.torch as cutlass_torch
from quack.softmax import softmax

def run_softmax(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA GPU is required to run this example!")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
    torch_dtype = cutlass_torch.dtype(dtype)
    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype)
    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    # Compile/warm up the quack kernel once before timing.
    out = softmax(x)
    torch._dynamo.config.recompile_limit = 1024
    compiled_func_ref = torch.compile(
        lambda x: F.softmax(x, dim=-1), dynamic=False, mode="max-autotune-no-cudagraphs"
    )

    fn = lambda: softmax(x)
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Bytes moved: read x + write out (factor 2); dtype.width is in bits,
    # avg_time is in ms.
    mem_bw = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw:.2f} GB/s")

    fn = lambda: compiled_func_ref(x)
    for _ in range(5):
        fn()  # warm up
    time.sleep(0.5)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(2 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
    return mem_bw, mem_bw_ref

def run_softmax_backward(
    M,
    N,
    dtype: Type[cutlass.Numeric],
    warmup_iterations=10,
    iterations=1000,
):
    if not torch.cuda.is_available():
        raise RuntimeError("A CUDA GPU is required to run this example!")
    print(f"Tensor dimensions: [{M}, {N}]")
    print(f"Input and Output Data type: {dtype}")
    torch_dtype = cutlass_torch.dtype(dtype)
    device = "cuda"
    x = 0.1 * torch.randn(M, N, device=device, dtype=torch_dtype, requires_grad=True)
    x_ref = x.detach().clone().requires_grad_()
    print("Input tensor shapes:")
    print(f"x: {x.shape}, dtype: {x.dtype}")
    y = softmax(x)
    dy = torch.randn_like(y)
    time.sleep(0.5)
    fn = lambda: torch.autograd.grad(y, x, grad_outputs=dy, retain_graph=True)
    avg_time = do_bench(fn, warmup=warmup_iterations, rep=iterations)
    # Bytes moved: read y and dy, write dx (factor 3).
    mem_bw = round(3 * x.numel() * dtype.width // 8 / (avg_time / 1000) / 1e9)
    print(f"Kernel execution time: {avg_time:.4f} ms")
    print(f"Mem throughput: {mem_bw:.2f} GB/s")

    # Reference implementation
    y_ref = F.softmax(x_ref, dim=-1)
    compiled_func_ref = torch.compile(
        lambda: torch.autograd.grad(y_ref, x_ref, grad_outputs=dy, retain_graph=True)
    )
    for _ in range(5):
        compiled_func_ref()  # warm up
    time.sleep(0.5)
    avg_time_ref = do_bench(compiled_func_ref, warmup=warmup_iterations, rep=iterations)
    mem_bw_ref = round(3 * x.numel() * dtype.width // 8 / (avg_time_ref / 1000) / 1e9)
    print(f"Ref kernel execution time: {avg_time_ref:.4f} ms")
    print(f"Ref mem throughput: {mem_bw_ref:.2f} GB/s")
    return mem_bw, mem_bw_ref

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark softmax forward and backward passes"
    )
    parser.add_argument("--M", default=8192, type=int)
    parser.add_argument("--N", default=16384, type=int)
    parser.add_argument(
        "--dtype",
        type=cutlass.dtype,
        choices=[cutlass.BFloat16, cutlass.Float16, cutlass.Float32],
        default=cutlass.BFloat16,
    )
    parser.add_argument("--warmup_iterations", default=10, type=int)
    parser.add_argument("--iterations", default=100, type=int)
    parser.add_argument(
        "--backward", action="store_true",
        help="Benchmark backward pass instead of forward pass",
    )
    args = parser.parse_args()
    torch.manual_seed(0)
    # Single-shape path, disabled in favor of the sweep below:
    # if args.backward:
    #     print("=== Softmax Backward Pass Benchmark ===")
    #     run_softmax_backward(
    #         args.M,
    #         args.N,
    #         dtype=args.dtype,
    #         warmup_iterations=args.warmup_iterations,
    #         iterations=args.iterations,
    #     )
    # else:
    #     print("=== Softmax Forward Pass Benchmark ===")
    #     run_softmax(
    #         args.M,
    #         args.N,
    #         dtype=args.dtype,
    #         warmup_iterations=args.warmup_iterations,
    #         iterations=args.iterations,
    #     )
    # exit(0)
    # Sweep row length N; the final entries shrink M as N grows to keep the
    # total element count bounded.
    MN_pairs = [
        (32768, 256),
        (32768, 512),
        (32768, 1024),
        (32768, 2048),
        (32768, 4096),
        (32768, 8192),
        (32768, 16384),
        (32768, 32768),
        (32768, 65536),
        (16384, 131072),
        (8192, 262144),
    ]
    # MN_pairs = [(32768, 32768)]
    # MN_pairs = [(32768, 1024)]
    results = []
    for M, N in MN_pairs:
        res = run_softmax(
            M,
            N,
            dtype=args.dtype,
            warmup_iterations=args.warmup_iterations,
            iterations=args.iterations,
        )
        results.append(res)
    print(results)
    # print([x for x, _ in results])
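A minimal invocation sketch (the file name softmax_bench.py is hypothetical; quack, the CUTLASS Python DSL, Triton, and a CUDA GPU are assumed to be installed):

python softmax_bench.py                       # forward sweep over MN_pairs at BFloat16
python softmax_bench.py --iterations 1000     # more timing reps per shape

As written, --M, --N, and --backward only feed the commented-out single-shape path; the script always runs the forward sweep.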