@Chillee
Last active March 2, 2025 22:10
Compute Flop Utilization in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench

def get_flops_achieved(f):
    flop_counter = FlopCounterMode(display=False)
    with flop_counter:
        f()
    total_flops = flop_counter.get_total_flops()
    ms_per_iter = do_bench(f)
    iters_per_second = 1e3/ms_per_iter
    print(f"{iters_per_second * total_flops / 1e12} TF/s")

from torchvision.models import resnet18
model = resnet18().cuda().half()
inp = torch.randn(128, 3, 224, 224, device='cuda', dtype=torch.half)
get_flops_achieved(lambda: model(inp).sum().backward())
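
The script above prints achieved throughput in TF/s. To turn that into a utilization figure, one option is to divide by the hardware's peak throughput for the dtype in use. The following is a minimal sketch of that arithmetic only, not part of the original gist: the helper names `achieved_tflops` and `flop_utilization` are hypothetical, and `peak_tflops` is a value you would need to take from your GPU's spec sheet (the numbers below are made up for illustration).

```python
def achieved_tflops(total_flops: float, ms_per_iter: float) -> float:
    """Convert FLOPs per iteration and latency in milliseconds to TFLOP/s.

    Mirrors the arithmetic in get_flops_achieved: iterations per second
    times FLOPs per iteration, scaled to tera-FLOPs.
    """
    iters_per_second = 1e3 / ms_per_iter
    return iters_per_second * total_flops / 1e12


def flop_utilization(total_flops: float, ms_per_iter: float, peak_tflops: float) -> float:
    """Fraction of peak throughput achieved (hypothetical helper).

    peak_tflops is an assumption supplied by the caller, e.g. the vendor's
    quoted dense peak for the dtype being benchmarked.
    """
    return achieved_tflops(total_flops, ms_per_iter) / peak_tflops


# Illustrative, made-up inputs: 1e12 FLOPs per iteration at 10 ms per iteration,
# measured against an assumed peak of 312 TFLOP/s.
util = flop_utilization(total_flops=1e12, ms_per_iter=10.0, peak_tflops=312.0)
print(f"{util:.1%} of peak")
```

In practice, `total_flops` would come from `flop_counter.get_total_flops()` and `ms_per_iter` from `do_bench(f)`, as in the gist.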
stas00 commented Jan 23, 2025

Ooops, it looks like I didn't do it - just did here pytorch/pytorch#145555
