@kklemon
Last active March 22, 2024 15:01
PyTorch Transformer API-compatible wrapper around FlashAttention-2
from functools import partial

import torch.nn as nn
import torch.nn.functional as F


class FlashAttentionTransformerEncoder(nn.Module):
    """Drop-in style replacement for nn.TransformerEncoder backed by FlashAttention-2."""

    def __init__(
        self,
        dim_model,
        num_layers,
        num_heads=None,
        dim_feedforward=None,
        dropout=0.0,
        norm_first=False,
        activation=F.gelu,
        rotary_emb_dim=0,
    ):
        super().__init__()

        try:
            from flash_attn.bert_padding import pad_input, unpad_input
            from flash_attn.modules.block import Block
            from flash_attn.modules.mha import MHA
            from flash_attn.modules.mlp import Mlp
        except ImportError:
            raise ImportError('Please install flash_attn from https://github.com/Dao-AILab/flash-attention')

        self._pad_input = pad_input
        self._unpad_input = unpad_input

        # Default to 64-dimensional heads and a 4x feed-forward expansion,
        # matching common Transformer conventions.
        if num_heads is None:
            num_heads = dim_model // 64

        if dim_feedforward is None:
            dim_feedforward = dim_model * 4

        if isinstance(activation, str):
            activation_name = activation
            activation = {
                'relu': F.relu,
                'gelu': F.gelu,
            }.get(activation_name)

            if activation is None:
                raise ValueError(f'Unknown activation {activation_name}')

        mixer_cls = partial(
            MHA,
            num_heads=num_heads,
            use_flash_attn=True,
            rotary_emb_dim=rotary_emb_dim,
        )
        # Pass the resolved activation through to the feed-forward block so the
        # `activation` argument actually takes effect.
        mlp_cls = partial(Mlp, hidden_features=dim_feedforward, activation=activation)

        self.layers = nn.ModuleList([
            Block(
                dim_model,
                mixer_cls=mixer_cls,
                mlp_cls=mlp_cls,
                resid_dropout1=dropout,
                resid_dropout2=dropout,
                prenorm=norm_first,
            ) for _ in range(num_layers)
        ])

    def forward(self, x, src_key_padding_mask=None):
        batch, seqlen = x.shape[:2]

        if src_key_padding_mask is None:
            # No padding: run the dense (batch, seqlen, dim) path.
            for layer in self.layers:
                x = layer(x)
        else:
            # Remove padded positions and run FlashAttention on the packed
            # (total_tokens, dim) representation. src_key_padding_mask follows
            # the nn.TransformerEncoder convention (True = padded), while
            # unpad_input expects True = valid, hence the inversion.
            x, indices, cu_seqlens, max_seqlen_in_batch = self._unpad_input(x, ~src_key_padding_mask)

            for layer in self.layers:
                x = layer(x, mixer_kwargs=dict(
                    cu_seqlens=cu_seqlens,
                    max_seqlen=max_seqlen_in_batch,
                ))

            # Scatter the packed tokens back into the padded (batch, seqlen, dim) layout.
            x = self._pad_input(x, indices, batch, seqlen)

        return x
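
A minimal usage sketch (not part of the original gist), assuming a CUDA device and float16 inputs, which FlashAttention-2 requires; the sizes and mask below are purely illustrative:

# Hypothetical example: 6-layer encoder with model dimension 512.
import torch

encoder = FlashAttentionTransformerEncoder(
    dim_model=512,
    num_layers=6,
    dropout=0.1,
).cuda().half()

x = torch.randn(2, 128, 512, device='cuda', dtype=torch.float16)

# True marks padded positions, same convention as nn.TransformerEncoder;
# here the second sequence is padded after position 100.
src_key_padding_mask = torch.zeros(2, 128, dtype=torch.bool, device='cuda')
src_key_padding_mask[1, 100:] = True

out = encoder(x, src_key_padding_mask=src_key_padding_mask)
print(out.shape)  # torch.Size([2, 128, 512])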