Skip to content

Instantly share code, notes, and snippets.

@eqy
Created September 19, 2025 16:50
Show Gist options
  • Select an option

  • Save eqy/34459e277be009112676e8b651aebbac to your computer and use it in GitHub Desktop.

Select an option

Save eqy/34459e277be009112676e8b651aebbac to your computer and use it in GitHub Desktop.
tf32.py
import torch
import time
# Benchmark configuration: time a (16, 4096) x (4096, 4096) linear layer on GPU.
warmup = 100   # untimed iterations to amortize startup / caching effects
iters = 1000   # timed iterations
dtype = torch.float32

# Inputs for torch.nn.functional.linear: y = x @ w.T + b
x = torch.randn(16, 4096, device='cuda', dtype=dtype)
w = torch.randn(4096, 4096, device='cuda', dtype=dtype)
b = torch.randn(4096, device='cuda', dtype=dtype)

# FLOPs per iteration: 2*M*N*K for the matmul plus one add per output element
# for the bias.  (Fix: the original counted only 4096 bias adds, but the
# output has 16*4096 elements; the correction is negligible in the TFLOPS
# figure but makes the constant honest.)
flops = 16 * 4096 * 4096 * 2 + 16 * 4096

# Keep TF32 disabled so the matmul runs in full FP32 precision; flip the
# commented line to True to benchmark TF32 throughput instead.
torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cuda.matmul.allow_tf32 = True
#@torch.compile(mode='reduce-overhead')
def func(x, w, b):
    """Apply a single affine (linear) layer: returns x @ w.T + b."""
    out = torch.nn.functional.linear(x, w, b)
    return out
# Warm up so the timed loop does not include one-time startup costs
# (kernel selection, cuBLAS handle init, etc.).
for _ in range(warmup):
    func(x, w, b)
torch.cuda.synchronize()

t0 = time.perf_counter()
for _ in range(iters):
    func(x, w, b)
# Bug fix: synchronize BEFORE reading the end timestamp.  CUDA kernel
# launches are asynchronous, so the original order (t1, then synchronize)
# stopped the clock while GPU work was still in flight, under-reporting
# the per-iteration time and inflating the TFLOPS figure.
torch.cuda.synchronize()
t1 = time.perf_counter()

tpi = (t1 - t0) / iters
print(f"time per iter {tpi}, {flops/tpi/1e12} tflops")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment