Skip to content

Instantly share code, notes, and snippets.

@MLWhiz
Last active July 30, 2020 08:38
Show Gist options
  • Select an option

  • Save MLWhiz/f7108eeb4b0f12c33717136ae67e901a to your computer and use it in GitHub Desktop.

Select an option

Save MLWhiz/f7108eeb4b0f12c33717136ae67e901a to your computer and use it in GitHub Desktop.
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Root directory for dataset
dataroot = "anime_images/"
# Number of worker processes for the dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. Each image's shorter side is resized to this
# size and the result is then center-cropped to image_size x image_size.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Number of GPUs to use; 0 forces CPU mode. Read by the device selection below,
# which previously referenced this name without it being defined.
ngpu = 1

# ---------------------------------------------------------------------------
# Dataset and dataloader
# ---------------------------------------------------------------------------
# ImageFolder expects the images to live inside at least one subdirectory of
# dataroot (e.g. anime_images/<any_name>/*.jpg). Only the images (index 0 of
# each batch) are used by the preview below.
dataset = datasets.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # ToTensor yields values in [0, 1]; this maps each channel to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True, num_workers=workers
)

# Use the first CUDA device only when CUDA is available and GPUs were requested.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# ---------------------------------------------------------------------------
# Preview: plot up to 64 training images as a grid
# ---------------------------------------------------------------------------
# This is display-only, so the batch stays on the CPU. The original moved the
# full batch to `device` and then copied it straight back with .cpu().
real_batch = next(iter(dataloader))
grid = vutils.make_grid(real_batch[0][:64], padding=2, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment