Skip to content

Instantly share code, notes, and snippets.

@ptrblck
Created January 26, 2020 05:27
Show Gist options
  • Save ptrblck/19a5685a118f6b1e4e3cb61c7a8cd55e to your computer and use it in GitHub Desktop.
import torch
import torch.nn.functional as F
import time
def test(input, kernel, target):
    """Benchmark the forward and backward pass of ``F.conv3d``.

    Times 100 iterations of the forward pass alone, then 100 iterations of
    forward+backward, and reports the per-iteration wall-clock time in ms.
    The backward-only time is estimated by subtracting the forward time
    from the combined forward+backward time.

    Args:
        input: 5-D input tensor for ``F.conv3d``; expected on a CUDA device
            (``torch.cuda.synchronize`` is called unconditionally).
        kernel: convolution weight tensor with ``requires_grad=True`` so the
            backward pass has a leaf to populate.
        target: unused; kept for signature compatibility with callers.
    """
    # Warmup: pays one-time costs (kernel compilation, and cuDNN algorithm
    # autotuning when cudnn.benchmark is enabled) so they don't pollute timings.
    for _ in range(50):
        output = F.conv3d(input, kernel)
    # Fixed upstream gradient reused for every backward call.
    g0 = torch.rand_like(output)
    for _ in range(50):
        # Reset .grad so the warmup doesn't accumulate 50 gradients into kernel.
        kernel.grad = None
        output = F.conv3d(input, kernel)
        output.backward(g0)

    nb_iters = 100

    # Profile forward pass. CUDA kernels launch asynchronously, so synchronize
    # before reading the clock on both sides of the timed region.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(nb_iters):
        output = F.conv3d(input, kernel)
    torch.cuda.synchronize()
    end = time.time()
    fwd_time = (end - start) / nb_iters

    # Profile backward pass. backward() frees the graph, so the forward pass
    # must be re-run each iteration; its cost is subtracted out below.
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(nb_iters):
        output = F.conv3d(input, kernel)
        kernel.grad = None  # drop the old grad instead of accumulating into it
        output.backward(g0)
    torch.cuda.synchronize()
    end = time.time()
    all_time = (end - start) / nb_iters
    bwd_time = all_time - fwd_time  # estimate: (fwd+bwd) minus fwd

    print('cudnn={}'.format(torch.backends.cudnn.enabled))
    print('cudnn.benchmark={}'.format(torch.backends.cudnn.benchmark))
    print('Forward took {}ms/iter'.format(fwd_time*1e3))
    print('Backward took {}ms/iter'.format(bwd_time*1e3))
if __name__ == '__main__':
    # Configuration 1: native (non-cuDNN) convolution baseline.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    # 1x2x33^3 input, 729 output channels with 2^3 kernels -> 1x729x32^3 output.
    input = torch.rand(1, 2, 33, 33, 33, device="cuda")
    kernel = torch.rand(729, 2, 2, 2, 2, device="cuda", requires_grad=True)
    # target is unused by test(); kept to match its signature.
    target = torch.rand(1, 729, 32, 32, 32, device="cuda")
    test(input, kernel, target)

    # Configuration 2: cuDNN with heuristic algorithm selection.
    torch.backends.cudnn.enabled = True
    test(input, kernel, target)

    # Configuration 3: cuDNN with benchmark-mode autotuning (test() warms up
    # first, so the autotuning cost stays out of the timed loops).
    torch.backends.cudnn.benchmark = True
    test(input, kernel, target)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment