@antferdom
Created September 16, 2024 11:33
FlashAttention v3 wrapped as a custom op compatible with torch.compile
from typing import Optional, Tuple

import torch

try:
    from flash_attn_interface import flashattn_hopper_cuda as _C_flashattention3
except ImportError:
    # We end up here if the arch is not 90a (FlashAttention v3 targets Hopper GPUs).
    _C_flashattention3 = None
if _C_flashattention3 is not None:
    # The raw kernel returns: out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p
    @torch.library.custom_op(
        "hopper_flash3::flash_fwd", mutates_args=(), device_types=["cuda"]
    )
    def mha_fwd(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        softmax_scale: Optional[float],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if softmax_scale is None:
            softmax_scale = query.shape[-1] ** (-0.5)
        (
            out,
            q_padded,
            k_padded,
            v_padded,
            out_padded,
            softmax_lse,
            p,
        ) = _C_flashattention3.fwd(
            query, key, value, None, softmax_scale, None, None, None, is_causal
        )
        return out, softmax_lse
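
    # A fake-tensor registration sketch (not in the original gist) so the custom
    # op can be traced under torch.compile with FakeTensors. Assumption: `out`
    # matches `query`'s shape and dtype, and `softmax_lse` is float32 with shape
    # (batch, n_heads, seqlen_q), following the common flash-attn convention;
    # adjust if the kernel's actual output layout differs.
    @mha_fwd.register_fake
    def _(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        softmax_scale: Optional[float],
        is_causal: bool,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        batch, seqlen_q, n_heads, _head_dim = query.shape
        out = torch.empty_like(query)
        softmax_lse = query.new_empty((batch, n_heads, seqlen_q), dtype=torch.float32)
        return out, softmax_lse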
class HopperMHA(torch.autograd.Function):
    @staticmethod
    def forward(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        softmax_scale: Optional[float],
        is_causal: bool,
    ):
        return torch.ops.hopper_flash3.flash_fwd(
            query,
            key,
            value,
            softmax_scale,
            is_causal,
        )

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Forward-only example: nothing is saved for the backward pass.
        pass

    @staticmethod
    def backward(ctx, grad_output):
        # Backward is intentionally left unimplemented in this snippet.
        pass


hopper_mha = HopperMHA.apply
torch.manual_seed(0)

repeats = 10
dropout_p = 0.0
causal = False
dtype = torch.float16
device = "cuda"
verbose = False

batch_size = 1
seqlen = 512
dim = 2048
head_dim = 256
n_heads = dim // head_dim
n_heads_kv = n_heads

qkv = torch.randn(
    batch_size, seqlen, 3, n_heads, head_dim,
    device=device, dtype=dtype, requires_grad=True,
)
q = torch.randn(batch_size, seqlen, n_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn(batch_size, seqlen, n_heads_kv, head_dim, device=device, dtype=dtype, requires_grad=True)
v = torch.randn(batch_size, seqlen, n_heads_kv, head_dim, device=device, dtype=dtype, requires_grad=True)
q_t = q.transpose(1, 2).contiguous().detach().requires_grad_()
k_t = k.transpose(1, 2).contiguous().detach().requires_grad_()
v_t = v.transpose(1, 2).contiguous().detach().requires_grad_()  # fixed: original transposed k here

ref_o = hopper_mha(q, k, v, None, causal)
print(ref_o)
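
# A minimal sanity-check sketch (not part of the original gist): run the same
# call through torch.compile to confirm the custom op traces without graph
# breaks. Assumes the fake-tensor registration above and a Hopper (sm90a) GPU;
# drop fullgraph=True if graph breaks are acceptable.
compiled_mha = torch.compile(hopper_mha, fullgraph=True)
compiled_o = compiled_mha(q, k, v, None, causal)
torch.testing.assert_close(compiled_o[0], ref_o[0])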