@Chillee
Last active November 2, 2024 15:41
flex_attention_tutorial.py
import torch
from torch.nn.attention._flex_attention import _create_block_mask, _create_mask
from functools import partial
from torch.nn.attention._flex_attention import _flex_attention
from triton.testing import do_bench
import torch.nn.functional as F
from functools import lru_cache
torch.set_default_device('cuda')
# Example usage
flex_attention = torch.compile(_flex_attention, dynamic=False)
# Autotunes for better perf
# flex_attention = torch.compile(_flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
torch.manual_seed(0)
data_type = torch.float16
@lru_cache
def create_block_mask_from_score_mod(score_mod, B, H, M, N, device='cuda'):
    SPARSE_BLOCK = 128
    block_mask = _create_block_mask(score_mod, B, H, M, N, device=device)
    return block_mask
def eager_sdpa(query, key, value, score_mod, mask):
    return F.scaled_dot_product_attention(query, key, value, is_causal=True)
def test_mask(score_mod, mask_fn=None, B=16, H=16, S=8192, D=64, skip_correctness=False):
    if mask_fn is None:
        mask_fn = score_mod
    query = torch.randn(B, H, S, D, device="cuda", dtype=data_type, requires_grad=True)
    key = torch.randn(B, H, S, D, device="cuda", dtype=data_type, requires_grad=True)
    value = torch.randn(B, H, S, D, device="cuda", dtype=data_type, requires_grad=True)
    gradOut = torch.randn(B, H, S, D, device='cuda', dtype=data_type)
    # In this case assume that the mask only depends on q/kv, and so we can
    # broadcast the mask across batch and heads. If that's not the case, then
    # pass B and H instead of 1.
    block_mask = create_block_mask_from_score_mod(mask_fn, 1, 1, S, S, device=query.device)
    # Not needed for FlexAttention, only for F.scaled_dot_product_attention to check correctness.
    mask = _create_mask(mask_fn, 1, 1, S, S, device=query.device)
    causal_fa2 = lambda: F.scaled_dot_product_attention(query, key, value, is_causal=True)
    xformers_mask = lambda: F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
    flex_attention_call = lambda: flex_attention(query, key, value, score_mod=score_mod, block_mask=block_mask)
    print(score_mod.__name__)
    print("Forward: ")
    print("causal FA2: ", do_bench(causal_fa2))
    print("F.sdpa + mask: ", do_bench(xformers_mask))
    flex_ms = do_bench(flex_attention_call)
    print("flexattention: ", flex_ms)
    density = (100 - block_mask.sparsity()) / 100
    flops = density * B * H * D * S * S
    print("Flex FW FLOPS: ", 4 * flops * (1e3 / flex_ms) / 1e12, "TF/s")
    causal_fa2_out = causal_fa2()
    xformers_out = xformers_mask()
    flex_out = flex_attention_call()
    print("Backward: ")
    print("causal FA2: ", do_bench(lambda: causal_fa2_out.backward(gradOut, retain_graph=True)))
    flex_bw_ms = do_bench(lambda: flex_out.backward(gradOut, retain_graph=True))
    print("flexattention: ", flex_bw_ms)
    print("Flex BW FLOPS: ", 10 * flops * (1e3 / flex_bw_ms) / 1e12, "TF/s")
    print(block_mask)
    print()
    if not skip_correctness:
        xformers_outs = []
        flex_outs = []
        query.grad = None
        key.grad = None
        value.grad = None
        out1 = xformers_mask()
        xformers_outs.append(out1)
        out1.backward(gradOut)
        xformers_outs += [query.grad, key.grad, value.grad]
        query.grad = None
        key.grad = None
        value.grad = None
        out2 = flex_attention_call()
        flex_outs.append(out2)
        out2.backward(gradOut)
        flex_outs += [query.grad, key.grad, value.grad]
        for flex, xformer in zip(flex_outs, xformers_outs):
            torch.testing.assert_close(flex, xformer, atol=1e-1, rtol=1e-2)
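# A note on the FLOP accounting above (this follows standard attention FLOP counting):
# the forward pass runs two S x S x D matmuls per batch/head (Q @ K^T and P @ V), each
# costing roughly 2 * S * S * D FLOPs, hence the factor of 4 on
# `flops = density * B * H * D * S * S`. The backward pass is counted as 2.5x the
# forward (5 matmuls instead of 2), hence the factor of 10. `density` scales both
# down by the fraction of blocks that are not masked out.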
##################################
# Score mod examples start here!
##################################
################
# Full attention
################
def noop(score, b, h, q_idx, kv_idx):
    return score
################
# Standard causal mask
################
def causal_mask(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, -float("inf"))
SLIDING_WINDOW = 1024
################
# Sliding window attention + causal
################
def sliding_window_causal(score, b, h, q_idx, kv_idx):
    return torch.where((q_idx >= kv_idx) & (q_idx - kv_idx <= SLIDING_WINDOW), score, -float("inf"))
################
# prefix LM (bidirectional attention for first PREFIX_LENGTH tokens, then causal for the rest)
################
PREFIX_LENGTH = 2048
def prefix_lm_causal(score, b, h, q_idx, kv_idx):
    prefix_mask = kv_idx <= PREFIX_LENGTH
    causal_mask = q_idx >= kv_idx
    return torch.where(prefix_mask | causal_mask, score, -float("inf"))
################
# Document masking
################
# Imagine that we have multiple documents of different lengths. We want to mask
# out the attention between documents, but allow attention between tokens within
# the same document. We can do this by using a document_id tensor that gives the
# document each token belongs to. Then, we mask out all attention scores where
# document_id[q_idx] differs from document_id[kv_idx].
#
# Note: we *only* need to compile a new kernel when the `score_mod` changes
# (torch.compile's infra detects that automatically). This example code caches the
# BlockMask, but in general, changing the BlockMask *does not* require a recompile.
# That is, for document masking we only need to compute a new BlockMask when the
# document lengths change, *not* a new kernel; see the commented sketch after
# `document_masking_causal` below.
document_id = torch.zeros(32768, dtype=torch.int, device='cuda')
document_id[:4096] = 0
document_id[4096:8192] = 1
for i in range(8192, 32768, 8192):
    document_id[i:i+8192] = i // 8192 + 1
def document_masking_causal(score, b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    document_mask = (document_id[q_idx] == document_id[kv_idx])
    return torch.where(causal_mask & document_mask, score, -float("inf"))
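# Illustrative sketch of the note above: if the document boundaries change, only the
# BlockMask needs to be rebuilt; the compiled kernel for `document_masking_causal` is
# reused since the score_mod itself is unchanged. (Kept commented out so it doesn't
# affect the benchmark below; the lru_cache above would also hand back a stale mask,
# so we would call _create_block_mask directly. q/k/v stand in for your own tensors.)
#
# document_id[:6000] = 0      # hypothetical new document lengths
# document_id[6000:] = 1
# new_block_mask = _create_block_mask(document_masking_causal, 1, 1, 32768, 32768, device='cuda')
# out = flex_attention(q, k, v, score_mod=document_masking_causal, block_mask=new_block_mask)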
################
# Natten masking
################
# In this case, imagine that we have a 2D image of size (H x W) flattened into a
# sequence of tokens. We only want to attend to tokens within 8 "pixels" in each
# direction, measured in the 2D coordinates.
#
# We can implement this score_mod by first translating each 1D position into its
# 2D coordinates. Then, we simply check that the distance along both coordinates
# is within the window.
H = 128
W = 128
WINDOW = 8
def get_x_y(idx):
    return idx // W, idx % W
def natten_mask(score, b, h, q_idx, kv_idx):
    q_x, q_y = get_x_y(q_idx)
    kv_x, kv_y = get_x_y(kv_idx)
    return torch.where(
        ((q_x - kv_x).abs() <= WINDOW) & ((q_y - kv_y).abs() <= WINDOW),
        score,
        -float("inf"),
    )
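# Worked example of the coordinate translation: with W = 128, the token at 1D index
# 300 maps to (300 // 128, 300 % 128) = (2, 44), so it attends to every kv token
# whose 2D coordinates are within WINDOW = 8 rows and 8 columns of (2, 44).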
################
# Alibi Bias
################
# We are not restricted only to masking. For example, you can also implement
# alibi with this API.
alibi_bias = torch.randn(H, device='cuda')
def alibi_and_causal(score, b, h, q_idx, kv_idx):
    causal_mask = q_idx >= kv_idx
    bias = alibi_bias[h] * (q_idx - kv_idx)
    return torch.where(causal_mask, score + bias, -float("inf"))
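# Note: in the original ALiBi formulation the per-head slopes are a fixed geometric
# sequence chosen so the bias penalizes attending to distant tokens; a random tensor
# stands in here since this script only benchmarks the score_mod (correctness
# checking is skipped for this case below).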
################
# Tanh Soft-Capping
################
# We can also implement tanh soft-capping with this API.
# In this case, there are some nuances. In particular, the standard `tanh`
# operator in PyTorch (and CUDA/Triton) lowers to a numerically accurate but
# (relatively) quite slow implementation in SASS. See
# https://godbolt.org/z/W8afevWv1 for what the SASS looks like.
#
# So, in this case, we want to lower the `tanh` into the approximate tanh
# implementation. We can do so by registering a custom operator in PyTorch and
# then adding an Inductor lowering.
@torch.library.custom_op("approx::tanh", mutates_args=())
def tanh_approx(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)
@tanh_approx.register_fake
def _(inp: torch.Tensor) -> torch.Tensor:
    return torch.tanh(inp)
# Some internal torch.compile details :P
from torch._inductor.virtualized import ops
from torch._inductor.lowering import make_pointwise, register_lowering
def tanh_approx_lowering(inp):
    fn = partial(ops.inline_asm_elementwise, asm="tanh.approx.f32 $0, $1;")
    return make_pointwise(fn)(inp)
register_lowering(torch.ops.approx.tanh)(tanh_approx_lowering)
class TanhApprox(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return torch.ops.approx.tanh(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result = output
        ctx.save_for_backward(result)

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * (1 - result * result)
tanh_approx = TanhApprox.apply
def tanh_soft_cap(score, b, h, q_idx, kv_idx):
    score = score / 2
    score = tanh_approx(score)
    score = score * 2
    return torch.where(q_idx >= kv_idx, score, -float("inf"))
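# The general form of tanh soft-capping is `soft_cap * tanh(score / soft_cap)`, which
# smoothly bounds the pre-softmax scores to (-soft_cap, soft_cap); this example
# hard-codes soft_cap = 2.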
test_mask(noop)
test_mask(causal_mask)
test_mask(sliding_window_causal)
test_mask(prefix_lm_causal)
test_mask(document_masking_causal, B=4, H=16, S=32768, D=64)
test_mask(natten_mask, B=4, H=16, S=H*W, D=64)
test_mask(alibi_and_causal, skip_correctness=True)  # Biases are more annoying to test for correctness in our current setup
test_mask(tanh_soft_cap, mask_fn=causal_mask, skip_correctness=True)
@Chillee (Author) commented Jul 3, 2024

noop
Forward: 
causal FA2:  7.297799587249756
F.sdpa + mask:  47.97502136230469
flexattention:  11.321952819824219
Flex FW FLOPS:  388.4529975609175 TF/s
Backward: 
causal FA2:  23.11943244934082
flexattention:  36.00835418701172
Flex BW FLOPS:  305.349037077789 TF/s
BlockMask(sparsity=0.00%, mask=
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
████████████████████████████████
)
causal_mask
Forward: 
causal FA2:  7.333255767822266
F.sdpa + mask:  48.00724792480469
flexattention:  7.069721221923828
Flex FW FLOPS:  315.90821247577384 TF/s
Backward: 
causal FA2:  23.15100860595703
flexattention:  23.34650421142578
Flex BW FLOPS:  239.1560395610515 TF/s
BlockMask(sparsity=49.22%, mask=
░░                              
██░░                            
████░░                          
██████░░                        
████████░░                      
██████████░░                    
████████████░░                  
██████████████░░                
████████████████░░              
██████████████████░░            
████████████████████░░          
██████████████████████░░        
████████████████████████░░      
██████████████████████████░░    
████████████████████████████░░  
██████████████████████████████░░
)
sliding_window_causal
Forward: 
causal FA2:  7.3098955154418945
F.sdpa + mask:  47.99585723876953
flexattention:  2.253040075302124
Flex FW FLOPS:  257.3503202699349 TF/s
Backward: 
causal FA2:  23.12451171875
flexattention:  7.815808296203613
Flex BW FLOPS:  185.46405022550172 TF/s
BlockMask(sparsity=86.82%, mask=
░░                              
██░░                            
░░██░░                          
  ░░██░░                        
    ░░██░░                      
      ░░██░░                    
        ░░██░░                  
          ░░██░░                
            ░░██░░              
              ░░██░░            
                ░░██░░          
                  ░░██░░        
                    ░░██░░      
                      ░░██░░    
                        ░░██░░  
                          ░░██░░
)
prefix_lm_causal
Forward: 
causal FA2:  7.220276355743408
F.sdpa + mask:  48.05806350708008
flexattention:  7.7534003257751465
Flex FW FLOPS:  306.8862411339657 TF/s
Backward: 
causal FA2:  23.12909698486328
flexattention:  26.538402557373047
Flex BW FLOPS:  224.1479942924201 TF/s
BlockMask(sparsity=45.90%, mask=
████████░░                      
████████░░                      
████████░░                      
████████░░                      
████████░░                      
██████████░░                    
████████████░░                  
██████████████░░                
████████████████░░              
██████████████████░░            
████████████████████░░          
██████████████████████░░        
████████████████████████░░      
██████████████████████████░░    
████████████████████████████░░  
██████████████████████████████░░
)
document_masking_causal
Forward: 
causal FA2:  28.504758834838867
F.sdpa + mask:  195.62339782714844
flexattention:  8.883938789367676
Flex FW FLOPS:  220.4545904030704 TF/s
Backward: 
causal FA2:  82.22627258300781
flexattention:  22.474464416503906
Flex BW FLOPS:  217.85892765677994 TF/s
BlockMask(sparsity=88.87%, mask=
░░                                      
██░░                                    
░░░░░░                                  
    ░░░░                                
    ░░░░░░                              
        ░░░░                            
        ░░██░░                          
        ░░████░░                        
        ░░██████░░                      
        ░░░░░░░░░░░░                    
                  ░░░░                  
                  ░░██░░                
                  ░░████░░              
                  ░░██████░░            
                  ░░░░░░░░░░░░          
                            ░░░░        
                            ░░██░░      
                            ░░████░░    
                            ░░██████░░  
                            ░░████████░░
)
natten_mask
Forward: 
causal FA2:  7.291690349578857
F.sdpa + mask:  49.27201843261719
flexattention:  2.7057206630706787
Flex FW FLOPS:  208.73854686205152 TF/s
Backward: 
causal FA2:  21.542888641357422
flexattention:  7.721324920654297
Flex BW FLOPS:  182.86634911361188 TF/s
BlockMask(sparsity=87.16%, mask=
██░░░░                                
░░██░░░░                              
░░░░██░░░░                            
  ░░░░██░░░░                          
    ░░░░██░░░░                        
      ░░░░██░░░░                      
        ░░░░██░░░░                    
          ░░░░██░░░░                  
            ░░░░██░░░░                
              ░░░░██░░░░              
                ░░░░██░░░░            
                  ░░░░██░░░░          
                    ░░░░██░░░░        
                      ░░░░██░░░░      
                        ░░░░██░░░░    
                          ░░░░██░░░░  
                            ░░░░██░░░░
                              ░░░░████
                                ░░████
)
alibi_and_causal
Forward: 
causal FA2:  7.342166423797607
F.sdpa + mask:  48.03105545043945
flexattention:  8.002997398376465
Flex FW FLOPS:  279.0683143759459 TF/s
Backward: 
causal FA2:  23.15782356262207
flexattention:  25.39375877380371
Flex BW FLOPS:  219.87518801509265 TF/s
BlockMask(sparsity=49.22%, mask=
░░                              
██░░                            
████░░                          
██████░░                        
████████░░                      
██████████░░                    
████████████░░                  
██████████████░░                
████████████████░░              
██████████████████░░            
████████████████████░░          
██████████████████████░░        
████████████████████████░░      
██████████████████████████░░    
████████████████████████████░░  
██████████████████████████████░░
)
tanh_soft_cap
Forward: 
causal FA2:  7.31266450881958
F.sdpa + mask:  48.008174896240234
flexattention:  8.117090225219727
Flex FW FLOPS:  275.1457643012146 TF/s
Backward: 
causal FA2:  23.170608520507812
flexattention:  27.235042572021484
Flex BW FLOPS:  205.01005166541862 TF/s
BlockMask(sparsity=49.22%, mask=
░░                              
██░░                            
████░░                          
██████░░                        
████████░░                      
██████████░░                    
████████████░░                  
██████████████░░                
████████████████░░              
██████████████████░░            
████████████████████░░          
██████████████████████░░        
████████████████████████░░      
██████████████████████████░░    
████████████████████████████░░  
██████████████████████████████░░
)
