@mlazos
Created June 2, 2025 22:20
import torch
from torch import nn
from torch.nn import functional as F
from triton.testing import do_bench
import triton
import triton.language as tl
import torch._inductor.config as config
from torch._inductor.utils import fresh_inductor_cache
#torch._logging.set_logs(autotuning=True)
# Restrict max-autotune GEMM choices to the CUTLASS backend.
config.max_autotune_gemm_backends = "CUTLASS"
config.benchmark_fusion = False
config.benchmark_epilogue_fusion = False
torch.set_default_device('cuda')
with fresh_inductor_cache():
    shape_a = 4224, 8192
    # shape_a = 256, 256
    shape_b = 2048, 8192
    # shape_b = 256, 256
    a = torch.randn(*shape_a, dtype=torch.bfloat16)
    b = torch.randn(*shape_b, dtype=torch.bfloat16).t()
    c = torch.randn(shape_a[0], 1, dtype=torch.bfloat16)
    d = torch.randn(shape_a[0], shape_b[0], dtype=torch.float32)

    def get_flops(f):
        ms = do_bench(f, warmup=100, rep=10000)
        print(ms)
        # 2 * M * K * N FLOPs per matmul, reported as TFLOP/s
        print((1e3 / ms) * a.shape[0] * a.shape[1] * b.shape[1] * 2 / 1e12, 'TF')

    # f_layout = lambda: (torch.mm(a, b) + c).permute(1, 0) + d.permute(1, 0)
    # f = lambda: torch.mm(a, b).relu().permute(1, 0).sigmoid()
    # f = lambda: torch.mm(a, b).relu().permute(1, 0).sigmoid().permute(1, 0)
    f = lambda: torch.mm(a, b).relu() + c
    f = torch.compile(f, mode="max-autotune-no-cudagraphs")

    # Also set `sudo nvidia-smi boost-slider --vboost 1`, which shifts more power from the L2 cache to the tensor cores.
    get_flops(f)  # 780.1689058368037 TF
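
    # Optional correctness check (a minimal sketch): compare the compiled kernel's
    # output against the eager bfloat16 reference. The rtol/atol values below are
    # illustrative assumptions for bfloat16, not tuned tolerances.
    eager_out = torch.mm(a, b).relu() + c
    torch.testing.assert_close(f(), eager_out, rtol=1e-2, atol=1e-2)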