Skip to content

Instantly share code, notes, and snippets.

@scottt
Created May 26, 2025 01:37
Show Gist options
  • Save scottt/fb45ba422f9f133223ebb281fca8dc5d to your computer and use it in GitHub Desktop.
Save scottt/fb45ba422f9f133223ebb281fca8dc5d to your computer and use it in GitHub Desktop.
PyTorch Performance Validation
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend
###############################################################################
# Check for GPU
###############################################################################
# Every benchmark below uses the torch.cuda APIs, so bail out early with a
# readable message (SystemExit) when no GPU is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    raise SystemExit("CUDA GPU is not available. Please run on a CUDA-enabled device.")
# Create the CUDA context up front instead of lazily inside the first
# benchmark (optional, helps measure baseline).
torch.cuda.init()
###############################################################################
# Helper for repeating a measurement and averaging the results
###############################################################################
def measure_op(op_func, warmup=3, total_runs=10):
    """
    Call ``op_func`` ``total_runs`` times and average its self-reported metrics.

    op_func: a zero-argument callable that runs the operation, does its own
        timing and memory measurement, and returns
        (time_ms, peak_mem_MB, gflops_s).
    warmup: number of leading runs whose results are discarded.
    total_runs: total number of runs, including the warm-up runs.

    Before each call the CUDA peak-memory counter for the module-level
    ``device`` is reset, so each run reports the peak reached since its own
    start. Calls are made under ``torch.no_grad()``.

    Returns (average_time_ms, average_peak_mem_MB, average_gflops_s) over the
    runs after warm-up.

    Raises ValueError if warm-up would leave no runs to average.
    """
    # Runs with index >= warmup are kept; a negative warmup keeps all of them.
    measured_runs = total_runs - max(warmup, 0)
    if measured_runs < 1:
        raise ValueError(
            f"total_runs ({total_runs}) must be greater than warmup ({warmup}) "
            "so that at least one measured run remains"
        )

    times, mems, flops = [], [], []
    for run_idx in range(total_runs):
        # Reset peak memory stats at the start of each run.
        torch.cuda.reset_peak_memory_stats(device)
        with torch.no_grad():
            t_ms, peak_mb, gf_s = op_func()
        if run_idx < warmup:
            continue  # warm-up result, discarded
        times.append(t_ms)
        mems.append(peak_mb)
        flops.append(gf_s)

    count = len(times)
    return sum(times) / count, sum(mems) / count, sum(flops) / count
###############################################################################
# 1) Define the Scaled Dot-Product Attention test
###############################################################################
def run_sdpa(batch=1, heads=8, seq_len=8192, head_dim=64, kv_len=None):
    """
    Time one FP16 scaled dot-product attention forward pass (Flash backend).

    Allocates random float16 tensors on the module-level ``device``:
    Q of shape (batch, heads, seq_len, head_dim), and K, V of shape
    (batch, heads, kv_len, head_dim). It then runs
    ``scaled_dot_product_attention(q, k, v)`` inside
    ``sdpa_kernel(SDPBackend.FLASH_ATTENTION)``, timing only that call with
    CUDA events.

    batch, heads, seq_len, head_dim: problem size; the defaults are the
        1 x 8 x 8192 x 64 configuration.
    kv_len: key/value sequence length; defaults to ``seq_len``.

    Returns (time_ms, peak_mem_MB, gflops_s). The peak-memory figure is what
    ``torch.cuda.max_memory_allocated`` reports after the call, so it includes
    the Q/K/V allocations made here.
    """
    if kv_len is None:
        kv_len = seq_len

    # Create random Q, K, V in half precision.
    q = torch.randn(batch, heads, seq_len, head_dim, device=device, dtype=torch.float16)
    k = torch.randn(batch, heads, kv_len, head_dim, device=device, dtype=torch.float16)
    v = torch.randn(batch, heads, kv_len, head_dim, device=device, dtype=torch.float16)

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # Enter the backend-selection context *before* recording the start event.
    # The stream is idle after synchronize(), so any Python work done between
    # start_event.record() and the kernel launch would be counted as GPU time.
    with torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        torch.cuda.synchronize()
        start_event.record()
        scaled_dot_product_attention(q, k, v)  # output unused; only timing matters
        end_event.record()
        torch.cuda.synchronize()

    time_ms = start_event.elapsed_time(end_event)

    # Peak memory usage (MB) since the caller's last reset.
    peak_mem_mb = torch.cuda.max_memory_allocated(device) / (1024**2)

    # FLOPs for the two matmuls (softmax and scaling are not counted):
    #   Q*K^T  -> 2 * B * heads * L * S * E
    #   Attn*V -> 2 * B * heads * L * S * E
    #   Total  =  4 * B * heads * L * S * E
    flops = 4.0 * batch * heads * seq_len * kv_len * head_dim

    elapsed_s = time_ms / 1000.0
    gflops_s = flops / elapsed_s / 1e9
    return time_ms, peak_mem_mb, gflops_s
