import os
import random
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
# This file contains hidden or bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review, open
# the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
# Dependencies (run in a shell, not in Python):
#   pip3 install matplotlib torch torchvision numpy pillow

# --- Configuration -----------------------------------------------------------
dataroot = "art/"  # ImageFolder root: expects dataroot/<subfolder>/<image files>
workers = 10  # number of DataLoader worker processes
batch_size = 128  # images per batch yielded by the DataLoader
image_size = 64  # shorter side resized to this, then center-cropped to a square
nc = 3  # image channels (the Normalize below uses 3 per-channel values)
nz = 100  # latent-vector length; the generator's first layer takes nz channels
ngf = 64  # base feature-map width of the generator
ndf = 64  # base feature-map width of the discriminator
# Number of GPUs to use (0 = CPU only). Passed to the Generator/Discriminator
# constructors and used by the multi-GPU DataParallel check; set it explicitly
# to use fewer GPUs than are visible.
ngpu = torch.cuda.device_count()

# --- Data --------------------------------------------------------------------
# NOTE(review): `dset` and `transforms` (torchvision.datasets /
# torchvision.transforms) are not imported in the visible import block --
# confirm they are imported before this line.
dataset = dset.ImageFolder(
    root=dataroot,
    transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        # (x - 0.5) / 0.5 maps every channel from [0, 1] to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)

# shuffle=True because ImageFolder orders samples by subfolder; without it,
# consecutive batches would be drawn one subfolder at a time.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

# Prefer the GPU, but fall back to the CPU instead of failing when CUDA is
# unavailable or ngpu is 0.
device = torch.device("cuda" if torch.cuda.is_available() and ngpu > 0 else "cpu")
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.imshow(
np.transpose(
vutils.make_grid(real_batch[0].to(device)[:64],
padding=2,
normalize=True).cpu(), (1, 2,0)))def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') !=-1:
nn.init.normal_(m.weight.data,0.0,.02)
elif classname.find('BatchNorm') !=-1:
nn.init.normal_(m.weight.data,1.0,.02)
nn.init.constant_(m.bias.data,0)class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),netG = Generator(ngpu).to(device)
# Multi-GPU: when running on CUDA with more than one GPU requested, replicate
# the generator across GPUs 0 .. ngpu-1 with DataParallel.
if device.type == 'cuda' and ngpu > 1:
    netG = nn.DataParallel(netG, device_ids=[*range(ngpu)])
# Module.apply calls weights_init on every submodule (and on the module
# itself), re-initialising the conv and batch-norm layers in place.
netG.apply(weights_init)
print(netG)class Discriminator(nn.Module):
def __init__(self, ngpu):
super(Discriminator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32