SDPA benchmark for torch SDPA backends, FlashAttention-2 (FA2), FlashAttention-3 (FA3), Transformer Engine, xFormers, SageAttention, and the HF kernels library.

#!/usr/bin/env python3
# Benchmarking common shapes for Flux 1024x1024px image + varying text sequence lengths

import functools
import os
import pathlib

import matplotlib.pyplot as plt
import torch
import torch._dynamo.config
import triton
import triton.language as tl
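
# Optional attention backends: each import below is wrapped in try/except so the
# benchmark still runs (with the missing provider skipped) when a library is not installed.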
try:
    from flash_attn import flash_attn_func
except Exception:
    flash_attn_func = None
    print("Flash Attention 2 not found.")

try:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
except Exception:
    flash_attn_3_func = None
    print("Flash Attention 3 not found.")

try:
    from kernels import get_kernel
    hf_kernels_flash_attn = get_kernel("kernels-community/flash-attn")
except Exception:
    hf_kernels_flash_attn = None
    print("HF Kernels not found.")

try:
    from sageattention import (
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
    )
except Exception:
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    print("SageAttention not found.")

try:
    from transformer_engine.pytorch.attention import DotProductAttention
except Exception:
    DotProductAttention = None
    print("Transformer Engine not found.")

try:
    import xformers.ops as xops
except Exception:
    xops = None
    print("xFormers not found.")

plt.rcParams.update({
    "figure.figsize": (12, 10),
    "figure.dpi": 120,
    "font.size": 10,
    "axes.titlesize": 12,
    "axes.labelsize": 14,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 8,
    "axes.grid": True,
    "grid.alpha": 0.3,
    "grid.linestyle": "--",
    "lines.linewidth": 2.0,
    "lines.markersize": 6,
    "legend.frameon": True,
    "legend.framealpha": 0.9,
    "legend.loc": "best",
    "axes.spines.top": False,
    "axes.spines.right": False,
})

# We want to compare the best compiled version for each specific shape (dynamic=False)
torch._dynamo.config.cache_size_limit = 10000

# We need suppress_errors for FA3 to work under torch.compile; it makes FA3 run in eager mode.
# I can't seem to get it to work any other way, so any suggestions are welcome!
torch._dynamo.config.suppress_errors = True

output_dir = pathlib.Path("dump_attention_benchmark")
output_dir.mkdir(parents=True, exist_ok=True)

batch_size = 1
num_attention_heads = 24
attention_head_dim = 128
image_sequence_length = 4096  # 1024x1024px
text_sequence_lengths = [128, 256, 320, 384, 448, 512]
sequence_lengths = [image_sequence_length + i for i in text_sequence_lengths]
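# Each benchmarked length is image tokens plus text tokens, e.g. 4096 + 128 = 4224,
# so the sweep covers 4224, 4352, 4416, 4480, 4544, and 4608 total tokens.
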
def _attention_torch(query, key, value, *, backend):
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out

_compiled_attention_torch_default = torch.compile(_attention_torch, mode="default", fullgraph=True, dynamic=False)

def _attention_torch_compile_default(query, key, value, *, backend):
    return _compiled_attention_torch_default(query, key, value, backend=backend)

_compiled_attention_torch_max_autotune = torch.compile(_attention_torch, mode="max-autotune", fullgraph=True, dynamic=False)

def _attention_torch_compile_max_autotune(query, key, value, *, backend):
    return _compiled_attention_torch_max_autotune(query, key, value, backend=backend)

def _attention_flash_attn_2(query, key, value):
    return flash_attn_func(query, key, value)

_compiled_flash_attn_2_default = torch.compile(_attention_flash_attn_2, mode="default", fullgraph=True, dynamic=False)

def _attention_flash_attn_2_compile_default(query, key, value):
    return _compiled_flash_attn_2_default(query, key, value)

_compiled_flash_attn_2_max_autotune = torch.compile(_attention_flash_attn_2, mode="max-autotune", fullgraph=True, dynamic=False)

def _attention_flash_attn_2_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_2_max_autotune(query, key, value)

# Wrap FA3 in a custom op so that fullgraph=True tracing is compatible
@torch.library.custom_op("flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    out, lse = flash_attn_3_func(query, key, value)
    return out

@torch.library.register_fake("flash_attn_3::_flash_attn_forward")
def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    return torch.empty_like(query)
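
# The fake (meta) implementation above only needs to return a tensor with the right
# shape/dtype/device so torch.compile can trace the custom op without running the kernel.
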
def _attention_flash_attn_3(query, key, value):
    out = _wrapped_flash_attn_3(query, key, value)
    return out

_compiled_flash_attn_3_default = torch.compile(_attention_flash_attn_3, mode="default", fullgraph=True, dynamic=False)

def _attention_flash_attn_3_compile_default(query, key, value):
    return _compiled_flash_attn_3_default(query, key, value)

_compiled_flash_attn_3_max_autotune = torch.compile(_attention_flash_attn_3, mode="max-autotune", fullgraph=True, dynamic=False)

def _attention_flash_attn_3_compile_max_autotune(query, key, value):
    return _compiled_flash_attn_3_max_autotune(query, key, value)

def _attention_hf_kernels_flash_attn(query, key, value):
    return hf_kernels_flash_attn.mha_fwd(query, key, value, is_causal=False)[0]
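
# SageAttention kernels quantize Q/K to INT8 and compute the PV product in FP16 or FP8;
# tensor_layout="NHD" matches the (batch, seq, heads, head_dim) tensors used throughout this script.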
def _attention_sageattn_qk_int8_pv_fp16_cuda(query, key, value):
    return sageattn_qk_int8_pv_fp16_cuda(query, key, value, tensor_layout="NHD")

def _attention_sageattn_qk_int8_pv_fp16_triton(query, key, value):
    return sageattn_qk_int8_pv_fp16_triton(query, key, value, tensor_layout="NHD")

def _attention_sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value):
    return sageattn_qk_int8_pv_fp8_cuda_sm90(query, key, value, tensor_layout="NHD")

if DotProductAttention is not None:
    def set_te_backend(backend):
        # Must be applied before the first use of transformer_engine.pytorch.attention
        os.environ["NVTE_FLASH_ATTN"] = "0"
        os.environ["NVTE_FUSED_ATTN"] = "0"
        os.environ["NVTE_UNFUSED_ATTN"] = "0"
        if backend == "flash":
            os.environ["NVTE_FLASH_ATTN"] = "1"
        if backend == "fused":
            os.environ["NVTE_FUSED_ATTN"] = "1"
        if backend == "unfused":
            os.environ["NVTE_UNFUSED_ATTN"] = "1"

    set_te_backend("fused")
    te_attn_fn = DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=attention_head_dim,
        qkv_format="bshd",
        attn_mask_type="no_mask",
    )
else:
    def te_attn_fn(query, key, value):
        raise RuntimeError("Transformer Engine is not available. Please install it for TE-based attention.")
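
# TE's DotProductAttention returns the context with heads and head_dim flattened into a single
# hidden dimension, so _attention_te unflattens it back to (batch, seq, heads, head_dim).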
def _attention_te(query, key, value):
    out = te_attn_fn(query, key, value)
    out = out.unflatten(2, (num_attention_heads, attention_head_dim))
    return out

# Cannot fullgraph compile TE
_compiled_te_attn_fn_default = torch.compile(_attention_te, mode="default", fullgraph=False, dynamic=False)

def _attention_te_compile_default(query, key, value):
    return _compiled_te_attn_fn_default(query, key, value)

# Cannot fullgraph compile TE
_compiled_te_attn_fn_max_autotune = torch.compile(_attention_te, mode="max-autotune", fullgraph=False, dynamic=False)

def _attention_te_compile_max_autotune(query, key, value):
    return _compiled_te_attn_fn_max_autotune(query, key, value)

def _attention_xformers(query, key, value):
    return xops.memory_efficient_attention(query, key, value)

_compiled_xformers_default = torch.compile(_attention_xformers, mode="default", fullgraph=True, dynamic=False)

def _attention_xformers_compile_default(query, key, value):
    return _compiled_xformers_default(query, key, value)

_compiled_xformers_max_autotune = torch.compile(_attention_xformers, mode="max-autotune", fullgraph=True, dynamic=False)

def _attention_xformers_compile_max_autotune(query, key, value):
    return _compiled_xformers_max_autotune(query, key, value)
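
# Provider registry. Name suffixes: "_compile_d" = torch.compile(mode="default"),
# "_compile_ma" = torch.compile(mode="max-autotune").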
attention_ops = {}
attention_ops["torch_cudnn"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_cudnn_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.CUDNN_ATTENTION)
attention_ops["torch_flash"] = functools.partial(_attention_torch, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_d"] = functools.partial(_attention_torch_compile_default, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
attention_ops["torch_flash_compile_ma"] = functools.partial(_attention_torch_compile_max_autotune, backend=torch.nn.attention.SDPBackend.FLASH_ATTENTION)
if hf_kernels_flash_attn is not None:
    attention_ops["hf_flash_attn"] = _attention_hf_kernels_flash_attn
if flash_attn_func is not None:
    attention_ops["flash_attn_2"] = _attention_flash_attn_2
    attention_ops["flash_attn_2_compile_d"] = _attention_flash_attn_2_compile_default
    attention_ops["flash_attn_2_compile_ma"] = _attention_flash_attn_2_compile_max_autotune
if flash_attn_3_func is not None:
    attention_ops["flash_attn_3"] = _attention_flash_attn_3
    attention_ops["flash_attn_3_compile_d"] = _attention_flash_attn_3_compile_default
    attention_ops["flash_attn_3_compile_ma"] = _attention_flash_attn_3_compile_max_autotune
if sageattn_qk_int8_pv_fp16_cuda is not None:
    attention_ops["sageattn_qk_int8_pv_fp16_cuda"] = _attention_sageattn_qk_int8_pv_fp16_cuda
    attention_ops["sageattn_qk_int8_pv_fp16_triton"] = _attention_sageattn_qk_int8_pv_fp16_triton
    if torch.cuda.get_device_capability()[0] >= 9:
        attention_ops["sageattn_qk_int8_pv_fp8_cuda_sm90"] = _attention_sageattn_qk_int8_pv_fp8_cuda_sm90
if DotProductAttention is not None:
    attention_ops["te_fused"] = _attention_te
    attention_ops["te_fused_compile_d"] = _attention_te_compile_default
    attention_ops["te_fused_compile_ma"] = _attention_te_compile_max_autotune
if xops is not None:
    attention_ops["xformers"] = _attention_xformers
    attention_ops["xformers_compile_d"] = _attention_xformers_compile_default
    attention_ops["xformers_compile_ma"] = _attention_xformers_compile_max_autotune

def get_color_and_linestyle(n: int) -> list[tuple[str, str]]:
    colors = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3", "#ff7f00", "#a65628", "#f781bf", "#999999"]
    line_styles = ["-", ":", "-.", "--"]
    if n > len(colors) * len(line_styles):
        raise ValueError(f"Required {n=} styles but maximum is {len(colors) * len(line_styles)}")
    styles = []
    for i in range(n):
        color = colors[i % len(colors)]
        linestyle = line_styles[i // len(colors)]
        styles.append((color, linestyle))
    return styles
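
# correctness() compares each provider's bfloat16 output against a float32 reference
# computed with the exact math SDPA backend, reporting absmax, MAE, and MSE.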
def correctness():
    for seq_len in sequence_lengths:
        shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
        print(f"\n\n===== Testing shape: {shape} =====")
        query = torch.randn(shape, device="cuda", dtype=torch.float32)
        key = torch.randn(shape, device="cuda", dtype=torch.float32)
        value = torch.randn(shape, device="cuda", dtype=torch.float32)
        golden_truth = _attention_torch(query, key, value, backend=torch.nn.attention.SDPBackend.MATH)
        query, key, value = (x.bfloat16() for x in (query, key, value))
        for name, fn in attention_ops.items():
            out = fn(query, key, value)
            absdiff = (out - golden_truth).abs()
            absmax = torch.max(absdiff)
            mae = torch.mean(absdiff)
            mse = torch.mean((golden_truth - out) ** 2)
            print(f"{name:<30}: absmax={absmax:.6f}, mae={mae:.6f}, mse={mse:.6f}")

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=sequence_lengths,
        x_log=False,
        line_arg="provider",
        line_vals=list(attention_ops.keys()),
        line_names=list(attention_ops.keys()),
        ylabel="Time (ms)",
        styles=get_color_and_linestyle(len(attention_ops)),
        plot_name="Attention Benchmark",
        args={},
    )
)
def benchmark_fn(seq_len: int, provider: str):
    torch.manual_seed(0)
    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
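    # Random integer scaling in [1, 4] widens the input value range relative to plain randn.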
    query = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    value = torch.randn(shape, device="cuda", dtype=torch.bfloat16) * torch.randint(1, 5, shape, device="cuda", dtype=torch.bfloat16)
    fn = attention_ops[provider]
    # quantiles=[0.5, 0.2, 0.8] -> do_bench returns (median, 20th percentile, 80th percentile) in ms
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(query, key, value),
        warmup=3,
        rep=10,
        quantiles=[0.5, 0.2, 0.8],
    )
    return ms, min_ms, max_ms

with torch.inference_mode():
    correctness()
    benchmark_fn.run(print_data=True, save_path=output_dir.as_posix())
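
# correctness() prints the per-provider error stats; benchmark_fn.run() prints the timing
# table and saves the plot and CSV under dump_attention_benchmark/.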