###############################################################################
# 2) Define the Conv2d test
###############################################################################
def run_conv2d():
    """
    Time one FP16 Conv2d forward pass on a 1 x 3 x 2048 x 2048 input.

    Each call builds a fresh 3 -> 64 channel, 7x7 convolution (stride 1,
    padding 3) converted to float16 on the module-level ``device``, plus a
    random float16 input there. Only the forward call is timed, with CUDA
    events.

    Returns (time_ms, peak_mem_MB, gflops_s). The peak-memory figure is what
    ``torch.cuda.max_memory_allocated`` reports after the call, so it includes
    the layer and input allocations made here.
    """
    batch, in_ch, out_ch = 1, 3, 64
    height = width = 2048

    layer = torch.nn.Conv2d(
        in_ch,
        out_ch,
        kernel_size=7,
        stride=1,
        padding=3,
    ).to(device, dtype=torch.float16)
    inp = torch.randn(batch, in_ch, height, width, device=device, dtype=torch.float16)

    t_begin, t_end = (torch.cuda.Event(enable_timing=True) for _ in range(2))

    torch.cuda.synchronize()
    t_begin.record()
    result = layer(inp)
    t_end.record()
    torch.cuda.synchronize()
    elapsed_ms = t_begin.elapsed_time(t_end)

    peak_mb = torch.cuda.max_memory_allocated(device) / (1024**2)

    # FLOPs = 2 * N * Cout * H_out * W_out * Cin * kH * kW
    # (one multiply and one add per multiply-accumulate).
    n_out, c_out, h_out, w_out = result.shape
    ks = layer.kernel_size
    k_h, k_w = ks if isinstance(ks, tuple) else (ks, ks)
    total_flops = 2.0 * n_out * c_out * h_out * w_out * layer.in_channels * k_h * k_w

    gflops_per_s = total_flops / (elapsed_ms / 1000.0) / 1e9
    return elapsed_ms, peak_mb, gflops_per_s
###############################################################################
# 3) Define the GEMM test
# Inspired by https://github.com/shisa-ai/mamf-finder/blob/main/mamf-finder.py
###############################################################################
def benchmark_mm(m, n, k, dtype, device, warmup=3, total_runs=10):
    """
    Benchmark C = A @ B with A of shape (m, k) and B of shape (k, n).

    A, B and the output C are random/preallocated ``dtype`` tensors on
    ``device``. Before every run C is overwritten from a random source tensor,
    and the GEMM writes its result into that same C. L2/LLC cache flushing is
    not emulated.

    m, n, k: GEMM dimensions.
    dtype: element type for A, B and C.
    device: device to allocate on. Timing and memory stats use torch.cuda
        APIs, so this is expected to be a GPU device.
    warmup: number of leading runs whose results are discarded.
    total_runs: total number of runs, including the warm-up runs.

    Returns (average_time_ms, average_peak_mem_MB, average_gflops_s) over the
    runs after warm-up.

    Raises ValueError if warm-up would leave no runs to average.
    """
    # Runs with index >= warmup are kept; a negative warmup keeps all of them.
    measured_runs = total_runs - max(warmup, 0)
    if measured_runs < 1:
        raise ValueError(
            f"total_runs ({total_runs}) must be greater than warmup ({warmup}) "
            "so that at least one measured run remains"
        )

    # Preallocated output plus a random source used to overwrite it each run.
    C = torch.empty(m, n, dtype=dtype, device=device)
    C_rand = torch.randn(m, n, dtype=dtype, device=device)
    A = torch.randn(m, k, dtype=dtype, device=device)
    B = torch.randn(k, n, dtype=dtype, device=device)
    flops_per_run = 2.0 * m * n * k

    # The events are re-recorded on every iteration, so one pair suffices.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    times, mems, gflops = [], [], []
    for run_idx in range(total_runs):
        torch.cuda.reset_peak_memory_stats(device)
        # Re-randomize C so the GEMM below overwrites fresh data.
        C.copy_(C_rand)

        torch.cuda.synchronize()
        start_event.record()
        # out=C: write into the preallocated, re-randomized output instead of
        # allocating a new result tensor (which would leave C untouched).
        torch.mm(A, B, out=C)
        end_event.record()
        torch.cuda.synchronize()

        if run_idx < warmup:
            continue  # warm-up result, discarded
        t_ms = start_event.elapsed_time(end_event)
        times.append(t_ms)
        mems.append(torch.cuda.max_memory_allocated(device) / (1024**2))
        gflops.append(flops_per_run / (t_ms / 1000.0) / 1e9)

    count = len(times)
    return sum(times) / count, sum(mems) / count, sum(gflops) / count
