Last active
November 22, 2022 21:10
-
-
Save Birch-san/8f3eb99deffdc3541595e46a01605dea to your computer and use it in GitHub Desktop.
benchmark: batched matmul with scale factor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch
from torch import einsum, tensor, matmul, bmm, baddbmm, empty
import time

# Scale factor applied to the attention-style q @ k^T product.
scale = 2
# Number of timed iterations per benchmarked operation.
repeats = 10

# Both einsum 0s use the same plan, so whichever batch runs first has to pay
# the price of warmup. Uncomment this to run a warmup before either batch
# runs, for fairer comparison of batch avg time.
# q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
# k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
# start = time.perf_counter()
# (einsum('b i d, b j d -> b i j', q, k) * scale).max().item()
# duration = time.perf_counter()-start
# print('einsum 0 warmup took %.4f seconds' % (duration))


def benchmark(name, op):
    """Time `op(q, k)` over `repeats` fresh random (16, 4096, 40) MPS tensors.

    Prints each iteration's duration and the batch total/average, labeled
    with `name`. `op` must return a tensor; calling `.max().item()` on it
    forces a host sync, so the (asynchronous) MPS work is included in the
    measured time rather than just the kernel launch.
    """
    batch_duration = 0
    for ix in range(repeats):
        q = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
        k = torch.rand(16, 4096, 40, dtype=torch.float, device="mps")
        start = time.perf_counter()
        op(q, k).max().item()
        duration = time.perf_counter() - start
        print('%s iteration %d took %.4f seconds' % (name, ix, duration))
        batch_duration += duration
    print('%d iterations of %s took %.4f seconds; avg %.4f secs' % (repeats, name, batch_duration, batch_duration / repeats))


# Same batched matmul-with-scale expressed five ways, benchmarked back to back.
benchmark('einsum 0', lambda q, k: einsum('b i d, b j d -> b i j', q, k) * scale)
benchmark('einsum 0 transposed k', lambda q, k: einsum('b n m, b m p -> b n p', q, k.transpose(1, 2)) * scale)
benchmark('matmul', lambda q, k: matmul(q, k.transpose(1, 2)) * scale)
benchmark('bmm', lambda q, k: bmm(q, k.transpose(1, 2)) * scale)

# baddbmm fuses the scale into the matmul via alpha=scale; beta=0 means the
# (uninitialized, broadcast) 1x1x1 `e` contributes nothing to the result.
e = empty((1, 1, 1), device='mps')
benchmark('baddbmm', lambda q, k: baddbmm(e, q, k.transpose(1, 2), alpha=scale, beta=0))
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks for this. Similar results here (base M1, 16 GB).
torch 1.12.1:
torch 1.14.0.dev20221121: