@ayush29feb
Created November 15, 2018 18:22
Loading data in PyTorch
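# Pillow is the maintained package that provides the PIL module ImageFolder uses
# to open images. There is no pip package named "PIL", and "image" is unrelated.
!pip install Pillow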
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
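# ImageFolder expects one subdirectory per class under the root,
# e.g. <path/to/data>/cat/xxx.png and <path/to/data>/dog/yyy.png.
dataset = datasets.ImageFolder(
    '<path/to/data>',
    transforms.Compose([
        transforms.ToTensor(),  # PIL image (H, W, C) -> float tensor (C, H, W) in [0, 1]
    ]))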
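The sketch below shows how `ImageFolder` turns the `root/<class>/<image>` layout into labels. Each subdirectory is a class, class names are sorted, and they are numbered from 0. It uses only the standard library. `scan_image_folder` and `IMAGE_EXTS` are hypothetical names for illustration, not torchvision API. torchvision's real extension list is longer than the one assumed here.

```python
import os
import tempfile

# Extensions treated as images in this sketch (an assumption; torchvision's
# own list is longer).
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp")


def scan_image_folder(root):
    """Hypothetical helper mimicking ImageFolder's root/<class>/<image> convention."""
    # Every subdirectory of root is one class; sorting makes label indices deterministic.
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    class_to_idx = {name: i for i, name in enumerate(classes)}
    samples = []
    for name in classes:
        class_dir = os.path.join(root, name)
        for dirpath, _, filenames in sorted(os.walk(class_dir)):
            for fname in sorted(filenames):
                # Skip files that do not look like images (e.g. notes.txt).
                if fname.lower().endswith(IMAGE_EXTS):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[name]))
    return classes, class_to_idx, samples


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        # Tiny fake layout: root/cat/a.png, root/dog/b.jpg, plus a non-image file.
        for rel in ("cat/a.png", "dog/b.jpg", "dog/notes.txt"):
            path = os.path.join(root, rel)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            open(path, "wb").close()
        classes, class_to_idx, samples = scan_image_folder(root)
        print(classes)                            # ['cat', 'dog']
        print(class_to_idx)                       # {'cat': 0, 'dog': 1}
        print([label for _, label in samples])    # [0, 1]
```

On the real dataset, `dataset.classes` and `dataset.class_to_idx` expose the same mapping.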
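# With batch_size > 1 the default collate function stacks the images into one
# tensor, so all images must have the same size (add transforms.Resize otherwise).
data_loader = DataLoader(dataset=dataset,
                         batch_size=2,
                         shuffle=False,
                         num_workers=2)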
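The following is a rough NumPy-only sketch of what the loader above does with `shuffle=False`. It walks the dataset in order, takes `batch_size` samples at a time, and stacks the images into one `(B, C, H, W)` array. `iterate_batches` is a hypothetical stand-in for illustration. It is not how `DataLoader` is implemented (no worker processes, no tensors). It does show why the last batch can be smaller and why mismatched image sizes fail: `np.stack` raises `ValueError` here, much as the default collate function raises an error when tensors differ in size.

```python
import numpy as np


def iterate_batches(samples, batch_size, drop_last=False):
    """Hypothetical sketch of sequential (shuffle=False) batching plus stacking."""
    for start in range(0, len(samples), batch_size):
        chunk = samples[start:start + batch_size]
        if drop_last and len(chunk) < batch_size:
            break  # discard the incomplete final batch
        images = np.stack([img for img, _ in chunk])   # (B, C, H, W); shapes must match
        labels = np.array([lbl for _, lbl in chunk])   # (B,)
        yield images, labels


# Five fake (C, H, W) = (3, 4, 4) images with alternating labels.
samples = [(np.full((3, 4, 4), i, dtype=np.float32), i % 2) for i in range(5)]
batches = list(iterate_batches(samples, batch_size=2))
print(len(batches))          # 3
print(batches[0][0].shape)   # (2, 3, 4, 4)
print(batches[-1][0].shape)  # (1, 3, 4, 4)
```

With 5 samples and `batch_size=2` there are ceil(5 / 2) = 3 batches, and the last one holds a single image.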
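for i, data in enumerate(data_loader):
    img, lbl = data  # img: (batch, C, H, W) float tensor; lbl: (batch,) class indices
    # matplotlib wants (H, W, C); .T would reverse every axis and give (W, H, C).
    first = np.moveaxis(img.numpy()[0], 0, -1)
    print(first.shape)
    plt.imshow(first)
    plt.show()
    break  # only inspect the first batch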
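The sketch below uses NumPy only to show the axis bookkeeping in the loop. For an ordinary 8-bit RGB image, `ToTensor` turns `(H, W, C)` values in 0..255 into a `(C, H, W)` float array in [0, 1]. `np.moveaxis(x, 0, -1)` moves channels back to the end, whereas `.T` reverses all axes and silently swaps height and width. `to_tensor_like` is a hypothetical helper that only mimics this one case of `ToTensor`. Other PIL modes (e.g. 16-bit or float images) are not scaled the same way.

```python
import numpy as np


def to_tensor_like(img_hwc_uint8):
    """Hypothetical NumPy mimic of ToTensor for a uint8 RGB image."""
    # (H, W, C) uint8 in [0, 255] -> (C, H, W) float32 in [0.0, 1.0]
    return np.moveaxis(img_hwc_uint8, -1, 0).astype(np.float32) / 255.0


# A fake RGB image with H=4, W=6, C=3.
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(4, 6, 3), dtype=np.uint8)

chw = to_tensor_like(img)
print(chw.shape)                    # (3, 4, 6)

hwc = np.moveaxis(chw, 0, -1)       # layout plt.imshow expects
print(hwc.shape)                    # (4, 6, 3)
print(chw.T.shape)                  # (6, 4, 3): H and W are swapped
print(np.allclose(hwc * 255, img))  # True
```

`np.transpose(chw, (1, 2, 0))` is an equivalent spelling of the `moveaxis` call. On the PyTorch side, `img[0].permute(1, 2, 0)` does the same before converting to NumPy.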