Last active February 25, 2020 13:09
Variational Auto-Encoder using Lasagne
import sys
import os
import numpy as np
import theano
import theano.tensor as T
import lasagne as nn
import time
from PIL import Image
from scipy.stats import norm
from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
# ############################################################################
# Tencia Lee
# Some code borrowed from:
# Implementation of variational autoencoder (AEVB) algorithm as in:
# [1] arXiv:1312.6114 [stat.ML] (Diederik P Kingma, Max Welling 2013)
# ################## Download and prepare the MNIST dataset ##################
# For the linked MNIST data, the autoencoder learns well only in binary mode.
# This is most likely due to the distribution of the values. Most pixels are
# either very close to 0, or very close to 1.
# Running this code with default settings should produce a manifold similar
# to the example in this directory. An animation of the manifold's evolution
# can be found here:
def load_dataset():
    """Load the MNIST images and split them into train/validation/test sets.

    The two gzip-compressed IDX image files are read from the current working
    directory and downloaded first if they are not present.

    Returns:
        A tuple ``(X_train, X_val, X_test)`` of float32 arrays shaped
        ``(N, 1, 28, 28)`` with pixel values scaled to ``[0, 1]``. The last
        10000 training images are held out as the validation set.
    """
    import gzip

    # urlretrieve lives in a different module on Python 2 and Python 3.
    if sys.version_info[0] == 2:
        from urllib import urlretrieve
    else:
        from urllib.request import urlretrieve

    # NOTE(review): the base URL was an empty string in the reviewed copy,
    # which makes every download fail. Restored to the canonical MNIST host
    # used by the Lasagne MNIST example -- confirm it is still reachable.
    def download(filename, source='http://yann.lecun.com/exdb/mnist/'):
        print("Downloading %s" % filename)
        urlretrieve(source + filename, filename)

    def load_mnist_images(filename):
        if not os.path.exists(filename):
            download(filename)
        with gzip.open(filename, 'rb') as f:
            # Skip the 16-byte IDX header; the rest is raw uint8 pixels.
            data = np.frombuffer(f.read(), np.uint8, offset=16)
        # One channel per image; the last two axes are swapped on load and
        # swapped back by get_image_array when images are written out.
        data = data.reshape(-1, 1, 28, 28).transpose(0, 1, 3, 2)
        # Dividing by a float32 scalar keeps the result in float32.
        return data / np.float32(255)

    X_train = load_mnist_images('train-images-idx3-ubyte.gz')
    X_test = load_mnist_images('t10k-images-idx3-ubyte.gz')
    X_train, X_val = X_train[:-10000], X_train[-10000:]
    return X_train, X_val, X_test
# ############################# Output images ################################
# image processing using PIL
def get_image_array(X, index, shp=(28,28), channels=1):
    """Convert one image from a batch into a uint8 array for PIL.

    Args:
        X: batch of images; ``X[index]`` must hold
            ``channels * shp[0] * shp[1]`` values scaled so that 1.0 is white.
        index: position of the image inside ``X``.
        shp: the two spatial dimensions of the stored image.
        channels: number of colour channels (1 for greyscale).

    Returns:
        A uint8 array of shape ``(shp[1], shp[0])`` when ``channels == 1``,
        otherwise ``(shp[1], shp[0], channels)``.
    """
    scaled = (X[index] * 255.).reshape(channels, shp[0], shp[1])
    # NOTE(review): the continuation of this expression was missing in the
    # reviewed copy. The axis reversal is implied by the (shp[1], shp[0])
    # reshape below; clipping and the uint8 cast give PIL valid 8-bit pixels.
    ret = scaled.transpose(2, 1, 0).clip(0, 255).astype(np.uint8)
    if channels == 1:
        # Drop the trailing singleton channel axis for greyscale images.
        ret = ret.reshape(shp[1], shp[0])
    return ret
def get_image_pair(X, Xpr, channels=1, idx=-1):
    """Build one PIL image with an original stacked above its reconstruction.

    Args:
        X: batch of original images.
        Xpr: batch of reconstructed images, indexed the same way as ``X``.
        channels: 3 selects PIL mode 'RGB'; anything else selects 'L'.
        idx: index of the image to render; -1 picks a random index.

    Returns:
        A PIL image twice as tall as a single image, original on top.
    """
    mode = 'RGB' if channels == 3 else 'L'
    # NOTE(review): `shp` was used but never assigned in the reviewed copy.
    # It is taken from the two trailing axes of X, which assumes X is laid
    # out as (N, channels, dim0, dim1) -- confirm against callers.
    shp = tuple(X.shape[-2:])
    i = np.random.randint(X.shape[0]) if idx == -1 else idx
    orig = Image.fromarray(get_image_array(X, i, shp, channels), mode=mode)
    # Canvas with room for two images stacked vertically.
    ret = Image.new(mode, (orig.size[0], orig.size[1] * 2))
    ret.paste(orig, (0, 0))
    new = Image.fromarray(get_image_array(Xpr, i, shp, channels), mode=mode)
    ret.paste(new, (0, orig.size[1]))
    return ret
# ############################# Batch iterator ###############################
def iterate_minibatches(inputs, batchsize, shuffle=False):
    """Yield consecutive minibatches of ``batchsize`` items from ``inputs``.

    Only full batches are produced; a trailing remainder shorter than
    ``batchsize`` is dropped.

    Args:
        inputs: indexable array-like supporting slice and index-array lookup.
        batchsize: number of items per yielded batch.
        shuffle: when True, items are visited in a fresh random order drawn
            from NumPy's global random state.

    Yields:
        ``inputs[excerpt]`` for each batch.
    """
    if shuffle:
        indices = np.arange(len(inputs))
        # Without this the "shuffled" index array would still be in order.
        np.random.shuffle(indices)
    for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
        if shuffle:
            excerpt = indices[start_idx:start_idx + batchsize]
        else:
            excerpt = slice(start_idx, start_idx + batchsize)
        yield inputs[excerpt]
# ##################### Custom layer for middle of VCAE ######################
# This layer takes the mu and log-sigma layers (both DenseLayers) and combines
# them with a random vector epsilon to sample values from a multivariate Gaussian
class GaussianSampleLayer(nn.layers.MergeLayer):
    """Merge layer computing ``mu + exp(logsigma) * eps`` with Gaussian ``eps``.

    The two incoming layers supply ``mu`` and ``logsigma``; the output has
    the shape of the ``mu`` input.
    """

    def __init__(self, mu, logsigma, rng=None, **kwargs):
        """Store the random stream and register ``[mu, logsigma]`` as inputs.

        When ``rng`` is not supplied, a new stream is seeded from Lasagne's
        global random generator.
        """
        if rng:
            self.rng = rng
        else:
            seed = nn.random.get_rng().randint(1,2147462579)
            self.rng = RandomStreams(seed)
        super(GaussianSampleLayer, self).__init__([mu, logsigma], **kwargs)

    def get_output_shape_for(self, input_shapes):
        """Return the shape of the first (``mu``) input."""
        mu_shape = input_shapes[0]
        return mu_shape

    def get_output_for(self, inputs, withNoise=True, **kwargs):
        """Return a noisy sample, or ``mu`` itself when ``withNoise`` is false."""
        mu, logsigma = inputs
        declared = self.input_shapes[0]
        # Use the declared static dimensions where available and fall back
        # to the symbolic shape of mu otherwise.
        noise_shape = (declared[0] or mu.shape[0],
                       declared[1] or mu.shape[1])
        if not withNoise:
            return mu
        scale = T.exp(logsigma)
        eps = self.rng.normal(noise_shape)
        return mu + scale * eps
# ############################## Build Model #################################
# encoder has 1 hidden layer, where we get mu and log-sigma for Z given an input X
# continuous decoder has 1 hidden layer, where we get mu and log-sigma for X given code Z
# binary decoder has 1 hidden layer, where we calculate p(X=1)
# once we have (mu, log-sigma) for Z, we sample L times
# Then L separate outputs are constructed and the final layer averages them
def build_vae(inputvar, L=2, binary=True, imgshape=(28,28), channels=1, z_dim=2, n_hid=1024):
    """Build the encoder and ``L`` weight-tied decoder copies of the VAE.

    Args:
        inputvar: Theano variable fed to the input layer.
        L: number of latent samples (and decoder copies) per input.
        binary: True builds a sigmoid decoder giving p(X=1); False builds a
            Gaussian decoder with separate mu and log-sigma layers.
        imgshape: the two spatial dimensions of the input images.
        channels: number of image channels.
        z_dim: dimensionality of the latent code Z.
        n_hid: number of units in each hidden layer.

    Returns:
        ``(l_enc_mu, l_enc_logsigma, l_dec_mu_list, l_dec_logsigma_list,
        l_output_list, l_output)``. The two ``l_dec_*`` lists stay empty in
        binary mode. ``l_output`` sums the ``L`` outputs with weight ``1/L``.

    NOTE(review): several call endings, the ``else:`` between the two decoder
    variants and the list appends were missing in the reviewed copy. They are
    reconstructed from the return value and from how ``main`` slices these
    lists. The layer names 'enc_hid', 'dec_hid' and 'dec_logsigma', and the
    'dec_output' name on the sampled output, follow the visible naming
    pattern -- confirm.
    """
    x_dim = imgshape[0] * imgshape[1] * channels
    hid_nonlin = nn.nonlinearities.tanh if binary else T.nnet.softplus

    l_input = nn.layers.InputLayer(shape=(None, channels, imgshape[0], imgshape[1]),
            input_var=inputvar, name='input')
    l_enc_hid = nn.layers.DenseLayer(l_input, num_units=n_hid,
            nonlinearity=hid_nonlin,
            name='enc_hid')
    l_enc_mu = nn.layers.DenseLayer(l_enc_hid, num_units=z_dim,
            nonlinearity=None, name='enc_mu')
    l_enc_logsigma = nn.layers.DenseLayer(l_enc_hid, num_units=z_dim,
            nonlinearity=None, name='enc_logsigma')

    l_dec_mu_list = []
    l_dec_logsigma_list = []
    l_output_list = []

    # tie the weights of all L versions so they are the "same" layer:
    # the first copy creates the parameters, later copies reuse them.
    W_dec_hid = None
    b_dec_hid = None
    W_dec_mu = None
    b_dec_mu = None
    W_dec_ls = None
    b_dec_ls = None

    # relu_shift is for numerical stability - if training data has any
    # dimensions where stdev=0, allowing logsigma to approach -inf
    # will cause the loss function to become NAN. So we set the limit
    # stdev >= exp(-1 * relu_shift)
    relu_shift = 10

    # `range` instead of `xrange` so the loop also runs on Python 3.
    for i in range(L):
        l_Z = GaussianSampleLayer(l_enc_mu, l_enc_logsigma, name='Z')
        l_dec_hid = nn.layers.DenseLayer(l_Z, num_units=n_hid,
                nonlinearity=hid_nonlin,
                W=nn.init.GlorotUniform() if W_dec_hid is None else W_dec_hid,
                b=nn.init.Constant(0.) if b_dec_hid is None else b_dec_hid,
                name='dec_hid')
        if binary:
            l_output = nn.layers.DenseLayer(l_dec_hid, num_units=x_dim,
                    nonlinearity=nn.nonlinearities.sigmoid,
                    W=nn.init.GlorotUniform() if W_dec_mu is None else W_dec_mu,
                    b=nn.init.Constant(0.) if b_dec_mu is None else b_dec_mu,
                    name='dec_output')
            l_output_list.append(l_output)
            if W_dec_hid is None:
                W_dec_hid = l_dec_hid.W
                b_dec_hid = l_dec_hid.b
                W_dec_mu = l_output.W
                b_dec_mu = l_output.b
        else:
            l_dec_mu = nn.layers.DenseLayer(l_dec_hid, num_units=x_dim,
                    nonlinearity=None,
                    W=nn.init.GlorotUniform() if W_dec_mu is None else W_dec_mu,
                    b=nn.init.Constant(0) if b_dec_mu is None else b_dec_mu,
                    name='dec_mu')
            l_dec_logsigma = nn.layers.DenseLayer(l_dec_hid, num_units=x_dim,
                    W=nn.init.GlorotUniform() if W_dec_ls is None else W_dec_ls,
                    b=nn.init.Constant(0) if b_dec_ls is None else b_dec_ls,
                    nonlinearity=lambda a: T.nnet.relu(a+relu_shift)-relu_shift,
                    name='dec_logsigma')
            l_output = GaussianSampleLayer(l_dec_mu, l_dec_logsigma,
                    name='dec_output')
            l_dec_mu_list.append(l_dec_mu)
            l_dec_logsigma_list.append(l_dec_logsigma)
            l_output_list.append(l_output)
            if W_dec_hid is None:
                W_dec_hid = l_dec_hid.W
                b_dec_hid = l_dec_hid.b
                W_dec_mu = l_dec_mu.W
                b_dec_mu = l_dec_mu.b
                W_dec_ls = l_dec_logsigma.W
                b_dec_ls = l_dec_logsigma.b

    # Average the L per-sample outputs.
    l_output = nn.layers.ElemwiseSumLayer(l_output_list, coeffs=1./L, name='output')
    return l_enc_mu, l_enc_logsigma, l_dec_mu_list, l_dec_logsigma_list, l_output_list, l_output
