Skip to content

Instantly share code, notes, and snippets.

@xpzouying
Last active October 30, 2018 14:29
Show Gist options
  • Save xpzouying/8078000043a0713056ba14490a90807b to your computer and use it in GitHub Desktop.
Save xpzouying/8078000043a0713056ba14490a90807b to your computer and use it in GitHub Desktop.
# Uncomment the following line when running inside a Jupyter notebook
# %matplotlib inline
import os

import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
# ImageFolder treats each subdirectory of data_dir as one class.
data_dir = '/tmp/images2/'
# No-op for this absolute path; kept so a '~/...' path would also work.
data_dir = os.path.expanduser(data_dir)

# Resize every image to 512x512 so all images in a batch share one shape and
# can be stacked by the DataLoader; ToTensor then converts each image to a
# float tensor.
train_transform = transforms.Compose([
    # `transforms.Scale` was deprecated in favour of `Resize` and later removed.
    transforms.Resize([512, 512]),
    transforms.ToTensor(),
])

trainset = ImageFolder(data_dir, transform=train_transform)
train_dataloader = DataLoader(trainset, batch_size=4, shuffle=True)
def imshow(inp_tensor, title=None):
    """Display an image tensor with matplotlib.

    Args:
        inp_tensor: 3-D image tensor in channel-first order (C, H, W), e.g.
            the output of ``torchvision.utils.make_grid``. It is transposed
            to (H, W, C) because ``plt.imshow`` expects channels last.
        title: optional figure title. A list or tuple (such as one class
            name per image) is joined with ", " instead of being rendered
            as a raw Python list repr.
    """
    # detach()/cpu() let .numpy() succeed for tensors that track gradients
    # or live on a GPU; they are harmless for a plain CPU tensor.
    img = inp_tensor.detach().cpu().numpy().transpose((1, 2, 0))
    plt.imshow(img)
    if title is not None:
        if isinstance(title, (list, tuple)):
            title = ", ".join(str(t) for t in title)
        plt.title(title)
    # A short pause lets an interactive backend draw the figure.
    plt.pause(0.001)
# Pull one shuffled batch and show it as a single grid image, titled with the
# class name of each image in the batch.
trainset_iter = iter(train_dataloader)
# Builtin next() works on every Python 3 iterator; the `.next()` method is a
# Python 2 leftover that newer PyTorch DataLoader iterators no longer provide.
images, labels = next(trainset_iter)
batch_imgs = torchvision.utils.make_grid(images)
imshow(batch_imgs, title=", ".join(trainset.classes[i] for i in labels.tolist()))
# Keep the window open when run as a script; plt.pause alone returns at once.
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment