SDPA op benchmark
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
import argparse
import sys

# --- XPU Device Setup (Strictly XPU) ---
DEVICE = "xpu"

# Check that torch.xpu is available, which is necessary for the XPU-specific code.
if not hasattr(torch, 'xpu') or not torch.xpu.is_available():
    print("FATAL ERROR: torch.xpu is not available. This script is strictly configured for XPU devices and cannot proceed.")
    sys.exit(1)
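# Note: torch.xpu is PyTorch's Intel GPU backend; running this script requires a
# PyTorch build with XPU support and an Intel GPU driver/runtime installed.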
# --- 0. Argument Parsing ---
def parse_args():
    parser = argparse.ArgumentParser(description="Benchmark Scaled Dot-Product Attention (SDPA) with variable head counts on XPU.")

    # Required parameters for attention head configuration
    parser.add_argument(
        '--num_heads',
        type=int,
        required=True,
        help="Number of Query heads (H_q)."
    )
    parser.add_argument(
        '--num_kv_heads',
        type=int,
        required=True,
        help="Number of Key/Value heads (H_kv)."
    )

    # Optional parameters
    parser.add_argument('--batch_size', type=int, default=1, help="Batch size (N).")
    parser.add_argument('--seq_len', type=int, default=8096, help="Sequence length (S/L).")
    parser.add_argument('--head_dim', type=int, default=128, help="Head dimension (d_k).")
    parser.add_argument('--warmup_iter', type=int, default=100, help="Number of warmup iterations.")
    parser.add_argument('--benchmark_iter', type=int, default=200, help="Number of benchmark iterations.")

    return parser.parse_args()

args = parse_args()
# --- 1. Parameter Setup and Data Preparation (using arguments) ---
DTYPE = torch.float16
BATCH_SIZE = args.batch_size
NUM_HEADS = args.num_heads        # Number of Query heads (H_q)
NUM_KV_HEADS = args.num_kv_heads  # Number of Key/Value heads (H_kv)
SEQ_LEN = args.seq_len
HEAD_DIM = args.head_dim
WARMUP_ITER = args.warmup_iter
BENCHMARK_ITER = args.benchmark_iter
# Determine the attention variant from the head counts:
#   MHA: Q heads == KV heads
#   MQA: a single KV head shared by all Q heads
#   GQA: Q heads are an integer multiple of the KV heads (more than one KV head)
if NUM_HEADS == NUM_KV_HEADS:
    ATTENTION_TYPE = "MHA"
elif NUM_KV_HEADS == 1:
    ATTENTION_TYPE = "MQA"
elif NUM_HEADS % NUM_KV_HEADS == 0:
    ATTENTION_TYPE = "GQA"
else:
    print(f"FATAL ERROR: num_heads ({NUM_HEADS}) must be a multiple of num_kv_heads ({NUM_KV_HEADS}).")
    sys.exit(1)

print(f"Device: {DEVICE}, DType: {DTYPE}")
print(f"Q Heads: {NUM_HEADS}, KV Heads: {NUM_KV_HEADS}. Attention Type: {ATTENTION_TYPE}")
# SDPA consumes (N, H, S, d) tensors. The KV tensors carry NUM_KV_HEADS heads; with
# enable_gqa=True, SDPA broadcasts them across the query-head groups based on these shapes.
query = torch.rand(BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
key = torch.rand(BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
value = torch.rand(BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM, dtype=DTYPE, device=DEVICE)
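# (Sanity check, not in the original gist.) Verify the tensor layouts before benchmarking.
assert query.shape == (BATCH_SIZE, NUM_HEADS, SEQ_LEN, HEAD_DIM)
assert key.shape == value.shape == (BATCH_SIZE, NUM_KV_HEADS, SEQ_LEN, HEAD_DIM)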
# --- 2. Define Benchmark Function (with GQA reporting) ---
def benchmark_sdpa(q, k, v, backend_name, attention_type, backend_list=None):
    # Use xpu.Event for precise device-side timing
    starter = torch.xpu.Event(enable_timing=True)
    ender = torch.xpu.Event(enable_timing=True)
    timings = []

    # GQA and MQA both need enable_gqa so the KV heads are broadcast across the Q-head groups.
    enable_gqa = attention_type != "MHA"

    # Use the sdpa_kernel context manager to restrict SDPA to the requested backend
    with sdpa_kernel(backends=backend_list):
        # Warm-up
        print(f"\nWarm-up for {backend_name} ({attention_type})...")
        for _ in range(WARMUP_ITER):
            _ = scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
        # Ensure warm-up operations are complete
        torch.xpu.synchronize()

        # Benchmark
        print(f"Benchmarking {backend_name} for {BENCHMARK_ITER} iterations...")
        for _ in range(BENCHMARK_ITER):
            starter.record()
            _ = scaled_dot_product_attention(q, k, v, enable_gqa=enable_gqa)
            ender.record()
            # Wait for the operation to complete
            ender.synchronize()
            # Record elapsed time (in milliseconds)
            timings.append(starter.elapsed_time(ender))

    # Calculate results
    mean_time_ms = sum(timings) / len(timings)
    print(f"--- {backend_name} ({attention_type}) Results ---")
    print(f"Mean execution time: {mean_time_ms:.3f} ms")
    return mean_time_ms
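
# (Optional helper, not part of the original gist.) Converts the measured mean latency into a
# rough throughput figure, assuming the standard non-causal attention FLOP estimate of
# 4 * N * H_q * S^2 * d (two S x S matmuls per query head, 2 FLOPs per multiply-add;
# softmax cost is ignored).
def approx_attention_tflops(mean_time_ms):
    flops = 4 * BATCH_SIZE * NUM_HEADS * SEQ_LEN * SEQ_LEN * HEAD_DIM
    return flops / (mean_time_ms * 1e-3) / 1e12
# Example: print(f"Flash TFLOP/s: {approx_attention_tflops(flash_time):.2f}")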
# --- 3. Execute Benchmarks ---
# 1. Test Flash Attention backend (if available)
flash_time = benchmark_sdpa(
    query, key, value,
    backend_name="Flash Attention",
    attention_type=ATTENTION_TYPE,  # Pass the determined attention type
    backend_list=[SDPBackend.FLASH_ATTENTION]
)

# 2. Test the pure Math (Eager) PyTorch implementation
math_time = benchmark_sdpa(
    query, key, value,
    backend_name="Math (Eager)",
    attention_type=ATTENTION_TYPE,  # Pass the determined attention type
    backend_list=[SDPBackend.MATH]
)
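# (Not in the original gist.) The memory-efficient backend could be timed the same way,
# if it is supported for this dtype/shape on the XPU device:
# efficient_time = benchmark_sdpa(
#     query, key, value,
#     backend_name="Efficient Attention",
#     attention_type=ATTENTION_TYPE,
#     backend_list=[SDPBackend.EFFICIENT_ATTENTION]
# )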
# --- 4. Summary ---
print("\n" + "="*40)
print(" Benchmark Summary")
print("="*40)
print(f"Attention Type: {ATTENTION_TYPE}")
print(f"Flash Attention Time: {flash_time:.3f} ms")
print(f"Math (Eager) Time: {math_time:.3f} ms")
print("="*40)
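
Example invocation (the file name sdpa_benchmark.py is assumed; the gist does not name the file):

python sdpa_benchmark.py --num_heads 32 --num_kv_heads 8 --seq_len 4096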