# GitHub gist by @ptrblck, created August 16, 2019.
# Gist id: ptrblck/cd38dbdd0570741aa44f2725ccbc1288
import time
import torch
import torch.nn as nn
def test(cudnn, benchmark, dtype, *, device="cuda",
         input_shape=(8, 64, 64, 64, 64), warmup=40, iters=100):
    """Time repeated forward passes of a kernel-size-3 ``nn.Conv3d`` layer.

    The global ``torch.backends.cudnn.enabled`` and
    ``torch.backends.cudnn.benchmark`` flags are set from the arguments and
    are left set when the function returns, so later code sees them.

    Args:
        cudnn: Value assigned to ``torch.backends.cudnn.enabled``.
        benchmark: Value assigned to ``torch.backends.cudnn.benchmark``.
        dtype: ``torch.dtype`` used for both the layer and the random input.
        device: Device to run on; defaults to ``"cuda"``.
        input_shape: ``(N, C, D, H, W)`` shape of the random input. ``C`` is
            used for both the input and output channel count of the layer.
        warmup: Number of untimed forward passes run before measuring.
        iters: Number of timed forward passes.

    Returns:
        Elapsed seconds (float) for the ``iters`` timed forward passes.
    """
    print(f'cudnn {cudnn}, benchmark {benchmark}, dtype {dtype}')
    torch.backends.cudnn.enabled = cudnn
    torch.backends.cudnn.benchmark = benchmark

    device = torch.device(device)
    channels = input_shape[1]
    mod = nn.Conv3d(channels, channels, 3).to(device, dtype=dtype)
    input_ = torch.randn(input_shape, device=device, dtype=dtype)

    def _sync():
        # CUDA work is queued asynchronously; wait for it to finish so the
        # clock brackets the actual computation. Nothing to wait for on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    # Untimed passes so one-off start-up costs stay out of the measurement.
    for _ in range(warmup):
        mod(input_)
    _sync()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(iters):
        mod(input_)
    _sync()
    tend = time.perf_counter() - start
    print(tend)
    return tend


# The name matches pytest's default ``test*`` discovery prefix; mark it
# explicitly so it is never collected as a test case.
test.__test__ = False
# Run every cuDNN configuration for float32 first, then again for float16.
for _dtype in (torch.float32, torch.float16):
    for _cudnn, _benchmark in ((True, True), (True, False), (False, False)):
        test(_cudnn, _benchmark, _dtype)
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")