Skip to content

Instantly share code, notes, and snippets.

@n2cholas
Created June 6, 2020 06:19
Show Gist options
  • Save n2cholas/654850d6345242c4406e0bef23e3d060 to your computer and use it in GitHub Desktop.
Save n2cholas/654850d6345242c4406e0bef23e3d060 to your computer and use it in GitHub Desktop.
import sys
import time
import torch
from torch import nn
import torchvision
from torchvision import transforms
from ignite.metrics import Accuracy
# Command-line arguments:
#   <script> ignite         -> benchmark ignite's stock Accuracy
#   <script> custom DEVICE  -> benchmark MyAccuracy, accumulating on DEVICE
#                              (a torch device string, e.g. cpu or cuda)
_USAGE = 'usage: <script> ignite | <script> custom DEVICE'
if len(sys.argv) < 2 or sys.argv[1] not in ('custom', 'ignite'):
    # Reject unknown modes instead of silently running the custom path.
    raise SystemExit(_USAGE)
acc_type = sys.argv[1]  # custom or ignite
if acc_type == 'ignite':
    # The stock metric does not use acc_device; 'cpu' only labels the trace file.
    acc_device = 'cpu'
elif len(sys.argv) < 3:
    # 'custom' needs a device; report usage rather than an IndexError.
    raise SystemExit(_USAGE)
else:
    acc_device = sys.argv[2]
# Set up Classes =============================================================
class MyAccuracy(Accuracy):
    """Ignite ``Accuracy`` whose correct-prediction count is a device tensor.

    The running count of correct predictions is stored as a tensor on the
    module-level ``acc_device`` (chosen on the command line). The number of
    examples stays a Python int. This script benchmarks this class against
    the stock ``Accuracy``.
    """

    def reset(self) -> None:
        """Clear the running counts.

        The parent ``reset`` runs first because ignite's ``Accuracy.reset``
        assigns ``_num_correct`` / ``_num_examples`` itself. Calling it after
        the assignments below would overwrite the ``acc_device`` tensor.
        """
        super().reset()
        # Float zero on ``acc_device``. ``update`` adds each batch's count in
        # place, so the running total stays on that device.
        self._num_correct = torch.tensor(0., device=acc_device)
        self._num_examples = 0

    def update(self, output) -> None:
        """Add one batch of predictions to the running counts.

        Args:
            output: ``(y_pred, y)`` pair. The inherited ``_check_shape`` and
                ``_check_type`` validate it and determine ``self._type``,
                which selects how predictions are scored below.
        """
        y_pred, y = output
        self._check_shape((y_pred, y))
        self._check_type((y_pred, y))

        if self._type == "binary":
            # Cast predictions to y's dtype/device, then compare elementwise.
            correct = torch.eq(y_pred.view(-1).to(y), y.view(-1))
        elif self._type == "multiclass":
            # The predicted class is the argmax over dim 1.
            indices = torch.argmax(y_pred, dim=1)
            correct = torch.eq(indices, y).view(-1)
        elif self._type == "multilabel":
            # if y, y_pred shape is (N, C, ...) -> (N x ..., C)
            num_classes = y_pred.size(1)
            last_dim = y_pred.ndimension()
            y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
            y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
            # A row is correct only if every one of its labels matches.
            correct = torch.all(y == y_pred.type_as(y), dim=-1)

        # Sum on the predictions' device, then move only the scalar result to
        # ``acc_device``.
        self._num_correct += torch.sum(correct).to(acc_device)
        self._num_examples += correct.shape[0]
def Net(num_classes: int = 10, in_channels: int = 3) -> nn.Sequential:
    """Build the small CNN classifier used by this benchmark.

    The network has four conv + ReLU stages. The first three are each
    followed by a 2x2 max-pool, and the last by a global max-pool to 1x1.
    These feed a three-layer MLP head. The defaults reproduce the original
    CIFAR-10 model.

    The function keeps its class-like name because callers write ``Net()``.

    Args:
        num_classes: Number of output logits (10 for CIFAR-10).
        in_channels: Channels of the input images (3 for RGB).

    Returns:
        An ``nn.Sequential`` mapping ``(N, in_channels, H, W)`` images to
        ``(N, num_classes)`` logits. H and W must be at least 8 so that the
        three 2x2 max-pools leave a non-empty feature map.
    """
    return nn.Sequential(
        # Feature extractor
        nn.Conv2d(in_channels, 128, 5, padding=2),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(128, 256, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(256, 512, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2, 2),
        nn.Conv2d(512, 1024, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.AdaptiveMaxPool2d((1, 1)),  # global max-pool -> (N, 1024, 1, 1)
        nn.Flatten(),                  # -> (N, 1024)
        # Classifier head
        nn.Linear(1024, 512),
        nn.ReLU(inplace=True),
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Linear(256, num_classes),
    )
# Set up Data, Network, etc ==================================================
# NOTE(review): this setup runs at import time with no
# ``if __name__ == '__main__':`` guard. With num_workers > 0 and the "spawn"
# start method (the default on Windows/macOS), DataLoader workers re-import
# this module. Confirm the script only runs where "fork" is used.

# ToTensor, then per-channel Normalize with mean = std = 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
])
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Pinned host memory lets the later ``.cuda(non_blocking=True)`` copies be
# asynchronous.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

net = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# 'ignite' benchmarks the stock metric; any other mode uses MyAccuracy.
if acc_type == 'ignite':
    acc_metric = Accuracy()
else:
    acc_metric = MyAccuracy()
# Run Profiler ===============================================================
# Profile four training steps (data copy, forward, backward, SGD step and the
# accuracy update) and export them as a Chrome trace.
loader_iter = iter(dataloader)
time.sleep(15.)  # preload some batches so the trace is more condensed
with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as p:
    for _ in range(4):
        data = next(loader_iter)
        inputs, labels = (data[0].cuda(non_blocking=True),
                          data[1].cuda(non_blocking=True))
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        acc_metric.update((outputs, labels))
# Only four batches are consumed. Drop the iterator so the DataLoader shuts
# down its worker processes and frees the prefetched pinned batches, instead
# of keeping them alive through the timed loop.
del loader_iter
trace_file = f'{acc_type}_{acc_device}_trace'
p.export_chrome_trace(trace_file)
print(f'Done trace: {trace_file}')
# Time Train Loop ============================================================
# Time 20 runs of (up to) 502 training steps each using CUDA events, then
# report the mean / std of the per-run time in seconds.
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
times = []
for _run in range(20):
    # Each run starts a fresh pass over the shuffled DataLoader (with new
    # worker processes). The timed span includes that start-up.
    start.record()
    for i, data in enumerate(dataloader):
        inputs, labels = (data[0].cuda(non_blocking=True),
                          data[1].cuda(non_blocking=True))
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        acc_metric.update((outputs, labels))
        # Break after batch index 501 has been trained on: 502 batches per run.
        if i > 500:
            break
    end.record()
    # Wait until the GPU has reached ``end`` so elapsed_time is valid.
    torch.cuda.synchronize()
    times.append(start.elapsed_time(end) * 0.001)  # elapsed_time is in ms
std, mean = torch.std_mean(torch.tensor(times))
print(f'Mean Time: {mean.item()}s Std Time: {std.item()}s All Times: {times}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment