Last active
November 22, 2022 21:10
-
-
Save Birch-san/8f3eb99deffdc3541595e46a01605dea to your computer and use it in GitHub Desktop.
benchmark: batched matmul with scale factor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import einsum, tensor, matmul, bmm, baddbmm, empty | |
import time | |
scale=2 | |
repeats = 10 | |
# both einsum 0s use the same plan, so whichever batch runs first has to pay the price of warmup | |
# uncomment this to run a warmup before either batch runs, for fairer comparison of batch avg time | |
# q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
# k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
# start = time.perf_counter() | |
# (einsum('b i d, b j d -> b i j', q, k) * scale).max().item() | |
# duration = time.perf_counter()-start | |
# print('einsum 0 warmup took %.4f seconds' % (duration)) | |
batch_duration = 0 | |
for ix in range(repeats): | |
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
start = time.perf_counter() | |
(einsum('b i d, b j d -> b i j', q, k) * scale).max().item() | |
duration = time.perf_counter()-start | |
print('einsum 0 iteration %d took %.4f seconds' % (ix, duration)) | |
batch_duration += duration | |
print('%d iterations of einsum 0 took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats)) | |
batch_duration = 0 | |
for ix in range(repeats): | |
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
start = time.perf_counter() | |
(einsum('b n m, b m p -> b n p', q, k.transpose(1, 2)) * scale).max().item() | |
duration = time.perf_counter()-start | |
print('einsum 0 transposed k iteration %d took %.4f seconds' % (ix, duration)) | |
batch_duration += duration | |
print('%d iterations of einsum 0 transposed k took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats)) | |
batch_duration = 0 | |
for ix in range(repeats): | |
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
start = time.perf_counter() | |
(matmul(q, k.transpose(1, 2)) * scale).max().item() | |
duration = time.perf_counter()-start | |
print('matmul iteration %d took %.4f seconds' % (ix, duration)) | |
batch_duration += duration | |
print('%d iterations of matmul took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats)) | |
batch_duration = 0 | |
for ix in range(repeats): | |
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
start = time.perf_counter() | |
(bmm(q, k.transpose(1, 2)) * scale).max().item() | |
duration = time.perf_counter()-start | |
print('bmm iteration %d took %.4f seconds' % (ix, duration)) | |
batch_duration += duration | |
print('%d iterations of bmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats)) | |
e = empty((1, 1, 1), device='mps') | |
batch_duration = 0 | |
for ix in range(repeats): | |
q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps") | |
start = time.perf_counter() | |
baddbmm(e, q, k.transpose(1, 2), alpha=scale, beta=0).max().item() | |
duration = time.perf_counter()-start | |
print('baddbmm iteration %d took %.4f seconds' % (ix, duration)) | |
batch_duration += duration | |
print('%d iterations of baddbmm took %.4f seconds; avg %.4f secs' % (repeats, batch_duration, batch_duration/repeats)) |
Thanks for this — similar results here (base M1, 16 GB):
torch 1.12.1:
einsum 0 iteration 0 took 0.1881 seconds
einsum 0 iteration 1 took 0.0805 seconds
einsum 0 iteration 2 took 0.0803 seconds
einsum 0 iteration 3 took 0.0799 seconds
einsum 0 iteration 4 took 0.0802 seconds
einsum 0 iteration 5 took 0.0831 seconds
einsum 0 iteration 6 took 0.0817 seconds
einsum 0 iteration 7 took 0.0814 seconds
einsum 0 iteration 8 took 0.0810 seconds
einsum 0 iteration 9 took 0.0805 seconds
10 iterations of einsum 0 took 0.9165 seconds; avg 0.0917 secs
einsum 0 transposed k iteration 0 took 0.0811 seconds
einsum 0 transposed k iteration 1 took 0.0810 seconds
einsum 0 transposed k iteration 2 took 0.0796 seconds
einsum 0 transposed k iteration 3 took 0.0803 seconds
einsum 0 transposed k iteration 4 took 0.0803 seconds
einsum 0 transposed k iteration 5 took 0.0809 seconds
einsum 0 transposed k iteration 6 took 0.0808 seconds
einsum 0 transposed k iteration 7 took 0.0801 seconds
einsum 0 transposed k iteration 8 took 0.0803 seconds
einsum 0 transposed k iteration 9 took 0.0805 seconds
10 iterations of einsum 0 transposed k took 0.8051 seconds; avg 0.0805 secs
matmul iteration 0 took 0.0996 seconds
matmul iteration 1 took 0.0830 seconds
matmul iteration 2 took 0.0823 seconds
matmul iteration 3 took 0.0844 seconds
matmul iteration 4 took 0.0856 seconds
matmul iteration 5 took 0.0834 seconds
matmul iteration 6 took 0.0823 seconds
matmul iteration 7 took 0.0826 seconds
matmul iteration 8 took 0.0828 seconds
matmul iteration 9 took 0.0821 seconds
10 iterations of matmul took 0.8483 seconds; avg 0.0848 secs
bmm iteration 0 took 0.0794 seconds
bmm iteration 1 took 0.0796 seconds
bmm iteration 2 took 0.0798 seconds
bmm iteration 3 took 0.0812 seconds
bmm iteration 4 took 0.0800 seconds
bmm iteration 5 took 0.0805 seconds
bmm iteration 6 took 0.0894 seconds
bmm iteration 7 took 0.0887 seconds
bmm iteration 8 took 0.0802 seconds
bmm iteration 9 took 0.0799 seconds
10 iterations of bmm took 0.8188 seconds; avg 0.0819 secs
baddbmm iteration 0 took 0.0535 seconds
baddbmm iteration 1 took 0.0456 seconds
baddbmm iteration 2 took 0.0457 seconds
baddbmm iteration 3 took 0.0456 seconds
baddbmm iteration 4 took 0.0454 seconds
baddbmm iteration 5 took 0.0451 seconds
baddbmm iteration 6 took 0.0458 seconds
baddbmm iteration 7 took 0.0458 seconds
baddbmm iteration 8 took 0.0457 seconds
baddbmm iteration 9 took 0.0453 seconds
10 iterations of baddbmm took 0.4635 seconds; avg 0.0463 secs
torch 1.14.0.dev20221121:
einsum 0 iteration 0 took 0.6219 seconds
einsum 0 iteration 1 took 0.1123 seconds
einsum 0 iteration 2 took 0.1147 seconds
einsum 0 iteration 3 took 0.1134 seconds
einsum 0 iteration 4 took 0.1126 seconds
einsum 0 iteration 5 took 0.1126 seconds
einsum 0 iteration 6 took 0.1122 seconds
einsum 0 iteration 7 took 0.1137 seconds
einsum 0 iteration 8 took 0.1137 seconds
einsum 0 iteration 9 took 0.1126 seconds
10 iterations of einsum 0 took 1.6395 seconds; avg 0.1640 secs
einsum 0 transposed k iteration 0 took 0.1133 seconds
einsum 0 transposed k iteration 1 took 0.1125 seconds
einsum 0 transposed k iteration 2 took 0.1123 seconds
einsum 0 transposed k iteration 3 took 0.1128 seconds
einsum 0 transposed k iteration 4 took 0.1128 seconds
einsum 0 transposed k iteration 5 took 0.1127 seconds
einsum 0 transposed k iteration 6 took 0.1129 seconds
einsum 0 transposed k iteration 7 took 0.1128 seconds
einsum 0 transposed k iteration 8 took 0.1131 seconds
einsum 0 transposed k iteration 9 took 0.1126 seconds
10 iterations of einsum 0 transposed k took 1.1277 seconds; avg 0.1128 secs
matmul iteration 0 took 0.0979 seconds
matmul iteration 1 took 0.0800 seconds
matmul iteration 2 took 0.0808 seconds
matmul iteration 3 took 0.0812 seconds
matmul iteration 4 took 0.0794 seconds
matmul iteration 5 took 0.0792 seconds
matmul iteration 6 took 0.0794 seconds
matmul iteration 7 took 0.0796 seconds
matmul iteration 8 took 0.0790 seconds
matmul iteration 9 took 0.0791 seconds
10 iterations of matmul took 0.8157 seconds; avg 0.0816 secs
bmm iteration 0 took 0.0765 seconds
bmm iteration 1 took 0.0956 seconds
bmm iteration 2 took 0.0757 seconds
bmm iteration 3 took 0.0763 seconds
bmm iteration 4 took 0.0773 seconds
bmm iteration 5 took 0.0764 seconds
bmm iteration 6 took 0.0768 seconds
bmm iteration 7 took 0.0768 seconds
bmm iteration 8 took 0.0767 seconds
bmm iteration 9 took 0.0763 seconds
10 iterations of bmm took 0.7844 seconds; avg 0.0784 secs
baddbmm iteration 0 took 0.4508 seconds
baddbmm iteration 1 took 0.0462 seconds
baddbmm iteration 2 took 0.0442 seconds
baddbmm iteration 3 took 0.0445 seconds
baddbmm iteration 4 took 0.0442 seconds
baddbmm iteration 5 took 0.0453 seconds
baddbmm iteration 6 took 0.0450 seconds
baddbmm iteration 7 took 0.0452 seconds
baddbmm iteration 8 took 0.0448 seconds
baddbmm iteration 9 took 0.0453 seconds
10 iterations of baddbmm took 0.8555 seconds; avg 0.0855 secs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
8 Heun steps
1.12.1
einsum:
9.688710416987306
baddbmm:
9.598911916022189
1.14.0.dev20221103
einsum:
10.383701542014023
baddbmm:
9.281007582991151