how to benchmark and profile with cudagraphs
# the attention layer from this benchmark is from modded-nanogpt, MIT-licensed
# https://github.com/KellerJordan/modded-nanogpt
from typing import Callable, Optional
import math
from pathlib import Path
from dataclasses import dataclass
from functools import partial
import torch
from torch import nn, Tensor, FloatTensor, IntTensor
import torch.nn.functional as F
from torch.testing import assert_close
from torch.profiler import ProfilerActivity, profile
#---
# do_bench is from triton, MIT-licensed
# https://github.com/triton-lang/triton/blob/11ec6354/python/triton/testing.py#L127
# with fixes by Alex Birch to clear grads before warmup steps too (cudagraphs models are sensitive to tensor reuse)
from triton import runtime
from triton.testing import _summarize_statistics
def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, return_mode="mean"):
"""
Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
the 20-th and 80-th performance percentile.
:param fn: Function to benchmark
:type fn: Callable
:param warmup: Warmup time (in ms)
:type warmup: int
:param rep: Repetition time (in ms)
:type rep: int
:param grad_to_none: Reset the gradient of the provided tensor to None
:type grad_to_none: torch.tensor, optional
:param quantiles: Performance percentile to return in addition to the median.
:type quantiles: list[float], optional
:param return_mode: The statistical measure to return. Options are "min", "max", "mean", "median", or "all". Default is "mean".
:type return_mode: str
"""
assert return_mode in ["min", "max", "mean", "median", "all"]
di = runtime.driver.active.get_device_interface()
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
di.synchronize()
cache = runtime.driver.active.get_empty_cache_for_benchmark()
# Estimate the runtime of the function
start_event = di.Event(enable_timing=True)
end_event = di.Event(enable_timing=True)
start_event.record()
for _ in range(5):
runtime.driver.active.clear_cache(cache)
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
end_event.record()
di.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
start_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
end_event = [di.Event(enable_timing=True) for i in range(n_repeat)]
# Warm-up
for _ in range(n_warmup):
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
fn()
# Benchmark
for i in range(n_repeat):
# we don't want `fn` to accumulate gradient values
# if it contains a backward pass. So we clear the
# provided gradients
if grad_to_none is not None:
for x in grad_to_none:
x.grad = None
# we clear the L2 cache before each run
runtime.driver.active.clear_cache(cache)
# record time of `fn`
start_event[i].record()
fn()
end_event[i].record()
# Record clocks
di.synchronize()
times = [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
return _summarize_statistics(times, quantiles, return_mode)
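# illustrative usage (added; not from the original gist): do_bench times any zero-argument
# callable and returns milliseconds per call (mean by default). `demo_do_bench` is a
# hypothetical flag, mirroring the gist's `if flag := False:` convention.
if demo_do_bench := False:
    a = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
    b = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
    matmul_ms = do_bench(lambda: a @ b, warmup=25, rep=100)
    print(f"bf16 4096x4096 matmul: {matmul_ms:.3f} ms")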
#---
if use_fa3 := False:
    get_attn_out: Callable[[Tensor|tuple[Tensor, ...]], Tensor]
    if use_local_fa3_kernel := True:
        from kernels import get_local_kernel
        flash_attn_interface = get_local_kernel(repo_path=Path('hf-kernels/flash-attention-3'), package_name='flash_attention_3').flash_attn_interface
        get_attn_out = lambda out: out
    elif use_community_fa3_kernel := False:
        from kernels import get_kernel
        flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
        get_attn_out = lambda out: out
    elif use_dist_fa3 := False:
        import flash_attn_interface
        def get_attn_out(out):
            assert isinstance(out, tuple)
            out, lse = out
            return out
    else:
        raise ValueError("well you have to pick one")

    def varlen_attn(
        q: FloatTensor,
        k: FloatTensor,
        v: FloatTensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        cum_seq_q: Optional[IntTensor] = None,
        cum_seq_k: Optional[IntTensor] = None,
        causal=False,
        scale: Optional[float] = None,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
    ) -> FloatTensor:
        assert (window_size_left is None) == (window_size_right is None)
        window_size = None if window_size_left is None else (window_size_left, window_size_right)
        out = flash_attn_interface.flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cum_seq_q,
            cu_seqlens_k=cum_seq_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            causal=causal,
            softmax_scale=scale,
            window_size=window_size,
        )
        return get_attn_out(out)
else:
    print("[WARN] falling back to torch builtin private varlen attn API, which is probably FA2")
    def varlen_attn(
        q: FloatTensor,
        k: FloatTensor,
        v: FloatTensor,
        max_seqlen_q: int,
        max_seqlen_k: int,
        cum_seq_q: Optional[IntTensor] = None,
        cum_seq_k: Optional[IntTensor] = None,
        causal=False,
        scale: Optional[float] = None,
        window_size_left: Optional[int] = None,
        window_size_right: Optional[int] = None,
    ) -> FloatTensor:
        out, _, _, _, _ = torch.ops.aten._flash_attention_forward(
            q, k, v,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=max_seqlen_q,
            max_k=max_seqlen_k,
            dropout_p=0.0,
            is_causal=causal,
            return_debug_mask=False,
            scale=scale,
            window_size_left=window_size_left,
            window_size_right=window_size_right,
        )
        return out
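# note (added): cu_seqlens_q/k are cumulative offsets delimiting documents packed into one
# varlen batch, e.g. tensor([0, 400, 800, 1200], dtype=torch.int32) packs three 400-token
# docs; max_seqlen_q/k bound the length of the longest single document.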
def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)
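# e.g. (added): next_multiple_of_n(130, n=128) == 256; next_multiple_of_n(128, n=128) == 128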
@dataclass
class Hyperparameters:
    train_bs_schedule: tuple = (8 * 2048 * 8, 16 * 2048 * 8, 24 * 2048 * 8)
    train_max_seq_len: int = 128 * 16
    val_batch_size: int = 4 * 64 * 1024 * 8
args = Hyperparameters()

rank = 0
world_size = 8
grad_accum_steps = 8 // world_size
max_seq_len = args.val_batch_size // (grad_accum_steps * world_size)
device = torch.device("cuda", 0)
torch.cuda.set_device(device)
@dataclass
class AttnArgs:
    ve: torch.Tensor
    sa_lambdas: torch.Tensor
    seqlens: torch.Tensor
    bm_size: int
    cos: torch.Tensor
    sin: torch.Tensor
    attn_scale: float
    key_shift: bool

hp_dtype = torch.bfloat16

# eps: float = torch.finfo(torch.bfloat16).eps  # rms_norm eps defaults to the finfo eps of x.dtype... well, it first upcasts x to at least fp32, and *then* checks its dtype.
def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))
class Yarn(nn.Module):
    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.reset()

    def reset(self):
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)])
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device)
        theta = torch.outer(t, angular_freq)
        self.cos = nn.Buffer(theta.cos().to(hp_dtype), persistent=False)
        self.sin = nn.Buffer(theta.sin().to(hp_dtype), persistent=False)
        self.angular_freq = angular_freq
        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.1

    def apply(self, old_window: int, new_window: int, alpha: int = 1, beta: int = 32):
        # note: unused in this benchmark; args.block_size is not defined in the trimmed-down Hyperparameters above
        rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
        scaling_factor = old_window / new_window
        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
        theta = torch.outer(t, self.angular_freq)
        self.cos.copy_(theta.cos())
        self.sin.copy_(theta.sin())
        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1
def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
    assert cos.size(0) >= x_BTHD.size(-3)
    cos, sin = (
        cos[None, : x_BTHD.size(-3), None, :],
        sin[None, : x_BTHD.size(-3), None, :],
    )
    x1, x2 = x_BTHD.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat((y1, y2), 3)
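# note (added): rotary() rotates each (x1[i], x2[i]) pair by theta = position * angular_freq[i];
# since Yarn zeroes the upper half of angular_freq (half-truncated RoPE), those dims pass
# through unchanged (cos=1, sin=0).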
class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        with torch.no_grad():
            self.weight.zero_() # @Grad62304977 and others

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))
class CausalSelfAttentionOrig(nn.Module):
    def __init__(self, dim: int, head_dim: int, num_heads: int):
        super().__init__()
        self.call_super_init = True
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        std = self.dim ** -0.5
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        self.qkvo_w = nn.Parameter(torch.empty(self.dim * 4, self.hdim))
        # label all modules for explicit optimizer grouping
        self.qkvo_w.label = 'attn'
        with torch.no_grad():
            self.qkvo_w[:self.dim * 3].uniform_(-bound, bound) # init QKV weights
            self.qkvo_w[self.dim * 3:].zero_() # init O weights to zero
        # sparse gated attention to enable context based no-op by @classiclarryd
        self.attn_gate = CastedLinear(12, num_heads)
        self.attn_gate.weight.label = 'attn_gate'
    def forward(self, x: Tensor, attn_args: AttnArgs):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences require B == 1"
        assert T % 16 == 0
        # unpack attention args
        cos, sin = attn_args.cos, attn_args.sin
        ve, sa_lambdas, key_shift = attn_args.ve, attn_args.sa_lambdas, attn_args.key_shift
        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
        q, k, v = F.linear(x, sa_lambdas[0] * self.qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
        if key_shift:
            # shift keys forward for the stationary head dims. Enables 1-layer induction.
            k[:, 1:, :, self.head_dim//4:self.head_dim//2] = k[:, :-1, :, self.head_dim//4:self.head_dim//2]
            k[:, 1:, :, self.head_dim//4+self.head_dim//2:] = k[:, :-1, :, self.head_dim//4+self.head_dim//2:]
        if ve is not None:
            v = v + ve.view_as(v) # @KoszarskyB & @Grad62304977
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y: Tensor = varlen_attn(
            q[0],
            k[0],
            v[0],
            max_seqlen_q=max_len,
            max_seqlen_k=max_len,
            cum_seq_q=seqlens,
            cum_seq_k=seqlens,
            causal=True,
            scale=attn_scale,
            window_size_left=bm_size,
            window_size_right=0,
        )
        y = y.view(B, T, self.num_heads, self.head_dim)
        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y
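# note (added): forward() expects x of shape (1, T, dim) with documents packed along T and
# delimited by attn_args.seqlens; varlen_attn then runs causal, sliding-window attention
# (window = bm_size tokens back, 0 forward) independently per document.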
head_dim=128
with torch.device('cuda'):
    yarn = Yarn(head_dim, max_seq_len)
dim=768
avg_seqlen=400 # median doc length is ~400
minibsz=args.train_bs_schedule[0]
pergpu_minibsz=minibsz//grad_accum_steps
microbsz = pergpu_minibsz // world_size
max_num_docs = next_multiple_of_n(microbsz // 300, n=128)
seqlens=torch.arange(0, avg_seqlen*max_num_docs, avg_seqlen, dtype=torch.int32, device=device).clamp_max_(microbsz)
short_bm=128
num_heads=6
with torch.device('meta'):
    orig = CausalSelfAttentionOrig(
        dim=dim,
        head_dim=head_dim,
        num_heads=num_heads,
    )
    cg = CausalSelfAttentionOrig(
        dim=dim,
        head_dim=head_dim,
        num_heads=num_heads,
    )
seed=42
gen=torch.Generator(device)
loss_fn = nn.MSELoss()
for attn in (orig, cg):
    attn.to_empty(device=device)
    attn.qkvo_w.data.normal_(std=attn.dim**-.5, generator=gen.manual_seed(seed))
    attn.attn_gate.weight.data.normal_(std=attn.attn_gate.in_features**-.5, generator=gen.manual_seed(seed))
cg_grads_to_none = [cg.qkvo_w, cg.attn_gate.weight]
orig = torch.compile(orig, dynamic=False, fullgraph=True)
cg = torch.compile(cg, dynamic=False, fullgraph=True, mode='reduce-overhead')
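# note (added): mode='reduce-overhead' makes torch.compile capture and replay CUDA graphs
# (cudagraph trees); that is what makes `cg` sensitive to tensor reuse between iterations,
# while `orig` is the same module compiled without cudagraphs, as a baseline.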
ve = torch.randn((microbsz, dim), device=device, dtype=hp_dtype, requires_grad=True)
sa_lambdas = torch.tensor((.5, 1.), device=device, requires_grad=True)
input = torch.randn((1, microbsz, dim), device=device, dtype=hp_dtype, generator=gen.manual_seed(seed+1), requires_grad=True)
target = torch.randn((1, microbsz, dim), device=device, dtype=hp_dtype, generator=gen.manual_seed(seed+2))
def do_fwd(mod: nn.Module):
    attn_args = AttnArgs(
        ve=ve,
        sa_lambdas=sa_lambdas,
        seqlens=seqlens,
        bm_size=short_bm,
        cos=yarn.cos,
        sin=yarn.sin,
        attn_scale=yarn.attn_scale,
        # attn_scale=head_dim**-.5,
        key_shift=False,
    )
    out: Tensor = mod(input, attn_args=attn_args)
    return out
def do_lossbwd(out: Tensor):
    loss: Tensor = loss_fn(out, target)
    loss.backward()
    return loss

def do_fwdbwd(mod: nn.Module):
    out: Tensor = do_fwd(mod)
    do_lossbwd(out)

def with_cudagraph_do_fwd(mod: nn.Module):
    torch.compiler.cudagraph_mark_step_begin()
    return do_fwd(mod)

def with_cudagraph_do_fwdbwd(mod: nn.Module):
    torch.compiler.cudagraph_mark_step_begin()
    out: Tensor = do_fwd(mod)
    loss: Tensor = loss_fn(out, target)
    loss.backward()
    return loss
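# note (added): torch.compiler.cudagraph_mark_step_begin() tells the cudagraph-trees runtime
# that a new iteration is starting, so outputs of the previous replay may be overwritten;
# without it, re-using last step's outputs can raise "overwritten by a subsequent run"-style
# errors when they aren't cloned between steps.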
if test_correctness := False:
    out_orig = do_fwd(orig)
    out_cg = do_fwd(cg)
    # rtol=1e-2 due to torch.finfo(torch.bfloat16).resolution
    assert_close(out_orig, out_cg, rtol=1e-2, atol=1e-2)
    do_lossbwd(out_orig)
    do_lossbwd(out_cg)
    assert_close(orig.qkvo_w.grad, cg.qkvo_w.grad, rtol=1e-6, atol=5e-5)
    assert_close(orig.attn_gate.weight.grad, cg.attn_gate.weight.grad)
inputs_to_none = [input, ve, sa_lambdas]
def clear_input_grads():
    for t in inputs_to_none:
        t.grad = None
if do_profile := False:
    wait, warmup, active = 1, 1, 1
    prof_its = wait + warmup + active
    prof = profile(
        activities=[
            ProfilerActivity.CPU,
            ProfilerActivity.CUDA,
        ],
        record_shapes=False,
        # stack traces introduce enough CPU overhead to be misleading, so don't fully trust such profiles.
        # with_stack=True,
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
    )
    torch.cuda.synchronize()
    for mod, label, wants_cudagraph in zip((orig, cg), ("orig", "cg"), (False, True), strict=True):
        do_fwdbwd_ = with_cudagraph_do_fwdbwd if wants_cudagraph else do_fwdbwd
        with prof:
            for step in range(prof_its):
                clear_input_grads()
                mod.zero_grad()
                do_fwdbwd_(mod)
                torch.cuda.synchronize()
                prof.step()
        trace_dir = Path("out_trace")
        trace_dir.mkdir(exist_ok=True)
        profile_path = trace_dir / f"{label}.json"
        print(f"Saving profile to {profile_path}")
        prof.export_chrome_trace(str(profile_path))
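# note (added): the exported .json files are Chrome-trace format; load them in chrome://tracing
# or https://ui.perfetto.dev to compare the orig vs cg kernel timelines.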
if test_fwdbwd := False:
    clear_input_grads()
    cg.zero_grad()
    with_cudagraph_do_fwdbwd(cg)
    torch.cuda.synchronize()
    clear_input_grads()
    cg.zero_grad()
    with_cudagraph_do_fwdbwd(cg)
    torch.cuda.synchronize()
if test_latency := True:
    warmup, rep = 1000, 2000
    orig_ms: float = do_bench(partial(do_fwdbwd, mod=orig), rep=rep, warmup=warmup, grad_to_none=inputs_to_none)
    cg_ms: float = do_bench(partial(with_cudagraph_do_fwdbwd, mod=cg), rep=rep, warmup=warmup, grad_to_none=[*cg_grads_to_none, *inputs_to_none])
    orig_its: float = 1000 / orig_ms
    cg_its: float = 1000 / cg_ms
    print(f"""
orig: {orig_its:.2f} it/s
cg: {cg_its:.2f} it/s
""")
pass