from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Callable

import torch
from torch import enable_grad, no_grad
import torch.autograd.forward_ad as fwAD
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode

NiladicFn = Callable[[], None]


def get_flop_count(f: NiladicFn, display_ops: bool = True) -> int:
    """Run f under FlopCounterMode and return the total FLOPs it recorded."""
    flop_counter = FlopCounterMode(display=display_ops)
    with flop_counter:
        f()
    return flop_counter.get_total_flops()


@dataclass
class Args:
    bsz: int
    model_dim: int
    head_dim: int
    seq_len: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--bsz", default=1, type=int)
        parser.add_argument("--model-dim", default=320, type=int)
        parser.add_argument("--head-dim", default=64, type=int)
        parser.add_argument("--seq-len", default=128, type=int)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args


def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)
    heads = args.model_dim // args.head_dim
    # primals (q_p, k_p, v_p) and tangents (q_t, k_t, v_t), each seeded deterministically
    q_p, q_t, k_p, k_t, v_p, v_t = (
        torch.randn(
            args.bsz, heads, args.seq_len, args.head_dim,
            device=device, dtype=dtype, generator=gen.manual_seed(seed + ix),
        )
        for ix in range(6)
    )

    # flash attention backend, forward pass only
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION), no_grad():
        print("Flash, fwd only")
        flop_count_flash_fwd: int = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)

    # math (composite) backend
    with sdpa_kernel(SDPBackend.MATH):
        # forward pass only
        with no_grad():
            print("Math, fwd only")
            flop_count_math_fwd: int = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)
        # forward pass + JVP, propagated via forward-mode AD dual tensors
        with fwAD.dual_level(), enable_grad():
            print("Math, fwd+jvp")
            q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t)))
            flop_count_math_jvp: int = get_flop_count(partial(scaled_dot_product_attention, q, k, v), display_ops=True)


if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)
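
For reference, the defaults reproduce the output in the comment below. A typical invocation (assuming the script is saved as sdpa_flop_count.py, a hypothetical filename, and that a CUDA device is available, since the script hard-codes device='cuda') would look like:

python sdpa_flop_count.py --bsz 1 --model-dim 320 --head-dim 64 --seq-len 128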
Birch-san commented Jun 13, 2025

bsz=1
model_dim=320
head_dim=64
seq_len=128
Flash, fwd only
Module                                          FLOP    % Total
-------------------------------------------  -------  ---------
Global                                       20.972M    100.00%
 - aten._scaled_dot_product_flash_attention  20.972M    100.00%

Math, fwd only
Module          FLOP    % Total
-----------  -------  ---------
Global       20.972M    100.00%
 - aten.bmm  20.972M    100.00%

Math, fwd+jvp
Module          FLOP    % Total
-----------  -------  ---------
Global       62.915M    100.00%
 - aten.bmm  62.915M    100.00%

62.915 / 20.972 = 3
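
These totals are consistent with the counter only seeing batched matmuls (only aten.bmm appears in the Math tables; softmax and scaling are not counted). The SDPA forward pass is two batched matmuls, QKᵀ and PV, each costing 2·bsz·heads·seq_len²·head_dim FLOPs. Propagating a JVP through the math backend adds two tangent matmuls per primal matmul (dQ·Kᵀ + Q·dKᵀ, and dP·V + P·dV), for six matmuls in total, hence the 3× ratio. A quick arithmetic sketch of those totals (not part of the gist):

bsz, heads, seq_len, head_dim = 1, 5, 128, 64  # heads = model_dim // head_dim = 320 // 64
per_bmm = 2 * bsz * heads * seq_len * seq_len * head_dim  # FLOPs of one batched matmul
print(f"{2 * per_bmm:,}")  # fwd only: 2 matmuls -> 20,971,520 ≈ 20.972M
print(f"{6 * per_bmm:,}")  # fwd+jvp: 6 matmuls -> 62,914,560 ≈ 62.915M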
