Created September 13, 2017 19:41
-
-
Save ikhlestov/0f174783eb8b37a77ab34c07f21ccd6a to your computer and use it in GitHub Desktop.
PyTorch: custom data loader
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pandas as pd
import torch
import torchvision as tv
class ImagesDataset(torch.utils.data.Dataset):
    """Map-style dataset of images listed in a pandas DataFrame.

    Each row of ``df`` describes one sample: the value in ``path_col`` is
    passed to ``loader`` to obtain the image, and the value in
    ``target_col`` is returned alongside it as the target.

    Args:
        df: DataFrame with one row per sample. Rows are addressed by
            position (``iloc``), so the DataFrame index need not be 0..n-1.
        transform: Optional callable applied to the loaded image
            (e.g. a ``torchvision.transforms.Compose``).
        loader: Callable mapping a path to an image. ``None`` (the default)
            selects ``torchvision.datasets.folder.default_loader``.
        target_col: Name of the column holding the target. Defaults to
            ``'class_'``.
        path_col: Name of the column holding the image path. Defaults to
            ``'path'``.
    """

    def __init__(self, df, transform=None, loader=None,
                 target_col='class_', path_col='path'):
        if loader is None:
            # Resolved at call time so that an explicit ``loader=None`` also
            # falls back to torchvision's default image loader.
            loader = tv.datasets.folder.default_loader
        self.df = df
        self.transform = transform
        self.loader = loader
        self.target_col = target_col
        self.path_col = path_col

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at position ``index``."""
        row = self.df.iloc[index]
        target = row[self.target_col]
        img = self.loader(row[self.path_col])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in ``df``."""
        return len(self.df)
# Transformations applied to every image: a random 64x64 crop taken after
# padding each border by 4 pixels, a random horizontal flip, and conversion
# to a tensor.
data_transforms = tv.transforms.Compose([
    tv.transforms.RandomCrop((64, 64), padding=4),
    tv.transforms.RandomHorizontalFlip(),
    tv.transforms.ToTensor(),
])


if __name__ == "__main__":
    # The guard is required because num_workers > 0 starts worker processes.
    # Under the "spawn" start method (the default on Windows and macOS),
    # each worker re-imports this module. Unguarded top-level code would
    # re-run the whole pipeline there and fail while bootstrapping.
    # It also keeps a plain ``import`` of this module from reading the CSV.
    train_df = pd.read_csv('path/to/some.csv')

    # Initialize the dataset first.
    train_dataset = ImagesDataset(
        df=train_df,
        transform=data_transforms,
    )

    # Initialize the data loader with the required number of workers and
    # other parameters.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=10,
                                               shuffle=True,
                                               num_workers=16)

    # Each iteration yields one collated batch. The loader calls
    # ``ImagesDataset.__getitem__`` once per sample, then collates the
    # samples into batched ``img`` and ``target`` values.
    for img, target in train_loader:
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.