Created October 30, 2024 23:00
Save eqy/24246e2c70072aa5f3e3a803ef98f58f to your computer and use it in GitHub Desktop.
cuDNN GQA
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch.nn.functional import scaled_dot_product_attention | |
from torch.nn.attention import SDPBackend, sdpa_kernel | |
batch = 4 | |
seq_len_q = 512 | |
seq_len_kv = 1024 | |
D = 128 | |
# Sample call to SDPA - GQ | |
query = torch.rand(batch, 32, seq_len_q, D, device='cuda', dtype=torch.bfloat16) | |
key = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) | |
value = torch.rand(batch, 8, seq_len_kv, D, device='cuda', dtype=torch.bfloat16) | |
with sdpa_kernel([SDPBackend.MATH]): | |
output_math = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True) | |
with sdpa_kernel([SDPBackend.CUDNN_ATTENTION]): | |
output_cudnn = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True) | |
torch.testing.assert_close(output_math, output_cudnn) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.