Benchmark various ways of doing T5 Encoder flex_attention against SDPA
from enum import Enum
from typing import Callable, Optional, Any
from einops import rearrange
from dataclasses import dataclass
import math
import torch
from torch import FloatTensor, LongTensor, IntTensor, BoolTensor, ByteTensor, no_grad, inference_mode
from torch.nn import Embedding, Linear, Module
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, _score_mod_signature, _mask_mod_signature
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
import argparse
def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * flop_count
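# e.g. a workload of 2e12 FLOPs measured at 10 ms/iter works out to (1e3 / 10) * 2e12 = 2e14 FLOP/s, i.e. 200 TFLOP/s.
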
def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

def get_flop_count(f: Callable[[], None]) -> int:
    flop_counter = FlopCounterMode(display=True)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()

def get_flops(f: Callable[[], None], do_print=True, flop_count: Optional[int] = None) -> float:
    ms_per_iter: float = do_bench(f, warmup=1000, return_mode='median')
    if flop_count is None:
        # we deliberately count FLOPs *after* benchmarking, because counting FLOPs of a (compiled) model
        # *before* benchmarking seems to regress the compiled model to a slower implementation.
        # this is a bug(?) that did not occur in torch 2.4.1 but started occurring in 2.5.0.
        # https://x.com/Birchlabs/status/1847369302976188819
        flop_count: int = get_flop_count(f)
    flops: float = mpi_to_flops(ms_per_iter, flop_count)
    if do_print:
        print(fmt_flops(flops))
    return flops
class Checkpoint(str, Enum):
    Small = 'small'
    Base = 'base'
    Large = 'large'
    XL = 'xl'
    XXL = 'xxl'

ckpt_to_heads: dict[Checkpoint, int] = {
    Checkpoint.Small: 6,
    Checkpoint.Base: 12,
    Checkpoint.Large: 16,
    Checkpoint.XL: 32,
    Checkpoint.XXL: 64,
}
ckpt_to_dim: dict[Checkpoint, int] = {
    Checkpoint.Small: 512,
    Checkpoint.Base: 768,
    Checkpoint.Large: 1024,
    Checkpoint.XL: 2048,
    Checkpoint.XXL: 4096,
}

class ScoreModAlgo(str, Enum):
    JumpTableAndEmb = 'jump_table_and_emb'
    "probably the most reasonable algorithm"
    RepeatEmbs = 'repeat_embs'
    "more copying of embeddings, less arithmetic on positions"
    ComputeBuckets = 'compute_buckets'
    "no jump table; compute the bucket directly from the position. entails float32 ratio-of-logarithms arithmetic in the score_mod."
    ReadBias = 'read_bias'
    "[fastest] nothing clever; maximizes IO"
    MinJumpTable = 'min_jump_table'
    "smallest jump table, using complicated arithmetic"
def _relative_position(
    q_len: int,
    k_len: Optional[int] = None,
    cached_autoregressive=False,
    device = torch.device('cpu'),
) -> LongTensor:
    if k_len is None:
        k_len = q_len
    memory_position = torch.arange(k_len, dtype=torch.long, device=device).unsqueeze(0)
    if cached_autoregressive:
        # only the final query position will be kept, so that's the only one we'll compute
        context_position = q_len - 1
    else:
        context_position = torch.arange(q_len, dtype=torch.long, device=device).unsqueeze(-1)
    relative_position = memory_position - context_position  # shape (q_len, k_len)
    return relative_position
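# e.g. _relative_position(q_len=3) returns [[0, 1, 2], [-1, 0, 1], [-2, -1, 0]]:
# entry [i, j] = j - i, negative for keys behind the query, positive for keys ahead of it.
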
# based on HF implementation, Apache-licensed:
# https://github.com/huggingface/transformers/blob/9138935784583203fb5f61e8f581cdfdcd887e0f/src/transformers/models/t5/modeling_t5.py#L384
def _relative_position_bucket(
    relative_position: LongTensor, bidirectional: bool, num_buckets=32, max_distance=128
) -> LongTensor:
    # in cached autoregressive inference, we have 1 query attending to n keys.
    # we move the diagonal to be equivalent to having n queries attending to n keys.
    *_, q_len, k_len = relative_position.shape
    excess_keys: int = k_len - q_len
    if bidirectional:
        num_buckets //= 2
        # I think the excess_keys offset here is never exercised in practice,
        # because the only bidirectional case is encoder self-attn, which doesn't need KV-caching.
        # still, it's probably the correct way to adjust the diagonal if you somehow had that use-case.
        relative_buckets = torch.triu(torch.full_like(relative_position, num_buckets), diagonal=1 + excess_keys)
        relative_position = torch.abs(relative_position)
    else:
        relative_buckets = torch.zeros_like(relative_position)
        relative_position = -torch.tril(relative_position, diagonal=excess_keys)
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in position
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = (
        max_exact
        + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
    )
    relative_position_if_large = relative_position_if_large.min(relative_position_if_large.new_tensor(num_buckets - 1))
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
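# worked example of the bucketing above, with the defaults this script uses (bidirectional=True,
# num_buckets=32, max_distance=128): each direction gets 16 buckets; distances 0..7 map to exact
# buckets 0..7, distances 8..127 map logarithmically onto buckets 8..15, anything farther saturates
# at bucket 15, and keys ahead of the query additionally add 16.
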
class SDPAAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int

    def __init__(self, qkv_proj: Linear, o_proj: Linear, head_dim: int):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim

    def forward(
        self,
        x: FloatTensor,
        position_bias: FloatTensor,
        mask: Optional[BoolTensor] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        if mask is not None:
            assert mask.ndim == 4, "Expected [batch, heads, q, k] attention mask"
            position_bias = position_bias.where(mask, -1e5)
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            # fused kernel requires last dimension of input to have stride 1.
            attn_mask=position_bias.contiguous(),
            dropout_p=0.0,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o
class FlexAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int
    score_mod: _score_mod_signature
    flex_kernel_options: dict[str, Any]

    def __init__(
        self,
        qkv_proj: Linear,
        o_proj: Linear,
        head_dim: int,
        score_mod: _score_mod_signature,
        flex_kernel_options: dict[str, Any] = {},
    ):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim
        self.score_mod = score_mod
        self.flex_kernel_options = flex_kernel_options

    def forward(
        self,
        x: FloatTensor,
        block_mask: Optional[BlockMask] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        a = flex_attention(
            q,
            k,
            v,
            score_mod=self.score_mod,
            block_mask=block_mask,
            kernel_options=self.flex_kernel_options,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o
def get_score_mod_jump_table_and_emb(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
    ctx_len=512,
) -> _score_mod_signature:
    """
    Minimize the amount that we read into the score_mod kernel, to just the emb_weight and a jump table:
      bfloat16 [num_buckets=32, heads=6~64]
      int8     [ctx_len=512]  # this would fit in uint4 if torch supported it
    It *is* possible to shrink the jump table to ~91 entries instead of 512, by introducing more offsets and fallbacks into the jump arithmetic.
    I have an implementation of that, and it was slightly slower. Probably not worth the code complexity, but it had a cool property whereby you don't need to know ctx_len.
    """
    half_buckets = num_buckets // 2
    max_exact = half_buckets // 2
    relpos_to_bucket: ByteTensor = torch.arange(ctx_len, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()
    relpos_to_bucket[:max_exact].copy_(torch.arange(max_exact, device=emb_weight.device))
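    # worked example of the jump table above, with the defaults (num_buckets=32, max_distance=128, ctx_len=512):
    #   relpos_to_bucket[0:8]   == 0..7   (one exact bucket per distance)
    #   relpos_to_bucket[8:512] is in 8..15 (log-spaced bins, saturating at 15)
    # the score_mod then adds half_buckets (16) whenever the key lies to the right of the query.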
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        relpos = (kv_idx - q_idx).abs_()
        relpos_buckets = relpos_to_bucket[relpos]
        relpos_buckets.add_(kv_idx > q_idx, alpha=half_buckets)
        return score + emb_weight[relpos_buckets.int(), h]
    return score_mod
def get_score_mod_repeat_embs(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
    ctx_len=512,
) -> _score_mod_signature:
    """
    Minimize the computation in the score_mod kernel by repeating the embedding over every relative position we'll use:
      bfloat16 [positions=2*max_relpos_ix+1, heads=6~64]
    The position arithmetic is quite fiddly; I was surprised it was allclose on the first try. Don't be surprised if there's still an off-by-one error somewhere.
    """
    half_buckets = num_buckets // 2
    max_exact = half_buckets // 2
    max_relpos_ix: int = math.ceil(math.exp((max_exact - 1) * math.log(max_distance/max_exact) / (half_buckets - max_exact)) * max_exact)
    relpos_to_bucket: ByteTensor = torch.arange(ctx_len, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()
    relpos_to_bucket[:max_exact].copy_(torch.arange(max_exact, device=emb_weight.device))
    relpos_to_bucket_bidi = torch.cat([relpos_to_bucket[:max_relpos_ix+1].flip(-1), relpos_to_bucket[1:max_relpos_ix+1]+half_buckets])
    relpos_to_emb = emb_weight[relpos_to_bucket_bidi.int()]
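    # relpos_to_emb has one row per relative position in [-max_relpos_ix, +max_relpos_ix], i.e. 2*max_relpos_ix+1 rows
    # (≈183 with the defaults); the score_mod indexes it with the clamped relative position + max_relpos_ix.
    # positive offsets (key to the right of the query) land in the upper half_buckets, matching _relative_position_bucket.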
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + relpos_to_emb[(kv_idx - q_idx).clamp_(-max_relpos_ix, max_relpos_ix).add_(max_relpos_ix), h]
    return score_mod
def get_score_min_jump_table(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
) -> _score_mod_signature:
    """
    Minimize the amount that we read into the score_mod kernel to just the emb_weight and a (small) jump table:
      bfloat16 [num_buckets=32, heads=6~64]
      int8     [positions=~91]  # this would fit in uint4 if torch supported it
    """
    # encoder self-attn is bidirectional; each direction uses half the buckets
    half_buckets = num_buckets // 2
    max_half_bucket_ix = half_buckets - 1
    # (in each direction)
    # half of the buckets are allocated for exact increments in position
    # the other half are logarithmically growing bins in positions up to max_distance
    max_exact = half_buckets // 2
    max_relpos_ix: int = math.ceil(math.exp((max_exact - 1) * math.log(max_distance/max_exact) / (half_buckets - max_exact)) * max_exact)
    distant_relpos_to_bucket: ByteTensor = torch.arange(max_exact, max_relpos_ix, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()
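    # distant_relpos_to_bucket covers only the log-spaced region: index d - max_exact holds the bucket for
    # distance d in [max_exact, max_relpos_ix) (8..90 with the defaults); nearer distances use their exact
    # bucket, and farther ones saturate at max_half_bucket_ix in the score_mod below.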
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        relative_position = (kv_idx - q_idx).abs_()
        is_small = relative_position < max_exact
        below_max_distance = relative_position < max_relpos_ix
        distant_position = torch.where(below_max_distance, distant_relpos_to_bucket[(relative_position-max_exact).clamp_(0, max_relpos_ix-max_exact-1)], max_half_bucket_ix)
        relpos_buckets: ByteTensor = torch.where(is_small, relative_position, distant_position)
        relpos_buckets.add_(kv_idx > q_idx, alpha=half_buckets)
        return score + emb_weight[relpos_buckets.int(), h]
    return score_mod
def get_score_mod_compute_buckets(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
) -> _score_mod_signature:
    """
    Maximize the computation in the score_mod kernel, but eliminate the need for a jump table. Only requires emb_weight, and doesn't need to know ctx_len:
      bfloat16 [positions=num_buckets, heads=6~64]
    """
    # note: only implemented for bidirectional (i.e. encoder self-attn)
    # encoder self-attn is bidirectional; each direction uses half the buckets
    half_buckets = num_buckets // 2
    max_half_bucket_ix = half_buckets - 1
    # (in each direction)
    # half of the buckets are allocated for exact increments in position
    # the other half are logarithmically growing bins in positions up to max_distance
    max_exact = half_buckets // 2
    relpos_coeff: float = (half_buckets - max_exact) / math.log(max_distance / max_exact)
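    # the closed form evaluated in the score_mod below (mirroring _relative_position_bucket above):
    #   bucket(d) = d                                                                          for d < max_exact
    #   bucket(d) = min(max_exact + floor(log(d / max_exact) * relpos_coeff), max_half_bucket_ix)  otherwise
    # plus half_buckets whenever the key lies to the right of the query.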
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        # NOTE: in decoder self-attn, if kv-cached decoding is used: we would see
        # only the final query element, and its position would need adjusting.
        # q_idx would not run from 0 to q_len-1,
        # the only query would be q_idx=0, and you'd want to add q_len-1 to it, e.g.
        #   ctx_pos = q_idx + q_len-1
        # or just:
        #   ctx_pos = q_len-1
        ctx_pos = q_idx
        relative_position = (kv_idx - ctx_pos).abs_()
        is_small = relative_position < max_exact
        # NOTE: in decoder self-attn, if kv-cached decoding is used: we would need
        # to move the diagonal of this upper-triangular mask, so that no key gains
        # a positive position (they'd all be behind the query).
        # maybe like this:
        #   excess_keys: int = k_len - q_len
        #   relative_buckets = torch.where(kv_idx > ctx_pos + excess_keys, half_buckets, 0)
        # errr or maybe just:
        #   relative_buckets = 0
        relative_buckets = torch.where(kv_idx > ctx_pos, half_buckets, 0)
        relpos_distant = (max_exact + relative_position.to(torch.float32, copy=True).div_(max_exact).log_().mul_(relpos_coeff).long()).clamp_max_(max_half_bucket_ix)
        relative_buckets.add_(torch.where(is_small, relative_position, relpos_distant))
        return score + emb_weight[relative_buckets, h]
    return score_mod
def get_score_mod_read_bias(bias: FloatTensor) -> _score_mod_signature:
    """
    Maximize the amount that we read into the score_mod kernel:
      bfloat16 [heads=6~64, ctx_len=512, ctx_len=512]
    you'd still hope this would be competitive with cutlassF/cuDNN
    """
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + bias[h, q_idx, kv_idx]
    return score_mod
def get_score_mod_read_masked_bias(bias: FloatTensor) -> _score_mod_signature:
    """
    Maximize the amount that we read into the score_mod kernel; same as read_bias, but the bias has already had the padding mask applied, so it carries a batch dimension:
      bfloat16 [batch, heads=6~64, ctx_len=512, ctx_len=512]
    you'd still hope this would be competitive with cutlassF/cuDNN
    """
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + bias[b, h, q_idx, kv_idx]
    return score_mod
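# note on the two read-bias score_mods above: both feed flex_attention the same precomputed position bias
# that the SDPA path consumes as attn_mask. the masked variant bakes the padding mask into the bias up-front;
# main() below selects it when the block mask is disabled but padding still needs masking
# (--disable-block-mask without --disable-bias-mask).
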
def make_mask_mod(mask: BoolTensor, mask_pad_queries=False) -> _mask_mod_signature:
    if mask_pad_queries:
        # faster (more sparsity) but outputs in pad positions will be 0-valued so you will fail parity tests against SDPA, where our masked bias doesn't do this
        def mask_mod(
            batch: IntTensor,
            head: IntTensor,
            q_idx: IntTensor,
            kv_idx: IntTensor,
        ) -> BoolTensor:
            return mask[batch, kv_idx] & mask[batch, q_idx]
    else:
        def mask_mod(
            batch: IntTensor,
            head: IntTensor,
            q_idx: IntTensor,
            kv_idx: IntTensor,
        ) -> BoolTensor:
            return mask[batch, kv_idx]
    return mask_mod
def make_block_mask(mask: BoolTensor, mask_pad_queries=False) -> BlockMask:
    seq_len: int = mask.size(-1)
    mask_mod: _mask_mod_signature = make_mask_mod(mask, mask_pad_queries)
    block_mask: BlockMask = create_block_mask(
        mask_mod=mask_mod,
        B=mask.size(0),
        H=1,  # broadcast over all heads
        Q_LEN=seq_len,
        KV_LEN=seq_len,
    )
    return block_mask
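# note on make_block_mask above: create_block_mask evaluates mask_mod at block granularity, letting
# flex_attention skip tiles that are masked out entirely; block_mask.sparsity() (printed in main() below)
# reports how much of the computation can be skipped.
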
@dataclass
class Args:
    ckpt: Checkpoint
    num_buckets: int
    max_distance: int
    batch_size: int
    ctx_len: int
    head_dim: int
    visible_tokens: int
    block_m: Optional[int]
    block_n: Optional[int]
    disable_block_mask: bool
    disable_bias_mask: bool
    score_mod_algo: ScoreModAlgo
    seed: int
    parity_test: bool
    mask_pad_queries: bool
def main(args: Args):
    heads: int = ckpt_to_heads[args.ckpt]
    dim: int = ckpt_to_dim[args.ckpt]
    flex_kernel_options: dict[str, Any] = {}
    if args.block_m is not None:
        flex_kernel_options['BLOCK_M'] = args.block_m
    if args.block_n is not None:
        flex_kernel_options['BLOCK_N'] = args.block_n
    print(f'''
ckpt: {args.ckpt}
num_buckets: {args.num_buckets}
max_distance: {args.max_distance}
batch_size: {args.batch_size}
ctx_len: {args.ctx_len}
head_dim: {args.head_dim}
heads: {heads}
dim: {dim}
visible_tokens: {args.visible_tokens}
disable_block_mask: {args.disable_block_mask}
disable_bias_mask: {args.disable_bias_mask}
mask_pad_queries: {args.mask_pad_queries}
score_mod_algo: {args.score_mod_algo}
seed: {args.seed}
flex_kernel_options: {flex_kernel_options}
''')
    dtype = torch.bfloat16
    device = torch.device('cuda')
    with device:
        torch.manual_seed(args.seed)
        pos_emb = Embedding(
            num_embeddings=args.num_buckets,
            embedding_dim=heads,
            dtype=dtype,
        ).eval()
        pos_emb.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        # strictly speaking we don't need to measure the perf of the projections,
        # but maybe they'll enforce some contiguity/stride constraints somehow.
        qkv_proj = Linear(
            in_features=dim,
            out_features=args.head_dim * heads * 3,
            bias=False,
            dtype=dtype,
        ).eval()
        qkv_proj.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        o_proj = Linear(
            in_features=args.head_dim * heads,
            out_features=dim,
            bias=False,
            dtype=dtype,
        ).eval()
        o_proj.weight.requires_grad_(False)

        gen = torch.Generator(device='cpu')
        gen.manual_seed(args.seed)
        hidden_states = torch.randn(args.batch_size, args.ctx_len, dim, generator=gen, device='cpu', dtype=dtype).to(device)

        mask_2d = torch.zeros(args.batch_size, args.ctx_len, device=device, dtype=torch.bool)
        mask_2d[:, :args.visible_tokens] = True
        mask_broadcast = rearrange(mask_2d, "b k -> b 1 1 k")

        relative_position: LongTensor = _relative_position(
            q_len=args.ctx_len,
            k_len=args.ctx_len,
            cached_autoregressive=False,
            device=device,
        )
        relative_position_bucket: LongTensor = _relative_position_bucket(
            relative_position,  # shape (q_len, k_len)
            bidirectional=True,
            num_buckets=args.num_buckets,
        )
        with inference_mode():
            pos_bias: FloatTensor = pos_emb(relative_position_bucket)
            pos_bias = rearrange(pos_bias, "q k heads -> 1 heads q k")
            # need stride of last dimension to be 1 in order to be eligible for torch sdp mem-eff kernels.
            # for some reason pos_bias.contiguous() doesn't achieve this, but cloning with contiguous format does.
            pos_bias = pos_bias.clone(memory_format=torch.contiguous_format)
        match args.score_mod_algo:
            case ScoreModAlgo.JumpTableAndEmb:
                score_mod: _score_mod_signature = get_score_mod_jump_table_and_emb(
                    pos_emb.weight,
                    num_buckets=args.num_buckets,
                    max_distance=args.max_distance,
                    # size the jump table to the actual context length (defaults to 512 otherwise)
                    ctx_len=args.ctx_len,
                )
            case ScoreModAlgo.RepeatEmbs:
                score_mod: _score_mod_signature = get_score_mod_repeat_embs(
                    pos_emb.weight,
                    num_buckets=args.num_buckets,
                    max_distance=args.max_distance,
                    ctx_len=args.ctx_len,
                )
            case ScoreModAlgo.MinJumpTable:
                score_mod: _score_mod_signature = get_score_min_jump_table(
                    pos_emb.weight,
                    num_buckets=args.num_buckets,
                    max_distance=args.max_distance,
                )
            case ScoreModAlgo.ComputeBuckets:
                score_mod: _score_mod_signature = get_score_mod_compute_buckets(
                    pos_emb.weight,
                    num_buckets=args.num_buckets,
                    max_distance=args.max_distance,
                )
            case ScoreModAlgo.ReadBias:
                if args.disable_block_mask and not args.disable_bias_mask:
                    masked_bias: FloatTensor = pos_bias.where(mask_broadcast, -1e5)
                    score_mod: _score_mod_signature = get_score_mod_read_masked_bias(masked_bias)
                else:
                    score_mod: _score_mod_signature = get_score_mod_read_bias(pos_bias.squeeze(0))
        sdpa_attn = SDPAAttn(qkv_proj, o_proj, args.head_dim).eval()
        flex_attn = FlexAttn(qkv_proj, o_proj, args.head_dim, score_mod, flex_kernel_options).eval()

        get_do_sdpa: Callable[[SDPAAttn], Callable[[], FloatTensor]] = lambda sdpa: lambda: sdpa(hidden_states, pos_bias, mask=None if args.disable_bias_mask else mask_broadcast)
        if args.disable_block_mask:
            if args.score_mod_algo != ScoreModAlgo.ReadBias and args.visible_tokens < args.ctx_len:
                print("WARN: you have masked tokens, but you've disabled the block mask. flex's result will be incorrect in padding positions.")
            else:
                print('INFO: you have disabled the block mask, but we still expect the result to match SDPA (i.e. visible_tokens==ctx_len or score_mod_algo == "read_bias").')
            block_mask: Optional[BlockMask] = None
        else:
            block_mask: BlockMask = make_block_mask(mask_2d, args.mask_pad_queries)
            print('sparsity:', block_mask.sparsity())
        get_do_flex: Callable[[FlexAttn], Callable[[], FloatTensor]] = lambda flex: lambda: flex(hidden_states, block_mask=block_mask)

        qs = torch.tensor([.5, .75, .9, .95, .99, .999, .9999], device=device)
        torch.set_printoptions(linewidth=200)
        # SDPA doesn't have safe_softmax, so if we mask out its pad queries we'll get inf in those positions (whereas flex would return 0).
        # but the downstream consumer of the embedded sequence shouldn't want to look at tokens in pad positions anyway.
        # thus, we exclude pad tokens from the parity test; we know they're different but we also know they shouldn't be read anyway.
        eligible_toks: int = args.visible_tokens if args.mask_pad_queries else args.ctx_len
        # count FLOPs under no_grad rather than inference_mode: FlopCounterMode only expects the matmul
        # operations dispatched under no_grad; different ops get dispatched under inference_mode.
        with no_grad():
            print('tracing SDPA...')
            sdpa_flop: int = get_flop_count(get_do_sdpa(sdpa_attn))
            # flex does the same matmul work as SDPA, so reuse the count
            flex_flop: int = sdpa_flop
            if args.parity_test:
                print("testing SDPA for parity..")
                sdpa_out: FloatTensor = get_do_sdpa(sdpa_attn)()
                print("testing Flex for parity..")
                flex_out: FloatTensor = get_do_flex(flex_attn)()
                print('absmax() diff quantiles:')
                print(str(qs.cpu()).removeprefix("tensor(").removesuffix(")"))
                # slice the sequence dim so that pad positions are excluded from the comparison
                print(str(sdpa_out[:, :eligible_toks].to(torch.float32, copy=True).sub_(flex_out[:, :eligible_toks].float()).abs_().quantile(qs).cpu()).removeprefix("tensor(").removesuffix(")"))
        with inference_mode():
            print('benchmarking SDPA...')
            get_flops(get_do_sdpa(sdpa_attn), flop_count=sdpa_flop)
            print('benchmarking Flex...')
            get_flops(get_do_flex(flex_attn), flop_count=flex_flop)
            print('benchmarking SDPA (compiled)...')
            do_sdpa_c: Callable[[], FloatTensor] = get_do_sdpa(torch.compile(sdpa_attn, dynamic=False))
            get_flops(do_sdpa_c, flop_count=sdpa_flop)
            print('benchmarking Flex (compiled)...')
            do_flex_c: Callable[[], FloatTensor] = get_do_flex(torch.compile(flex_attn, dynamic=False))
            get_flops(do_flex_c, flop_count=flex_flop)
            if args.parity_test:
                print("testing SDPA (compiled) for parity..")
                sdpa_c_out: FloatTensor = do_sdpa_c()
                print("testing Flex (compiled) for parity..")
                flex_c_out: FloatTensor = do_flex_c()
                print('absmax() diff quantiles:')
                print(str(qs.cpu()).removeprefix("tensor(").removesuffix(")"))
                # slice the sequence dim so that pad positions are excluded from the comparison
                print(str(sdpa_c_out[:, :eligible_toks].to(torch.float32, copy=True).sub_(flex_c_out[:, :eligible_toks].float()).abs_().quantile(qs).cpu()).removeprefix("tensor(").removesuffix(")"))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--parity-test', action='store_true', help='compare implementation correctness')
    parser.add_argument('--ckpt', type=Checkpoint, default=Checkpoint.XXL, choices=[t.value for t in Checkpoint])
    parser.add_argument('--num-buckets', type=int, default=32)
    parser.add_argument('--max-distance', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=5)
    parser.add_argument('--ctx-len', type=int, default=512)
    parser.add_argument('--head-dim', type=int, default=64)
    parser.add_argument('--visible-tokens', type=int, default=512)
    parser.add_argument('--block-m', type=int, default=None, help='kernel option BLOCK_M for flex attention')
    parser.add_argument('--block-n', type=int, default=None, help='kernel option BLOCK_N for flex attention')
    parser.add_argument('--disable-block-mask', action='store_true', help="when visible_tokens==ctx_len, you can disable the block mask to compare SDPA vs Flex without the advantage of sparsity. this is also useful for testing whether block mask brings its own cost in the no-op case.")
    parser.add_argument('--disable-bias-mask', action='store_true', help="attend even to pad tokens")
    parser.add_argument('--mask-pad-queries', action='store_true', help="in Flex, pad queries will attend to nothing, and rely on safe_softmax to prevent inf probabilities. this improves sparsity but may make parity tests fail (outputs in pad positions will be 0-valued).")
    parser.add_argument('--score-mod-algo', type=ScoreModAlgo, default=ScoreModAlgo.JumpTableAndEmb, choices=[t.value for t in ScoreModAlgo])
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    main(Args(**vars(args)))
@Birch-san (Author):

Invoke via:

python -m t5_enc_attn_bench --score-mod-algo jump_table_and_emb --ckpt xxl
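
Other example invocations (these flags are all defined in the script's argparse; the values are just illustrative):

python -m t5_enc_attn_bench --score-mod-algo read_bias --ckpt large --parity-test
python -m t5_enc_attn_bench --score-mod-algo jump_table_and_emb --ckpt xxl --visible-tokens 384 --mask-pad-queries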
