A simple example of DCGAN on MNIST using PyTorch
# model.py: the DCGAN generator and discriminator
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_weight(m):
    # DCGAN initialization: conv weights ~ N(0, 0.02),
    # BatchNorm weights ~ N(1, 0.02) with zero bias
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0., 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1., 0.02)
        m.bias.data.fill_(0.)


class DCGenerator(nn.Module):
    def __init__(self, convs):
        # convs: list of (out_channels, kernel_size, stride, padding) tuples,
        # one per transposed-convolution layer
        super(DCGenerator, self).__init__()
        self.convs = nn.ModuleList()
        in_channels = 1
        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.ConvTranspose2d(in_channels, out_channels,
                                                 kernel_size, stride, padding, bias=False))
            if i < len(convs) - 1:
                # we use BN and ReLU for each layer except the output
                self.convs.append(nn.BatchNorm2d(out_channels))
                self.convs.append(nn.ReLU())
            else:
                # in the output layer, we use Tanh to generate data in [-1, 1]
                self.convs.append(nn.Tanh())
            in_channels = out_channels
        self.apply(init_weight)

    def forward(self, input):
        out = input
        for module in self.convs:
            out = module(out)
        return out


class Discriminator(nn.Module):
    def __init__(self, convs):
        # convs: list of (out_channels, kernel_size, stride, padding) tuples,
        # one per convolution layer
        super(Discriminator, self).__init__()
        self.convs = nn.ModuleList()
        in_channels = 1
        for i, (out_channels, kernel_size, stride, padding) in enumerate(convs):
            self.convs.append(nn.Conv2d(in_channels, out_channels,
                                        kernel_size, stride, padding, bias=False))
            if i != 0 and i != len(convs) - 1:
                # we do not use BN in the input layer of D
                self.convs.append(nn.BatchNorm2d(out_channels))
            if i != len(convs) - 1:
                self.convs.append(nn.LeakyReLU(0.2))
            in_channels = out_channels
        self.apply(init_weight)

    def forward(self, input):
        out = input
        for layer in self.convs:
            out = layer(out)
        out = out.view(out.size(0), -1)
        out = F.sigmoid(out)
        return out
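As a quick sanity check (not part of the original gist; the expected shapes are my own reading of the layer specs used in the training script below), the two networks can be run on a small batch to confirm that the generator maps 1-channel noise vectors to 28x28 images and the discriminator maps those images back to one probability per sample:

# assumed usage sketch: verify tensor shapes with the conv specs from the training script
import torch
from torch.autograd import Variable
from model import DCGenerator, Discriminator

g = DCGenerator([(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)])
d = Discriminator([(32, 4, 2, 1), (64, 4, 2, 1), (1, 7, 1, 0)])
z = Variable(torch.randn(4, 1, 1, 1))   # 4 latent vectors, 1 channel each
img = g(z)                               # expected: (4, 1, 28, 28), values in [-1, 1]
score = d(img)                           # expected: (4, 1), values in (0, 1)
print(img.size(), score.size())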
# training script: imports the models from model.py and trains a DCGAN on MNIST
from __future__ import print_function
import torch
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import *


def sample_noise(batch_size, channels):
    # latent vectors from a standard normal, shaped (N, C, 1, 1) for ConvTranspose2d
    return torch.randn(batch_size, channels, 1, 1).float()


max_iter = 25
download = False  # set download=True on the first run if MNIST is not already in ./
# scale the images to [-1, 1] to match the Tanh output of the generator
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize([0.5, ], [0.5, ])])
mnist = datasets.MNIST('./', train=True, transform=trans, download=download)
batch_size = 64
use_cuda = True

if __name__ == '__main__':
    # each tuple is (out_channels, kernel_size, stride, padding)
    d_convs = [(32, 4, 2, 1), (64, 4, 2, 1), (1, 7, 1, 0)]
    discriminator = Discriminator(d_convs)
    g_convs = [(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)]
    generator = DCGenerator(g_convs)
    print(discriminator)
    print(generator)
    if use_cuda:
        discriminator, generator = discriminator.cuda(), generator.cuda()
    dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)
    optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    real_label, fake_label = 1, 0
    criterion = nn.BCELoss()
    if use_cuda:
        criterion = criterion.cuda()
    # a fixed noise batch, used only to visualize progress after every epoch
    fixed_noise = sample_noise(batch_size, 1)
    if use_cuda:
        fixed_noise = fixed_noise.cuda()
    fixed_noise = Variable(fixed_noise, volatile=True)
    for epoch in range(1, max_iter + 1):
        for i, (x, _) in enumerate(dataloader):
            batch_size = x.size(0)
            # train D on real data
            optimizer_d.zero_grad()
            x = Variable(x)
            if use_cuda:
                x = x.cuda()
            output = discriminator(x)
            real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())
            if use_cuda:
                real_v = real_v.cuda()
            loss_d_real = criterion(output, real_v)
            loss_d_real.backward()
            Dx = output.data.mean(dim=0)[0]
            # train D on fake data (detach so no gradients flow back into G)
            z = sample_noise(batch_size, 1)
            z = Variable(z)
            if use_cuda:
                z = z.cuda()
            fake = generator(z)
            output = discriminator(fake.detach())
            fake_v = Variable(torch.Tensor(batch_size).fill_(fake_label).float())
            if use_cuda:
                fake_v = fake_v.cuda()
            loss_d_fake = criterion(output, fake_v)
            loss_d_fake.backward()
            optimizer_d.step()
            err_D = loss_d_real.data[0] + loss_d_fake.data[0]
            # train G: push D towards classifying the fake batch as real
            optimizer_g.zero_grad()
            output = discriminator(fake)
            real_v = Variable(torch.Tensor(batch_size).fill_(real_label).float())
            if use_cuda:
                real_v = real_v.cuda()
            loss = criterion(output, real_v)
            loss.backward()
            optimizer_g.step()
            err_G = loss.data[0]
            DGz = output.data.mean(dim=0)[0]
            print('[{:02d}/{:02d}],[{:03d}/{:03d}], errD: {:.4f}, D(x): {:.4f}, errG: {:.4f}, D(G(z)): {:.4f}'.format(
                epoch, max_iter, i, len(dataloader), err_D, Dx, err_G, DGz))
        # save a grid of samples from the fixed noise at the end of each epoch
        fake = generator(fixed_noise)
        save_image(fake.data, './mnist-fake-{:02d}.png'.format(epoch),
                   normalize=True)
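The script only writes sample grids to disk during training. If one also wants to keep the trained generator and draw new digits from it afterwards, a sketch along the following lines could be appended at the end of the main block; the file names and sample count are my own assumptions, and `generator`, `sample_noise`, `use_cuda`, and `save_image` are the objects already defined above:

# assumed follow-up, not part of the original gist: persist G and sample from it later
torch.save(generator.state_dict(), './dcgan-generator.pth')

generator.load_state_dict(torch.load('./dcgan-generator.pth'))
z = Variable(sample_noise(64, 1), volatile=True)
if use_cuda:
    z = z.cuda()
samples = generator(z)   # a fresh batch of 64 generated digits
save_image(samples.data, './mnist-samples.png', normalize=True)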