Skip to content

Instantly share code, notes, and snippets.

@Chillee
Last active August 1, 2024 18:11
Show Gist options
  • Save Chillee/f86675147366a7a0c6e244eaa78660f7 to your computer and use it in GitHub Desktop.
Save Chillee/f86675147366a7a0c6e244eaa78660f7 to your computer and use it in GitHub Desktop.
PT 2.0 Benchmarks
import torch
import torch._inductor.config
import time
# Turn off CUDA graphs for inductor-compiled code in this benchmark.
torch._inductor.config.triton.cudagraphs = False
# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls;
# see torch.set_float32_matmul_precision docs for what 'high' allows.
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time a zero-argument callable and report microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is formatted as
            "<name>: <us>us" and it is also the Chrome-trace filename stem.
        iters: Number of timed calls (must be >= 1).
        warmup: Number of untimed calls made before timing (e.g. so that
            torch.compile finishes compiling before measurement).
        display: If True, print the result.
        profile: If True, run one extra call under torch.profiler and export
            a Chrome trace to "<name or 'trace'>.json".

    Returns:
        Microseconds per iteration as a float when ``name`` is None,
        otherwise the formatted string.

    Raises:
        ValueError: if ``iters`` < 1 (the per-iteration average is undefined).
    """
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # CUDA work is launched asynchronously, so synchronize around the timed
    # window; skip this when CUDA is unavailable so CPU callables work too.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    sync()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter}us"
    if display:
        print(res)
    return res
def f1(a, b, c, d):
    """Pointwise workload used to compare eager execution with torch.compile.

    Computes ``(relu(a) * tanh(b) + cos(c + 2)) * d`` element-wise.
    """
    gated = a.relu() * b.tanh()
    return (gated + (c + 2).cos()) * d
# Benchmark inputs: four 1-D tensors of 2**24 elements on the GPU.
inp = [torch.randn(2**24, device='cuda') for _ in range(4)]
f = f1
# torch.compile compiles lazily; the first call (bench's warmup) triggers it.
nf = torch.compile(f)
bench(name="eager", f=lambda: f(*inp))
bench(name="PT 2.0", f=lambda: nf(*inp))
import torch
# NOTE(review): no name from this star import is referenced in the visible
# script below; consider removing it to avoid namespace pollution.
from torch.nn import *
# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time a zero-argument callable and report microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is formatted as
            "<name>: <us>us" (2 decimals) and is the Chrome-trace filename stem.
        iters: Number of timed calls (must be >= 1).
        warmup: Number of untimed calls made before timing.
        display: If True, print the result.
        profile: If True, run one extra call under torch.profiler and export
            a Chrome trace to "<name or 'trace'>.json".

    Returns:
        Microseconds per iteration as a float when ``name`` is None,
        otherwise the formatted string.

    Raises:
        ValueError: if ``iters`` < 1.
    """
    import time
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Synchronize around the timed window because CUDA launches are
    # asynchronous; skip when CUDA is unavailable.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    sync()
    # perf_counter: monotonic, high-resolution clock meant for benchmarking.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    sync()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.2f}us"
    if display:
        print(res)
    return res
import torchvision.models as models

# ResNet-18 in eval mode on the GPU, plus a torch.compile'd wrapper that
# uses the "reduce-overhead" mode.
mod = models.resnet18().eval().cuda()
opt_mod = torch.compile(mod, mode="reduce-overhead")
# One input of shape (1, 3, 224, 224), created on the CPU then moved to the GPU.
inp = torch.randn(1, 3, 224, 224).cuda()

with torch.no_grad():
    # Eager: 1938.18us
    bench(f=lambda: mod(inp), name="Eager")
    # torch.compile (default): 953.96us  -- recorded figure; default mode is
    # not benchmarked by this script.
    # torch.compile (reduce-overhead): 744.02us
    bench(f=lambda: opt_mod(inp), name="torch.compile (reduce-overhead)")
import torch
from triton.testing import do_bench
def get_flops(N, get_kernels=False, dtype=torch.float16):
    """Benchmark an (N, N) @ (N, N) matmul on CUDA and print achieved TF/s.

    Args:
        N: Matrix dimension; both operands are (N, N).
        get_kernels: If True, run one profiled matmul first and print the names
            of profiler events containing "gemm", "triton" or "gemv".
        dtype: Element dtype of both operands (float16 by default, as before).

    Returns:
        Achieved throughput in TFLOP/s (also printed).
    """
    A = torch.randn(N, N, device='cuda', dtype=dtype)
    B = torch.randn(N, N, device='cuda', dtype=dtype)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        for e in prof.events():
            if any(tag in e.name for tag in ("gemm", "triton", "gemv")):
                print(f"{N}: {e.name}")
    # do_bench reports milliseconds per call.
    timer = do_bench(f)
    iters_per_second = 1e3 / timer
    # (M, K) @ (K, P) does M*K*P multiply-adds, i.e. 2*M*K*P FLOPs.
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops / 1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")
    return flops_achieved


# Sweep square matmul sizes N = 1 .. 4095 (the range end is exclusive).
for N in range(1, 4096):
    get_flops(N)
import torch
# Permit reduced-precision internal math (e.g. TF32) for float32 matmuls.
torch.set_float32_matmul_precision('high')
import torch._inductor.config
# Enable inductor's debug mode for the compiled functions below.
# NOTE(review): debug mode may emit extra artifacts/logging that perturb
# timings — confirm it is intended for benchmarking.
torch._inductor.config.debug = True
def bench(f, name=None, iters=100, warmup=5, display=True, profile=False):
    """Time a zero-argument callable and report microseconds per call.

    Args:
        f: Zero-argument callable to benchmark.
        name: Optional label. When given, the result is formatted as
            "<name>: <us>us" (3 decimals) and is the Chrome-trace filename stem.
        iters: Number of timed calls (must be >= 1).
        warmup: Number of untimed calls made before timing.
        display: If True, print the result.
        profile: If True, run one extra call under torch.profiler and export
            a Chrome trace to "<name or 'trace'>.json".

    Returns:
        Microseconds per iteration as a float when ``name`` is None,
        otherwise the formatted string.

    Raises:
        ValueError: if ``iters`` < 1.
    """
    import time
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")
    # Synchronize around the timed window because CUDA launches are
    # asynchronous; skip when CUDA is unavailable.
    sync = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    sync()
    # perf_counter: monotonic, high-resolution clock meant for benchmarking.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    sync()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    if name is None:
        res = us_per_iter
    else:
        res = f"{name}: {us_per_iter:.3f}us"
    if display:
        print(res)
    return res
def get_bandwidth(name, f, bytes_accessed=None):
    """Benchmark ``f`` with ``bench`` and print its effective bandwidth in GB/s.

    Args:
        name: Label printed in front of the result.
        f: Zero-argument callable to time.
        bytes_accessed: Bytes moved per call. Defaults to ``N**2*4*3`` using the
            module-level ``N``. That matches an element-wise binary op on
            (N, N) float32 tensors: two inputs read plus one output written.

    Returns:
        The effective bandwidth in GB/s (also printed).
    """
    if bytes_accessed is None:
        bytes_accessed = N**2 * 4 * 3
    # Without a name, bench returns microseconds per call as a float.
    iters_per_second = 1e6 / bench(f, display=False)
    gb_per_s = iters_per_second * bytes_accessed / 1e9
    # The quantity is a rate (bytes per second), so label it GB/s, not GB.
    print(f"{name}: {gb_per_s:.2f}GB/s")
    return gb_per_s
N = 2**14


def f(a, b):
    return a + b


A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
# Compile once, outside the timed lambda, so every timed call measures only the
# compiled function and not the construction of a fresh torch.compile wrapper.
compiled_f = torch.compile(f)
# eager: 1389.84GB/s
get_bandwidth("eager", lambda: f(A, B))
# torch.compile: 1388.19GB/s
get_bandwidth("torch.compile", lambda: compiled_f(A, B))


def f2(a, b):
    # Adding a transposed operand gives a non-contiguous read pattern for b.
    return a + b.t()


A = torch.randn(N, N, device='cuda')
B = torch.randn(N, N, device='cuda')
compiled_f2 = torch.compile(f2)
# eager: 904.01GB/s
get_bandwidth("eager", lambda: f2(A, B))
# torch.compile: 1334.89GB/s
get_bandwidth("torch.compile", lambda: compiled_f2(A, B))
import torch
from triton.testing import do_bench
def get_flops(N, get_kernels=False, dtype=torch.float16):
    """Benchmark an (N, N) @ (N, N) matmul on CUDA and print achieved TF/s.

    Args:
        N: Matrix dimension; both operands are (N, N).
        get_kernels: If True, run one profiled matmul first and print the names
            of profiler events containing "gemm", "triton" or "gemv".
        dtype: Element dtype of both operands (float16 by default, as before).

    Returns:
        Achieved throughput in TFLOP/s (also printed).
    """
    A = torch.randn(N, N, device='cuda', dtype=dtype)
    B = torch.randn(N, N, device='cuda', dtype=dtype)

    def f():
        return torch.mm(A, B)

    if get_kernels:
        with torch.profiler.profile() as prof:
            f()
        for e in prof.events():
            if any(tag in e.name for tag in ("gemm", "triton", "gemv")):
                print(f"{N}: {e.name}")
    # do_bench reports milliseconds per call.
    timer = do_bench(f)
    iters_per_second = 1e3 / timer
    # (M, K) @ (K, P) does M*K*P multiply-adds, i.e. 2*M*K*P FLOPs.
    flops = A.shape[0] * A.shape[1] * B.shape[1] * 2
    flops_achieved = iters_per_second * flops / 1e12
    print(f"{N}: {flops_achieved:.2f}TF/s")
    return flops_achieved


# Sweep square matmul sizes N = 1 .. 4095 (the range end is exclusive).
for N in range(1, 4096):
    get_flops(N)
import torch
# NOTE(review): nn, F and autograd are not referenced in the visible script
# below; consider dropping these imports.
from torch import nn
import torch.nn.functional as F
import torch.autograd as autograd
# New tensors default to the CUDA device, so the torch.randn / torch.tensor
# calls below allocate on the GPU without an explicit device argument.
torch.set_default_device('cuda')
import torch._inductor.config
# NOTE(review): presumably gives generated Triton kernels unique/descriptive
# names (useful in profiles) — confirm against torch/_inductor/config.py.
torch._inductor.config.triton.unique_kernel_names = True
# Enable inductor's coordinate-descent tuning of generated kernel configs.
torch._inductor.config.coordinate_descent_tuning = True
# NOTE(review): presumably disables inductor's bounds assertions on indirect
# (gathered) indices — confirm against torch/_inductor/config.py.
torch._inductor.config.assert_indirect_indexing = False
E = 8


def bench(f, name=None, iters=1000, warmup=5, display=True, profile=False):
    """Time ``f`` with triton's do_bench and print effective bandwidth in GB/s.

    The byte count is ``2 * D * D * 4`` using the module-level ``D``. That is
    two (D, D) matrices of 4-byte elements, i.e. the two slices of ``W``
    selected by ``score_idxs`` below (float32 assumed — TODO confirm if the
    default dtype is changed).

    Args:
        f: Zero-argument callable to benchmark.
        name: Label for the printed line and the Chrome-trace filename stem.
        iters: Unused; kept for signature compatibility (do_bench chooses its
            own repetition count).
        warmup: Number of untimed calls made before timing.
        display: If True, print the bandwidth line.
        profile: If True, run one extra call under torch.profiler and export a
            Chrome trace to "<name or 'trace'>.json".

    Returns:
        Effective bandwidth in GB/s.
    """
    from triton.testing import do_bench
    for _ in range(warmup):
        f()
    if profile:
        with torch.profiler.profile() as prof:
            f()
        prof.export_chrome_trace(f"{name if name is not None else 'trace'}.json")
    # do_bench returns milliseconds; convert to microseconds.
    us_per_iter = do_bench(f) * 1000
    gb_per_s = (1e6 / us_per_iter) * 2 * D * D * 4 / 1e9
    if display:
        print(f"{name}: {gb_per_s} GB/s")
    return gb_per_s


def cuda_indexing(W, score_idxs, x):
    # Gather all selected slices with one advanced-indexing op, then a single
    # matmul of the gathered stack with x.
    return W[score_idxs] @ x


def python_indexing(W, score_idxs, x):
    # Index each selected slice separately: two indexing ops and two matmuls.
    return W[score_idxs[0]] @ x, W[score_idxs[1]] @ x


# Loop-invariant setup, hoisted out of the size sweep.
score_idxs = torch.tensor([3, 5])
# dynamic=False: each new D is compiled for its static shape.
compiled_cuda = torch.compile(cuda_indexing, dynamic=False)

for D in [1024, 2048, 4096, 8192, 16384]:
    W = torch.randn(E, D, D)
    x = torch.randn(D)
    print(f"D={D}")
    bench(lambda: python_indexing(W, score_idxs, x), "python indexing")
    bench(lambda: cuda_indexing(W, score_idxs, x), "eager CUDA indexing")
    bench(lambda: compiled_cuda(W, score_idxs, x), "compiled CUDA indexing")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment