SDPA op benchmark
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import argparse
import sys
# --- XPU Device Setup (Strictly XPU) ---
DEVICE = "xpu"
# Check if torch.xpu is available, which is necessary for the XPU-specific code.
if not hasattr(torch, 'xpu') or not torch.xpu.is_available():
    print("FATAL ERROR: torch.xpu is not available. This script is strictly configured for XPU devices and cannot proceed.")
    sys.exit(1)
# --- 0. Argument Parsing ---
def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark Scaled Dot-Product Attention (SDPA) with variable head counts on XPU.")
    # Required parameters for attention head configuration
    parser.add_argument(
        '--num_heads',
        type=int,
        required=True,
        help="Number of Query heads (H_q)."
    )
    parser.add_argument(
        '--num_kv_heads',
        type=int,
        required=True,
        help="Number of Key/Value heads (H_kv)."
    )
    # Optional parameters
    parser.add_argument('--batch_size', type=int, default=1, help="Batch size (N).")
    parser.add_argument('--seq_len', type=int, default=8096, help="Sequence length (S/L).")
    parser.add_argument('--head_dim', type=int, default=128, help="Head dimension (d_k).")
    parser.add_argument('--warmup_iter', type=int, default=100, help="Number of warmup iterations.")
    parser.add_argument('--benchmark_iter', type=int, default=200, help="Number of benchmark iterations.")
    return parser.parse_args()
args = parse_args()
# --- 1. Parameter Setup and Data Preparation (using arguments) ---
DTYPE = torch.float16
BATCH_SIZE = args.batch_size
NUM_HEADS = args.num_heads # Number of Query heads (H_q)
NUM_KV_HEADS = args.num_kv_heads # Number of Key/Value heads (H_kv)
SEQ_LEN = args.seq_len
HEAD_DIM = args.head_dim
WARMUP_ITER = args.warmup_iter
BENCHMARK_ITER = args.benchmark_iter
# Determine the attention variant from the head counts:
#   MHA if Q_heads == KV_heads
#   MQA if KV_heads == 1 (and Q_heads > 1)
#   GQA if Q_heads > KV_heads > 1, with Q_heads divisible by KV_heads
if NUM_HEADS == NUM_KV_HEADS:
    ATTENTION_TYPE = "MHA"
elif NUM_KV_HEADS == 1:
    ATTENTION_TYPE = "MQA"
elif NUM_HEADS % NUM_KV_HEADS == 0:
    ATTENTION_TYPE = "GQA"
else:
    print(f"FATAL ERROR: num_heads ({NUM_HEADS}) must be divisible by num_kv_heads ({NUM_KV_HEADS}).")
    sys.exit(1)
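# Worked example of the classification above: num_heads=32, num_kv_heads=8
# gives GQA, with each of the 8 KV heads shared by 32 // 8 = 4 query heads;
# num_kv_heads=1 would give MQA instead.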
print(f"Device: {DEVICE}, DType: {DTYPE}")
print(f"Q Heads: {NUM_HEADS}, KV Heads: {NUM_KV_HEADS}. Attention Type: {ATTENTION_TYPE}")
# PyTorch SDPA uses the tensor shapes to determine GQA/MHA/MQA internally.
query = torch.rand(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
key = torch.rand(BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
value = torch.rand(BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
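# Shape note: Q is (N, H_q, S, D) while K/V are (N, H_kv, S, D). With
# enable_gqa=True, SDPA broadcasts each KV head across its group of
# H_q // H_kv query heads, so no manual repeat of K/V is needed here.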
# --- 2. Define Benchmark Function (with GQA reporting) ---
def benchmark_sdpa(q, k, v, backend_name, attention_type, backend_list=None):
    # Use xpu.Event for precise device-side timing
    starter = torch.xpu.Event(enable_timing=True)
    ender = torch.xpu.Event(enable_timing=True)
    timings = []
    # enable_gqa must be set whenever Q has more heads than K/V, i.e. for both GQA and MQA
    enable_gqa = attention_type != "MHA"
    # Use the sdpa_kernel context manager to pin the attention backend
    with sdpa_kernel(backends=backend_list):
        # Warm-up
        print(f"\nWarm-up for {backend_name} ({attention_type})...")
        for _ in range(WARMUP_ITER):
            _ = scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
        # Ensure warm-up operations are complete
        torch.xpu.synchronize()
        # Benchmark
        print(f"Benchmarking {backend_name} for {BENCHMARK_ITER} iterations...")
        for _ in range(BENCHMARK_ITER):
            starter.record()
            _ = scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
            ender.record()
            # Wait for the operation to complete
            ender.synchronize()
            # Record time (in milliseconds)
            timings.append(starter.elapsed_time(ender))
    # Calculate results
    mean_time_ms = sum(timings) / BENCHMARK_ITER
    print(f"--- {backend_name} ({attention_type}) Results ---")
    print(f"Mean execution time: {mean_time_ms:.3f} ms")
    return mean_time_ms
# --- 3. Execute Benchmarks ---
# 1. Test Flash Attention backend (if available)
flash_time = benchmark_sdpa(
    query, key, value,
    backend_name="Flash Attention",
    attention_type=ATTENTION_TYPE,  # Pass the determined attention type
    backend_list=[SDPBackend.FLASH_ATTENTION]
)
# 2. Test the pure Math (Eager) PyTorch implementation
math_time = benchmark_sdpa(
    query, key, value,
    backend_name="Math (Eager)",
    attention_type=ATTENTION_TYPE,  # Pass the determined attention type
    backend_list=[SDPBackend.MATH]
)
# --- 4. Summary ---
print("\n" + "="*40)
print(" Benchmark Summary")
print("="*40)
print(f"Attention Type: {ATTENTION_TYPE}")
print(f"Flash Attention Time: {flash_time:.3f} ms")
print(f"Math (Eager) Time: {math_time:.3f} ms")
print("="*40)