@Birch-san
Created November 5, 2022 01:02
benchmark: batched matmul
import torch
from torch import einsum, matmul, bmm
import time

repeats = 10

# einsum: batched matmul expressed as 'b i j, b j d -> b i d'
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    # .max().item() copies a scalar to the CPU, which waits for the MPS kernel to finish
    einsum('b i j, b j d -> b i d', attn, v).max().item()
    duration = time.perf_counter() - start
    print('einsum 1 iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of einsum 1 took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration / repeats))

# matmul: same contraction via torch.matmul broadcasting over the batch dim
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    matmul(attn, v).max().item()
    duration = time.perf_counter() - start
    print('matmul iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of matmul took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration / repeats))

# bmm: explicit batched matrix multiply
batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    bmm(attn, v).max().item()
    duration = time.perf_counter() - start
    print('bmm iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of bmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration / repeats))
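
A note on the timing methodology: each .max().item() call copies a scalar back to the CPU, which forces the queued MPS work to complete before the clock stops. On builds that expose torch.mps.synchronize() (available in newer PyTorch releases; treat this as an assumption about your build), an explicit sync is an alternative way to bound the timed region. A minimal sketch:

import time
import torch

def time_mps(fn, *args, repeats=10):
    # hypothetical helper (not part of the gist): times fn(*args) on MPS with explicit syncs
    best = float('inf')
    for _ in range(repeats):
        torch.mps.synchronize()          # drain any queued MPS work before starting the clock
        start = time.perf_counter()
        fn(*args)
        torch.mps.synchronize()          # wait for this launch to finish before stopping the clock
        best = min(best, time.perf_counter() - start)
    return best

# e.g. best_bmm = time_mps(torch.bmm, attn, v)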
Birch-san commented Nov 5, 2022

Torch 1.14.0.dev20221103, MPS

einsum 1 iteration 0 took 0.1539 seconds
einsum 1 iteration 1 took 0.0418 seconds
einsum 1 iteration 2 took 0.0421 seconds
einsum 1 iteration 3 took 0.0421 seconds
einsum 1 iteration 4 took 0.0421 seconds
einsum 1 iteration 5 took 0.0426 seconds
einsum 1 iteration 6 took 0.0418 seconds
einsum 1 iteration 7 took 0.0421 seconds
einsum 1 iteration 8 took 0.0422 seconds
einsum 1 iteration 9 took 0.0427 seconds
10 iterations of einsum 1 took 0.5335 seconds; avg 0.0533 secs
matmul iteration 0 took 0.2850 seconds
matmul iteration 1 took 0.1274 seconds
matmul iteration 2 took 0.0813 seconds
matmul iteration 3 took 0.0822 seconds
matmul iteration 4 took 0.0819 seconds
matmul iteration 5 took 0.0823 seconds
matmul iteration 6 took 0.0821 seconds
matmul iteration 7 took 0.0823 seconds
matmul iteration 8 took 0.0813 seconds
matmul iteration 9 took 0.0827 seconds
10 iterations of matmul took 1.0686 seconds; avg 0.1069 secs
bmm iteration 0 took 0.0368 seconds
bmm iteration 1 took 0.0367 seconds
bmm iteration 2 took 0.0366 seconds
bmm iteration 3 took 0.0367 seconds
bmm iteration 4 took 0.0367 seconds
bmm iteration 5 took 0.0370 seconds
bmm iteration 6 took 0.0367 seconds
bmm iteration 7 took 0.0362 seconds
bmm iteration 8 took 0.0368 seconds
bmm iteration 9 took 0.0373 seconds
10 iterations of bmm took 0.3674 seconds; avg 0.0367 secs

comparing best times of einsum 1 vs bmm: 0.0418 vs 0.0362

bmm is ~15% faster (einsum 1's best iteration takes about 15% longer than bmm's)
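
For reference, the 'b i j, b j d -> b i d' einsum is exactly a batched matmul, so einsum and bmm compute the same result; a quick equivalence check on small CPU tensors (shapes here are illustrative, not the benchmark's):

import torch
from torch import einsum, bmm

attn = torch.rand(2, 64, 64)
v = torch.rand(2, 64, 40)
out_einsum = einsum('b i j, b j d -> b i d', attn, v)
out_bmm = bmm(attn, v)
print(torch.allclose(out_einsum, out_bmm))  # expect True (up to float tolerance)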

Birch-san commented Nov 5, 2022

PyTorch 1.12.1, MPS

einsum 1 iteration 0 took 0.1368 seconds
einsum 1 iteration 1 took 0.1184 seconds
einsum 1 iteration 2 took 0.0402 seconds
einsum 1 iteration 3 took 0.0408 seconds
einsum 1 iteration 4 took 0.0414 seconds
einsum 1 iteration 5 took 0.0414 seconds
einsum 1 iteration 6 took 0.0393 seconds
einsum 1 iteration 7 took 0.0416 seconds
einsum 1 iteration 8 took 0.0413 seconds
einsum 1 iteration 9 took 0.0406 seconds
10 iterations of einsum 1 took 0.5818 seconds; avg 0.0582 secs
matmul iteration 0 took 0.2365 seconds
matmul iteration 1 took 0.1446 seconds
matmul iteration 2 took 0.0865 seconds
matmul iteration 3 took 0.0863 seconds
matmul iteration 4 took 0.0866 seconds
matmul iteration 5 took 0.0870 seconds
matmul iteration 6 took 0.0864 seconds
matmul iteration 7 took 0.0870 seconds
matmul iteration 8 took 0.0864 seconds
matmul iteration 9 took 0.0864 seconds
10 iterations of matmul took 1.0737 seconds; avg 0.1074 secs
bmm iteration 0 took 0.0412 seconds
bmm iteration 1 took 0.0409 seconds
bmm iteration 2 took 0.0415 seconds
bmm iteration 3 took 0.0406 seconds
bmm iteration 4 took 0.0404 seconds
bmm iteration 5 took 0.0413 seconds
bmm iteration 6 took 0.0407 seconds
bmm iteration 7 took 0.0410 seconds
bmm iteration 8 took 0.0410 seconds
bmm iteration 9 took 0.0411 seconds
10 iterations of bmm took 0.4096 seconds; avg 0.0410 secs

comparing best times of einsum 1 vs bmm: 0.0393 vs 0.0404

bmm is 2.7% slower

@Birch-san

8 Heun steps, building on top of the baddbmm optimization from:
https://gist.github.com/Birch-san/8f3eb99deffdc3541595e46a01605dea

                     1.12.1               1.14.0.dev20221103
einsum + einsum      9.688710416987306    10.383701542014023
baddbmm + einsum     9.598911916022189     9.281007582991151
baddbmm + bmm        9.153142041992396     8.686828749952838