@Birch-san
Created November 17, 2024 18:09
Enabling --count-flops-early (running the model under FlopCounterMode before benchmarking it) regresses the performance of the compiled model.
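For context, the ordering that this flag toggles is roughly the following (a minimal standalone sketch distilled from get_flops in the script; the toy Linear module, shapes, and dtype are illustrative placeholders and may not reproduce the regression on their own — the full script exercises a T5-style SDPA block):

# sketch only: count FLOPs under FlopCounterMode either before or after benchmarking
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

lin = torch.nn.Linear(4096, 4096, bias=False, device='cuda', dtype=torch.bfloat16).eval()
x = torch.randn(64, 4096, device='cuda', dtype=torch.bfloat16)
f = torch.compile(lin, dynamic=False)

with torch.no_grad():
    # "early" FLOP counting: run the compiled callable under FlopCounterMode first...
    with FlopCounterMode(display=False) as flop_counter:
        f(x)
    # ...then benchmark it. Counting FLOPs first is what --count-flops-early does,
    # and per the script below, in torch 2.5.0+ that ordering regresses the compiled
    # model's performance relative to benchmarking first and counting FLOPs afterwards.
    ms_per_iter = do_bench(lambda: f(x), warmup=1000, return_mode='median')
    print(f"{flop_counter.get_total_flops() / (ms_per_iter * 1e-3) / 1e12:5.1f} TFLOP/s")

The full repro script (invoked below as scripts.bench_repro) follows.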
import argparse
import math
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Optional
import torch
from einops import rearrange
from torch import (
    BoolTensor,
    FloatTensor,
    LongTensor,
    inference_mode,
    no_grad,
)
from torch.nn import Embedding, Linear, Module
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
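# convert do_bench's ms-per-iteration into achieved FLOP/s, given the FLOP count of one iteration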
def mpi_to_flops(ms_per_iter: float, flop_count: int) -> float:
    iters_per_second = 1e3/ms_per_iter
    return iters_per_second * flop_count


def fmt_flops(flops: int) -> str:
    return f"{flops / 1e12:5.1f} TFLOP/s"


def get_flop_count(f: Callable[[], None]) -> int:
    flop_counter = FlopCounterMode(display=True)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()


def get_flops(f: Callable[[], None], do_print=True, flop_count_early=False) -> float:
    if flop_count_early:
        flop_count: int = get_flop_count(f)
    ms_per_iter: float = do_bench(f, warmup=1000, return_mode='median')
    if not flop_count_early:
        # we deliberately count FLOPs *after* benchmarking, because counting FLOPs of a (compiled) model *before*
        # benchmarking seems to regress the compiled model to use a slower implementation.
        # this is a bug(?) that did not occur in torch 2.4.1 but started occurring in 2.5.0.
        # https://x.com/Birchlabs/status/1847369302976188819
        flop_count: int = get_flop_count(f)
    flops: float = mpi_to_flops(ms_per_iter, flop_count)
    if do_print:
        print(fmt_flops(flops))
    return flops


class Checkpoint(str, Enum):
    Small = 'small'
    Base = 'base'
    Large = 'large'
    XL = 'xl'
    XXL = 'xxl'


ckpt_to_heads: dict[Checkpoint, int] = {
    Checkpoint.Small: 6,
    Checkpoint.Base: 12,
    Checkpoint.Large: 16,
    Checkpoint.XL: 32,
    Checkpoint.XXL: 64,
}

ckpt_to_dim: dict[Checkpoint, int] = {
    Checkpoint.Small: 512,
    Checkpoint.Base: 768,
    Checkpoint.Large: 1024,
    Checkpoint.XL: 2048,
    Checkpoint.XXL: 4096,
}


def _relative_position(
    q_len: int,
    k_len: Optional[int] = None,
    cached_autoregressive=False,
    device = torch.device('cpu'),
) -> LongTensor:
    if k_len is None:
        k_len = q_len
    memory_position = torch.arange(k_len, dtype=torch.long, device=device).unsqueeze(0)
    if cached_autoregressive:
        # only the final query position will be kept, so that's the only one we'll compute
        context_position = q_len - 1
    else:
        context_position = torch.arange(q_len, dtype=torch.long, device=device).unsqueeze(-1)
    relative_position = memory_position - context_position  # shape (q_len, k_len)
    return relative_position


# based on HF implementation, Apache-licensed:
# https://github.com/huggingface/transformers/blob/9138935784583203fb5f61e8f581cdfdcd887e0f/src/transformers/models/t5/modeling_t5.py#L384
def _relative_position_bucket(
    relative_position: LongTensor, bidirectional: bool, num_buckets=32, max_distance=128
) -> LongTensor:
    # in cached autoregressive inference, we have 1 query attending to n keys.
    # we move the diagonal to be equivalent to having n queries attending to n keys.
    *_, q_len, k_len = relative_position.shape
    excess_keys: int = k_len - q_len
    if bidirectional:
        num_buckets //= 2
        # I think the excess_keys offset here is never exercised in practice,
        # because the only bidirectional case is encoder self-attn, which doesn't need KV-caching.
        # still, it's probably the correct way to adjust the diagonal if you somehow had that use-case.
        relative_buckets = torch.triu(torch.full_like(relative_position, num_buckets), diagonal=1 + excess_keys)
        relative_position = torch.abs(relative_position)
    else:
        relative_buckets = torch.zeros_like(relative_position)
        relative_position = -torch.tril(relative_position, diagonal=excess_keys)
    # now relative_position is in the range [0, inf)
    # half of the buckets are for exact increments in positions
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
    # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
    relative_position_if_large = (
        max_exact
        + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).long()
    )
    relative_position_if_large = relative_position_if_large.min(relative_position_if_large.new_tensor(num_buckets - 1))
    relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
    return relative_buckets


class SDPAAttn(Module):
    qkv_proj: Linear
    o_proj: Linear
    head_dim: int

    def __init__(self, qkv_proj: Linear, o_proj: Linear, head_dim: int):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.head_dim = head_dim

    def forward(
        self,
        x: FloatTensor,
        position_bias: FloatTensor,
        mask: Optional[BoolTensor] = None,
    ) -> FloatTensor:
        qkv: FloatTensor = self.qkv_proj(x)
        q, k, v = rearrange(
            qkv, "batch seq (proj heads head_dim) -> proj batch heads seq head_dim", proj=3, head_dim=self.head_dim
        ).unbind()
        if mask is not None:
            assert mask.ndim == 4, "Expected [batch, heads, q, k] attention mask"
            position_bias = position_bias.where(mask, -1e5)
        a = scaled_dot_product_attention(
            q,
            k,
            v,
            # fused kernel requires last dimension of input to have stride 1.
            attn_mask=position_bias.contiguous(),
            dropout_p=0.0,
        )
        a = rearrange(a, "batch heads seq head_dim -> batch seq (heads head_dim)")
        o = self.o_proj(a)
        return o


@dataclass
class Args:
    ckpt: Checkpoint
    num_buckets: int
    max_distance: int
    batch_size: int
    ctx_len: int
    head_dim: int
    visible_tokens: int
    disable_bias_mask: bool
    seed: int
    parity_test: bool
    count_flops_early: bool


def main(args: Args):
    heads: int = ckpt_to_heads[args.ckpt]
    dim: int = ckpt_to_dim[args.ckpt]
    print(f'''
ckpt: {args.ckpt}
num_buckets: {args.num_buckets}
max_distance: {args.max_distance}
batch_size: {args.batch_size}
ctx_len: {args.ctx_len}
head_dim: {args.head_dim}
heads: {heads}
dim: {dim}
visible_tokens: {args.visible_tokens}
disable_bias_mask: {args.disable_bias_mask}
seed: {args.seed}
''')
    dtype = torch.bfloat16
    device = torch.device('cuda')
    with device:
        torch.manual_seed(args.seed)
        pos_emb = Embedding(
            num_embeddings=args.num_buckets,
            embedding_dim=heads,
            dtype=dtype,
        ).eval()
        pos_emb.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        # strictly speaking we don't need to measure the perf of the projections,
        # but maybe they'll enforce some contiguity/stride constraints somehow.
        qkv_proj = Linear(
            in_features=dim,
            out_features=args.head_dim * heads * 3,
            bias=False,
            dtype=dtype,
        ).eval()
        qkv_proj.weight.requires_grad_(False)
        torch.manual_seed(args.seed)
        o_proj = Linear(
            in_features=args.head_dim * heads,
            out_features=dim,
            bias=False,
            dtype=dtype,
        ).eval()
        o_proj.weight.requires_grad_(False)
        gen = torch.Generator(device='cpu')
        gen.manual_seed(args.seed)
        hidden_states = torch.randn(args.batch_size, args.ctx_len, dim, generator=gen, device='cpu', dtype=dtype).to(device)
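        # key-visibility mask: True for the first visible_tokens key positions (i.e. non-pad tokens),
        # broadcast below to [batch, 1, 1, k] so it applies across all heads and query positions.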
        mask_2d = torch.zeros(args.batch_size, args.ctx_len, device=device, dtype=torch.bool)
        mask_2d[:, :args.visible_tokens] = True
        mask_broadcast = rearrange(mask_2d, "b k -> b 1 1 k")
        relative_position: LongTensor = _relative_position(
            q_len=args.ctx_len,
            k_len=args.ctx_len,
            cached_autoregressive=False,
            device=device,
        )
        relative_position_bucket: LongTensor = _relative_position_bucket(
            relative_position,  # shape (q_len, k_len)
            bidirectional=True,
            num_buckets=args.num_buckets,
        )
        with inference_mode():
            pos_bias: FloatTensor = pos_emb(relative_position_bucket)
            pos_bias = rearrange(pos_bias, "q k heads -> 1 heads q k")
            # need stride of last dimension to be 1 in order to be eligible for torch sdp mem-eff kernels
            # for some reason pos_bias.contiguous() doesn't achieve this, but cloning with contiguous format does
            pos_bias = pos_bias.clone(memory_format=torch.contiguous_format)
        sdpa_attn = SDPAAttn(qkv_proj, o_proj, args.head_dim).eval()
        get_do_sdpa: Callable[[SDPAAttn], Callable[[], FloatTensor]] = lambda sdpa: lambda: sdpa(
            hidden_states, pos_bias, mask=None if args.disable_bias_mask else mask_broadcast
        )
        # FlopCounterMode only expects the matmul operations dispatched under no_grad; different ops are dispatched under inference_mode
        with no_grad():
            print('benchmarking SDPA...')
            get_flops(get_do_sdpa(sdpa_attn), flop_count_early=args.count_flops_early)
            print('benchmarking SDPA (compiled)...')
            do_sdpa_c: Callable[[], FloatTensor] = get_do_sdpa(torch.compile(sdpa_attn, dynamic=False))
            get_flops(do_sdpa_c, flop_count_early=args.count_flops_early)
    pass


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--parity-test', action='store_true', help='compare implementation correctness')
    parser.add_argument('--ckpt', type=Checkpoint, default=Checkpoint.XXL, choices=[t.value for t in Checkpoint])
    parser.add_argument('--num-buckets', type=int, default=32)
    parser.add_argument('--max-distance', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=5)
    parser.add_argument('--ctx-len', type=int, default=512)
    parser.add_argument('--head-dim', type=int, default=64)
    parser.add_argument('--visible-tokens', type=int, default=512)
    parser.add_argument('--disable-bias-mask', action='store_true', help="attend even to pad tokens")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--count-flops-early', action='store_true', help="If true, counts FLOPs before running benchmark (this is expected to regress performance in torch 2.5.0+).")
    args = parser.parse_args()
    main(Args(**vars(args)))
Birch-san commented Nov 17, 2024

Fast mode:

python -m scripts.bench_repro --ckpt xxl

Slow mode (reproduces torch 2.5.0+ bug):

python -m scripts.bench_repro --ckpt xxl --count-flops-early
