MNIST Vanilla GAN Failure: Worked after weight initialization!
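The failure-then-fix in the title comes down to weight initialization: every parameter of both networks is drawn from N(0, 0.02), the DCGAN convention, in `GAN.__init__` below. A sketch of the same idea written with `Module.apply` instead of inline loops (a common variant that zeroes biases, unlike the gist's loops, which also randomize them; `weights_init` is a hypothetical name introduced here):

import torch.nn as nn

def weights_init(m):
    # DCGAN-style init: Linear weights ~ N(0, 0.02), biases zeroed.
    if isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.zero_()

# gan.generator.apply(weights_init)
# gan.discriminator.apply(weights_init)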
# Implements GAN
from itertools import chain
import torch
# Wildcard imports supply the bare names used below
# (Adam, Variable, Sequential, Linear, LeakyReLU, Dropout,
#  Sigmoid, Tanh, BCELoss, ...).
from torch.optim import *
from torch.autograd import *
from torch.nn.modules import *
import torchvision.transforms as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
def make_generator(*layers, leaky=0.2):
    # Stack a Linear -> LeakyReLU pair for each consecutive pair of layer
    # sizes, then squash the output to [-1, 1] with Tanh to match the
    # normalized MNIST pixels.
    return Sequential(
        *chain(*((
            Linear(inp, out),
            LeakyReLU(leaky),
        ) for inp, out in zip(layers, layers[1:]))),
        Tanh()
    )
def make_discriminator(*layers, leaky=0.2, drops=0.3):
    # Same stacking trick, plus Dropout after each activation and a final
    # Sigmoid that outputs a real/fake probability.
    return Sequential(
        *chain(*((
            Linear(inp, out),
            LeakyReLU(leaky),
            Dropout(drops),
        ) for inp, out in zip(layers, layers[1:]))),
        Sigmoid()
    )
def ones(size):
    # Targets for "real"; the commented line is a soft-label variant
    # (noisy labels around 0.9) apparently left in from experiments.
    # return 0.9 + Variable(torch.randn(size, 1)).cuda()/100
    return Variable(torch.ones(size, 1)).cuda()

def zeros(size):
    # Targets for "fake" (soft-label variant around 0.1 commented out).
    # return 0.1 + Variable(torch.randn(size, 1)).cuda()/100
    return Variable(torch.zeros(size, 1)).cuda()
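# Aside: Salimans et al. 2016 ("Improved Techniques for Training GANs")
# recommend one-sided label smoothing -- soften only the real targets
# (e.g. 0.9) and keep the fake targets at 0 -- whereas the commented
# lines above smooth both sides.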
class GAN:
    def __init__(self, input_shape, latent_shape, *layers, leaky=0.2, drops=0.3):
        self.batch_size = None
        self.input_shape = input_shape
        self.latent_shape = latent_shape
        self.generator = make_generator(
            latent_shape,
            *layers,
            input_shape,
            leaky=leaky
        ).cuda()
        # The fix the title refers to: initialize every parameter
        # (weights and biases alike) from N(0, 0.02), DCGAN-style.
        for param in self.generator.parameters():
            param.data.normal_(0, 0.02)
        self.discriminator = make_discriminator(
            input_shape,
            *reversed(layers),
            1,
            leaky=leaky,
            drops=drops
        ).cuda()
        for param in self.discriminator.parameters():
            param.data.normal_(0, 0.02)
        # lr=0.0002 follows the DCGAN recipe (default Adam betas kept here).
        self.d_opt = Adam(self.discriminator.parameters(), lr=0.0002)
        self.g_opt = Adam(self.generator.parameters(), lr=0.0002)
        self.loss = BCELoss().cuda()
    def noise(self, size=None):
        # Default to the batch size of the most recent training batch.
        if not size:
            size = self.batch_size
        return Variable(torch.randn(size, self.latent_shape)).cuda()
    def train_discriminator(self, real_data):
        # detach() blocks gradients from flowing back into the generator
        # during the discriminator update.
        fake_data = self.generator(self.noise()).detach()
        self.d_opt.zero_grad()
        # Push D(x) towards 1 on real data...
        p_real = self.discriminator(real_data)
        e_real = self.loss(p_real, ones(self.batch_size))
        e_real.backward()
        # ...and D(G(z)) towards 0 on fake data.
        p_fake = self.discriminator(fake_data)
        e_fake = self.loss(p_fake, zeros(self.batch_size))
        e_fake.backward()
        self.d_opt.step()
        return e_fake + e_real
    def train_generator(self):
        fake_data = self.generator(self.noise())
        self.g_opt.zero_grad()
        # Label fakes as "real" for the generator update (see the note
        # after the class).
        prd = self.discriminator(fake_data)
        err = self.loss(prd, ones(self.batch_size))
        err.backward()
        self.g_opt.step()
        return err
    def generate(self, noise, *shape):
        # Reshape flat generator output into (batch, *shape) images.
        items = self.generator(noise)
        return items.view(noise.size(0), *shape)

    def train_step(self, batch_data):
        real_data = Variable(batch_data).cuda()
        self.batch_size = real_data.size(0)
        # Flatten 28x28 images into 784-dim vectors for the MLP.
        real_data = real_data.view(self.batch_size, self.input_shape)
        e_dis = self.train_discriminator(real_data)
        e_gen = self.train_generator()
        return e_dis.detach().cpu().item(), e_gen.detach().cpu().item()
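# Training protocol, per batch:
#   1. Discriminator step: push D(x) -> 1 on real data and D(G(z)) -> 0 on a
#      detached fake batch, i.e. maximize log D(x) + log(1 - D(G(z))).
#   2. Generator step: push D(G(z)) -> 1 on fresh noise -- the non-saturating
#      loss -log D(G(z)) from Goodfellow et al. 2014, used instead of
#      minimizing log(1 - D(G(z))) because it gives stronger gradients early on.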
if __name__ == '__main__':
    import numpy as np
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode so the figures refresh during training
    trans = T.Compose([
        T.ToTensor(),
        # MNIST is single-channel; map pixels from [0, 1] to [-1, 1]
        # to match the generator's Tanh output.
        T.Normalize((0.5,), (0.5,))
    ])
    data = MNIST('data/', train=True, transform=trans, download=True)
    loader = DataLoader(data, batch_size=100, shuffle=True)
    gen_loss = []
    dsc_loss = []
    # 784-dim images, 100-dim latent, hidden sizes 256 -> 512 -> 1024.
    mnist = GAN(28*28, 100, 256, 512, 1024)
    # Fixed noise so the sample grid tracks the generator's progress
    # on the same latent points every time.
    sample = mnist.noise(9)
    for i in range(100):
        for j, (batch, _) in enumerate(loader):
            dl, gl = mnist.train_step(batch)
            gen_loss.append(gl)
            dsc_loss.append(dl)
            if j % 50 != 0:
                continue
            # Plot at most ~1000 points by subsampling the loss history.
            n_samples = 1000
            n_items = len(gen_loss)
            plt.figure(0)
            plt.clf()
            plt.cla()
            d_x = max(n_items // n_samples, 1)
            x_xs = np.arange(0, n_items, d_x)
            plt.legend(handles=[
                plt.plot(x_xs, gen_loss[::d_x], label='Generator Loss')[0],
                plt.plot(x_xs, dsc_loss[::d_x], label='Discriminator Loss')[0]
            ])
            plt.figure(1)
            # Render the 9 fixed-noise samples as a 3x3 grid.
            img = mnist.generate(sample, 1, 28, 28).data.cpu()
            grid = make_grid(img, nrow=3, normalize=True).permute(1, 2, 0).numpy()
            plt.imshow(grid)
            plt.show()
            plt.pause(0.1)
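The gist ends with the training loop. If you want to keep the result around, a minimal follow-up sketch (not part of the original gist; assumes PyTorch >= 0.4, where tensors and Variables are merged, and `gan_generator.pt` is a hypothetical filename):

# Save the trained generator's weights, reload them later, and sample a grid.
torch.save(mnist.generator.state_dict(), 'gan_generator.pt')

gen = make_generator(100, 256, 512, 1024, 28*28).cuda()
gen.load_state_dict(torch.load('gan_generator.pt'))
gen.eval()  # no Dropout in the generator, but switching modes is good habit
with torch.no_grad():
    z = torch.randn(9, 100).cuda()
    imgs = gen(z).view(9, 1, 28, 28).cpu()
grid = make_grid(imgs, nrow=3, normalize=True)  # 3x3 grid, ready for imshow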