Skip to content

Instantly share code, notes, and snippets.

@jfsantos
Last active February 8, 2017 23:05
Show Gist options
  • Save jfsantos/9d9f6fce543c425a2746ac15c96a9ed4 to your computer and use it in GitHub Desktop.
Save jfsantos/9d9f6fce543c425a2746ac15c96a9ed4 to your computer and use it in GitHub Desktop.
from torch.utils.data import Dataset
class DummyDataset(Dataset):
def __init__(self, items):
super(DummyDataset, self).__init__()
self.items = items
def __getitem__(self, index):
return self.items[index]
def __len__(self):
return len(self.items)
if __name__ == '__main__':
from torch.utils.data import DataLoader
a = list(range(1000))
dataset = DummyDataset(a)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batches = [x for x in loader]
~
from torch.utils.data import Dataset
class DummyMultiDataset(Dataset):
def __init__(self, a, b, c):
super(DummyMultiDataset, self).__init__()
assert len(a) == len(b) == len(c)
self.a = a
self.b = b
self.c = c
def __getitem__(self, index):
return self.a[index], self.b[index], self.c[index]
def __len__(self):
return len(self.a)
if __name__ == '__main__':
from torch.utils.data import DataLoader
import numpy, torch
a = torch.FloatTensor(numpy.random.randn(1000, 10))
b = torch.LongTensor(numpy.random.randint(10, size=(1000,)))
c = list(range(1000))
dataset = DummyMultiDataset(a, b, c)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
batches = [(x, y, z) for x, y, z in loader]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment