Last active
August 18, 2017 16:08
-
-
Save t-vi/9f6118ff84867e89f3348707c7a1271f to your computer and use it in GitHub Desktop.
Torch validation set split (MNIST example)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import torch.utils.data
from torchvision import datasets, transforms
class PartialDataset(torch.utils.data.Dataset):
    """A contiguous, read-only window onto another dataset.

    Item ``i`` of this dataset is item ``offset + i`` of ``parent_ds``, and
    the window contains exactly ``length`` items.

    Args:
        parent_ds: object supporting ``len()`` and integer indexing
            (e.g. a ``torch.utils.data.Dataset`` or a plain list).
        offset: index in ``parent_ds`` of the first item of the window.
        length: number of items in the window.

    Raises:
        ValueError: if ``offset`` or ``length`` is negative, or if the window
            extends past the end of ``parent_ds``.
    """

    def __init__(self, parent_ds, offset, length):
        super(PartialDataset, self).__init__()
        # Validate with an explicit raise: an ``assert`` is stripped under
        # ``python -O`` and would silently accept an out-of-range window.
        if offset < 0 or length < 0:
            raise ValueError("offset and length must be non-negative")
        if len(parent_ds) < offset + length:
            raise ValueError("Parent Dataset not long enough")
        self.parent_ds = parent_ds
        self.offset = offset
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        # Interpret negative indices relative to this window, not the parent.
        if i < 0:
            i += self.length
        # Reject indices outside the window. Raising IndexError here also ends
        # Python's __getitem__-based iteration at the window's end, instead of
        # running on into the rest of the parent dataset.
        if not 0 <= i < self.length:
            raise IndexError(
                "index out of range for PartialDataset of length {}".format(self.length))
        return self.parent_ds[i + self.offset]
def validation_split(dataset, val_share=0.1):
    """Split a (training and validation combined) dataset into training and validation.

    The first ``int(len(dataset) * (1 - val_share))`` items form the training
    set and the remaining items form the validation set; no shuffling is done.
    To be statistically sound, the items in the dataset should therefore be
    statistically independent (e.g. not sorted by class, and without several
    instances of the same sample that could end up in either set).

    Args:
        dataset: ("training") dataset to split into training and validation.
        val_share: fraction of the data placed in the validation set; must
            satisfy ``0 <= val_share <= 1`` (default: 0.1).

    Returns:
        A ``(train_ds, val_ds)`` tuple of ``PartialDataset`` views onto
        ``dataset``.

    Raises:
        ValueError: if ``val_share`` is not within ``[0, 1]``.
    """
    # Reject out-of-range shares up front. E.g. val_share > 1 would otherwise
    # produce a negative training length that only fails later, inside len().
    # (The chained comparison also rejects NaN.)
    if not 0 <= val_share <= 1:
        raise ValueError("val_share must be between 0 and 1, got {!r}".format(val_share))
    val_offset = int(len(dataset) * (1 - val_share))
    train_ds = PartialDataset(dataset, 0, val_offset)
    val_ds = PartialDataset(dataset, val_offset, len(dataset) - val_offset)
    return train_ds, val_ds
# Load the MNIST training set (downloaded on first use), normalized with
# mean 0.1307 / std 0.3081 (presumably the MNIST pixel statistics), then
# split off the default 10% of it as a validation set.
mnist_train_ds = datasets.MNIST(os.path.expanduser('~/data/datasets/mnist'), train=True, download=True,
                                transform=transforms.Compose([
                                    transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))
                                ]))
train_ds, val_ds = validation_split(mnist_train_ds)

# `kwargs` (extra DataLoader options) is not defined in this file. It is
# presumably supplied by the training script this snippet is pasted into
# (e.g. num_workers / pin_memory settings) -- confirm. Fall back to the
# DataLoader defaults so the snippet also runs standalone instead of failing
# with a NameError.
try:
    kwargs
except NameError:
    kwargs = {}

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True, **kwargs)
# The validation set is only evaluated, never trained on, so shuffling it
# gains nothing; a fixed order keeps per-batch results reproducible.
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=64, shuffle=False, **kwargs)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment