import torch
from torch import nn

from flash_attn.flash_attention import FlashAttention
class Attention(nn.Module):
    use_flash_attn: bool = False

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.,
        proj_drop=0.,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5

        self.flash_attn = FlashAttention(attention_dropout=attn_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        if self.use_flash_attn:
            return self.flash_attn_forward(x)
        return self.naive_forward(x)

    def flash_attn_forward(self, x):
        # The input of FlashAttention is (B, N, 3, H, D)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # (B, N, H, D)
        x = self.flash_attn(qkv)[0]
        x = x.flatten(2)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def naive_forward(self, x):
        B, N, C = x.shape
        # (B, N, C) -> (B, N, 3, H, C/H) -> (3, B, H, N, C/H) -> (Q, K, V)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        # (B, H, N, C/H)
        q, k, v = qkv.unbind(0)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

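# Smoke test: run the same random fp16 input through both attention paths
# and exercise the FlashAttention backward pass.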
dim_per_head = 64
num_heads = 64
dim = dim_per_head * num_heads

# head_dim is derived as dim // num_heads inside the module, so it is not passed.
attn = Attention(dim, num_heads=num_heads)
attn.cuda()
attn.half()

B, N, C = 128, 14 * 14, dim
x = torch.randn(B, N, C, device='cuda', dtype=torch.float16)

y = attn(x)

attn.use_flash_attn = True
y2 = attn(x)
print('Flash Attention forward works!')

y2.sum().backward()
print('Flash Attention backward works!')
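Since attn_drop and proj_drop default to 0 here, the two code paths compute the same function and should agree up to fp16 rounding error. A minimal sketch of a numerical check one could append to the script above (the tolerances are illustrative assumptions, not values from the original gist):

# Optional numerical check: compare the naive and FlashAttention outputs.
# Tolerances are illustrative; small fp16 accumulation differences are expected.
max_err = (y - y2).abs().max().item()
print(f'max abs difference between naive and flash outputs: {max_err:.4e}')
assert torch.allclose(y, y2, atol=1e-2, rtol=1e-2), 'outputs diverged more than expected'

Gradients could be compared the same way, e.g. by zeroing attn.qkv.weight.grad, running backward through each path separately, and diffing the two gradient tensors.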