Created November 5, 2022 01:02
benchmark: batched matmul
import torch
from torch import einsum, matmul, bmm
import time

repeats = 10

# Note: MPS dispatch is asynchronous; calling .max().item() copies a scalar
# back to the CPU, which forces the queue to drain so perf_counter captures
# the full kernel time rather than just the launch.

batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    einsum('b i j, b j d -> b i d', attn, v).max().item()
    duration = time.perf_counter()-start
    print('einsum 1 iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of einsum 1 took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))

batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    matmul(attn, v).max().item()
    duration = time.perf_counter()-start
    print('matmul iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of matmul took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))

batch_duration = 0
for ix in range(repeats):
    attn = torch.rand(16, 4096, 4096, dtype=torch.float, device="mps")
    v = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
    start = time.perf_counter()
    bmm(attn, v).max().item()
    duration = time.perf_counter()-start
    print('bmm iteration %d took %.4f seconds' % (ix, duration))
    batch_duration += duration
print('%d iterations of bmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats))
PyTorch 1.12.1, MPS
einsum 1 iteration 0 took 0.1368 seconds
einsum 1 iteration 1 took 0.1184 seconds
einsum 1 iteration 2 took 0.0402 seconds
einsum 1 iteration 3 took 0.0408 seconds
einsum 1 iteration 4 took 0.0414 seconds
einsum 1 iteration 5 took 0.0414 seconds
einsum 1 iteration 6 took 0.0393 seconds
einsum 1 iteration 7 took 0.0416 seconds
einsum 1 iteration 8 took 0.0413 seconds
einsum 1 iteration 9 took 0.0406 seconds
10 iterations of einsum 1 took 0.5818 seconds; avg 0.0582 secs
matmul iteration 0 took 0.2365 seconds
matmul iteration 1 took 0.1446 seconds
matmul iteration 2 took 0.0865 seconds
matmul iteration 3 took 0.0863 seconds
matmul iteration 4 took 0.0866 seconds
matmul iteration 5 took 0.0870 seconds
matmul iteration 6 took 0.0864 seconds
matmul iteration 7 took 0.0870 seconds
matmul iteration 8 took 0.0864 seconds
matmul iteration 9 took 0.0864 seconds
10 iterations of matmul took 1.0737 seconds; avg 0.1074 secs
bmm iteration 0 took 0.0412 seconds
bmm iteration 1 took 0.0409 seconds
bmm iteration 2 took 0.0415 seconds
bmm iteration 3 took 0.0406 seconds
bmm iteration 4 took 0.0404 seconds
bmm iteration 5 took 0.0413 seconds
bmm iteration 6 took 0.0407 seconds
bmm iteration 7 took 0.0410 seconds
bmm iteration 8 took 0.0410 seconds
bmm iteration 9 took 0.0411 seconds
10 iterations of bmm took 0.4096 seconds; avg 0.0410 secs
comparing best times of einsum 1 vs bmm; .0393 vs .0404
bmm is 2.7% slower
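All three calls compute the same batched product, so any timing difference comes down to kernel selection. As a sanity check (a minimal sketch on small CPU tensors, not part of the original benchmark), the outputs can be compared directly:

import torch

# Small shapes keep the check fast; same layout as the benchmark: (b, i, j) @ (b, j, d)
attn = torch.rand(2, 64, 64)
v = torch.rand(2, 64, 8)

out_einsum = torch.einsum('b i j, b j d -> b i d', attn, v)
out_matmul = torch.matmul(attn, v)
out_bmm = torch.bmm(attn, v)

# All three are the same batched matmul, so results should agree to float32 tolerance.
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)
assert torch.allclose(out_einsum, out_bmm, atol=1e-5)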
8 Heun steps, building on top of the baddbmm optimization from:
https://gist.github.com/Birch-san/8f3eb99deffdc3541595e46a01605dea
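For context, here is a minimal sketch of what the variants below compute, assuming standard scaled dot-product attention (shapes and the scale factor are illustrative assumptions, not the exact code from the linked gist). baddbmm folds the 1/sqrt(d) scale into the q @ k^T matmul via its alpha argument; the final attn @ v product is then done either with einsum or with bmm:

import torch

b, i, j, d = 2, 128, 128, 40  # illustrative: batch*heads, query len, key len, head dim
scale = d ** -0.5
q = torch.rand(b, i, d)
k = torch.rand(b, j, d)
v = torch.rand(b, j, d)

# einsum + einsum: the scale is a separate elementwise multiply
scores = torch.einsum('b i d, b j d -> b i j', q, k) * scale
out_einsum = torch.einsum('b i j, b j d -> b i d', scores.softmax(dim=-1), v)

# baddbmm + bmm: fold the scale into the scores matmul
scores = torch.baddbmm(
    torch.empty(b, i, j, dtype=q.dtype, device=q.device),
    q, k.transpose(1, 2),
    beta=0,       # with beta=0 the input tensor is ignored (treated as zeros)
    alpha=scale,  # fused 1/sqrt(d) scale
)
out_bmm = torch.bmm(scores.softmax(dim=-1), v)
# (baddbmm + einsum computes scores the same way, then uses the einsum above for attn @ v)

assert torch.allclose(out_einsum, out_bmm, atol=1e-5)

The timings below are total seconds for 8 Heun steps: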
1.12.1
einsum + einsum:
9.688710416987306
1.14.0.dev20221103
einsum + einsum:
10.383701542014023
====
1.12.1
baddbmm + einsum:
9.598911916022189
1.14.0.dev20221103
baddbmm + einsum:
9.281007582991151
====
1.12.1
baddbmm + bmm:
9.153142041992396
1.14.0.dev20221103
baddbmm + bmm:
8.686828749952838
PyTorch 1.14.0.dev20221103, MPS
comparing best times of einsum 1 vs bmm; .0418 vs .0362
bmm is 15% faster