###############################################################################
# Run the measurements
###############################################################################
print("Benchmarking Scaled Dot-Product Attention (Flash) in FP16 ...")
sdpa_time, sdpa_mem, sdpa_gflops = measure_op(run_sdpa, warmup=3, total_runs=10)
print(f"Average time: {sdpa_time:.2f} ms")
print(f"Average peak memory: {sdpa_mem:.2f} MB")
print(f"Average throughput: {sdpa_gflops:.2f} GFLOP/s\n")

print("Benchmarking Conv2d in FP16 ...")
conv_time, conv_mem, conv_gflops = measure_op(run_conv2d, warmup=3, total_runs=10)
print(f"Average time: {conv_time:.2f} ms")
print(f"Average peak memory: {conv_mem:.2f} MB")
print(f"Average throughput: {conv_gflops:.2f} GFLOP/s\n")

# The GEMM runs in bfloat16; keep the printed label in sync with gemm_dtype.
gemm_dtype = torch.bfloat16
print("Benchmarking GEMM in BF16 ...")
mm_time, mm_mem, mm_gflops = benchmark_mm(
    4352, 13568, 3840, gemm_dtype, device, warmup=3, total_runs=10
)
print(f"Average time: {mm_time:.2f} ms")
print(f"Average peak memory: {mm_mem:.2f} MB")
print(f"Average throughput: {mm_gflops:.2f} GFLOP/s")
###############################################################################
# Check PyTorch build configuration
###############################################################################
# Report which BLAS library PyTorch prefers for torch.cuda matmuls.
print("\nPyTorch preferred CUDA BLAS backend:")
blas_library = torch.backends.cuda.preferred_blas_library()
print(blas_library)
###############################################################################
# Report the Python / PyTorch / GPU environment
###############################################################################
import sys
import platform
import subprocess  # NOTE(review): not used in this script; kept as-is


def gpu_backend_label():
    """Return "HIP" when PyTorch was built for ROCm (torch.version.hip is set), else "CUDA"."""
    return "HIP" if getattr(torch.version, "hip", None) else "CUDA"


print("\nPython environment:")
print(f"Executable : {sys.executable}")
print(f"Implementation : {platform.python_implementation()}")
print(f"Version : {platform.python_version()}")
print(f"Build : {platform.python_build()}")
print(f"Compiler : {platform.python_compiler()}")
print(f"Platform : {platform.platform()}")
print(f"Processor : {platform.processor()}")

# The torch.cuda API is used for both CUDA and ROCm/HIP builds; label the
# output after the backend this PyTorch build actually targets.
gpu_backend = gpu_backend_label()
print(f"\nPyTorch : {torch.__version__}")
print(f"{gpu_backend} available : {torch.cuda.is_available()}")
if torch.cuda.is_available():
    if gpu_backend == "HIP":
        backend_version = torch.version.hip
    else:
        backend_version = getattr(torch.version, "cuda", None)
    print(f"{gpu_backend} version : {backend_version}")
    print(f"{gpu_backend} devices ({torch.cuda.device_count()}):")
    for idx in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(idx)
        print(f"Device [{idx}] - {props.name}")
        print(" Multiprocessor Count :", props.multi_processor_count)
        print(" Total Memory :", f"{props.total_memory / (1024**2):.2f} MB")
        print(" Integrated :", bool(props.is_integrated))
        print(" Multi-GPU Board :", bool(props.is_multi_gpu_board))
        print(" GCN Arch Name :", props.gcnArchName)
        print(" Warp Size :", props.warp_size)
        print(" L2 Cache Size :", f"{props.L2_cache_size / 1024:.2f} KB")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment