Pytorch's Dataloader reimplemented using SeqTools
"""A reimplementeation of PyTorch's DataLoader to showcase seqtools. | |
:author: Nicolas Granger | |
:license: 0BSD (~public domain) | |
""" | |
import numbers | |
import random | |
from functools import singledispatch | |
from multiprocessing import sharedctypes | |
import numpy as np | |
import torch | |
import seqtools | |
@singledispatch
def into_tensors(value):
    """Recursively convert values (arrays, scalars, containers) into torch tensors."""
    return torch.tensor(value)


@into_tensors.register(torch.Tensor)
def _(value):
    return value


@into_tensors.register(np.ndarray)
def _(value):
    return torch.from_numpy(value)


@into_tensors.register(tuple)
def _(value):
    return tuple(into_tensors(v) for v in value)


@into_tensors.register(list)
def _(value):
    return [into_tensors(v) for v in value]


@into_tensors.register(dict)
def _(value):
    return {k: into_tensors(v) for k, v in value.items()}
@singledispatch
def into_pinned_memory(value):
    """Recursively move tensors into page-locked memory for faster GPU transfers."""
    return value.pin_memory()


@into_pinned_memory.register(tuple)
def _(value):
    return tuple(into_pinned_memory(v) for v in value)


@into_pinned_memory.register(list)
def _(value):
    return [into_pinned_memory(v) for v in value]


@into_pinned_memory.register(dict)
def _(value):
    return {k: into_pinned_memory(v) for k, v in value.items()}
def default_collate_fn(values):
    """Assemble a batch of samples, mirroring PyTorch's default collation."""
    if not isinstance(values, (list, tuple)):
        values = list(values)
    sample = values[0]
    if isinstance(sample, torch.Tensor):
        return torch.stack(values)
    elif isinstance(sample, np.ndarray):
        return np.stack(values)
    elif isinstance(sample, numbers.Integral):
        return np.array(values)
    elif isinstance(sample, tuple):
        return tuple(default_collate_fn(row) for row in zip(*values))
    elif isinstance(sample, list):
        return [default_collate_fn(row) for row in zip(*values)]
    elif isinstance(sample, dict):
        return {k: default_collate_fn([v[k] for v in values])
                for k in values[0].keys()}
class DataLoader:
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=None,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None):
        """A stand-in for torch.utils.data.DataLoader built on seqtools.

        Notable differences:
        - all dataset elements must have the same shape
        - custom samplers are not implemented
        - timeout is irrelevant
        """
        if sampler is not None or batch_sampler is not None:
            raise NotImplementedError("custom samplers are not supported yet")

        # shuffle: keep the index permutation in a shared ctypes buffer so it
        # can be reshuffled in place at the start of every epoch
        if shuffle:
            self.shuffling_indexes = memoryview(
                sharedctypes.RawArray('L', range(len(dataset))))
            dataset = seqtools.gather(dataset, self.shuffling_indexes)
        else:
            self.shuffling_indexes = None

        # group samples into collated batches
        if batch_size is not None:
            dataset = seqtools.batch(
                dataset,
                k=batch_size, drop_last=drop_last,
                collate_fn=collate_fn or default_collate_fn)

        # evaluate batches ahead of time in background workers
        if num_workers > 0:
            dataset = seqtools.prefetch(
                dataset,
                max_buffered=num_workers * 10, nworkers=num_workers, method='sharedmem',
                start_hook=worker_init_fn)

        # convert values into tensors
        dataset = seqtools.smap(into_tensors, dataset)

        # optionally move batches into page-locked memory
        if pin_memory:
            dataset = seqtools.smap(into_pinned_memory, dataset)

        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        if self.shuffling_indexes is not None:
            random.shuffle(self.shuffling_indexes)
        return iter(self.dataset)
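For quick reference, here is a minimal usage sketch of the DataLoader defined above. It is not part of the original gist: the toy dataset and the import path (assuming the file above is saved as dataloader.py next to the snippet) are illustrative assumptions; the MNIST script below is the complete, real example.

# Illustrative sketch only: assumes the DataLoader above is importable from dataloader.py;
# the toy dataset is made up for demonstration.
import numpy as np

from dataloader import DataLoader

# a toy dataset: any sequence of (features, label) pairs with a fixed shape
dataset = [(np.random.rand(3).astype(np.float32), i % 2) for i in range(10)]

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

for features, labels in loader:
    # features is a float32 tensor of shape (batch, 3),
    # labels is an int64 tensor of shape (batch,)
    print(features.shape, labels)

Each pass over the loader reshuffles the shared index buffer set up in __init__, so samples are visited in a fresh order every epoch; with num_workers > 0 the same loop would transparently read batches prepared by prefetching workers.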
# Adapted from https://github.com/pytorch/examples/blob/master/mnist/main.py
import argparse

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import seqtools
from docs.examples.dataloader import DataLoader
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=14, metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                        help='learning rate (default: 1.0)')
    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
                        help='Learning rate step gamma (default: 0.7)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    # samples are normalized on the fly with seqtools.starmap instead of a
    # torchvision transform pipeline, then batched by the seqtools DataLoader
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {'num_workers': 1}
    train_loader = DataLoader(
        seqtools.starmap(
            lambda img, label: ((np.asarray(img, dtype=np.float32)[None] - 0.1307) / 0.3081, label),
            datasets.MNIST('/tmp/MNIST', train=True, download=True)),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = DataLoader(
        seqtools.starmap(
            lambda img, label: ((np.asarray(img, dtype=np.float32)[None] - 0.1307) / 0.3081, label),
            datasets.MNIST('/tmp/MNIST', train=False, download=True)),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)
        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()