Multipack Sampler x Flash Attention
""" | |
Testing flash attn with multipacking which essentially packs sequences using https://github.com/imoneoi/multipack_sampler, | |
and passes a single sequence of `1 x (bs x seqlen)` to the model to avoid padding. | |
An alternative is to use block diagonal attention as attention bias, but the following uses flash attention 2 which | |
is much faster. | |
Multipacking can be used to speed up both pretraining and finetuning. | |
""" | |
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

try:
    from flash_attn.flash_attn_interface import (  # pylint: disable=ungrouped-imports
        flash_attn_kvpacked_func,
        flash_attn_varlen_kvpacked_func,
        flash_attn_varlen_qkvpacked_func,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func,
    )
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )
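
# `make_decoder_mask_pt` and `attn_mask2attn_bias` are used below but are not
# defined in this snippet. The following are assumed, minimal implementations
# that reproduce the expected outputs shown further down; the originals may differ.


def make_decoder_mask_pt(position_ids, dtype, decoder_segment_ids=None):
    """Build a block-diagonal causal mask of shape [bs, 1, seqlen, seqlen].

    A query position may attend to a key position only if both carry the same
    segment id and the key does not lie in the future of the pack.
    (`position_ids` is accepted only for signature parity with the call below;
    the causal check uses absolute positions within the pack.)
    """
    seqlen = decoder_segment_ids.shape[-1]
    idx = torch.arange(seqlen, device=decoder_segment_ids.device)
    causal = idx[:, None] >= idx[None, :]  # [seqlen, seqlen]
    same_segment = (
        decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :]
    )  # [bs, seqlen, seqlen]
    return (causal[None] & same_segment).to(dtype).unsqueeze(1)  # [bs, 1, seqlen, seqlen]


def attn_mask2attn_bias(attn_mask, neg_value=-128):
    """Turn a 0/1 attention mask into an additive bias: 0 where attention is
    allowed, a large negative value where it is masked."""
    return torch.where(
        attn_mask.bool(), torch.zeros_like(attn_mask), torch.full_like(attn_mask, neg_value)
    )
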
# segment ids of the packed sequence, total length bs x seqlen = 16:
# four sequences of lengths 4, 3, 5, 2 followed by 2 padding tokens (id 0)
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
# position ids restart at 0 for each packed sequence
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
# start index of each sequence in the pack plus the total length
# (size: num samples + 1; the trailing padding block counts as its own segment)
cu_seqlens = torch.tensor([0, 4, 7, 12, 14, 16]).to(torch.int32)
# length of the longest sequence in the pack
max_seqlen = 5
# the segment id at the start of each packed sequence
for i in cu_seqlens[:-1]:
    print(attn_mask[0][i])
# tensor(1)
# tensor(2)
# tensor(3)
# tensor(4)
# tensor(0)
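
# Sanity check (illustrative, not part of the original gist): cu_seqlens is the
# cumulative sum of the per-segment lengths (4, 3, 5, 2, plus 2 padding tokens)
# prepended with 0.
seg_lens = torch.tensor([4, 3, 5, 2, 2])
assert torch.equal(
    torch.cat([torch.zeros(1, dtype=torch.int32), seg_lens.cumsum(0).to(torch.int32)]),
    cu_seqlens,
)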
# total number of tokens after packing: bs x seqlen
bs_x_seqlen = 16
# random packed qkv of shape (total_tokens, 3, nheads, headdim) = (16, 3, 2, 128)
qkv = torch.randn(
    bs_x_seqlen, 3, 2, 128, device="cuda:0", dtype=torch.float16, requires_grad=True
)
# the varlen kernel applies causal attention independently within each sequence
# delimited by cu_seqlens, i.e. block-diagonal causal attention, without
# materializing a padded batch or an explicit mask
attn_output_flash = flash_attn_varlen_qkvpacked_func(
    qkv.cuda(), cu_seqlens.cuda(), max_seqlen, 0.0, softmax_scale=None, causal=True
)
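
# Quick shape check (illustrative): the varlen kernel returns one row per token,
# i.e. (total_tokens, nheads, headdim).
assert attn_output_flash.shape == (bs_x_seqlen, 2, 128)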
# build the same block-diagonal causal attention mask explicitly for a torch reference
attn_mask = make_decoder_mask_pt(position_ids, torch.int32, decoder_segment_ids=attn_mask)
# convert the 0/1 mask to an additive attention bias (the inverse of the mask:
# 0 where attention is allowed, -128 where it is masked)
attn_bias = attn_mask2attn_bias(attn_mask)
attn_mask, attn_bias
# (tensor([[[[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
#            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
#            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1]]]],
#           dtype=torch.int32),
#  tensor([[[[   0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [   0,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [   0,    0,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [   0,    0,    0,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128,    0,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128,    0,    0,    0, -128, -128, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128,    0, -128, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128,    0,    0, -128, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128,    0,    0,    0, -128, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128,    0,    0,    0,    0, -128, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128,    0,    0,    0,    0,    0, -128, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,    0, -128, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,    0,    0, -128, -128],
#            [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,    0, -128],
#            [-128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128, -128,    0,    0]]]],
#          dtype=torch.int32))
# reshape q, k, v for torch scaled_dot_product_attention: (bs, nheads, seqlen, headdim)
q, k, v = qkv.unbind(1)
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = rearrange(q, "b s h d -> b h s d")
k = rearrange(k, "b s h d -> b h s d")
v = rearrange(v, "b s h d -> b h s d")
attn_output_torch = F.scaled_dot_product_attention(
    q.cuda(), k.cuda(), v.cuda(), attn_bias.to(q.dtype).cuda(), 0.0, is_causal=False
)
# back to the flash-attn layout: (total_tokens, nheads, headdim)
attn_output_torch = rearrange(attn_output_torch[0], "h s d -> s h d")
torch.isclose(attn_output_flash, attn_output_torch).float().mean()
# tensor(1., device='cuda:0')
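
# Extra check (illustrative): the maximum absolute difference between the two
# outputs; with identical fp16 inputs it should be zero or within a few fp16 ulps.
print((attn_output_flash - attn_output_torch).abs().max())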