Benchmark various ways of doing T5 Encoder flex_attention against SDPA
from enum import Enum
from typing import Callable, Optional, Any
from einops import rearrange
from dataclasses import dataclass
import math
import torch
from torch import FloatTensor, LongTensor, IntTensor, BoolTensor, ByteTensor, no_grad, inference_mode
from torch.nn import Embedding, Linear, Module
from torch.nn.attention.flex_attention import BlockMask, flex_attention, create_block_mask, _score_mod_signature, _mask_mod_signature
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
import argparse

def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3/ms_per_iter
    return iters_per_second * flop_count
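# Illustrative arithmetic (made-up numbers): a kernel measured at 2.0 ms/iter whose counted FLOPs
# come out to 1e12 runs at (1e3 / 2.0) * 1e12 = 5e14 FLOP/s, which fmt_flops prints as "500.0 TFLOP/s".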
def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

def get_flop_count(f: Callable[[], None]) -> int:
    flop_counter = FlopCounterMode(display=True)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()

def get_flops(f: Callable[[], None], do_print=True, flop_count: Optional[int] = None) -> float:
    ms_per_iter: float = do_bench(f, warmup=1000, return_mode='median')
    if flop_count is None:
        # we deliberately count FLOPs *after* benchmarking, because counting FLOPs of a (compiled) model
        # *before* benchmarking seems to regress the compiled model to a slower implementation.
        # this is a bug(?) that did not occur in torch 2.4.1 but started occurring in 2.5.0.
        # https://x.com/Birchlabs/status/1847369302976188819
        flop_count: int = get_flop_count(f)
    flops: float = mpi_to_flops(ms_per_iter, flop_count)
    if do_print:
        print(fmt_flops(flops))
    return flops
class Checkpoint(str, Enum):
    Small = 'small'
    Base = 'base'
    Large = 'large'
    XL = 'xl'
    XXL = 'xxl'

ckpt_to_heads: dict[Checkpoint, int] = {
    Checkpoint.Small: 6,
    Checkpoint.Base: 12,
    Checkpoint.Large: 16,
    Checkpoint.XL: 32,
    Checkpoint.XXL: 64,
}
ckpt_to_dim: dict[Checkpoint, int] = {
    Checkpoint.Small: 512,
    Checkpoint.Base: 768,
    Checkpoint.Large: 1024,
    Checkpoint.XL: 2048,
    Checkpoint.XXL: 4096,
}

class ScoreModAlgo(str, Enum):
    JumpTableAndEmb = 'jump_table_and_emb'
    "probably the most reasonable algorithm"
    RepeatEmbs = 'repeat_embs'
    "more copying of embeddings, less arithmetic on positions"
    ComputeBuckets = 'compute_buckets'
    "no jump table; compute the bucket directly from the position. entails float32 ratio-of-logarithms arithmetic in the score_mod."
    ReadBias = 'read_bias'
    "[fastest] nothing clever; maximizes IO"
    MinJumpTable = 'min_jump_table'
    "smallest jump table, using complicated arithmetic"
def _relative_position(
    q_len: int,
    k_len: Optional[int] = None,
    cached_autoregressive=False,
    device = torch.device('cpu'),
) -> LongTensor:
    if k_len is None:
        k_len = q_len
    memory_position = torch.arange(k_len, dtype=torch.long, device=device).unsqueeze(0)
    if cached_autoregressive:
        # only the final query position will be kept, so that's the only one we'll compute
        context_position = q_len - 1
    else:
        context_position = torch.arange(q_len, dtype=torch.long, device=device).unsqueeze(-1)
    relative_position = memory_position - context_position  # shape (q_len, k_len)
    return relative_position

# based on the HF implementation, Apache-licensed:
# https://github.com/huggingface/transformers/blob/9138935784583203fb5f61e8f581cdfdcd887e0f/src/transformers/models/t5/modeling_t5.py#L384
def _relative_position_bucket(
    relative_position: LongTensor, bidirectional: bool, num_buckets=32, max_distance=128
) -> LongTensor:
    # in cached autoregressive inference, we have 1 query attending to n keys.
    # we move the diagonal to be equivalent to having n queries attending to n keys.
    *_, q_len, k_len = relative_position.shape
    excess_keys: int = k_len - q_len
    if bidirectional:
        num_buckets //= 2
        # I think the excess_keys offset here is never exercised in practice,
        # because the only bidirectional case is encoder self-attn, which doesn't need KV-caching.
        # still, it's probably the correct way to adjust the diagonal if you somehow had that use-case.
        relative_buckets = torch.triu(torch.full_like(relative_position, num_buckets), diagonal=1 + excess_keys)
        relative_position = torch.abs(relative_position)
    else:
        relative_buckets = torch.zeros_like(relative_position)
        relative_position = -torch.tril(relative_position, diagonal=excess_keys)
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = (
        max_exact
        + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
    )
    relative_position_if_large = relative_position_if_large.min(relative_position_if_large.new_tensor(num_buckets - 1))
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets
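# Illustrative mapping, assuming the defaults (num_buckets=32, max_distance=128, bidirectional=True,
# i.e. 16 buckets per direction with max_exact=8): relative distances 0..7 map to buckets 0..7 exactly;
# distance 10 falls in the first log bin (bucket 8); distance 100 lands in bucket 15; any distance of
# 128 or more saturates at bucket 15. Keys to the right of the query get a further +16.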
class SDPAAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int

    def __init__(self, qkv_proj: Linear, o_proj: Linear, head_dim: int):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim

    def forward(
        self,
        x: FloatTensor,
        position_bias: FloatTensor,
        mask: Optional[BoolTensor] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        if mask is not None:
            assert mask.ndim == 4, "Expected [batch, heads, q, k] attention mask"
            position_bias = position_bias.where(mask, -1e5)
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            # fused kernel requires last dimension of input to have stride 1.
            attn_mask=position_bias.contiguous(),
            dropout_p=0.0,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o
class FlexAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int
    score_mod: _score_mod_signature
    flex_kernel_options: dict[str, Any]

    def __init__(
        self,
        qkv_proj: Linear,
        o_proj: Linear,
        head_dim: int,
        score_mod: _score_mod_signature,
        flex_kernel_options: dict[str, Any] = {},
    ):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim
        self.score_mod = score_mod
        self.flex_kernel_options = flex_kernel_options

    def forward(
        self,
        x: FloatTensor,
        block_mask: Optional[BlockMask] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        a = flex_attention(
            q,
            k,
            v,
            score_mod=self.score_mod,
            block_mask=block_mask,
            kernel_options=self.flex_kernel_options,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o
def get_score_mod_jump_table_and_emb(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
    ctx_len=512,
) -> _score_mod_signature:
    """
    Minimize the amount that we read into the score_mod kernel, to just the emb_weight and a jump table:
      bfloat16 [num_buckets=32, heads=6~64]
      int8     [ctx_len=512]  # this would fit in uint4 if torch supported it
    It *is* possible to shrink the jump table to ~91 entries instead of 512 by introducing more offsets and
    fallbacks into the jump arithmetic. I have an implementation of that, and it was slightly slower;
    probably not worth the code complexity, but it had a cool property whereby you don't need to know ctx_len.
    """
    half_buckets = num_buckets // 2
    max_exact = half_buckets // 2
    relpos_to_bucket: ByteTensor = torch.arange(ctx_len, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()
    relpos_to_bucket[:max_exact].copy_(torch.arange(max_exact, device=emb_weight.device))

    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        relpos = (kv_idx - q_idx).abs_()
        relpos_buckets = relpos_to_bucket[relpos]
        relpos_buckets.add_(kv_idx > q_idx, alpha=half_buckets)
        return score + emb_weight[relpos_buckets.int(), h]
    return score_mod
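# With the defaults (num_buckets=32, max_distance=128, ctx_len=512), the precomputed jump table is
# 512 bytes: entries 0..7 hold the exact buckets 0..7, and entries 8..511 hold the log-spaced buckets
# 8..15 (saturating at 15 from distance ~91 onward). The score_mod then costs one table lookup plus a
# +16 offset whenever the key sits to the right of the query.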
def get_score_mod_repeat_embs(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
    ctx_len=512,
) -> _score_mod_signature:
    """
    Minimize the computation in the score_mod kernel by repeating the embedding over every relative position we'll use:
      bfloat16 [positions=2*max_relpos_ix+1, heads=6~64]
    The position arithmetic is quite fiddly; I am surprised it was allclose on the first try.
    Don't be surprised if there's still an off-by-one error somewhere.
    """
    half_buckets = num_buckets // 2
    max_exact = half_buckets // 2
    max_relpos_ix: int = math.ceil(math.exp((max_exact - 1) * math.log(max_distance/max_exact) / (half_buckets - max_exact)) * max_exact)
    relpos_to_bucket: ByteTensor = torch.arange(ctx_len, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()
    relpos_to_bucket[:max_exact].copy_(torch.arange(max_exact, device=emb_weight.device))
    relpos_to_bucket_bidi = torch.cat([relpos_to_bucket[:max_relpos_ix+1].flip(-1), relpos_to_bucket[1:max_relpos_ix+1]+half_buckets])
    relpos_to_emb = emb_weight[relpos_to_bucket_bidi.int()]

    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + relpos_to_emb[(kv_idx - q_idx).clamp_(-max_relpos_ix, max_relpos_ix).add_(max_relpos_ix), h]
    return score_mod
def get_score_min_jump_table(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
) -> _score_mod_signature:
    """
    Minimize the amount that we read into the score_mod kernel, to just the emb_weight and a (small) jump table:
      bfloat16 [num_buckets=32, heads=6~64]
      int8     [positions=~91]  # this would fit in uint4 if torch supported it
    """
    # encoder self-attn is bidirectional; each direction uses half the buckets
    half_buckets = num_buckets // 2
    max_half_bucket_ix = half_buckets - 1
    # (in each direction)
    # half of the buckets are allocated for exact increments in position;
    # the other half are logarithmically growing bins in positions up to max_distance
    max_exact = half_buckets // 2
    max_relpos_ix: int = math.ceil(math.exp((max_exact - 1) * math.log(max_distance/max_exact) / (half_buckets - max_exact)) * max_exact)
    distant_relpos_to_bucket: ByteTensor = torch.arange(max_exact, max_relpos_ix, device=emb_weight.device, dtype=torch.float32).div_(max_exact).log_().mul_((half_buckets - max_exact) / math.log(max_distance / max_exact)).long().clamp_max(max_exact-1).add_(max_exact).byte()

    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        relative_position = (kv_idx - q_idx).abs_()
        is_small = relative_position < max_exact
        below_max_distance = relative_position < max_relpos_ix
        distant_position = torch.where(below_max_distance, distant_relpos_to_bucket[(relative_position-max_exact).clamp_(0, max_relpos_ix-max_exact-1)], max_half_bucket_ix)
        relpos_buckets: ByteTensor = torch.where(is_small, relative_position, distant_position)
        relpos_buckets.add_(kv_idx > q_idx, alpha=half_buckets)
        return score + emb_weight[relpos_buckets.int(), h]
    return score_mod
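# With the defaults, max_relpos_ix = ceil(exp(7 * ln(16) / 8) * 8) = 91: distances below 8 use their
# exact value as the bucket, distances 8..90 go through the small table above, and anything at 91 or
# beyond saturates at max_half_bucket_ix (15), with a +16 offset for keys to the right of the query.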
def get_score_mod_compute_buckets(
    emb_weight: FloatTensor,
    num_buckets=32,
    max_distance=128,
) -> _score_mod_signature:
    """
    Maximize the computation in the score_mod kernel, but eliminate the need for a jump table.
    Only requires emb_weight, and doesn't need to know ctx_len:
      bfloat16 [positions=num_buckets, heads=6~64]
    """
    # note: only implemented for bidirectional (i.e. encoder self-attn)
    # encoder self-attn is bidirectional; each direction uses half the buckets
    half_buckets = num_buckets // 2
    max_half_bucket_ix = half_buckets - 1
    # (in each direction)
    # half of the buckets are allocated for exact increments in position;
    # the other half are logarithmically growing bins in positions up to max_distance
    max_exact = half_buckets // 2
    relpos_coeff: float = (half_buckets - max_exact) / math.log(max_distance / max_exact)

    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        # NOTE: in decoder self-attn, if kv-cached decoding is used, we would see
        # only the final query element, and its position would need adjusting.
        # q_idx would not run from 0 to q_len-1;
        # the only query would be q_idx=0, and you'd want to add q_len-1 to it, e.g.
        #   ctx_pos = q_idx + q_len-1
        # or just:
        #   ctx_pos = q_len-1
        ctx_pos = q_idx
        relative_position = (kv_idx - ctx_pos).abs_()
        is_small = relative_position < max_exact
        # NOTE: in decoder self-attn, if kv-cached decoding is used, we would need
        # to move the diagonal of this upper-triangular mask, so that no key gains
        # a positive position (they'd all be behind the query).
        # maybe like this:
        #   excess_keys: int = k_len - q_len
        #   relative_buckets = torch.where(kv_idx > ctx_pos + excess_keys, half_buckets, 0)
        # errr or maybe just:
        #   relative_buckets = 0
        relative_buckets = torch.where(kv_idx > ctx_pos, half_buckets, 0)
        relpos_distant = (max_exact + relative_position.to(torch.float32, copy=True).div_(max_exact).log_().mul_(relpos_coeff).long()).clamp_max_(max_half_bucket_ix)
        relative_buckets.add_(torch.where(is_small, relative_position, relpos_distant))
        return score + emb_weight[relative_buckets, h]
    return score_mod
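# The closed form this score_mod evaluates per (q_idx, kv_idx) pair, with d = |kv_idx - q_idx|, is:
#   bucket(d) = d                                                          if d < max_exact
#             = min(max_exact + floor(ln(d / max_exact) * relpos_coeff),
#                   half_buckets - 1)                                      otherwise
# plus half_buckets when the key is to the right of the query. Sanity check against
# _relative_position_bucket above (defaults): d=3 -> 3, d=10 -> 8, d=100 -> 15.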
def get_score_mod_read_bias(bias: FloatTensor) -> _score_mod_signature:
    """
    Maximize the amount that we read into the score_mod kernel:
      bfloat16 [heads=6~64, ctx_len=512, ctx_len=512]
    You'd still hope this would be competitive with cutlassF/cuDNN.
    """
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + bias[h, q_idx, kv_idx]
    return score_mod

def get_score_mod_read_masked_bias(bias: FloatTensor) -> _score_mod_signature:
    """
    Maximize the amount that we read into the score_mod kernel, with a per-batch mask folded into the bias:
      bfloat16 [batch, heads=6~64, ctx_len=512, ctx_len=512]
    You'd still hope this would be competitive with cutlassF/cuDNN.
    """
    def score_mod(
        score: FloatTensor,
        b: IntTensor,
        h: IntTensor,
        q_idx: IntTensor,
        kv_idx: IntTensor,
    ) -> FloatTensor:
        return score + bias[b, h, q_idx, kv_idx]
    return score_mod
def make_mask_mod(mask: BoolTensor, mask_pad_queries=False) -> _mask_mod_signature:
    if mask_pad_queries:
        # faster (more sparsity), but outputs in pad positions will be 0-valued, so you will fail parity
        # tests against SDPA, where our masked bias doesn't do this
        def mask_mod(
            batch: IntTensor,
            head: IntTensor,
            q_idx: IntTensor,
            kv_idx: IntTensor,
        ) -> BoolTensor:
            return mask[batch, kv_idx] & mask[batch, q_idx]
    else:
        def mask_mod(
            batch: IntTensor,
            head: IntTensor,
            q_idx: IntTensor,
            kv_idx: IntTensor,
        ) -> BoolTensor:
            return mask[batch, kv_idx]
    return mask_mod

def make_block_mask(mask: BoolTensor, mask_pad_queries=False) -> BlockMask:
    seq_len: int = mask.size(-1)
    mask_mod: _mask_mod_signature = make_mask_mod(mask, mask_pad_queries)
    block_mask: BlockMask = create_block_mask(
        mask_mod=mask_mod,
        B=mask.size(0),
        H=1,  # broadcast over all heads
        Q_LEN=seq_len,
        KV_LEN=seq_len,
    )
    return block_mask
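# Minimal usage sketch (hypothetical shapes, chosen only for illustration): a boolean key-padding mask
# of shape [batch, seq] with True marking real tokens yields a BlockMask broadcast over all heads (H=1):
#   mask = torch.ones(2, 512, dtype=torch.bool, device='cuda')
#   mask[:, 384:] = False        # final 128 positions are padding
#   bm = make_block_mask(mask)   # KV blocks that are entirely padding get pruned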
@dataclass
class Args:
    ckpt: Checkpoint
    num_buckets: int
    max_distance: int
    batch_size: int
    ctx_len: int
    head_dim: int
    visible_tokens: int
    block_m: Optional[int]
    block_n: Optional[int]
    disable_block_mask: bool
    disable_bias_mask: bool
    score_mod_algo: ScoreModAlgo
    seed: int
    parity_test: bool
    mask_pad_queries: bool
def main(args: Args):
    heads: int = ckpt_to_heads[args.ckpt]
    dim: int = ckpt_to_dim[args.ckpt]
    flex_kernel_options: dict[str, Any] = {}
    if args.block_m is not None:
        flex_kernel_options['BLOCK_M'] = args.block_m
    if args.block_n is not None:
        flex_kernel_options['BLOCK_N'] = args.block_n
    print(f'''
ckpt: {args.ckpt}
num_buckets: {args.num_buckets}
max_distance: {args.max_distance}
batch_size: {args.batch_size}
ctx_len: {args.ctx_len}
head_dim: {args.head_dim}
heads: {heads}
dim: {dim}
visible_tokens: {args.visible_tokens}
disable_block_mask: {args.disable_block_mask}
disable_bias_mask: {args.disable_bias_mask}
mask_pad_queries: {args.mask_pad_queries}
score_mod_algo: {args.score_mod_algo}
seed: {args.seed}
flex_kernel_options: {flex_kernel_options}
''')
    dtype = torch.bfloat16
    device = torch.device('cuda')
    with device:
        torch.manual_seed(args.seed)
        pos_emb = Embedding(
            num_embeddings=args.num_buckets,
            embedding_dim=heads,
            dtype=dtype,
        ).eval()
        pos_emb.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        # strictly speaking we don't need to measure the perf of the projections,
        # but maybe they'll enforce some contiguity/stride constraints somehow.
        qkv_proj = Linear(
            in_features=dim,
            out_features=args.head_dim * heads * 3,
            bias=False,
            dtype=dtype,
        ).eval()
        qkv_proj.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        o_proj = Linear(
            in_features=args.head_dim * heads,
            out_features=dim,
            bias=False,
            dtype=dtype,
        ).eval()
        o_proj.weight.requires_grad_(False)
    gen = torch.Generator(device='cpu')
    gen.manual_seed(args.seed)
    hidden_states = torch.randn(args.batch_size, args.ctx_len, dim, generator=gen, device='cpu', dtype=dtype).to(device)
    mask_2d = torch.zeros(args.batch_size, args.ctx_len, device=device, dtype=torch.bool)
    mask_2d[:, :args.visible_tokens] = True
    mask_broadcast = rearrange(mask_2d, "b k -> b 1 1 k")
    relative_position: LongTensor = _relative_position(
        q_len=args.ctx_len,
        k_len=args.ctx_len,
        cached_autoregressive=False,
        device=device,
    )
    relative_position_bucket: LongTensor = _relative_position_bucket(
        relative_position,  # shape (q_len, k_len)
        bidirectional=True,
        num_buckets=args.num_buckets,
    )
    with inference_mode():
        pos_bias: FloatTensor = pos_emb(relative_position_bucket)
        pos_bias = rearrange(pos_bias, "q k heads -> 1 heads q k")
        # need stride of last dimension to be 1 in order to be eligible for torch sdp mem-eff kernels.
        # for some reason pos_bias.contiguous() doesn't achieve this, but cloning with contiguous format does
        pos_bias = pos_bias.clone(memory_format=torch.contiguous_format)
    match args.score_mod_algo:
        case ScoreModAlgo.JumpTableAndEmb:
            score_mod: _score_mod_signature = get_score_mod_jump_table_and_emb(
                pos_emb.weight,
                num_buckets=args.num_buckets,
                max_distance=args.max_distance,
            )
        case ScoreModAlgo.RepeatEmbs:
            score_mod: _score_mod_signature = get_score_mod_repeat_embs(
                pos_emb.weight,
                num_buckets=args.num_buckets,
                max_distance=args.max_distance,
                ctx_len=args.ctx_len,
            )
        case ScoreModAlgo.MinJumpTable:
            score_mod: _score_mod_signature = get_score_min_jump_table(
                pos_emb.weight,
                num_buckets=args.num_buckets,
                max_distance=args.max_distance,
            )
        case ScoreModAlgo.ComputeBuckets:
            score_mod: _score_mod_signature = get_score_mod_compute_buckets(
                pos_emb.weight,
                num_buckets=args.num_buckets,
                max_distance=args.max_distance,
            )
        case ScoreModAlgo.ReadBias:
            if args.disable_block_mask and not args.disable_bias_mask:
                masked_bias: FloatTensor = pos_bias.where(mask_broadcast, -1e5)
                score_mod: _score_mod_signature = get_score_mod_read_masked_bias(masked_bias)
            else:
                score_mod: _score_mod_signature = get_score_mod_read_bias(pos_bias.squeeze(0))
    sdpa_attn = SDPAAttn(qkv_proj, o_proj, args.head_dim).eval()
    flex_attn = FlexAttn(qkv_proj, o_proj, args.head_dim, score_mod, flex_kernel_options).eval()
    get_do_sdpa: Callable[[SDPAAttn], FloatTensor] = lambda sdpa: lambda: sdpa(hidden_states, pos_bias, mask=None if args.disable_bias_mask else mask_broadcast)
    if args.disable_block_mask:
        if args.score_mod_algo != ScoreModAlgo.ReadBias and args.visible_tokens < args.ctx_len:
            print("WARN: you have masked tokens, but you've disabled the block mask. flex's result will be incorrect in padding positions.")
        else:
            print('INFO: you have disabled the block mask, but we still expect the result to match SDPA (i.e. visible_tokens==ctx_len or score_mod_algo == "read_bias").')
        block_mask: Optional[BlockMask] = None
    else:
        block_mask: BlockMask = make_block_mask(mask_2d, args.mask_pad_queries)
        print('sparsity:', block_mask.sparsity())
    get_do_flex: Callable[[FlexAttn], FloatTensor] = lambda flex: lambda: flex(hidden_states, block_mask=block_mask)
    qs = torch.tensor([.5, .75, .9, .95, .99, .999, .9999], device=device)
    torch.set_printoptions(linewidth=200)
    # SDPA doesn't have safe_softmax, so if we mask out its pad queries we'll get inf in those positions (whereas flex would return 0).
    # but the downstream consumer of the embedded sequence shouldn't want to look at tokens in pad positions anyway.
    # thus, we exclude pad tokens from the parity test; we know they're different but we also know they shouldn't be read anyway.
    eligible_toks: int = args.visible_tokens if args.mask_pad_queries else args.ctx_len
    # FlopCounterMode only expects the matmul operations dispatched under no_grad; different ops are dispatched under inference_mode
    with no_grad():
        print('tracing SDPA...')
        sdpa_flop: int = get_flop_count(get_do_sdpa(sdpa_attn))
        flex_flop: int = sdpa_flop
        if args.parity_test:
            print("testing SDPA for parity..")
            sdpa_out: FloatTensor = get_do_sdpa(sdpa_attn)()
            print("testing Flex for parity..")
            flex_out: FloatTensor = get_do_flex(flex_attn)()
            print('absmax() diff quantiles:')
            print(str(qs.cpu()).removeprefix("tensor(").removesuffix(")"))
            print(str(sdpa_out[:,:,:eligible_toks].to(torch.float32, copy=True).sub_(flex_out[:,:,:eligible_toks].float()).abs_().quantile(qs).cpu()).removeprefix("tensor(").removesuffix(")"))
    with inference_mode():
        print('benchmarking SDPA...')
        get_flops(get_do_sdpa(sdpa_attn), flop_count=sdpa_flop)
        print('benchmarking Flex...')
        get_flops(get_do_flex(flex_attn), flop_count=flex_flop)
        print('benchmarking SDPA (compiled)...')
        do_sdpa_c: Callable[[], FloatTensor] = get_do_sdpa(torch.compile(sdpa_attn, dynamic=False))
        get_flops(do_sdpa_c, flop_count=sdpa_flop)
        print('benchmarking Flex (compiled)...')
        do_flex_c: Callable[[], FloatTensor] = get_do_flex(torch.compile(flex_attn, dynamic=False))
        get_flops(do_flex_c, flop_count=flex_flop)
        if args.parity_test:
            print("testing SDPA (compiled) for parity..")
            sdpa_c_out: FloatTensor = do_sdpa_c()
            print("testing Flex (compiled) for parity..")
            flex_c_out: FloatTensor = do_flex_c()
            print('absmax() diff quantiles:')
            print(str(qs.cpu()).removeprefix("tensor(").removesuffix(")"))
            print(str(sdpa_c_out[:,:,:eligible_toks].to(torch.float32, copy=True).sub_(flex_c_out[:,:,:eligible_toks].float()).abs_().quantile(qs).cpu()).removeprefix("tensor(").removesuffix(")"))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--parity-test', action='store_true', help='compare implementation correctness')
    parser.add_argument('--ckpt', type=Checkpoint, default=Checkpoint.XXL, choices=[t.value for t in Checkpoint])
    parser.add_argument('--num-buckets', type=int, default=32)
    parser.add_argument('--max-distance', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=5)
    parser.add_argument('--ctx-len', type=int, default=512)
    parser.add_argument('--head-dim', type=int, default=64)
    parser.add_argument('--visible-tokens', type=int, default=512)
    parser.add_argument('--block-m', type=int, default=None, help='kernel option BLOCK_M for flex attention')
    parser.add_argument('--block-n', type=int, default=None, help='kernel option BLOCK_N for flex attention')
    parser.add_argument('--disable-block-mask', action='store_true', help="when visible_tokens==ctx_len, you can disable the block mask to compare SDPA vs Flex without the advantage of sparsity. this is also useful for testing whether block mask brings its own cost in the no-op case.")
    parser.add_argument('--disable-bias-mask', action='store_true', help="attend even to pad tokens")
    parser.add_argument('--mask-pad-queries', action='store_true', help="in Flex, pad queries will attend to nothing, and rely on safe_softmax to prevent inf probabilities. this improves sparsity but may make parity tests fail (outputs in pad positions will be 0-valued).")
    parser.add_argument('--score-mod-algo', type=ScoreModAlgo, default=ScoreModAlgo.JumpTableAndEmb, choices=[t.value for t in ScoreModAlgo])
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()
    main(Args(**vars(args)))
Invoke via:
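For example (the script filename here is an assumption; every flag shown is defined in the argparse block above):

python t5_encoder_flex_vs_sdpa.py --ckpt xxl --ctx-len 512 --visible-tokens 512 --score-mod-algo jump_table_and_emb --parity-test

Drop --parity-test to benchmark only, or set --visible-tokens below --ctx-len to exercise the block mask's sparsity.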