Batched GEMM profiling for transformers in PyTorch
# coding: utf-8
import torch
import time
import pandas as pd
import tqdm
B, L, N, H, W = 64, 50, 10, 256, 3
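# presumed meanings (not stated in the original gist): B = batch size, L = sequence length,
# N = number of attention heads, H = per-head hidden size, W = local context/window size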
print('warming up')
for _ in tqdm.trange(10):
    x = torch.randn(B, L, N, W, H).cuda()
    y = torch.randn(B, L, N, H, W).cuda()
    z = x @ y
    torch.cuda.synchronize()
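# note: each timed region below calls torch.cuda.synchronize() before reading time.time(),
# since CUDA kernels launch asynchronously and the wall clock would otherwise mostly measure
# kernel-launch overhead rather than the actual GPU work.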
print('star transformer profiling')
ts = []
for _ in tqdm.trange(10):
    t = []
    x = torch.randn(B, L, N, 1, H).cuda().requires_grad_()
    y = torch.randn(B, L, N, H, W).cuda().requires_grad_()
    y2 = torch.randn(B, L, N, W, H).cuda().requires_grad_()
    torch.cuda.synchronize()
    # bmm forward
    t0 = time.time()
    z1 = x @ y  # # of muls: B * L * N * H * W
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ones = torch.ones_like(z1)
    torch.cuda.synchronize()
    # bmm backward
    t0 = time.time()
    z1.backward(ones, retain_graph=True)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    # bmm backward manual
    with torch.no_grad():
        t0 = time.time()
        gx = ones @ y.transpose(-1, -2)
        torch.cuda.synchronize()
        tt = time.time()
        t.append(tt - t0)
        t0 = time.time()
        gy = x.transpose(-1, -2) @ ones
        torch.cuda.synchronize()
        tt = time.time()
        t.append(tt - t0)
    # mul-sum forward
    t0 = time.time()
    z2 = (x * y2).sum(-1)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ones = torch.ones_like(z2)
    torch.cuda.synchronize()
    # mul-sum backward
    t0 = time.time()
    z2.backward(ones, retain_graph=True)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ts.append(t)
print(pd.DataFrame(data=ts, columns=['dot', 'dotB', 'dotB x', 'dotB y', 'mulsum', 'mulsumB']).describe())
print('vanilla transformer profiling')
ts = []
for _ in tqdm.trange(10):
    t = []
    x = torch.randn(B, N, L, H).cuda().requires_grad_()
    y = torch.randn(B, N, H, L).cuda().requires_grad_()
    y2 = torch.randn(B, N, L, H).cuda().requires_grad_()
    torch.cuda.synchronize()
    # bmm forward
    t0 = time.time()
    z1 = x @ y  # # of muls: B * N * L * H * L
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ones = torch.ones_like(z1)
    torch.cuda.synchronize()
    # bmm backward
    t0 = time.time()
    z1.backward(ones, retain_graph=True)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    # bmm backward manual
    with torch.no_grad():
        t0 = time.time()
        gx = ones @ y.transpose(-1, -2)
        torch.cuda.synchronize()
        tt = time.time()
        t.append(tt - t0)
        t0 = time.time()
        gy = x.transpose(-1, -2) @ ones
        torch.cuda.synchronize()
        tt = time.time()
        t.append(tt - t0)
    # mul-sum forward
    t0 = time.time()
    z2 = (x * y2).sum(-1)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ones = torch.ones_like(z2)
    torch.cuda.synchronize()
    # mul-sum backward
    t0 = time.time()
    z2.backward(ones, retain_graph=True)
    torch.cuda.synchronize()
    tt = time.time()
    t.append(tt - t0)
    ts.append(t)
print(pd.DataFrame(data=ts, columns=['dot', 'dotB', 'dotB x', 'dotB y', 'mulsum', 'mulsumB']).describe())
# my output on V100 goes as follows:
# dot - bmm forward
# dotB - bmm backward
# dotB x - computing gradient wrt first argument manually
# dotB y - computing gradient wrt second argument manually
# mulsum - elementwise multiply followed by a reduce_sum (avoids gemm)
# mulsumB - backward pass of mulsum
# the star transformer case bmm's a (B, L, N, 1, H) tensor with a (B, L, N, H, W) tensor.
# the vanilla transformer case bmm's a (B, N, L, H) tensor with a (B, N, H, L) tensor.
# the number of multiplications is B*N*L*H*W and B*N*L*H*L respectively.
#
#star transformer profiling
# dot dotB dotB x dotB y mulsum mulsumB
#count 10.000000 10.000000 10.000000 10.000000 10.000000 10.000000
#mean 0.000367 0.004982 0.001826 0.002916 0.000511 0.000869
#std 0.000172 0.000110 0.000214 0.000015 0.000137 0.000113
#min 0.000288 0.004925 0.001706 0.002899 0.000449 0.000817
#25% 0.000298 0.004941 0.001719 0.002908 0.000459 0.000828
#50% 0.000314 0.004945 0.001731 0.002912 0.000466 0.000833
#75% 0.000328 0.004955 0.001742 0.002918 0.000479 0.000839
#max 0.000855 0.005291 0.002254 0.002942 0.000899 0.001189
#vanilla transformer profiling
# dot dotB dotB x dotB y mulsum mulsumB
#count 10.000000 10.000000 10.000000 10.000000 10.000000 10.000000
#mean 0.000354 0.000543 0.000315 0.000379 0.000227 0.000411
#std 0.000014 0.000015 0.000009 0.000013 0.000010 0.000007
#min 0.000335 0.000526 0.000299 0.000364 0.000214 0.000402
#25% 0.000345 0.000531 0.000310 0.000371 0.000218 0.000406
#50% 0.000351 0.000542 0.000316 0.000375 0.000229 0.000412
#75% 0.000358 0.000551 0.000322 0.000387 0.000232 0.000413
#max 0.000376 0.000571 0.000327 0.000406 0.000244 0.000426
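
# quick sanity check of the two multiplication counts quoted above, using the sizes
# defined at the top of the script (a small addition for illustration; the names
# star_muls / vanilla_muls are not part of the timed benchmark):
star_muls = B * N * L * H * W     # star transformer: 64 * 10 * 50 * 256 * 3
vanilla_muls = B * N * L * H * L  # vanilla transformer: 64 * 10 * 50 * 256 * 50
print('star muls:    %d' % star_muls)     # 24,576,000
print('vanilla muls: %d' % vanilla_muls)  # 409,600,000
print('ratio: %.1fx' % (vanilla_muls / star_muls))  # ~16.7x more muls for vanilla
# despite doing ~16.7x fewer multiplications, the star transformer bmm is the slower one
# in the tables above (e.g. dotB mean 0.004982 s vs 0.000543 s), presumably because its
# batched GEMM degenerates into B*L*N tiny (1 x H) by (H x W) products.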