-
-
Save GzuPark/ebdec00f6fd253e5a762a975a910ef55 to your computer and use it in GitHub Desktop.
Train, Validation and Test Split for torchvision Datasets
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Create train, valid, test iterators for CIFAR-10 [1]. | |
Easily extended to MNIST, CIFAR-100 and Imagenet. | |
[1]: https://discuss.pytorch.org/t/feedback-on-pytorch-for-kaggle-competitions/2252/4 | |
""" | |
import torch | |
import numpy as np | |
from utils import plot_images | |
from torchvision import datasets | |
from torchvision import transforms | |
from torch.utils.data.sampler import SubsetRandomSampler | |
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    3x3 grid (9 images) can be optionally displayed.

    Both loaders read the CIFAR-10 *training* split; the split is
    partitioned by index into a training part and a validation part,
    and only the training part receives the augmentation transforms.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset (downloaded there if missing).
    - batch_size: how many samples per batch to load.
    - augment: whether to apply random crop (32px, padding 4) and random
      horizontal flip. Only applied on the train split.
    - random_seed: seed passed to np.random.seed before shuffling the
      indices. Only used when shuffle is True.
    - valid_size: fraction of the training set used for the validation
      set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices before
      splitting. When False the first `split` indices form the
      validation set.
    - show_sample: plot a 3x3 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - AssertionError: if valid_size is outside [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    # Raised explicitly (instead of via `assert`) so the check still runs
    # under `python -O`; the exception type is unchanged for callers.
    if not ((valid_size >= 0) and (valid_size <= 1)):
        raise AssertionError(error_msg)

    # NOTE(review): these look like the commonly used ImageNet channel
    # statistics rather than CIFAR-10 specific ones -- confirm intent.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset: the same training split twice, so that the
    # validation samples are never augmented.
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )
    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # the first `split` indices go to validation, the rest to training
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        # Use the builtin next(): DataLoader iterators follow the Python 3
        # iterator protocol and do not provide a `.next()` method.
        images, labels = next(iter(sample_loader))
        # NCHW -> NHWC, the channel-last layout expected by plot_images
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)

    return (train_loader, valid_loader)


def get_test_loader(data_dir,
                    batch_size,
                    shuffle=True,
                    num_workers=4,
                    pin_memory=False):
    """
    Build a multi-process iterator over the CIFAR-10 test split.

    Every image is converted to a tensor and normalized channel-wise;
    no augmentation is applied.

    If using CUDA, num_workers should be set to 1 and pin_memory to True.

    Params
    ------
    - data_dir: path directory to the dataset (downloaded there if missing).
    - batch_size: how many samples per batch to load.
    - shuffle: whether to shuffle the dataset after every epoch.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - data_loader: test set iterator.
    """
    # preprocessing pipeline: PIL image -> tensor -> normalized tensor
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # train=False selects the held-out test split
    test_set = datasets.CIFAR10(
        root=data_dir,
        train=False,
        download=True,
        transform=preprocessing,
    )

    return torch.utils.data.DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import matplotlib.pyplot as plt | |
# CIFAR-10 class names; plot_images looks a name up by using the integer
# class label as the list index, so the order here must not change.
label_names = [
    'airplane', 'automobile', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
]


def plot_images(images, cls_true, cls_pred=None):
    """
    Show up to nine images in a 3x3 grid, captioned with class names.

    Adapted from https://github.com/Hvass-Labs/TensorFlow-Tutorials/

    Params
    ------
    - images: 4-D array indexed as images[i, :, :, :]; each slice is
      handed to imshow as one image. Only the first nine are drawn.
    - cls_true: sequence of integer class labels, one per image. Entries
      may be plain ints or integer scalars (e.g. tensor/NumPy elements).
    - cls_pred: optional sequence of predicted integer class labels. When
      given, each caption shows both the true and the predicted name.

    If fewer than nine images are supplied, the surplus grid cells are
    left blank instead of raising an IndexError.
    """
    _, axes = plt.subplots(3, 3)
    num_images = len(images)

    for i, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])

        # blank out grid cells that have no image to show
        if i >= num_images:
            ax.axis('off')
            continue

        # plot img
        ax.imshow(images[i, :, :, :], interpolation='spline16')

        # int() turns tensor / NumPy scalars into plain ints so that the
        # caption reads "frog (6)" rather than e.g. "frog (tensor(6))".
        true_idx = int(cls_true[i])
        cls_true_name = label_names[true_idx]

        # show true & predicted classes
        if cls_pred is None:
            xlabel = "{0} ({1})".format(cls_true_name, true_idx)
        else:
            cls_pred_name = label_names[int(cls_pred[i])]
            xlabel = "True: {0}\nPred: {1}".format(
                cls_true_name, cls_pred_name
            )
        ax.set_xlabel(xlabel)

    plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment