Last active
September 9, 2024 22:23
-
-
Save srikarplus/8bdb5bedf0ca25e894e39ea78fce2f39 to your computer and use it in GitHub Desktop.
Train and Validation Split for Pytorch torchvision Datasets
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
from utils import plot_images | |
from torchvision import datasets | |
from torchvision import transforms | |
from torch.utils.data.sampler import SubsetRandomSampler | |
def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Build train and validation loaders over one ``ImageFolder`` directory.

    The same directory is opened twice (once with the training transform,
    once with the validation transform) and the sample indices are split
    into two disjoint subsets, so both loaders read from the same files
    but never serve the same sample.

    Params
    ------
    - data_dir: path directory to the dataset, laid out as expected by
      ``torchvision.datasets.ImageFolder``.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply data augmentation (random rotation of up
      to 30 degrees, random resized crop to 224 and random horizontal
      flip). Only applied on the train split; when False the train split
      uses the same resize/center-crop pipeline as the validation split.
    - random_seed: seed used to shuffle the indices before splitting.
      Only used when `shuffle` is True.
    - valid_size: fraction of the dataset used for the validation set.
      Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the indices before splitting. When
      False the validation split is simply the first
      ``floor(valid_size * N)`` indices.
    - show_sample: accepted for backward compatibility; this function
      currently does not read it and no sample grid is plotted.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.

    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.

    Raises
    ------
    - AssertionError: if `valid_size` is outside [0, 1].
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    # Raised explicitly instead of using an `assert` statement so that the
    # check still runs under `python -O`; the exception type is unchanged.
    if not (0 <= valid_size <= 1):
        raise AssertionError(error_msg)

    # Channel statistics commonly used for ImageNet-pretrained models.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    # Deterministic pipeline: used for validation, and for training when
    # augmentation is disabled.
    valid_transform = transforms.Compose([
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = valid_transform

    # Two views over the same files, differing only in their transform.
    train_dataset = datasets.ImageFolder(
        root=data_dir, transform=train_transform,
    )
    valid_dataset = datasets.ImageFolder(
        root=data_dir, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        # NOTE: this reseeds NumPy's global random generator as a side effect.
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    # The first `split` indices go to validation, the remainder to training.
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    return (train_loader, valid_loader)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment