Measure performance difference of `torch.mm` vs `torch.bmm`
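The comparison relies on `torch.bmm` with a singleton batch dimension computing the same product as a plain `torch.mm` on the squeezed 2-D tensors, which the script also asserts for every shape it times. A minimal sketch of that equivalence (the shapes and CUDA device here are chosen only for illustration):

```python
import torch

a = torch.rand((128, 256), device="cuda")             # m x k
b = torch.rand((256, 64), device="cuda")              # k x n
mm_out = torch.mm(a, b)                                # m x n
bmm_out = torch.bmm(a.unsqueeze(0), b.unsqueeze(0))    # 1 x m x n, single batch
assert torch.allclose(bmm_out.squeeze(0), mm_out)
```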
# Benchmark relative performance of torch.mm and torch.bmm with single batch
import torch
import time


def benchmark_fn(fn, args, warmup=5, cycles=300, use_kineto=False) -> float:
    if use_kineto:
        # Profile a single call and return the summed CUDA kernel time reported by the profiler
        with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
            fn(*args)
        return sum([e.cuda_time for e in p.key_averages()])
    # Otherwise measure wall-clock time over `cycles` calls after a short warmup
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    begin = time.time()
    for _ in range(cycles):
        fn(*args)
    torch.cuda.synchronize()
    dt = time.time() - begin
    dt_us = int(dt * 1000000) / cycles  # average time per call in microseconds
    return dt_us


if __name__ == "__main__":
    print("torch: ", torch.__version__, " device: ", torch.cuda.get_device_name(0))
    msizes = [(1, 1, 4096), (1, 1, 65536), (129, 129, 129), (257, 257, 257), (128, 257, 512), (16385, 5, 16385)]
    # Shapes to benchmark: dot products, power-of-two squares, odd-sized squares, and tall-skinny matrices
    msizes = [(1, 1, 2**x) for x in range(12, 18)]
    msizes += [(2**x, 2**x, 2**x) for x in range(7, 12)]
    msizes += [(2**x+1, 2**x-1, 2**x+1) for x in range(7, 12)]
    msizes += [(2**x+1, 3, 2**x+1) for x in range(12, 17)]
    msizes += [(2**x+1, 5, 2**x+1) for x in range(12, 17)]
    msizes += [(2**x+1, 7, 2**x+1) for x in range(12, 17)]
    print("| Shape | bmm_time | mm_time | slow down (%) |")
    print("| -------------- | --------- | --------- | ------------- |")
    for (m, n, k) in msizes:
        a = torch.rand((m, k), device='cuda')
        b = torch.rand((k, n), device='cuda')
        # Time bmm with a singleton batch dimension against plain mm on the same operands
        bmm_time = benchmark_fn(torch.bmm, (a.unsqueeze(0), b.unsqueeze(0)))
        mm_time = benchmark_fn(torch.mm, (a, b))
        shape_str = f"{m}x{n}x{k}"
        print(f"| {shape_str :^14} | {bmm_time :^9.2f} | {mm_time :^9.2f} | {100.0*(bmm_time-mm_time)/mm_time :^13.2f} |")
        assert torch.allclose(torch.bmm(a.unsqueeze(0), b.unsqueeze(0)).squeeze(0), torch.mm(a, b))
# Running the above script on an A100 with torch-2.1.1+cu118 produces the following output:
# torch: 2.1.1+cu118 device: NVIDIA A100-SXM4-40GB
# | Shape | bmm_time | mm_time | slow down (%) |
# | -------------- | --------- | --------- | ------------- |
# | 1x1x4096 | 12.38 | 11.96 | 3.48 |
# | 1x1x8192 | 12.26 | 11.84 | 3.55 |
# | 1x1x16384 | 11.81 | 11.66 | 1.29 |
# | 1x1x32768 | 12.00 | 11.81 | 1.61 |
# | 1x1x65536 | 14.82 | 15.05 | -1.48 |
# | 1x1x131072 | 12.02 | 11.77 | 2.15 |
# | 128x128x128 | 9.47 | 9.69 | -2.24 |
# | 256x256x256 | 12.66 | 12.60 | 0.50 |
# | 512x512x512 | 27.34 | 27.31 | 0.10 |
# | 1024x1024x1024 | 129.59 | 129.48 | 0.08 |
# | 2048x2048x2048 | 973.63 | 973.04 | 0.06 |
# | 129x127x129 | 9.56 | 8.97 | 6.62 |
# | 257x255x257 | 12.85 | 12.78 | 0.52 |
# | 513x511x513 | 28.99 | 28.98 | 0.05 |
# | 1025x1023x1025 | 137.92 | 137.76 | 0.11 |
# | 2049x2047x2049 | 982.34 | 982.32 | 0.00 |
# | 4097x3x4097 | 86.94 | 86.91 | 0.03 |
# | 8193x3x8193 | 384.38 | 384.54 | -0.04 |
# | 16385x3x16385 | 1106.25 | 1107.35 | -0.10 |
# | 32769x3x32769 | 4736.79 | 4737.19 | -0.01 |
# | 65537x3x65537 | 17368.65 | 17371.21 | -0.01 |
# | 4097x5x4097 | 87.50 | 87.49 | 0.01 |
# | 8193x5x8193 | 302.27 | 302.29 | -0.00 |
# | 16385x5x16385 | 1107.69 | 1107.65 | 0.00 |
# | 32769x5x32769 | 4743.02 | 4743.13 | -0.00 |
# | 65537x5x65537 | 17393.08 | 17392.32 | 0.00 |
# | 4097x7x4097 | 87.58 | 87.60 | -0.02 |
# | 8193x7x8193 | 302.42 | 302.45 | -0.01 |
# | 16385x7x16385 | 1106.55 | 1107.34 | -0.07 |
# | 32769x7x32769 | 4746.99 | 4746.58 | 0.01 |
# | 65537x7x65537 | 17406.08 | 17424.31 | -0.10 |
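`benchmark_fn` also accepts `use_kineto=True`, a path the script never exercises: it profiles a single call and returns the summed per-kernel CUDA time from the profiler instead of a wall-clock average. A minimal sketch of how that path could be invoked, assuming it runs in the same module as the script above (the shape is arbitrary):

```python
# Assumes `benchmark_fn` from the script above is in scope and a CUDA device is available.
import torch

a = torch.rand((1024, 1024), device="cuda")
b = torch.rand((1024, 1024), device="cuda")
kineto_us = benchmark_fn(torch.mm, (a, b), use_kineto=True)  # summed CUDA kernel time, microseconds
print(f"torch.mm CUDA kernel time (Kineto): {kineto_us:.2f} us")
```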