Created November 17, 2024 18:09
Enabling --count-flops-early (running a model under FlopCounterMode before benchmarking it) regresses the performance of the compiled model.
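To make the shape of the repro concrete before the full script: the pattern the flag toggles is "run the compiled callable under FlopCounterMode, then time it with do_bench". The sketch below is not part of the gist; it assumes torch 2.5.0+, a CUDA device, and triton installed, and it substitutes a bare scaled_dot_product_attention call for the SDPAAttn module defined in the script, so it may not reproduce the regression on its own.

import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

# shapes mirror the script's XXL defaults: (batch, heads, seq, head_dim)
q, k, v = (torch.randn(5, 64, 512, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))
sdpa = torch.compile(scaled_dot_product_attention, dynamic=False)
f = lambda: sdpa(q, k, v)

with torch.no_grad():
    # "early" FLOP counting: running the compiled callable under FlopCounterMode
    # *before* timing it is the ordering that appears to trigger the slowdown.
    with FlopCounterMode(display=False) as flop_counter:
        f()
    ms_per_iter = do_bench(f, warmup=1000, return_mode='median')
    tflops = flop_counter.get_total_flops() / (ms_per_iter * 1e-3) / 1e12
    print(f"{tflops:5.1f} TFLOP/s")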
import argparse
import math
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional

import torch
from einops import rearrange
from torch import (
    BoolTensor,
    FloatTensor,
    LongTensor,
    inference_mode,
    no_grad,
)
from torch.nn import Embedding, Linear, Module
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * flop_count

def fmt_flops(flops: float) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"

def get_flop_count(f: Callable[[], None]) -> int:
    flop_counter = FlopCounterMode(display=True)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()

def get_flops(f: Callable[[], None], do_print=True, flop_count_early=False) -> float:
    if flop_count_early:
        flop_count: int = get_flop_count(f)
    ms_per_iter: float = do_bench(f, warmup=1000, return_mode='median')
    if not flop_count_early:
        # we deliberately count FLOPs *after* benchmarking, because counting FLOPs of a (compiled) model *before*
        # benchmarking seems to regress the compiled model to use a slower implementation.
        # this is a bug(?) that did not occur in torch 2.4.1 but started occurring in 2.5.0.
        # https://x.com/Birchlabs/status/1847369302976188819
        flop_count: int = get_flop_count(f)
    flops: float = mpi_to_flops(ms_per_iter, flop_count)
    if do_print:
        print(fmt_flops(flops))
    return flops

class Checkpoint(str, Enum):
    Small = 'small'
    Base = 'base'
    Large = 'large'
    XL = 'xl'
    XXL = 'xxl'

ckpt_to_heads: dict[Checkpoint, int] = {
    Checkpoint.Small: 6,
    Checkpoint.Base: 12,
    Checkpoint.Large: 16,
    Checkpoint.XL: 32,
    Checkpoint.XXL: 64,
}
ckpt_to_dim: dict[Checkpoint, int] = {
    Checkpoint.Small: 512,
    Checkpoint.Base: 768,
    Checkpoint.Large: 1024,
    Checkpoint.XL: 2048,
    Checkpoint.XXL: 4096,
}

def _relative_position(
    q_len: int,
    k_len: Optional[int] = None,
    cached_autoregressive=False,
    device=torch.device('cpu'),
) -> LongTensor:
    if k_len is None:
        k_len = q_len
    memory_position = torch.arange(k_len, dtype=torch.long, device=device).unsqueeze(0)
    if cached_autoregressive:
        # only the final query position will be kept, so that's the only one we'll compute
        context_position = q_len - 1
    else:
        context_position = torch.arange(q_len, dtype=torch.long, device=device).unsqueeze(-1)
    relative_position = memory_position - context_position  # shape (q_len, k_len)
    return relative_position

# based on HF implementation, Apache-licensed:
# https://github.com/huggingface/transformers/blob/9138935784583203fb5f61e8f581cdfdcd887e0f/src/transformers/models/t5/modeling_t5.py#L384
def _relative_position_bucket(
    relative_position: LongTensor, bidirectional: bool, num_buckets=32, max_distance=128
) -> LongTensor:
    # in cached autoregressive inference, we have 1 query attending to n keys.
    # we move the diagonal to be equivalent to having n queries attending to n keys.
    *_, q_len, k_len = relative_position.shape
    excess_keys: int = k_len - q_len
    if bidirectional:
        num_buckets //= 2
        # I think the excess_keys offset here is never exercised in practice,
        # because the only bidirectional case is encoder self-attn, which doesn't need KV-caching.
        # still, it's probably the correct way to adjust the diagonal if you somehow had that use-case.
        relative_buckets = torch.triu(torch.full_like(relative_position, num_buckets), diagonal=1 + excess_keys)
        relative_position = torch.abs(relative_position)
    else:
        relative_buckets = torch.zeros_like(relative_position)
        relative_position = -torch.tril(relative_position, diagonal=excess_keys)
    # now relative_position is in the range [0, inf)

    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # the other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = (
        max_exact
        + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
    )
    relative_position_if_large = relative_position_if_large.min(relative_position_if_large.new_tensor(num_buckets - 1))

    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets

class SDPAAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int

    def __init__(self, qkv_proj: Linear, o_proj: Linear, head_dim: int):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim

    def forward(
        self,
        x: FloatTensor,
        position_bias: FloatTensor,
        mask: Optional[BoolTensor] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        if mask is not None:
            assert mask.ndim == 4, "Expected [batch, heads, q, k] attention mask"
            position_bias = position_bias.where(mask, -1e5)
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            # fused kernel requires last dimension of input to have stride 1.
            attn_mask=position_bias.contiguous(),
            dropout_p=0.0,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o

@dataclass
class Args:
    ckpt: Checkpoint
    num_buckets: int
    max_distance: int
    batch_size: int
    ctx_len: int
    head_dim: int
    visible_tokens: int
    disable_bias_mask: bool
    seed: int
    parity_test: bool
    count_flops_early: bool

def main(args: Args):
    heads: int = ckpt_to_heads[args.ckpt]
    dim: int = ckpt_to_dim[args.ckpt]
    print(f'''
ckpt: {args.ckpt}
num_buckets: {args.num_buckets}
max_distance: {args.max_distance}
batch_size: {args.batch_size}
ctx_len: {args.ctx_len}
head_dim: {args.head_dim}
heads: {heads}
dim: {dim}
visible_tokens: {args.visible_tokens}
disable_bias_mask: {args.disable_bias_mask}
seed: {args.seed}
''')
    dtype = torch.bfloat16
    device = torch.device('cuda')
    with device:
        torch.manual_seed(args.seed)
        pos_emb = Embedding(
            num_embeddings=args.num_buckets,
            embedding_dim=heads,
            dtype=dtype,
        ).eval()
        pos_emb.weight.requires_grad_(False)

        torch.manual_seed(args.seed)
        # strictly speaking we don't need to measure the perf of the projections,
        # but maybe they'll enforce some contiguity/stride constraints somehow.
        qkv_proj = Linear(
            in_features=dim,
            out_features=args.head_dim * heads * 3,
            bias=False,
            dtype=dtype,
        ).eval()
        qkv_proj.weight.requires_grad_(False)

        torch.manual_seed(args.seed)
        o_proj = Linear(
            in_features=args.head_dim * heads,
            out_features=dim,
            bias=False,
            dtype=dtype,
        ).eval()
        o_proj.weight.requires_grad_(False)

    gen = torch.Generator(device='cpu')
    gen.manual_seed(args.seed)
    hidden_states = torch.randn(args.batch_size, args.ctx_len, dim, generator=gen, device='cpu', dtype=dtype).to(device)
    mask_2d = torch.zeros(args.batch_size, args.ctx_len, device=device, dtype=torch.bool)
    mask_2d[:, :args.visible_tokens] = True
    mask_broadcast = rearrange(mask_2d, "b k -> b 1 1 k")

    relative_position: LongTensor = _relative_position(
        q_len=args.ctx_len,
        k_len=args.ctx_len,
        cached_autoregressive=False,
        device=device,
    )
    relative_position_bucket: LongTensor = _relative_position_bucket(
        relative_position,  # shape (q_len, k_len)
        bidirectional=True,
        num_buckets=args.num_buckets,
        max_distance=args.max_distance,
    )
    with inference_mode():
        pos_bias: FloatTensor = pos_emb(relative_position_bucket)
        pos_bias = rearrange(pos_bias, "q k heads -> 1 heads q k")
        # need stride of last dimension to be 1 in order to be eligible for torch sdp mem-eff kernels
        # for some reason pos_bias.contiguous() doesn't achieve this, but cloning with contiguous format does
        pos_bias = pos_bias.clone(memory_format=torch.contiguous_format)

    sdpa_attn = SDPAAttn(qkv_proj, o_proj, args.head_dim).eval()

    get_do_sdpa: Callable[[SDPAAttn], Callable[[], FloatTensor]] = lambda sdpa: lambda: sdpa(hidden_states, pos_bias, mask=None if args.disable_bias_mask else mask_broadcast)

    # FlopCounterMode only expects the matmul operations dispatched under no_grad; different ops are dispatched under inference_mode
    with no_grad():
        print('benchmarking SDPA...')
        get_flops(get_do_sdpa(sdpa_attn), flop_count_early=args.count_flops_early)

        print('benchmarking SDPA (compiled)...')
        do_sdpa_c: Callable[[], FloatTensor] = get_do_sdpa(torch.compile(sdpa_attn, dynamic=False))
        get_flops(do_sdpa_c, flop_count_early=args.count_flops_early)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--parity-test', action='store_true', help='compare implementation correctness')
    parser.add_argument('--ckpt', type=Checkpoint, default=Checkpoint.XXL, choices=[t.value for t in Checkpoint])
    parser.add_argument('--num-buckets', type=int, default=32)
    parser.add_argument('--max-distance', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=5)
    parser.add_argument('--ctx-len', type=int, default=512)
    parser.add_argument('--head-dim', type=int, default=64)
    parser.add_argument('--visible-tokens', type=int, default=512)
    parser.add_argument('--disable-bias-mask', action='store_true', help="attend even to pad tokens")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--count-flops-early', action='store_true', help="If true, counts FLOPs before running benchmark (this is expected to regress performance in torch 2.5.0+).")
    args = parser.parse_args()
    main(Args(**vars(args)))
Fast mode (default): FLOPs are counted only after benchmarking.
Slow mode (reproduces the torch 2.5.0+ bug): pass --count-flops-early so FLOPs are counted before benchmarking.
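The gist does not state the script's filename, so bench_sdpa.py below is only a placeholder; the two invocations would look roughly like:

    python bench_sdpa.py                      # fast mode
    python bench_sdpa.py --count-flops-early  # slow mode: compiled benchmark regresses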