Skip to content

Instantly share code, notes, and snippets.

@n2cholas
Created June 6, 2020 21:09
Show Gist options
  • Save n2cholas/b6f04900321f0a49378da9ac22331aaf to your computer and use it in GitHub Desktop.
Save n2cholas/b6f04900321f0a49378da9ac22331aaf to your computer and use it in GitHub Desktop.
import sys
import time
import torch
from torch import nn
import torchvision
from torchvision import transforms
from ignite.metrics import Accuracy
# Command line: argv[1] picks the metric implementation ('custom' -> MyAccuracy,
# 'ignite' -> stock ignite Accuracy). For 'custom', argv[2] is the device that
# MyAccuracy accumulates its correct-count on (e.g. 'cpu' or 'cuda'); for
# 'ignite', acc_device is fixed to 'cpu' and is only used to name the trace file.
_USAGE = 'usage: python <script> (custom DEVICE | ignite)   e.g. "custom cuda"'
if len(sys.argv) < 2 or sys.argv[1] not in ('custom', 'ignite'):
    # Reject typos instead of silently benchmarking the custom metric.
    sys.exit(_USAGE)
acc_type = sys.argv[1]  # custom or ignite
if acc_type == 'ignite':
    acc_device = 'cpu'
elif len(sys.argv) > 2:
    acc_device = sys.argv[2]
else:
    # 'custom' needs a device; fail with a usage message, not an IndexError.
    sys.exit(_USAGE)
# Set up Classes =============================================================
class MyAccuracy(Accuracy):
    """ignite ``Accuracy`` variant whose correct-count lives on ``acc_device``.

    ``acc_device`` is the module-level global taken from the command line.
    Every ``update`` moves the batch's correct count to that device with
    ``.to(acc_device)`` and adds it in place into a tensor accumulator. This
    lets the benchmark compare accumulating on the GPU against the CPU.
    """

    def reset(self) -> None:
        """Zero the accumulators.

        The parent ``reset`` is called FIRST. NOTE(review): ignite's
        ``Accuracy.reset`` re-initialises ``_num_correct`` / ``_num_examples``
        itself (as plain Python numbers in the releases this benchmark targets
        — confirm against the installed version). Calling it last would
        therefore overwrite the device tensor installed below. Calling it
        first is correct either way.
        """
        super().reset()
        # Float tensor accumulator placed on the benchmarked device.
        self._num_correct = torch.tensor(0., device=acc_device)
        self._num_examples = 0

    def update(self, output):
        """Accumulate the number of correct predictions for one batch.

        Args:
            output: ``(y_pred, y)`` pair. The parent's ``_check_shape`` and
                ``_check_type`` helpers validate it. The problem type is then
                read back as ``self._type`` ("binary", "multiclass" or
                "multilabel").
        """
        y_pred, y = output
        self._check_shape((y_pred, y))
        self._check_type((y_pred, y))
        if self._type == "binary":
            # Element-wise comparison of flattened predictions and labels.
            correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
        elif self._type == "multiclass":
            # Predicted class = argmax over dim 1.
            indices = torch.argmax(y_pred, dim=1)
            correct = torch.eq(indices, y).view(-1)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (N x ..., C)
            num_classes = y_pred.size(1)
            last_dim = y_pred.ndimension()
            y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
            y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
            # A row counts as correct only if every label matches.
            correct = torch.all(y == y_pred.type_as(y), dim=-1)
        # ``.to(acc_device)`` is the transfer being benchmarked.
        # With CUDA inputs it is a no-op for a CUDA acc_device, and a
        # device->host copy for 'cpu'.
        self._num_correct += torch.sum(correct).to(acc_device)
        self._num_examples += correct.shape[0]
def Net() -> nn.Sequential:
    """Build the CNN benchmarked on CIFAR-10 batches.

    The feature extractor has four conv stages (a 5x5 conv, then three 3x3
    convs) widening the channels 3 -> 128 -> 256 -> 512 -> 1024, each followed
    by an in-place ReLU. The first three stages end in 2x2 max-pooling; the
    last ends in adaptive max-pooling to 1x1. After flattening, a three-layer
    MLP maps 1024 features to 10 outputs.
    """
    # (in_channels, out_channels, kernel_size, padding) per conv stage.
    conv_stages = [
        (3, 128, 5, 2),
        (128, 256, 3, 1),
        (256, 512, 3, 1),
        (512, 1024, 3, 1),
    ]
    modules = []
    for stage_idx, (c_in, c_out, kernel, pad) in enumerate(conv_stages):
        modules.append(nn.Conv2d(c_in, c_out, kernel, padding=pad))
        modules.append(nn.ReLU(inplace=True))
        if stage_idx == len(conv_stages) - 1:
            modules.append(nn.AdaptiveMaxPool2d((1, 1)))
        else:
            modules.append(nn.MaxPool2d(2, 2))
    modules.append(nn.Flatten())

    # Fully-connected head; a ReLU follows every Linear except the last.
    widths = [1024, 512, 256, 10]
    n_linear = len(widths) - 1
    for layer_idx, (f_in, f_out) in enumerate(zip(widths, widths[1:])):
        modules.append(nn.Linear(f_in, f_out))
        if layer_idx < n_linear - 1:
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
# Set up Data, Network, etc ==================================================
# ToTensor, then per-channel Normalize with mean 0.5 / std 0.5 on all three
# channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# CIFAR-10 test split (train=False), downloaded into ./data if missing.
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=512,
    num_workers=8,
    # Pinned host memory lets the later .cuda(non_blocking=True) copies be async.
    pin_memory=True,
)
net = Net().cuda()
if acc_type == 'ignite':
    acc_metric = Accuracy()
else:
    acc_metric = MyAccuracy()
# Run Profiler ===============================================================
# Profile 4 batches of inference + metric update and dump a Chrome trace.
loader_iter = iter(dataloader)
time.sleep(15.)  # preload some batches so the trace is more condensed
# Inference only: no_grad keeps autograd graph construction out of the trace.
with torch.no_grad(), torch.autograd.profiler.profile(enabled=True, use_cuda=True) as p:
    for _ in range(4):
        data = next(loader_iter)
        inputs = data[0].cuda(non_blocking=True)
        labels = data[1].cuda(non_blocking=True)
        outputs = net(inputs)
        acc_metric.update((outputs, labels))
# Drop the half-consumed iterator so its worker processes and prefetched
# pinned batches are released before the timed loop below starts.
del loader_iter
trace_file = f'{acc_type}_{acc_device}_trace'
p.export_chrome_trace(trace_file)
print(f'Done trace: {trace_file}')
# Time Validation Loop =======================================================
# Time 50 full passes over the dataloader with CUDA events and report the
# mean / std / per-pass durations in seconds.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
# A validation pass should not record autograd graphs; without no_grad the
# measured time includes graph construction for every forward.
with torch.no_grad():
    for _ in range(50):
        start.record()
        for data in dataloader:
            inputs = data[0].cuda(non_blocking=True)
            labels = data[1].cuda(non_blocking=True)
            outputs = net(inputs)
            acc_metric.update((outputs, labels))
        end.record()
        # Wait until `end` has actually been reached on the GPU before reading it.
        torch.cuda.synchronize()
        # elapsed_time() is in milliseconds; store seconds.
        times.append(start.elapsed_time(end) * 0.001)
std, mean = torch.std_mean(torch.tensor(times))
print(f'Mean Time: {mean.item()}s\nStd Time: {std.item()}s\nAll Times: {times}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment