Skip to content

Instantly share code, notes, and snippets.

@stas00
Forked from Chillee/mfu_compute.py
Created January 5, 2024 23:28
Show Gist options
  • Save stas00/54ffb290d855baa3fc5dbd906b3b46fe to your computer and use it in GitHub Desktop.
Save stas00/54ffb290d855baa3fc5dbd906b3b46fe to your computer and use it in GitHub Desktop.
Compute achieved FLOP throughput (TF/s) in PyTorch
import torch
from torch.utils.flop_counter import FlopCounterMode
from triton.testing import do_bench
def get_flops_achieved(f):
flop_counter = FlopCounterMode(display=False)
with flop_counter:
f()
total_flops = flop_counter.get_total_flops()
ms_per_iter = do_bench(f)
iters_per_second = 1e3/ms_per_iter
print(f"{iters_per_second * total_flops / 1e12} TF/s")
# Benchmark one forward+backward step of an fp16 ResNet-18 on CUDA, using a
# batch of 128 random 3x224x224 inputs: first in eager mode, then through
# torch.compile. get_flops_achieved reports the achieved TF/s for each run.
from torchvision.models import resnet18

model = resnet18().to(device="cuda", dtype=torch.half)
inp = torch.randn((128, 3, 224, 224), dtype=torch.half, device="cuda")


def _fwd_bwd(net):
    """Return a zero-argument callable that runs ``net(inp).sum().backward()``."""
    return lambda: net(inp).sum().backward()


get_flops_achieved(_fwd_bwd(model))  # eager-mode baseline
compiled_model = torch.compile(model)
get_flops_achieved(_fwd_bwd(compiled_model))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment