@maaquib
Created September 5, 2025 15:40
benchmark_flash_comparison
#!/usr/bin/env python3
"""
Benchmark comparing FlashInfer vs Flash Attention on B200 GPU.
Tests various sequence lengths and head dimensions.
"""
import argparse
import sys
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import datetime
from typing import Any

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

try:
    import flashinfer
    FLASHINFER_AVAILABLE = True
except ImportError:
    FLASHINFER_AVAILABLE = False
    print(
        "Warning: FlashInfer not available. Install with: pip install flashinfer-python"
    )

try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False
    print(
        "Warning: Flash Attention not available. Install with: pip install flash-attn --no-build-isolation"
    )

try:
    import max
    MAX_AVAILABLE = True
except ImportError:
    MAX_AVAILABLE = False
    print("Warning: Max not available. Install with: pip install max")


@dataclass
class BenchmarkConfig:
    """Configuration for benchmark parameters."""

    batch_sizes: Sequence[int] | None = None
    seq_lengths: Sequence[int] | None = None
    num_heads_list: Sequence[int] | None = None
    head_dims: Sequence[int] | None = None
    causal: bool = True
    dtype: torch.dtype = torch.float16
    warmup_iters: int = 10
    benchmark_iters: int = 100


def benchmark_flashinfer(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark FlashInfer attention."""
    if not FLASHINFER_AVAILABLE:
        return None
    batch_size, num_heads, seq_len, head_dim = q.shape
    # Reshape for FlashInfer (batch_size * seq_len, num_heads, head_dim)
    q_fi = q.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
    k_fi = k.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
    v_fi = v.transpose(1, 2).contiguous().view(-1, num_heads, head_dim)
    scale = 1.0 / (head_dim ** 0.5)
    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = flashinfer.single_prefill_with_kv_cache(
                q_fi,
                k_fi,
                v_fi,
                causal=causal,
                sm_scale=scale,
            )
    torch.cuda.synchronize()
    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = flashinfer.single_prefill_with_kv_cache(
                q_fi,
                k_fi,
                v_fi,
                causal=causal,
                sm_scale=scale,
            )
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()
    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))
    # Convert back to original shape for validation
    output = output.view(batch_size, seq_len, num_heads, head_dim)
    output = output.transpose(1, 2).contiguous()
    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


def benchmark_flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark Flash Attention."""
    if not FLASH_ATTN_AVAILABLE:
        return None
    batch_size, num_heads, seq_len, head_dim = q.shape
    # Reshape for Flash Attention (batch_size, seq_len, num_heads, head_dim)
    q_fa = q.transpose(1, 2).contiguous()
    k_fa = k.transpose(1, 2).contiguous()
    v_fa = v.transpose(1, 2).contiguous()
    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = flash_attn_func(q_fa, k_fa, v_fa, causal=causal)
    torch.cuda.synchronize()
    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = flash_attn_func(q_fa, k_fa, v_fa, causal=causal)
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()
    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))
    # Convert back to original shape for validation
    output = output.transpose(1, 2).contiguous()
    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


def benchmark_torch_sdpa(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark PyTorch's scaled dot-product attention (baseline)."""
    # Warmup
    for _ in range(warmup_iters):
        with torch.no_grad():
            _ = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    torch.cuda.synchronize()
    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            output = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()
    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))
    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output,
    }


def benchmark_max(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = True,
    warmup_iters: int = 10,
    benchmark_iters: int = 100,
):
    """Benchmark Max's attention using Graph API with flash_attention_gpu kernel."""
    if not MAX_AVAILABLE:
        return None
    try:
        from max.driver import Accelerator, Tensor
        from max.dtype import DType
        from max.engine import InferenceSession
        from max.graph import DeviceRef, Graph, TensorType, TensorValue, ops
        from max.nn.attention.mask_config import MHAMaskVariant
        from max.nn.kernels import flash_attention_gpu
        from torch.utils.dlpack import from_dlpack
    except ImportError as e:
        print(f"Warning: Could not import MAX modules: {e}")
        return None
    _, num_heads, _, head_dim = q.shape
    # Convert torch tensors to the right format (batch, seq_len, num_heads, head_dim)
    q_max = q.transpose(1, 2).contiguous()
    k_max = k.transpose(1, 2).contiguous()
    v_max = v.transpose(1, 2).contiguous()
    dtype_map = {
        torch.float16: DType.float16,
        torch.bfloat16: DType.bfloat16,
        torch.float32: DType.float32,
    }
    max_dtype = dtype_map.get(q.dtype, DType.float32)
    device_ref, device = DeviceRef.GPU(), Accelerator()
    with Graph(
        "max_attention_benchmark",
        input_types=(
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
            TensorType(
                shape=("batch_size", "seq_len", num_heads, head_dim),
                dtype=max_dtype,
                device=device_ref,
            ),
        ),
    ) as graph:
        q_input, k_input, v_input = graph.inputs
        assert isinstance(q_input, TensorValue)
        assert isinstance(k_input, TensorValue)
        assert isinstance(v_input, TensorValue)
        mask_variant = MHAMaskVariant.CAUSAL_MASK if causal else MHAMaskVariant.NULL_MASK
        scale = 1.0 / np.sqrt(head_dim)
        output = flash_attention_gpu(
            q=q_input,
            k=k_input,
            v=v_input,
            mask_variant=mask_variant,
            scale=scale,
        )
        # Transpose back to (batch, num_heads, seq_len, head_dim)
        output = ops.transpose(output, 1, 2)
        graph.output(output)
    session = InferenceSession(devices=[device])
    compiled = session.load(graph)
    q_max = Tensor.from_dlpack(q_max).to(device)
    k_max = Tensor.from_dlpack(k_max).to(device)
    v_max = Tensor.from_dlpack(v_max).to(device)
    # Warmup
    for _ in range(warmup_iters):
        _ = compiled.execute(
            q_max,
            k_max,
            v_max,
        )
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    # Benchmark
    start_events = []
    end_events = []
    for _ in range(benchmark_iters):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = compiled.execute(
            q_max,
            k_max,
            v_max,
        )[0]
        end.record()
        start_events.append(start)
        end_events.append(end)
    torch.cuda.synchronize()
    latencies = []
    for start, end in zip(start_events, end_events):
        latencies.append(start.elapsed_time(end))
    # Convert output back to torch tensor with original shape
    output_torch = from_dlpack(result)
    return {
        "mean_ms": np.mean(latencies),
        "p50_ms": np.percentile(latencies, 50),
        "p90_ms": np.percentile(latencies, 90),
        "p95_ms": np.percentile(latencies, 95),
        "p99_ms": np.percentile(latencies, 99),
        "std_ms": np.std(latencies),
        "median_ms": np.median(latencies),
        "min_ms": np.min(latencies),
        "max_ms": np.max(latencies),
        "output": output_torch,
    }


def validate_outputs(
    outputs_dict: dict[str, dict[str, Any]], tolerance: float = 1e-2
):
    """Validate that all implementations produce similar outputs."""
    outputs = {k: v["output"] for k, v in outputs_dict.items() if v is not None}
    if len(outputs) < 2:
        print("Not enough implementations to validate")
        return True
    reference_key = list(outputs.keys())[0]
    reference = outputs[reference_key]
    print(f"Reference: {reference_key}")
    valid = True
    for name, output in outputs.items():
        if name == reference_key:
            continue
        max_diff = torch.max(torch.abs(reference - output)).item()
        if max_diff > tolerance:
            valid = False
            print(f"Validation failed: {reference_key} vs {name}, max_diff={max_diff}")
    if valid:
        print("All outputs match within tolerance")
    return valid


def run_benchmark(config: BenchmarkConfig):
    """Run the complete benchmark suite."""
    results = []
    print(f"\n{'=' * 80}")
    print(f"Running benchmarks on {torch.cuda.get_device_name()}")
    print(f"{'=' * 80}\n")
    total_configs = (
        len(config.batch_sizes or [])
        * len(config.seq_lengths or [])
        * len(config.num_heads_list or [])
        * len(config.head_dims or [])
    )
    current = 0
    for batch_size in config.batch_sizes or []:
        for seq_len in config.seq_lengths or []:
            for num_heads in config.num_heads_list or []:
                for head_dim in config.head_dims or []:
                    current += 1
                    print(
                        f"[{current}/{total_configs}] Testing: seq_len={seq_len}, "
                        f"num_heads={num_heads}, head_dim={head_dim}"
                    )
                    q = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )
                    k = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )
                    v = torch.randn(
                        batch_size,
                        num_heads,
                        seq_len,
                        head_dim,
                        dtype=config.dtype,
                        device="cuda",
                    )
                    bench_results = {}
                    # PyTorch SDPA (baseline)
                    print(" - Running PyTorch SDPA w/ CUDNN attention...")
                    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
                        sdpa_result = benchmark_torch_sdpa(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                    if sdpa_result:
                        bench_results["pytorch_sdpa"] = sdpa_result
                    # FlashInfer
                    if FLASHINFER_AVAILABLE:
                        print(" - Running FlashInfer...")
                        fi_result = benchmark_flashinfer(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if fi_result:
                            bench_results["flashinfer"] = fi_result
                    # Flash Attention
                    if FLASH_ATTN_AVAILABLE:
                        print(" - Running Flash Attention...")
                        fa_result = benchmark_flash_attention(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if fa_result:
                            bench_results["flash_attn"] = fa_result
                    # Max
                    if MAX_AVAILABLE:
                        print(" - Running Max...")
                        max_result = benchmark_max(
                            q,
                            k,
                            v,
                            config.causal,
                            config.warmup_iters,
                            config.benchmark_iters,
                        )
                        if max_result:
                            bench_results["max"] = max_result
                    valid = validate_outputs(bench_results)
                    for impl_name, impl_results in bench_results.items():
                        results.append(
                            {
                                "implementation": impl_name,
                                "batch_size": batch_size,
                                "seq_length": seq_len,
                                "num_heads": num_heads,
                                "head_dim": head_dim,
                                "mean_latency_ms": impl_results["mean_ms"],
                                "p50_latency_ms": impl_results["p50_ms"],
                                "p90_latency_ms": impl_results["p90_ms"],
                                "p95_latency_ms": impl_results["p95_ms"],
                                "p99_latency_ms": impl_results["p99_ms"],
                                "std_latency_ms": impl_results["std_ms"],
                                "median_latency_ms": impl_results["median_ms"],
                                "min_latency_ms": impl_results["min_ms"],
                                "max_latency_ms": impl_results["max_ms"],
                                "causal": config.causal,
                                "dtype": str(config.dtype),
                                "validation_passed": valid,
                            }
                        )
                    print(" Results:")
                    for impl_name, impl_results in bench_results.items():
                        print(
                            f" {impl_name:15s}: {impl_results['mean_ms']:.3f} ± "
                            f"{impl_results['std_ms']:.3f} ms"
                        )
                    print()
    return results


def save_results(results: list[dict[str, Any]], output_dir: str = "."):
    """Save benchmark results to CSV files and print a summary."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    df = pd.DataFrame(results)
    csv_path = f"{output_dir}/flash_benchmark_results_{timestamp}.csv"
    df.to_csv(csv_path, index=False)
    print(f"Detailed results saved to: {csv_path}")
    if len(df["implementation"].unique()) > 1:
        pivot = df.pivot_table(
            values="mean_latency_ms",
            index=["seq_length", "num_heads", "head_dim"],
            columns="implementation",
            aggfunc="mean",
        )
        # Calculate speedups relative to pytorch_sdpa baseline
        if "pytorch_sdpa" in pivot.columns:
            for col in pivot.columns:
                if col != "pytorch_sdpa":
                    pivot[f"{col}_speedup"] = pivot["pytorch_sdpa"] / pivot[col]
        pivot_path = f"{output_dir}/flash_benchmark_comparison_{timestamp}.csv"
        pivot.to_csv(pivot_path)
        print(f"Comparison table saved to: {pivot_path}")
    # Print summary
    print("\n" + "=" * 80)
    print("BENCHMARK SUMMARY")
    print("=" * 80)
    print("\nAverage Latencies by Implementation:")
    avg_latencies = df.groupby("implementation")["mean_latency_ms"].mean().round(3)
    print(avg_latencies)
    # Print comprehensive comparison table
    # print("\n" + "=" * 80)
    # print("ALL FRAMEWORKS COMPARISON (mean latency in ms)")
    # print("=" * 80)
    # print("\n" + pivot.round(3).to_string())
    if "pytorch_sdpa" in df["implementation"].unique():
        print("\n" + "=" * 80)
        print("SPEEDUP SUMMARY (relative to PyTorch SDPA)")
        print("=" * 80)
        baseline_df = df[df["implementation"] == "pytorch_sdpa"]
        for impl in df["implementation"].unique():
            if impl != "pytorch_sdpa":
                impl_df = df[df["implementation"] == impl]
                # Merge on common configurations
                merged = pd.merge(
                    baseline_df[["seq_length", "num_heads", "head_dim", "mean_latency_ms"]],
                    impl_df[["seq_length", "num_heads", "head_dim", "mean_latency_ms"]],
                    on=["seq_length", "num_heads", "head_dim"],
                    suffixes=("_baseline", f"_{impl}"),
                )
                if not merged.empty:
                    merged["speedup"] = (
                        merged["mean_latency_ms_baseline"]
                        / merged[f"mean_latency_ms_{impl}"]
                    )
                    print(f"\n{impl}:")
                    print(f" Average speedup: {merged['speedup'].mean():.3f}x")
                    print(f" Median speedup: {merged['speedup'].median():.3f}x")
                    print(f" Max speedup: {merged['speedup'].max():.3f}x")
                    print(f" Min speedup: {merged['speedup'].min():.3f}x")
    return df


def main():
    parser = argparse.ArgumentParser(
        description="Benchmark FlashInfer vs Flash Attention on B200"
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1],
        help="Batch sizes",
    )
    parser.add_argument(
        "--seq-lengths",
        type=int,
        nargs="+",
        default=[32, 64, 128, 256, 512, 1024, 1536, 2048, 4096, 8192, 16384],
        help="Sequence lengths to test",
    )
    parser.add_argument(
        "--num-heads",
        type=int,
        nargs="+",
        default=[16, 32, 64, 128],
        help="Number of attention heads to test",
    )
    parser.add_argument(
        "--head-dims",
        type=int,
        nargs="+",
        default=[64, 128],
        help="Head dimensions to test",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=10,
        help="Number of warmup iterations",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=100,
        help="Number of benchmark iterations",
    )
    parser.add_argument(
        "--no-causal", action="store_true", help="Disable causal masking"
    )
    parser.add_argument(
        "--fp16", action="store_true", help="Use FP16 instead of BF16"
    )
    parser.add_argument(
        "--output-dir", type=str, default=".", help="Directory to save results"
    )
    args = parser.parse_args()
    # Check CUDA availability
    if not torch.cuda.is_available():
        print("Error: CUDA is not available!")
        sys.exit(1)
    # Check for B200
    device_name = torch.cuda.get_device_name()
    if "B200" not in device_name and "Blackwell" not in device_name:
        print(f"Warning: Current GPU is {device_name}, not B200/Blackwell")
        response = input("Continue anyway? (y/n): ")
        if response.lower() != "y":
            sys.exit(0)
    config = BenchmarkConfig(
        batch_sizes=args.batch_sizes,
        seq_lengths=args.seq_lengths,
        num_heads_list=args.num_heads,
        head_dims=args.head_dims,
        causal=not args.no_causal,
        dtype=torch.float16 if args.fp16 else torch.bfloat16,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
    )
    print("\nBenchmark Configuration:")
    print(f" Batch sizes: {config.batch_sizes}")
    print(f" Sequence lengths: {config.seq_lengths}")
    print(f" Number of heads: {config.num_heads_list}")
    print(f" Head dimensions: {config.head_dims}")
    print(f" Causal: {config.causal}")
    print(f" Data type: {config.dtype}")
    print(f" Warmup iterations: {config.warmup_iters}")
    print(f" Benchmark iterations: {config.benchmark_iters}")
    # Run benchmark
    results = run_benchmark(config)
    # Save and display results
    if results:
        save_results(results, args.output_dir)
    else:
        print("No results collected!")


if __name__ == "__main__":
    main()
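
Usage note: a typical command-line run and an equivalent programmatic invocation are sketched below. The flag names mirror the argparse definitions in main(); the module name benchmark_flash_comparison is assumed from the gist's file name rather than stated in the script, so adjust the import to wherever the file is saved.

# Command line (assuming the gist is saved as benchmark_flash_comparison.py):
#   python benchmark_flash_comparison.py --seq-lengths 1024 4096 --num-heads 32 --head-dims 128

# Programmatic sketch: build a BenchmarkConfig directly and reuse the
# run_benchmark / save_results helpers defined above.
import torch

from benchmark_flash_comparison import BenchmarkConfig, run_benchmark, save_results

config = BenchmarkConfig(
    batch_sizes=[1],
    seq_lengths=[1024, 4096],
    num_heads_list=[32],
    head_dims=[128],
    causal=True,
    dtype=torch.bfloat16,
    warmup_iters=5,
    benchmark_iters=50,
)
results = run_benchmark(config)
if results:
    save_results(results, ".")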