import os, sys
sys.path.append(os.getcwd())

import time

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np

import sklearn.datasets

# tflib helpers (print_model_settings, save_images, mnist, plot) come from an
# external utility package that is not included in this gist.
import tflib as lib
import tflib.save_images
import tflib.mnist
import tflib.plot

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from scatmnist_data import ScatMnistDataset
from torch.utils.data import DataLoader

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu = 0

DIM = 64          # Model dimensionality
BATCH_SIZE = 50   # Batch size
CRITIC_ITERS = 5  # For WGAN and WGAN-GP, number of critic iters per gen iter
LAMBDA = 10       # Gradient penalty lambda hyperparameter
ITERS = 200000    # How many generator iterations to train for
#OUTPUT_DIM = 784 # Number of pixels in MNIST (28*28)

# Load data
smd = ScatMnistDataset('scatmnist.npy')
smd_loader = DataLoader(smd, batch_size=BATCH_SIZE, shuffle=True)
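
# ScatMnistDataset (imported above from scatmnist_data, which is not part of this
# gist) is assumed to be a thin torch Dataset wrapper exposing a .data array of
# precomputed MNIST scattering coefficients with shape (num_examples, 81, 7, 7);
# 81 channels at 7x7 spatial size would be consistent with a second-order
# scattering transform of 28x28 images (2 scales, 8 orientations), but that is
# an assumption. A minimal sketch of such a class:
#
#   class ScatMnistDataset(torch.utils.data.Dataset):
#       def __init__(self, path):
#           self.data = np.load(path).astype('float32')
#       def __len__(self):
#           return self.data.shape[0]
#       def __getitem__(self, index):
#           return self.data[index]
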
def inf_train_gen_scat():
    while True:
        for data in smd_loader:
            yield data

SCAT_CHANNELS = smd.data.shape[1]  # 81
SCAT_HEIGHT = smd.data.shape[2]    # 7
SCAT_WIDTH = smd.data.shape[3]     # 7
OUTPUT_DIM = SCAT_CHANNELS*SCAT_HEIGHT*SCAT_WIDTH  # 81*7*7 = 3969

lib.print_model_settings(locals().copy())

# ==================Definition Start======================

class Generator(nn.Module):
    '''
    See http://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose2d
    for shapes using ConvTranspose2d
    '''
    def __init__(self):
        super(Generator, self).__init__()
        # A DCGAN-style deconvolutional generator is kept below, commented out;
        # the active generator is a 3-layer MLP mapping a 128-d noise vector to
        # the flattened scattering representation.
        '''
        preprocess = nn.Sequential(
            nn.Linear(128, 4*4*4*DIM),
            nn.ReLU(True),
        )
        # in-channels, out-channels
        block1 = nn.Sequential(
            nn.ConvTranspose2d(4*DIM, 2*DIM, 5),
            nn.ReLU(True),
        )
        block2 = nn.Sequential(
            nn.ConvTranspose2d(2*DIM, DIM, 5),
            nn.ReLU(True),
        )
        deconv_out = nn.ConvTranspose2d(DIM, 1, 8, stride=2)

        self.block1 = block1
        self.block2 = block2
        self.deconv_out = deconv_out
        self.preprocess = preprocess
        self.sigmoid = nn.Sigmoid()
        '''
        N = SCAT_CHANNELS*SCAT_HEIGHT*SCAT_WIDTH
        self.dense1 = nn.Linear(128, N/4)
        self.dense2 = nn.Linear(N/4, N/2)
        self.dense3 = nn.Linear(N/2, N)

    def forward(self, input):
        out = F.relu(self.dense1(input))
        out = F.relu(self.dense2(out))
        out = self.dense3(out)
        return out.view(-1, SCAT_CHANNELS, SCAT_HEIGHT, SCAT_WIDTH)
        '''
        output = self.preprocess(input)
        output = output.view(-1, 4*DIM, 4, 4)
        #print output.size()
        output = self.block1(output)
        #print output.size()
        output = output[:, :, :7, :7]
        #print output.size()
        output = self.block2(output)
        #print output.size()
        output = self.deconv_out(output)
        output = self.sigmoid(output)
        #print output.size()
        return output.view(-1, OUTPUT_DIM)
        '''

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        # A convolutional critic is kept below, commented out; the active critic
        # is a 3-layer MLP on the flattened scattering coefficients.
        '''
        self.conv1 = nn.Conv2d(SCAT_CHANNELS, DIM, 3, stride=1, padding=2)
        self.conv2 = nn.Conv2d(DIM, DIM, 3, stride=1, padding=2)
        self.conv3 = nn.Conv2d(DIM, DIM, 3, stride=1, padding=2)
        self.dense4 = nn.Linear(4*4*4*DIM, 1)

        main = nn.Sequential(
            #nn.Conv2d(1, DIM, 5, stride=2, padding=2),
            nn.Conv2d(SCAT_CHANNELS, DIM, 3, stride=1, padding=2),
            nn.ReLU(True),
            nn.Conv2d(DIM, 2*DIM, 5, stride=2, padding=2),
            nn.ReLU(True),
            nn.Conv2d(2*DIM, 4*DIM, 5, stride=2, padding=2),
            nn.ReLU(True),
        )
        self.main = main
        '''
        N = SCAT_CHANNELS*SCAT_HEIGHT*SCAT_WIDTH
        self.dense1 = nn.Linear(N, N/2)
        self.dense2 = nn.Linear(N/2, N/4)
        self.dense3 = nn.Linear(N/4, 1)

    def forward(self, input):
        '''
        #input = input.view(-1, 1, 28, 28)
        print input.size()
        out = F.relu(self.conv1(input))
        print out.size()
        out = F.relu(self.conv2(out))
        print out.size()
        out = F.relu(self.conv3(out))
        print out.size()
        out = out.view(-1, 4*4*4*DIM)
        print out.size()
        out = self.dense4(out)
        '''
        # Flatten (batch, 81, 7, 7) scattering coefficients to (batch, 3969)
        # before the fully-connected layers.
        out = input.view(-1, SCAT_CHANNELS*SCAT_HEIGHT*SCAT_WIDTH)
        out = F.relu(self.dense1(out))
        out = F.relu(self.dense2(out))
        out = self.dense3(out)
        return out.view(-1, 1)

def generate_image(frame, netG):
    noise = torch.randn(BATCH_SIZE, 128)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise, volatile=True)
    samples = netG(noisev)
    #samples = samples.view(BATCH_SIZE, 28, 28)
    # print samples.size()

    samples = samples.cpu().data
    with open('tmp/mnist/ft_{}.pkl'.format(frame), 'wb') as fp:
        np.save(fp, samples.numpy())
    torch.save(samples, 'tmp/mnist/ft_{}.torch'.format(frame))
    '''
    lib.save_images.save_images(
        samples,
        'tmp/mnist/samples_{}.png'.format(frame)
    )
    '''
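
# NOTE: the tmp/mnist/ output directory is assumed to exist already; create it
# before running (e.g. os.makedirs('tmp/mnist')), otherwise the open() and
# torch.save() calls above will fail.
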
# Dataset iterator for raw MNIST pixels (unused here: training below draws from
# inf_train_gen_scat(), and the dev-set evaluation that used dev_gen is commented out).
train_gen, dev_gen, test_gen = lib.mnist.load(BATCH_SIZE, BATCH_SIZE)
def inf_train_gen():
    while True:
        for images, targets in train_gen():
            yield images
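
# Gradient penalty of WGAN-GP (Gulrajani et al., 2017, "Improved Training of
# Wasserstein GANs"): sample a point x_hat = alpha*x_real + (1-alpha)*x_fake on
# the segment between a real and a fake sample (alpha ~ U[0,1], one alpha per
# example) and penalize the squared deviation of the critic's gradient norm
# from 1:
#     LAMBDA * E[ (||grad_{x_hat} D(x_hat)||_2 - 1)^2 ]
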
def calc_gradient_penalty(netD, real_data, fake_data):
    #print real_data.size()
    alpha = torch.rand(BATCH_SIZE, 1, 1, 1)  # (BATCH, CHANNEL, HEI, WID)
    #alpha = torch.rand(BATCH_SIZE, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda(gpu) if use_cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    if use_cuda:
        interpolates = interpolates.cuda(gpu)
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu) if use_cuda
                              else torch.ones(disc_interpolates.size()),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Flatten so the norm is taken over all non-batch dimensions of each example.
    gradients = gradients.contiguous().view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty

# ==================Definition End======================

netG = Generator()
netD = Discriminator()
print netG
print netD

if use_cuda:
    netD = netD.cuda(gpu)
    netG = netG.cuda(gpu)

#optimizerD = optim.Adam(netD.parameters(), lr=1e-5, betas=(0.5, 0.9))
#optimizerG = optim.Adam(netG.parameters(), lr=1e-5, betas=(0.5, 0.9))
optimizerD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizerG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))

one = torch.FloatTensor([1])
mone = one * -1
if use_cuda:
    one = one.cuda(gpu)
    mone = mone.cuda(gpu)
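
# Training loop. For each generator step the critic is updated CRITIC_ITERS
# times. Calling backward(mone) instead of backward(one) flips the sign of the
# gradients, so the critic ascends on D(real) and descends on D(fake)
# (maximizing the Wasserstein estimate D(real) - D(fake)), while the generator
# ascends on D(G(z)).
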
data = inf_train_gen_scat()

for iteration in xrange(ITERS):
    start_time = time.time()
    ############################
    # (1) Update D network
    ###########################
    for p in netD.parameters():  # reset requires_grad
        p.requires_grad = True   # they are set to False below in netG update

    for iter_d in xrange(CRITIC_ITERS):
        _data = data.next()
        real_data = torch.Tensor(_data)
        if use_cuda:
            real_data = real_data.cuda(gpu)
        real_data_v = autograd.Variable(real_data)

        netD.zero_grad()

        # train with real
        D_real = netD(real_data_v)
        D_real = D_real.mean()
        # print D_real
        D_real.backward(mone)

        # train with fake
        noise = torch.randn(BATCH_SIZE, 128)
        if use_cuda:
            noise = noise.cuda(gpu)
        noisev = autograd.Variable(noise, volatile=True)  # totally freeze netG
        fake = autograd.Variable(netG(noisev).data)
        inputv = fake
        D_fake = netD(inputv)
        D_fake = D_fake.mean()
        D_fake.backward(one)

        # train with gradient penalty
        gradient_penalty = calc_gradient_penalty(netD, real_data_v.data, fake.data)
        gradient_penalty.backward()

        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake
        optimizerD.step()

    ############################
    # (2) Update G network
    ###########################
    for p in netD.parameters():
        p.requires_grad = False  # to avoid computation
    netG.zero_grad()

    noise = torch.randn(BATCH_SIZE, 128)
    if use_cuda:
        noise = noise.cuda(gpu)
    noisev = autograd.Variable(noise)
    fake = netG(noisev)
    G = netD(fake)
    G = G.mean()
    G.backward(mone)
    G_cost = -G
    optimizerG.step()

    # Write logs and save samples
    lib.plot.plot('tmp/mnist/time', time.time() - start_time)
    lib.plot.plot('tmp/mnist/train disc cost', D_cost.cpu().data.numpy())
    lib.plot.plot('tmp/mnist/train gen cost', G_cost.cpu().data.numpy())
    lib.plot.plot('tmp/mnist/wasserstein distance', Wasserstein_D.cpu().data.numpy())

    # Calculate dev loss and generate samples every 100 iters
    if iteration % 100 == 99:
        '''
        dev_disc_costs = []
        for images, _ in dev_gen():
            imgs = torch.Tensor(images)
            if use_cuda:
                imgs = imgs.cuda(gpu)
            imgs_v = autograd.Variable(imgs, volatile=True)
            D = netD(imgs_v)
            _dev_disc_cost = -D.mean().cpu().data.numpy()
            dev_disc_costs.append(_dev_disc_cost)
        lib.plot.plot('tmp/mnist/dev disc cost', np.mean(dev_disc_costs))
        '''
        generate_image(iteration, netG)

    # Write logs every 100 iters
    if (iteration < 5) or (iteration % 100 == 99):
        lib.plot.flush()

    lib.plot.tick()
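
# Usage (assumed): precompute scatmnist.npy beforehand (an array of MNIST
# scattering coefficients of shape (num_examples, 81, 7, 7)), create the
# tmp/mnist/ output directory, then run this script with Python 2 and a
# 2017-era PyTorch release (the Variable/volatile API used here was removed
# in later versions).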