Lasagne WGAN example
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Example employing Lasagne for digit generation using the MNIST dataset and
Wasserstein Generative Adversarial Networks
(WGANs, see https://arxiv.org/abs/1701.07875 for the paper and
https://github.com/martinarjovsky/WassersteinGAN for the "official" code).

It is based on a DCGAN example:
https://gist.github.com/f0k/738fa2eedd9666b78404ed1751336f56
This, in turn, is based on the MNIST example in Lasagne:
https://lasagne.readthedocs.io/en/latest/user/tutorial.html

Jan Schlüter, 2017-02-02
"""

from __future__ import print_function

import sys
import os
import time

import numpy as np
import theano
import theano.tensor as T
import lasagne

# ################## Download and prepare the MNIST dataset ##################
# This is just some way of getting the MNIST dataset from an online location
# and loading it into numpy arrays. It doesn't involve Lasagne at all.

def load_dataset():
    # We first define a download function, supporting both Python 2 and 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    # We then define functions for loading MNIST images and labels.
    # For convenience, they also download the requested files if needed.
    import gzip

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the inputs in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # The inputs are vectors now, we reshape them to monochrome 2D images,
        # following the shape convention: (examples, channels, rows, columns)
        data = data.reshape(-1, 1, 28, 28)
        # The inputs come as bytes, we convert them to float32 in range [0,1].
        # (Actually to range [0, 255/256], for compatibility to the version
        # provided at http://deeplearning.net/data/mnist/mnist.pkl.gz.)
        return data / np.float32(256)

    def load_mnist_labels(filename):
        if not os.path.exists(filename):
            download(filename)
        # Read the labels in Yann LeCun's binary format.
        with gzip.open(filename, 'rb') as f:
            data = np.frombuffer(f.read(), np.uint8, offset=8)
        # The labels are vectors of integers now, that's exactly what we want.
        return data

    # We can now download and read the training and test set images and labels.
    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    y_train = load_mnist_labels('train-labels-idx1-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    y_test = load_mnist_labels('t10k-labels-idx1-ubyte.gz')

    # We reserve the last 10000 training examples for validation.
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    y_train, y_val = y_train[:-10000], y_train[-10000:]

    # We just return all the arrays in order, as expected in main().
    # (It doesn't matter how we do this as long as we can read them again.)
    return X_train, y_train, X_val, y_val, X_test, y_test
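
# (For reference: X_train then has shape (50000, 1, 28, 28) and dtype
# float32, and y_train has shape (50000,) and dtype uint8. The labels are
# carried along below, but the unsupervised WGAN training never uses them.)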

# ##################### Build the neural network model #######################
# We create two models: the generator and the critic network.
# The models are the same as in the Lasagne DCGAN example, except that the
# discriminator is now a critic with linear output instead of sigmoid output.

def build_generator(input_var=None):
    from lasagne.layers import InputLayer, ReshapeLayer, DenseLayer
    try:
        from lasagne.layers import TransposedConv2DLayer as Deconv2DLayer
    except ImportError:
        raise ImportError("Your Lasagne is too old. Try the bleeding-edge "
                          "version: http://lasagne.readthedocs.io/en/latest/"
                          "user/installation.html#bleeding-edge-version")
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import sigmoid
    # input: 100-dim noise vector
    layer = InputLayer(shape=(None, 100), input_var=input_var)
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024))
    # project and reshape
    layer = batch_norm(DenseLayer(layer, 128*7*7))
    layer = ReshapeLayer(layer, ([0], 128, 7, 7))
    # two fractional-stride convolutions
    layer = batch_norm(Deconv2DLayer(layer, 64, 5, stride=2, crop='same',
                                     output_size=14))
    layer = Deconv2DLayer(layer, 1, 5, stride=2, crop='same', output_size=28,
                          nonlinearity=sigmoid)
    print("Generator output:", layer.output_shape)
    return layer

def build_critic(input_var=None):
    from lasagne.layers import InputLayer, Conv2DLayer, DenseLayer
    try:
        from lasagne.layers.dnn import batch_norm_dnn as batch_norm
    except ImportError:
        from lasagne.layers import batch_norm
    from lasagne.nonlinearities import LeakyRectify
    lrelu = LeakyRectify(0.2)
    # input: (None, 1, 28, 28)
    layer = InputLayer(shape=(None, 1, 28, 28), input_var=input_var)
    # two convolutions
    layer = batch_norm(Conv2DLayer(layer, 64, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    layer = batch_norm(Conv2DLayer(layer, 128, 5, stride=2, pad='same',
                                   nonlinearity=lrelu))
    # fully-connected layer
    layer = batch_norm(DenseLayer(layer, 1024, nonlinearity=lrelu))
    # output layer (linear and without bias)
    layer = DenseLayer(layer, 1, nonlinearity=None, b=None)
    print("Critic output:", layer.output_shape)
    return layer
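
# (As a quick sanity check, a sketch rather than part of the script, you
# could build a network and print every layer's output shape:
#   gen = build_generator()
#   print([l.output_shape for l in lasagne.layers.get_all_layers(gen)])
# where `get_all_layers` is Lasagne's helper for traversing a network.)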

# ############################# Batch iterator ###############################
# This is just a simple helper function iterating over training data in
# mini-batches of a particular size, optionally in random order. It assumes
# the data is available as numpy arrays. For big datasets, you could load the
# arrays as memory-mapped files (np.load(..., mmap_mode='r')) or write your
# own custom data iteration function. For small datasets, you can also copy
# them to GPU at once for slightly improved performance. This would involve
# several changes in the main program, though, and is not demonstrated here
# (but see the rough sketch following the function below).

def iterate_minibatches(inputs, targets, batchsize, shuffle=False,
                        forever=False):
    assert len(inputs) == len(targets)
    if shuffle:
        indices = np.arange(len(inputs))
    while True:
        if shuffle:
            np.random.shuffle(indices)
        for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
            if shuffle:
                excerpt = indices[start_idx:start_idx + batchsize]
            else:
                excerpt = slice(start_idx, start_idx + batchsize)
            yield inputs[excerpt], targets[excerpt]
        if not forever:
            break
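
# (A rough sketch of the GPU-resident alternative mentioned above, kept as
# comments and not used below; the names are illustrative only:
#   data = theano.shared(lasagne.utils.floatX(X_train))
#   index = T.iscalar('index')
#   batch = data[index * batchsize:(index + 1) * batchsize]
# You would then substitute `batch` for the input via `givens` when
# compiling the Theano training function, passing `index` as its argument.)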

# ############################## Main program ################################
# Everything else will be handled in our main program now. We could pull out
# more functions to better separate the code, but it wouldn't make it any
# easier to read.

def main(num_epochs=1000, epochsize=100, batchsize=64, initial_eta=5e-5,
         clip=0.01):
    # Load the dataset
    print("Loading data...")
    X_train, y_train, X_val, y_val, X_test, y_test = load_dataset()

    # Prepare Theano variables for inputs and targets
    noise_var = T.matrix('noise')
    input_var = T.tensor4('inputs')

    # Create neural network model
    print("Building model and compiling functions...")
    generator = build_generator(noise_var)
    critic = build_critic(input_var)

    # Create expression for passing real data through the critic
    real_out = lasagne.layers.get_output(critic)
    # Create expression for passing fake data through the critic
    fake_out = lasagne.layers.get_output(critic,
                                         lasagne.layers.get_output(generator))

    # Create score expressions to be maximized (i.e., negative losses)
    generator_score = fake_out.mean()
    critic_score = real_out.mean() - fake_out.mean()
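    # (Under the WGAN objective, the critic maximizes
    # E_x[f(x)] - E_z[f(g(z))], which estimates the Wasserstein-1 distance
    # between real and generated data as long as f stays roughly Lipschitz;
    # the generator maximizes E_z[f(g(z))] to drive that estimate down.)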

    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    critic_params = lasagne.layers.get_all_params(critic, trainable=True)
    eta = theano.shared(lasagne.utils.floatX(initial_eta))
    generator_updates = lasagne.updates.rmsprop(
            -generator_score, generator_params, learning_rate=eta)
    critic_updates = lasagne.updates.rmsprop(
            -critic_score, critic_params, learning_rate=eta)

    # Clip critic parameters in a limited range around zero (except biases)
    for param in lasagne.layers.get_all_params(critic, trainable=True,
                                               regularizable=True):
        critic_updates[param] = T.clip(critic_updates[param], -clip, clip)
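    # (This hard clipping to [-clip, clip] is the paper's blunt way of
    # keeping the critic approximately Lipschitz. Filtering on the
    # `regularizable` tag excludes biases, which Lasagne tags as
    # non-regularizable by default.)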

    # Instantiate a symbolic noise generator to use for training
    from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
    srng = RandomStreams(seed=np.random.randint(2147462579, size=6))
    noise = srng.uniform((batchsize, 100))

    # Compile functions performing a training step on a mini-batch (according
    # to the updates dictionary) and returning the corresponding score:
    generator_train_fn = theano.function([], generator_score,
                                         givens={noise_var: noise},
                                         updates=generator_updates)
    critic_train_fn = theano.function([input_var], critic_score,
                                      givens={noise_var: noise},
                                      updates=critic_updates)

    # Compile another function generating some data
    gen_fn = theano.function([noise_var],
                             lasagne.layers.get_output(generator,
                                                       deterministic=True))

    # Finally, launch the training loop.
    print("Starting training...")
    # We create an infinite supply of batches (as an iterable generator):
    batches = iterate_minibatches(X_train, y_train, batchsize, shuffle=True,
                                  forever=True)
    # We iterate over epochs. (Reusing the name `generator_updates` as a
    # counter is safe: the updates dictionary of the same name has already
    # been compiled into the training function above.)
    generator_updates = 0
    for epoch in range(num_epochs):
        start_time = time.time()

        # In each epoch, we do `epochsize` generator updates. Usually, the
        # critic is updated 5 times before every generator update. For the
        # first 25 generator updates and every 500 generator updates, the
        # critic is updated 100 times instead, following the authors' code.
        critic_scores = []
        generator_scores = []
        for _ in range(epochsize):
            if (generator_updates < 25) or (generator_updates % 500 == 0):
                critic_runs = 100
            else:
                critic_runs = 5
            for _ in range(critic_runs):
                batch = next(batches)
                inputs, targets = batch
                critic_scores.append(critic_train_fn(inputs))
            generator_scores.append(generator_train_fn())
            generator_updates += 1

        # Then we print the results for this epoch:
        print("Epoch {} of {} took {:.3f}s".format(
                epoch + 1, num_epochs, time.time() - start_time))
        print("  generator score:\t\t{}".format(np.mean(generator_scores)))
        print("  Wasserstein distance:\t\t{}".format(np.mean(critic_scores)))

        # And finally, we plot some generated data
        samples = gen_fn(lasagne.utils.floatX(np.random.rand(42, 100)))
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            pass
        else:
            plt.imsave('wgan_mnist_samples.png',
                       (samples.reshape(6, 7, 28, 28)
                               .transpose(0, 2, 1, 3)
                               .reshape(6*28, 7*28)),
                       cmap='gray')

        # After half the epochs, we start decaying the learning rate to zero
        if epoch >= num_epochs // 2:
            progress = float(epoch) / num_epochs
            eta.set_value(lasagne.utils.floatX(initial_eta*2*(1 - progress)))
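        # (From epoch num_epochs//2 onwards, this decays eta linearly from
        # initial_eta down to nearly zero at the final epoch.)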

    # Optionally, you could now dump the network weights to a file like this:
    np.savez('wgan_mnist_gen.npz',
             *lasagne.layers.get_all_param_values(generator))
    np.savez('wgan_mnist_crit.npz',
             *lasagne.layers.get_all_param_values(critic))
    #
    # And load them again later on like this:
    # with np.load('model.npz') as f:
    #     param_values = [f['arr_%d' % i] for i in range(len(f.files))]
    # lasagne.layers.set_all_param_values(network, param_values)


if __name__ == '__main__':
    if ('--help' in sys.argv) or ('-h' in sys.argv):
        print("Trains a WGAN on MNIST using Lasagne.")
        print("Usage: %s [EPOCHS [EPOCHSIZE]]" % sys.argv[0])
        print()
        print("EPOCHS: number of training epochs to perform (default: 1000)")
        print("EPOCHSIZE: number of generator updates per epoch (default: 100)")
    else:
        kwargs = {}
        if len(sys.argv) > 1:
            kwargs['num_epochs'] = int(sys.argv[1])
        if len(sys.argv) > 2:
            kwargs['epochsize'] = int(sys.argv[2])
        main(**kwargs)