@Birch-san
Created June 14, 2025 00:12
Does linearize work? Am I using it right?
from __future__ import annotations
from argparse import ArgumentParser, Namespace
from dataclasses import dataclass
from functools import partial
from typing import Callable, Generic, TypeVar

import torch
from torch import enable_grad, no_grad
import torch.autograd.forward_ad as fwAD
from torch.func import linearize
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention
from torch.utils.flop_counter import FlopCounterMode

T = TypeVar('T')


# Python *please* bring back support for generic NamedTuples
def get_flop_count(f: Callable[[], T], display_ops=True) -> tuple[int, T]:
    """Runs f() under FlopCounterMode and returns (total FLOPs, f's output)."""
    flop_counter = FlopCounterMode(display=display_ops)
    with flop_counter:
        out: T = f()
    return flop_counter.get_total_flops(), out


@dataclass
class Args:
    bsz: int
    model_dim: int
    head_dim: int
    seq_len: int

    @staticmethod
    def get_parser() -> ArgumentParser:
        parser = ArgumentParser()
        parser.add_argument("--bsz", default=1, type=int)
        parser.add_argument("--model-dim", default=320, type=int)
        parser.add_argument("--head-dim", default=64, type=int)
        parser.add_argument("--seq-len", default=128, type=int)
        return parser

    @staticmethod
    def from_namespace(namespace: Namespace) -> Args:
        args = Args(**vars(namespace))
        return args


def main(args: Args) -> None:
    device = torch.device('cuda')
    dtype = torch.float16
    seed = 42
    gen = torch.Generator(device=device)
    heads = args.model_dim // args.head_dim
    # primals (q_p, k_p, v_p) and tangents (q_t, k_t, v_t)
    q_p, q_t, k_p, k_t, v_p, v_t = (
        torch.randn(args.bsz, heads, args.seq_len, args.head_dim, device=device, dtype=dtype, generator=gen.manual_seed(seed + ix))
        for ix in range(6)
    )

    with sdpa_kernel(SDPBackend.FLASH_ATTENTION), no_grad():
        print("Flash, fwd only")
        flop_count_flash_fwd, out_flash_fwd = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)

    with sdpa_kernel(SDPBackend.MATH):
        with no_grad():
            print("Math, fwd only")
            flop_count_math_fwd, out_math_fwd = get_flop_count(partial(scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)

        # forward-mode AD: attach tangents to the primals and run SDPA on the duals
        with fwAD.dual_level(), enable_grad():
            print("Math, fwd+jvp")
            q, k, v = (fwAD.make_dual(p, t) for p, t in zip((q_p, k_p, v_p), (q_t, k_t, v_t)))
            flop_count_math_jvp, out_math_jvp = get_flop_count(partial(scaled_dot_product_attention, q, k, v), display_ops=True)

        # linearize: pay for the forward pass up-front, then apply jvp_fn to the tangents
        print("Math, fwd+jvp, via linearize (step 1: invoking linearize)")
        flop_count_linearize, (attn_out, jvp_fn) = get_flop_count(partial(linearize, scaled_dot_product_attention, q_p, k_p, v_p), display_ops=True)
        print("Math, fwd+jvp, via linearize (step 2: invoking jvp_fn)")
        flop_count_jvp_fn, out_jvp_fn = get_flop_count(partial(jvp_fn, q_t, k_t, v_t), display_ops=True)
    pass


if __name__ == "__main__":
    parser = Args.get_parser()
    args_untyped: Namespace = parser.parse_args()
    args: Args = Args.from_namespace(args_untyped)
    main(args)
@Birch-san (Author)

I thought the idea of linearize() was that (compared to fwAD) it avoids redoing forward-pass computations. Yet here there seem to be no savings; if anything, it's far worse?
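
For context, here's a minimal sketch (toy function and shapes, not from this gist) of the amortization described above: pay for the forward pass once inside linearize, then call the returned jvp_fn repeatedly on fresh tangents.

import torch
from torch.func import linearize

def f(x):
    return x.sin()

x = torch.randn(16)
out, jvp_fn = linearize(f, x)     # forward pass (and tracing) happens here, once
jvp_a = jvp_fn(torch.randn(16))   # later JVPs were expected to reuse that work
jvp_b = jvp_fn(torch.randn(16))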

Flash, fwd only
Module                                          FLOP    % Total
-------------------------------------------  -------  ---------
Global                                       20.972M    100.00%
 - aten._scaled_dot_product_flash_attention  20.972M    100.00%

Math, fwd only
Module          FLOP    % Total
-----------  -------  ---------
Global       20.972M    100.00%
 - aten.bmm  20.972M    100.00%

Math, fwd+jvp
Module          FLOP    % Total
-----------  -------  ---------
Global       62.915M    100.00%
 - aten.bmm  62.915M    100.00%

Math, fwd+jvp, via linearize (step 1: invoking linearize)
/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/const_fold.py:264: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  new_node = root_const_gm.graph.get_attr(in_node.target)
Module          FLOP    % Total
-----------  -------  ---------
Global       83.886M    100.00%
 - aten.bmm  83.886M    100.00%
Math, fwd+jvp, via linearize (step 2: invoking jvp_fn)
Module          FLOP    % Total
-----------  -------  ---------
GraphModule  62.915M    100.00%
 - aten.bmm  62.915M    100.00%
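
For reference, the totals above line up with plain matmul FLOP counting (a back-of-envelope sketch, assuming 2*M*N*K FLOPs per bmm and the default shapes bsz=1, heads=5, seq_len=128, head_dim=64):

bsz, heads, seq, d = 1, 5, 128, 64
fwd = bsz * heads * 2 * (2 * seq * seq * d)   # QK^T plus attn@V
print(f"{fwd:,}")      # 20,971,520 -> the 20.972M of both fwd-only runs
print(f"{3 * fwd:,}")  # 62,914,560 -> the 62.915M of fwd+jvp and of jvp_fn
print(f"{4 * fwd:,}")  # 83,886,080 -> the 83.886M counted while invoking linearize itself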
