@Birch-san
Created November 5, 2022 01:02
benchmark: batched matmul
import time

import torch
from torch import einsum, matmul, bmm

repeats = 10

# Variant 1: einsum. Batched product of (16, 4096, 4096) attention weights with (16, 4096, 40) values.
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    # .item() copies the scalar back to the host, so the timer only stops once the device work has finished
    einsum('b i j, b j d -> b i d', attn, v).max().item()
    duration = time.perf_counter()-start
    print('einsum 1 iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of einsum 1 took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))

# Variant 2: matmul, same shapes
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    matmul(attn, v).max().item()
    duration = time.perf_counter()-start
    print('matmul iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of matmul took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))

# Variant 3: bmm, same shapes
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    bmm(attn, v).max().item()
    duration = time.perf_counter()-start
    print('bmm iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of bmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
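
The benchmark above only times the three calls; it does not check that they return the same result. A minimal sketch of such a check follows. The shrunken shapes, the CPU fallback and the tolerances are assumptions made here so the check runs quickly on any machine; they are not part of the original benchmark.

```python
import torch
from torch import einsum, matmul, bmm

# Assumption: fall back to CPU when MPS is unavailable.
device = "mps" if torch.backends.mps.is_available() else "cpu"

torch.manual_seed(0)
# Assumption: much smaller shapes than the benchmark's (16, 4096, 4096) and (16, 4096, 40).
attn = torch.rand(2, 64, 64, dtype=torch.float, device=device)
v = torch.rand(2, 64, 40, dtype=torch.float, device=device)

# The three formulations being benchmarked
out_einsum = einsum('b i j, b j d -> b i d', attn, v)
out_matmul = matmul(attn, v)
out_bmm = bmm(attn, v)

# All three should produce a (batch, i, d) tensor with matching values,
# up to float32 rounding differences between kernels.
same_shape = out_einsum.shape == out_matmul.shape == out_bmm.shape == (2, 64, 40)
close = (
    torch.allclose(out_einsum, out_matmul, rtol=1e-4, atol=1e-4)
    and torch.allclose(out_einsum, out_bmm, rtol=1e-4, atol=1e-4)
)
print(same_shape, close)
```

If the outputs agree, any timing difference between the variants comes from how each call is dispatched, not from different arithmetic.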
@Birch-san
Author

Timings for 8 Heun steps (each entry: PyTorch version, matmul pairing, elapsed time, presumably in seconds), building on top of the baddbmm optimization from:
https://gist.github.com/Birch-san/8f3eb99deffdc3541595e46a01605dea

1.12.1  
einsum + einsum:  
9.688710416987306

1.14.0.dev20221103  
einsum + einsum:  
10.383701542014023

====

1.12.1  
baddbmm + einsum:  
9.598911916022189

1.14.0.dev20221103  
baddbmm + einsum:  
9.281007582991151

=====

1.12.1  
baddbmm + bmm:  
9.153142041992396

1.14.0.dev20221103  
baddbmm + bmm:  
8.686828749952838
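
The labels above do not spell out which operation each name refers to. A rough sketch, assuming the first name is the attention-score product (scaled q·kᵀ) and the second is the attention-times-value product, with the last pairing read as baddbmm + bmm. The function names, shapes and the CPU fallback are hypothetical, chosen for illustration only; this is not the code that produced the timings.

```python
import torch
from torch import einsum, baddbmm, bmm

# Assumption: fall back to CPU when MPS is unavailable.
device = "mps" if torch.backends.mps.is_available() else "cpu"
torch.manual_seed(0)

# Hypothetical small shapes: (batch * heads, tokens, head_dim)
q = torch.rand(2, 64, 40, device=device)
k = torch.rand(2, 64, 40, device=device)
v = torch.rand(2, 64, 40, device=device)
scale = q.shape[-1] ** -0.5

def attn_einsum_einsum(q, k, v):
    # "einsum + einsum": both matmuls via einsum, scaling as a separate multiply
    scores = einsum('b i d, b j d -> b i j', q, k) * scale
    attn = scores.softmax(dim=-1)
    return einsum('b i j, b j d -> b i d', attn, v)

def attn_baddbmm_bmm(q, k, v):
    # "baddbmm + bmm": baddbmm folds the scale into the matmul via alpha.
    # With beta=0 the input tensor is ignored, so an uninitialised buffer is fine.
    buf = torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    scores = baddbmm(buf, q, k.transpose(-1, -2), beta=0, alpha=scale)
    attn = scores.softmax(dim=-1)
    return bmm(attn, v)

out_a = attn_einsum_einsum(q, k, v)
out_b = attn_baddbmm_bmm(q, k, v)
print(out_a.shape, torch.allclose(out_a, out_b, atol=1e-5))
```

Under this reading, "baddbmm + einsum" would be the middle case: baddbmm for the scores, einsum for the product with v.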
