Skip to content

Instantly share code, notes, and snippets.

@eqy
Created October 30, 2024 23:00
Show Gist options
  • Save eqy/24246e2c70072aa5f3e3a803ef98f58f to your computer and use it in GitHub Desktop.
Save eqy/24246e2c70072aa5f3e3a803ef98f58f to your computer and use it in GitHub Desktop.
cuDNN GQA
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
# Compare the cuDNN attention backend against the math backend for
# grouped-query attention (GQA) through scaled_dot_product_attention.
batch = 4
seq_len_q = 512
seq_len_kv = 1024
D = 128  # size of the last (embedding) dimension of query/key/value

# Sample call to SDPA - GQA
# query carries 32 heads while key/value carry 8; enable_gqa=True lets SDPA
# accept the mismatched head counts (32 is a multiple of 8).
query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16)
key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)
value = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16)

# Reference result: only the math backend is enabled inside this context.
with sdpa_kernel([SDPBackend.MATH]):
    output_math = scaled_dot_product_attention(
        query, key, value, is_causal=True, enable_gqa=True
    )

# Same inputs and flags, with only the cuDNN attention backend enabled.
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]):
    output_cudnn = scaled_dot_product_attention(
        query, key, value, is_causal=True, enable_gqa=True
    )

# Raises AssertionError if the two backends' outputs differ beyond
# assert_close's default tolerances for the tensors' dtype.
torch.testing.assert_close(output_math, output_cudnn)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment