Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created August 7, 2019 10:45
Show Gist options
  • Save ptrblck/0e84ecc6b29fa17b0b230fb0c52415c1 to your computer and use it in GitHub Desktop.
Save ptrblck/0e84ecc6b29fa17b0b230fb0c52415c1 to your computer and use it in GitHub Desktop.
import torch
import time
# Turn on the cuDNN autotuner.
# NOTE(review): this flag governs cuDNN algorithm selection; it presumably has
# no effect on the plain matmul calls timed below -- confirm.
torch.backends.cudnn.benchmark = True

# 1a) half-precision matmul: (64 x 1024) @ (1024 x 1024)
I, J, K = 64, 1024, 1024
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

# warmup: run the op untimed, then wait for the GPU to drain
for _ in range(50):
    C = torch.matmul(A, B)
torch.cuda.synchronize()

nb_iters = 1000
# Synchronize before each clock read so that all queued GPU work has finished
# and the measured interval covers exactly the nb_iters matmuls.
torch.cuda.synchronize()
t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.perf_counter()
# Average wall-clock time per matmul, in microseconds.
print('1a) {:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
# 1b) half-precision matmul: (1 x 1024) @ (1024 x 1024)
I, J, K = 1, 1024, 1024
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

# warmup: run the op untimed, then wait for the GPU to drain
for _ in range(50):
    C = torch.matmul(A, B)
torch.cuda.synchronize()

nb_iters = 1000
# Synchronize before each clock read so that all queued GPU work has finished
# and the measured interval covers exactly the nb_iters matmuls.
torch.cuda.synchronize()
t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.perf_counter()
# Average wall-clock time per matmul, in microseconds.
print('1b) {:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
# 2a) half-precision matmul: (63 x 1023) @ (1023 x 1023)
I, J, K = 63, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

# warmup: run the op untimed, then wait for the GPU to drain
for _ in range(50):
    C = torch.matmul(A, B)
torch.cuda.synchronize()

nb_iters = 1000
# Synchronize before each clock read so that all queued GPU work has finished
# and the measured interval covers exactly the nb_iters matmuls.
torch.cuda.synchronize()
t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.perf_counter()
# Average wall-clock time per matmul, in microseconds.
print('2a) {:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
# 2b) half-precision matmul: (1 x 1023) @ (1023 x 1023)
I, J, K = 1, 1023, 1023
A = torch.randn(I, J, device='cuda', dtype=torch.half)
B = torch.randn(J, K, device='cuda', dtype=torch.half)

# warmup: run the op untimed, then wait for the GPU to drain
for _ in range(50):
    C = torch.matmul(A, B)
torch.cuda.synchronize()

nb_iters = 1000
# Synchronize before each clock read so that all queued GPU work has finished
# and the measured interval covers exactly the nb_iters matmuls.
torch.cuda.synchronize()
t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
for _ in range(nb_iters):
    C = torch.matmul(A, B)
torch.cuda.synchronize()
t1 = time.perf_counter()
# Average wall-clock time per matmul, in microseconds.
print('2b) {:.3f}us per iteration'.format((t1 - t0) / nb_iters * 1e6))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment