@NaxAlpha · Last active December 23, 2018 07:41
MNIST Vanilla GAN Failure: Worked after weight initialization!
# Implements a vanilla GAN on MNIST (MLP generator/discriminator, BCE loss).
# Requires a CUDA device: all modules and tensors are moved with .cuda().
from itertools import chain

import torch
from torch.optim import Adam
from torch.autograd import Variable
from torch.nn import (
    BCELoss, Dropout, LeakyReLU, Linear, Sequential, Sigmoid, Tanh,
)
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.utils import make_grid


def make_generator(*layers, leaky=0.2):
    return Sequential(
        *chain(*((
            Linear(inp, out),
            LeakyReLU(leaky),
        ) for inp, out in zip(layers, layers[1:]))),
        Tanh()
    )

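# For example, make_generator(100, 256, 784) expands to:
#   Linear(100, 256) -> LeakyReLU(0.2)
#   -> Linear(256, 784) -> LeakyReLU(0.2) -> Tanh()
# (a quirk of the loop: a LeakyReLU sits between the last Linear and the Tanh).
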

def make_discriminator(*layers, leaky=0.2, drops=0.3):
    return Sequential(
        *chain(*((
            Linear(inp, out),
            LeakyReLU(leaky),
            Dropout(drops),
        ) for inp, out in zip(layers, layers[1:]))),
        Sigmoid()
    )

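# Likewise, make_discriminator(784, 256, 1) expands to:
#   Linear(784, 256) -> LeakyReLU(0.2) -> Dropout(0.3)
#   -> Linear(256, 1) -> LeakyReLU(0.2) -> Dropout(0.3) -> Sigmoid()
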

def ones(size):
    # Target labels for "real" (1) samples.
    # return 0.9 + Variable(torch.randn(size, 1)).cuda()/100
    return Variable(torch.ones(size, 1)).cuda()


def zeros(size):
    # Target labels for "fake" (0) samples.
    # return 0.1 + Variable(torch.randn(size, 1)).cuda()/100
    return Variable(torch.zeros(size, 1)).cuda()

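# The commented-out lines above sketch noisy soft labels (~0.9 for real,
# ~0.1 for fake), a common label-smoothing trick for stabilizing GAN
# training; the working version uses hard 0/1 targets.
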

class GAN:
    def __init__(self, input_shape, latent_shape, *layers, leaky=0.2, drops=0.3):
        self.batch_size = None
        self.input_shape = input_shape
        self.latent_shape = latent_shape
        # Generator maps latent noise up through `layers` to a flat image.
        self.generator = make_generator(
            latent_shape,
            *layers,
            input_shape,
            leaky=leaky
        ).cuda()
        # N(0, 0.02) weight init (as recommended in DCGAN) -- per the title,
        # this is the change that made training work.
        for param in self.generator.parameters():
            param.data.normal_(0, 0.02)
        # Discriminator mirrors the generator's layer sizes in reverse.
        self.discriminator = make_discriminator(
            input_shape,
            *reversed(layers),
            1,
            leaky=leaky,
            drops=drops
        ).cuda()
        for param in self.discriminator.parameters():
            param.data.normal_(0, 0.02)
        self.d_opt = Adam(self.discriminator.parameters(), lr=0.0002)
        self.g_opt = Adam(self.generator.parameters(), lr=0.0002)
        self.loss = BCELoss().cuda()
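
    # A minimal, equivalent init sketch using module.apply (assumes the
    # Linear-only stacks above; torch.nn.init is the modern API):
    #
    #   def init_weights(m):
    #       if isinstance(m, Linear):
    #           torch.nn.init.normal_(m.weight, 0.0, 0.02)
    #           torch.nn.init.normal_(m.bias, 0.0, 0.02)
    #   self.generator.apply(init_weights)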

    def noise(self, size=None):
        if not size:
            size = self.batch_size
        return Variable(torch.randn(size, self.latent_shape)).cuda()

    def train_discriminator(self, real_data):
        # Detach fake data so this step does not backprop into the generator.
        fake_data = self.generator(self.noise()).detach()
        self.d_opt.zero_grad()
        p_real = self.discriminator(real_data)
        e_real = self.loss(p_real, ones(self.batch_size))
        e_real.backward()
        p_fake = self.discriminator(fake_data)
        e_fake = self.loss(p_fake, zeros(self.batch_size))
        e_fake.backward()
        self.d_opt.step()
        return e_fake + e_real
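
    # Non-saturating generator loss: rather than minimizing log(1 - D(G(z))),
    # label fakes as "real" (1) so the generator maximizes log D(G(z)).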
    def train_generator(self):
        fake_data = self.generator(self.noise())
        self.g_opt.zero_grad()
        prd = self.discriminator(fake_data)
        err = self.loss(prd, ones(self.batch_size))
        err.backward()
        self.g_opt.step()
        return err

    def generate(self, noise, *shape):
        items = self.generator(noise)
        return items.view(noise.size(0), *shape)

    def train_step(self, batch_data):
        real_data = Variable(batch_data).cuda()
        self.batch_size = real_data.size(0)
        real_data = real_data.view(self.batch_size, self.input_shape)
        e_dis = self.train_discriminator(real_data)
        e_gen = self.train_generator()
        return e_dis.detach().cpu().item(), e_gen.detach().cpu().item()
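

# A minimal usage sketch (hypothetical batch; the actual run below uses MNIST):
#   gan = GAN(28*28, 100, 256, 512, 1024)
#   d_loss, g_loss = gan.train_step(torch.randn(32, 1, 28, 28))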

if __name__ == '__main__':
    import numpy as np
    import matplotlib.pyplot as plt

    plt.ion()
    trans = T.Compose([
        T.ToTensor(),
        # MNIST is single-channel, so normalize with a single mean/std pair;
        # this maps pixels to [-1, 1], matching the generator's Tanh output.
        T.Normalize((0.5,), (0.5,))
    ])
    data = MNIST('data/', train=True, transform=trans, download=True)
    loader = DataLoader(data, batch_size=100, shuffle=True)
    gen_loss = []
    dsc_loss = []
    # 784-dim images, 100-dim latent, hidden layers 256 -> 512 -> 1024.
    mnist = GAN(28*28, 100, 256, 512, 1024)
    # Fixed latent batch so the 3x3 preview grid is comparable over time.
    sample = mnist.noise(9)
    for i in range(100):
        for j, (batch, _) in enumerate(loader):
            dl, gl = mnist.train_step(batch)
            gen_loss.append(gl)
            dsc_loss.append(dl)
            if j % 50 != 0:
                continue
            # Plot both losses, subsampled to at most ~1000 points.
            n_samples = 1000
            n_items = len(gen_loss)
            plt.figure(0)
            plt.clf()
            d_x = max(n_items // n_samples, 1)
            x_xs = np.arange(0, n_items, d_x)
            plt.legend(handles=[
                plt.plot(x_xs, gen_loss[::d_x], label='Generator Loss')[0],
                plt.plot(x_xs, dsc_loss[::d_x], label='Discriminator Loss')[0]
            ])
            # Preview a 3x3 grid generated from the fixed noise sample.
            plt.figure(1)
            img = mnist.generate(sample, 1, 28, 28).data.cpu()
            grid = make_grid(img, nrow=3, normalize=True).permute(1, 2, 0).numpy()
            plt.imshow(grid)
            plt.show()
            plt.pause(0.1)
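
    # Optional: persist a sample grid after training (sketch; the filename is
    # arbitrary). save_image lives in torchvision.utils:
    #   from torchvision.utils import save_image
    #   save_image(mnist.generate(mnist.noise(64), 1, 28, 28).data.cpu(),
    #              'samples.png', nrow=8, normalize=True)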