# ############################## Main program ################################
def log_likelihood(tgt, mu, ls):
    """Return the summed Gaussian log-density of ``tgt``.

    ``mu`` is the mean and ``ls`` the log of the standard deviation; the
    per-element log-densities are summed over all elements.
    """
    # float32 constant 0.5 * log(2 * pi)
    half_log_two_pi = np.float32(0.5 * np.log(2 * np.pi))
    squared_error = T.sqr(tgt - mu)
    per_element = -(half_log_two_pi + ls) - 0.5 * squared_error / T.exp(2 * ls)
    return T.sum(per_element)
def main(L=2, z_dim=2, n_hid=1024, num_epochs=300, binary=True):
    """Train the VAE on MNIST, report losses and write example images.

    Args:
        L: number of latent samples drawn per input.
        z_dim: dimensionality of the latent code.
        n_hid: number of units in each hidden layer.
        num_epochs: number of passes over the training set.
        binary: True trains the sigmoid/cross-entropy decoder, False the
            Gaussian decoder.

    Side effects:
        Prints progress, writes ``output_<i>.jpg`` original/reconstruction
        pairs and ``params_<test loss>.npz``, and, when ``z_dim == 2``,
        writes ``gen.jpg`` with images decoded from a 19x19 grid of codes.

    NOTE(review): several statements were truncated in the reviewed copy
    (trailing keyword arguments of ``get_output``, the ``else:`` branches,
    the ``Image.new`` / ``im.save`` calls and the ``dtype`` of ``z``). They
    are reconstructed from the surrounding code -- confirm.
    """
    print("Loading data...")
    X_train, X_val, X_test = load_dataset()
    width, height = X_train.shape[2], X_train.shape[3]
    input_var = T.tensor4('inputs')

    # Create VAE model
    print("Building model and compiling functions...")
    print("L = {}, z_dim = {}, n_hid = {}, binary={}".format(L, z_dim, n_hid, binary))
    l_z_mu, l_z_ls, l_x_mu_list, l_x_ls_list, l_x_list, l_x = \
        build_vae(input_var, L=L, binary=binary, z_dim=z_dim, n_hid=n_hid)

    def build_loss(deterministic, withNoise=True):
        """Return ``(loss, prediction)`` expressions for the current model."""
        layer_outputs = nn.layers.get_output([l_z_mu, l_z_ls] + l_x_mu_list + l_x_ls_list
                + l_x_list, deterministic=deterministic)
        z_mu = layer_outputs[0]
        z_ls = layer_outputs[1]
        # In binary mode the mu/log-sigma lists are empty, so the L output
        # layers start directly at position 2.
        x_mu = [] if binary else layer_outputs[2:2+L]
        x_ls = [] if binary else layer_outputs[2+L:2+2*L]
        x_list = layer_outputs[2:2+L] if binary else layer_outputs[2+2*L:2+3*L]
        # Loss expression has two parts as specified in [1]
        # kl_div = negative KL divergence between the approximate posterior
        #          of z given x and the standard-normal prior on z (it is
        #          added to logpxz, and the sum is negated below)
        # logpxz = log p(x|z)
        #     - log-likelihood of x given z
        #     - in binary case logpxz = negative cross-entropy
        #     - in continuous case, is log-likelihood of seeing the target x under the
        #       Gaussian distribution parameterized by dec_mu, sigma = exp(dec_logsigma)
        kl_div = 0.5 * T.sum(1 + 2*z_ls - T.sqr(z_mu) - T.exp(2 * z_ls))
        if binary:
            logpxz = sum(nn.objectives.binary_crossentropy(x,
                input_var.flatten(2)).sum() for x in x_list) * (-1./L)
            prediction = nn.layers.get_output(l_x, deterministic=deterministic,
                    withNoise=withNoise)
        else:
            logpxz = sum(log_likelihood(input_var.flatten(2), mu, ls)
                for mu, ls in zip(x_mu, x_ls))/L
            prediction = T.sum(nn.layers.get_output(l_x_mu_list, deterministic=deterministic,
                    withNoise=withNoise), axis=0)/L
        loss = -1 * (logpxz + kl_div)
        return loss, prediction

    # If there are dropout layers etc these functions return masked or non-masked expressions
    # depending on if they will be used for training or validation/test err calcs
    loss, _ = build_loss(deterministic=False)
    test_loss, test_prediction = build_loss(deterministic=True, withNoise=False)

    # functions for generating images given a code (used for visualization)
    # for a given code z, we deterministically take x_mu as the generated data
    # (no Gaussian noise is used to either encode or decode).
    z_var = T.vector()
    if binary:
        generated_x = nn.layers.get_output(l_x, {l_z_mu:z_var}, withNoise=False,
                deterministic=True)
    else:
        generated_x = nn.layers.get_output(l_x_mu_list[0], {l_z_mu:z_var}, withNoise=False,
                deterministic=True)
    gen_fn = theano.function([z_var], generated_x)

    # ADAM updates
    params = nn.layers.get_all_params(l_x, trainable=True)
    updates = nn.updates.adam(loss, params, learning_rate=1e-4)
    train_fn = theano.function([input_var], loss, updates=updates)
    val_fn = theano.function([input_var], test_loss)

    print("Starting training...")
    batch_size = 100
    for epoch in range(num_epochs):
        train_err = 0
        train_batches = 0
        start_time = time.time()
        for batch in iterate_minibatches(X_train, batch_size, shuffle=True):
            this_err = train_fn(batch)
            train_err += this_err
            train_batches += 1
        val_err = 0
        val_batches = 0
        for batch in iterate_minibatches(X_val, batch_size, shuffle=False):
            err = val_fn(batch)
            val_err += err
            val_batches += 1
        print("Epoch {} of {} took {:.3f}s".format(
            epoch + 1, num_epochs, time.time() - start_time))
        print("  training loss:\t\t{:.6f}".format(train_err / train_batches))
        print("  validation loss:\t\t{:.6f}".format(val_err / val_batches))

    test_err = 0
    test_batches = 0
    for batch in iterate_minibatches(X_test, batch_size, shuffle=False):
        err = val_fn(batch)
        test_err += err
        test_batches += 1
    test_err /= test_batches
    print("Final results:")
    print("  test loss:\t\t\t{:.6f}".format(test_err))

    # save some example pictures so we can see what it's done
    example_batch_size = 20
    X_comp = X_test[:example_batch_size]
    pred_fn = theano.function([input_var], test_prediction)
    X_pred = pred_fn(X_comp).reshape(-1, 1, width, height)
    for i in range(example_batch_size):
        get_image_pair(X_comp, X_pred, idx=i, channels=1).save('output_{}.jpg'.format(i))

    # save the parameters so they can be loaded for next time
    fn = 'params_{:.6f}'.format(test_err)
    np.savez(fn + '.npz', *nn.layers.get_all_param_values(l_x))

    # sample from latent space if it's 2d
    if z_dim == 2:
        grid = 19
        im = Image.new('L', (width*grid, height*grid))
        for x in range(grid):
            for y in range(grid):
                # Codes sit at the 5%, 10%, ..., 95% quantiles of the
                # standard normal along each latent axis.
                z = np.asarray([norm.ppf(0.05*(x+1)), norm.ppf(0.05*(y+1))],
                        dtype=theano.config.floatX)
                x_gen = gen_fn(z).reshape(-1, 1, width, height)
                im.paste(Image.fromarray(get_image_array(x_gen,0)), (x*width,y*height))
        im.save('gen.jpg')
def parse_cli_args(argv=None):
    """Parse command-line options into keyword arguments for ``main``.

    Arguments are integers, except for binary/continuous. Options that are
    not given are left out of the result, so the defaults of ``main`` apply
    (binary mode unless ``--continuous`` is passed).

    Args:
        argv: list of argument strings; defaults to ``sys.argv[1:]``.

    Returns:
        Dict holding only the options supplied on the command line.
    """
    import argparse
    parser = argparse.ArgumentParser(description='Command line options')
    for option in ('num_epochs', 'L', 'z_dim', 'n_hid'):
        parser.add_argument('--' + option, type=int, dest=option)
    # Both flags write to `binary`. argparse applies the default of the first
    # action registered for a dest, and store_true defaults to False, so an
    # explicit default=None is needed; otherwise running without flags would
    # select continuous mode.
    parser.add_argument('--binary', dest='binary', action='store_true',
                        default=None)
    parser.add_argument('--continuous', dest='binary', action='store_false',
                        default=None)
    args = parser.parse_args(sys.argv[1:] if argv is None else argv)
    return {k: v for (k, v) in vars(args).items() if v is not None}


if __name__ == '__main__':
    # Default uses binary. Run with option --continuous for continuous output.
    main(**parse_cli_args())
How many epochs did you have to run it for to get the 2D picture from the paper?
And what level did you get the loss function down to?

Hi, sorry I never got a notification of your comment. I believe a full run through the script (300 epochs) is more than enough to get an image similar to the one in the video (it is a different embedding every time). Looking at my saved params, I think I got it down to a loss of around 13,958.89 (combined log-likelihood). Hope that helps!

