Forked from simopal6/gan_failure_normalization.py
Last active: December 1, 2017 09:47
# Imports
import torch
from PIL import Image
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import torch.backends.cudnn as cudnn; cudnn.benchmark = True

# Options
opt_bad = True                       # True: normalize with ImageNet mean/std; False: map images to [-1, 1] (matching the generator's Tanh output)
opt_dataset = "parent_of_n02510455"  # root directory for torchvision's ImageFolder
opt_batch_size = 16
opt_penalty = 10                     # weight of the Lipschitz penalty

# Setup data transforms
load_size = 80
crop_size = 64
if opt_bad:
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
else:
    mean = (0.5, 0.5, 0.5)
    std = (0.5, 0.5, 0.5)
train_transform = transforms.Compose([
    transforms.Scale(load_size),
    transforms.RandomCrop(crop_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Load dataset
train_dataset = datasets.ImageFolder(root=opt_dataset, transform=train_transform)

# Create loader
loader = DataLoader(train_dataset, batch_size=opt_batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
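# Quick check (a sketch, not part of the original gist): each batch from the
# loader is an (images, labels) pair, where images has shape
# (opt_batch_size, 3, crop_size, crop_size), e.g.
#   images, _ = next(iter(loader))
#   print(images.size())   # torch.Size([16, 3, 64, 64])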
# Discriminator ("critic") -- Wasserstein
class WDiscriminator(nn.Module):
    def __init__(self, isize, ndf, nc=3):
        super(WDiscriminator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        main = nn.Sequential()
        # input is nc x isize x isize
        main.add_module('initial.conv.{0}-{1}'.format(nc, ndf), nn.Conv2d(nc, ndf, 4, 2, 1, bias=False))
        main.add_module('initial.relu.{0}'.format(ndf), nn.LeakyReLU(0.2, inplace=True))
        csize, cndf = isize // 2, ndf
        # Reduce map size
        while csize > 4:
            in_feat = cndf
            out_feat = cndf * 2
            main.add_module('pyramid.{0}-{1}.conv'.format(in_feat, out_feat), nn.Conv2d(in_feat, out_feat, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.batchnorm'.format(out_feat), nn.InstanceNorm2d(out_feat, affine=True))
            main.add_module('pyramid.{0}.relu'.format(out_feat), nn.LeakyReLU(0.2, inplace=True))
            cndf = cndf * 2
            csize = csize // 2
        # state size. cndf x 4 x 4
        main.add_module('final.{0}-{1}.conv'.format(cndf, 1), nn.Conv2d(cndf, 1, 4, 1, 0, bias=False))
        self.main = main

    def forward(self, input):
        output = self.main(input)
        #output = output.mean(0)
        # return one critic score per sample (no mean over the batch here)
        return output.view(-1)
# Generator -- Wasserstein
class WGenerator(nn.Module):
    def __init__(self, isize, nz, ngf, nc=3):
        super(WGenerator, self).__init__()
        assert isize % 16 == 0, "isize has to be a multiple of 16"
        cngf, tisize = ngf // 2, 4
        while tisize != isize:
            cngf = cngf * 2
            tisize = tisize * 2
        main = nn.Sequential()
        # input is Z, going into a convolution
        main.add_module('initial.{0}-{1}.convt'.format(nz, cngf), nn.ConvTranspose2d(nz, cngf, 4, 1, 0, bias=False))
        main.add_module('initial.{0}.batchnorm'.format(cngf), nn.InstanceNorm2d(cngf, affine=True))
        main.add_module('initial.{0}.relu'.format(cngf), nn.ReLU(True))
        csize = 4
        while csize < isize // 2:
            main.add_module('pyramid.{0}-{1}.convt'.format(cngf, cngf // 2), nn.ConvTranspose2d(cngf, cngf // 2, 4, 2, 1, bias=False))
            main.add_module('pyramid.{0}.batchnorm'.format(cngf // 2), nn.InstanceNorm2d(cngf // 2, affine=True))
            main.add_module('pyramid.{0}.relu'.format(cngf // 2), nn.ReLU(True))
            cngf = cngf // 2
            csize = csize * 2
        main.add_module('final.{0}-{1}.convt'.format(cngf, nc), nn.ConvTranspose2d(cngf, nc, 4, 2, 1, bias=False))
        main.add_module('final.{0}.tanh'.format(nc), nn.Tanh())
        self.main = main

    def forward(self, input):
        output = self.main(input)
        return output
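# Shape sanity check (a minimal sketch, not part of the original gist): with the
# settings used below (isize=64, nz=100), the generator maps noise of shape
# (B, 100, 1, 1) to images of shape (B, 3, 64, 64), and the critic maps those
# images to a score vector of shape (B,). For example:
#   z = Variable(torch.randn(2, 100, 1, 1))
#   fake = WGenerator(isize=64, nz=100, ngf=64)(z)    # -> (2, 3, 64, 64)
#   score = WDiscriminator(isize=64, ndf=64)(fake)    # -> (2,)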
# Custom weight initialization
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('InstanceNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

# Create generator model/optimizer
g_net = WGenerator(nz=100, isize=64, ngf=64)
g_net.apply(weights_init)
g_optimizer = torch.optim.RMSprop(g_net.parameters(), lr=1e-4)

# Create discriminator model/optimizer
d_net = WDiscriminator(isize=64, ndf=64)
d_net.apply(weights_init)
d_optimizer = torch.optim.RMSprop(d_net.parameters(), lr=5e-5)

# Setup CUDA
g_net.cuda()
d_net.cuda()

# Debug options
save_images_every = 20
cnt = 0

# Auxiliary variables
noise = torch.FloatTensor(opt_batch_size, 100, 1, 1)
one = torch.FloatTensor([1])
minus_one = one * -1
noise = noise.cuda()
one = one.cuda()
minus_one = minus_one.cuda()

# Re-initialize all critic parameters with a small Gaussian (overrides weights_init for the critic)
for p in d_net.parameters():
    torch.nn.init.normal(p.data, std=0.1)

# Start training
d_net.train()
g_net.train()
g_iterations = 0
for epoch in range(0, 10000):
    # Keep track of losses
    d_real_loss_sum = 0
    d_fake_loss_sum = 0
    d_loss_cnt = 0
    g_loss_sum = 0
    g_loss_cnt = 0
    # Get data iterator
    data_iter = iter(loader)
    data_len = len(loader)
    data_i = 0
    # Process until data ends
    while data_i < data_len:
        # Compute gradients for discriminator
        for p in d_net.parameters(): p.requires_grad = True
        for p in g_net.parameters(): p.requires_grad = False
        # Set number of discriminator iterations (train the critic longer early on and periodically)
        d_iters = 500 if g_iterations <= 25 or g_iterations % 500 == 0 else 5
        # Perform discriminator iterations
        d_i = 0
        while data_i < data_len and d_i < d_iters:
            # Increase discriminator iterations
            d_i += 1
            # Clamp parameters to a cube (original WGAN weight clipping; replaced by the Lipschitz penalty below)
            #for p in d_net.parameters():
            #    p.data.clamp_(-0.01, 0.01)
            # Get data (keep reference to data on host)
            (real_input_cpu, _) = next(data_iter)
            real_input = real_input_cpu
            data_i += 1
            # Move to CUDA
            real_input = real_input.cuda(async=True)
            # Wrap for autograd
            real_input = Variable(real_input)
            # Reset gradients
            d_optimizer.zero_grad()
            g_optimizer.zero_grad()
            # Forward (discriminator, real)
            d_real_loss_vec = d_net(real_input)
            d_real_loss = d_real_loss_vec.mean(0).view(1)
            d_real_loss_sum += d_real_loss.data[0]
            # Backward (discriminator, real)
            #d_real_loss.backward(one)
            # Forward (discriminator, fake)
            noise.normal_(0, 1)
            noise_v = Variable(noise, volatile=True)
            g_output = Variable(g_net(noise_v).data)
            d_fake_loss_vec = d_net(g_output)
            d_fake_loss = d_fake_loss_vec.mean(0).view(1)
            d_fake_loss_sum += d_fake_loss.data[0]
            # Backward (discriminator, fake)
            #d_fake_loss.backward(minus_one)
            # SLOGAN Lipschitz penalty: per-sample distance between fake and real images ...
            dist = (((g_output - real_input).view(g_output.size(0), -1) ** 2).sum(1) + 1e-6) ** 0.5
            # ... yields a per-sample estimate of the critic's slope, penalized when it exceeds 1
            lip_est = (d_fake_loss_vec - d_real_loss_vec).abs() / (dist + 1e-6)
            lip_loss = opt_penalty * ((lip_est - 1).clamp(min=0) ** 2).mean(0).view(1)
            d_loss = d_real_loss - d_fake_loss + lip_loss
            d_loss.backward()
            # Update discriminator
            d_optimizer.step()
            # Update loss count
            d_loss_cnt += 1
        # Don't compute gradients w.r.t. parameters for discriminator
        for p in d_net.parameters(): p.requires_grad = False
        for p in g_net.parameters(): p.requires_grad = True
        # Forward (generator)
        noise.normal_(0, 1)
        noise_v = Variable(noise)
        g_output = g_net(noise_v)
        g_loss = d_net(g_output).mean(0).view(-1)
        g_loss_sum += g_loss.data[0]
        g_loss_cnt += 1
        # Backward (generator)
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()
        # Increase generator iterations
        g_iterations += 1
        # Save images every once in a while
        cnt += 1
        if cnt % save_images_every == 0:
            # Move generator output to host
            g_output_cpu = g_output.data.cpu()
            # Normalize images between 0 and 1
            real_input_cpu = (real_input_cpu - real_input_cpu.min()) / (real_input_cpu.max() - real_input_cpu.min())
            g_output_cpu = (g_output_cpu - g_output_cpu.min()) / (g_output_cpu.max() - g_output_cpu.min())
            # Save images
            Image.fromarray(torchvision.utils.make_grid(real_input_cpu, nrow=4).permute(1, 2, 0).mul(255).byte().numpy()).save("real_input_" + ("bad" if opt_bad else "good") + ".png")
            Image.fromarray(torchvision.utils.make_grid(g_output_cpu, nrow=4).permute(1, 2, 0).mul(255).byte().numpy()).save("g_output_" + ("bad" if opt_bad else "good") + ".png")
    # Print losses at the end of the epoch
    print("Epoch {0}: GL={1:.4f}, DRL={2:.4f}, DFL={3:.4f}".format(epoch, g_loss_sum / g_loss_cnt, d_real_loss_sum / d_loss_cnt, d_fake_loss_sum / d_loss_cnt))

Uses the SLOGAN Lipschitz penalty instead of WGAN weight clipping.
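For reference, a minimal self-contained sketch of that penalty term, mirroring the computation in the critic loop above; the function name and the use of plain tensors (rather than Variables) are illustrative choices, not part of the original gist:

import torch

def lipschitz_penalty(d_real_vec, d_fake_vec, real_batch, fake_batch, penalty=10.0):
    # Per-sample Euclidean distance between paired fake and real images
    dist = (((fake_batch - real_batch).view(fake_batch.size(0), -1) ** 2).sum(1) + 1e-6) ** 0.5
    # Per-sample estimate of the critic's slope between the two points
    lip_est = (d_fake_vec - d_real_vec).abs() / (dist + 1e-6)
    # One-sided penalty: only slopes above 1 are pushed down
    return penalty * ((lip_est - 1).clamp(min=0) ** 2).mean()

In the training loop above, the resulting term is added to d_real_loss - d_fake_loss before the backward pass, taking the place of the weight clipping that is left commented out.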