Variational Autoencoder with MNIST
from tqdm import tqdm
import os
import tensorflow as tf
import edward as ed
import numpy as np
from edward.models import Bernoulli, Normal
from tensorflow.examples.tutorials.mnist import input_data
from scipy.misc import imsave
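
# NOTE: this script assumes TensorFlow 1.x and Edward 1.x;
# scipy.misc.imsave was removed from newer SciPy releases, so an
# older SciPy (or a drop-in such as imageio.imwrite) is needed
# for saving images.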
# shorthands
dense = tf.contrib.layers.fully_connected
mnist = input_data.read_data_sets("MNIST_data/")

# image attributes
IMGX, IMGY = 28, 28
# save directory
OUT = 'results/'
class VAE():

    def __init__(self, latent_dims=2, batch_size=50,
                 x_dims=IMGX*IMGY, hdim_prob=256, hdim_var=256,
                 lr=0.01):
        """
        Variational Autoencoder

        Args:
            [latent_dims] : (2) number of latent factors
            [batch_size]  : (50) batch size
            [x_dims]      : (IMGX*IMGY) input size (flattened MNIST image)
            [hdim_prob]   : (256) hidden dims of probabilistic model
            [hdim_var]    : (256) hidden dims of variational model
            [lr]          : (0.01) learning rate

        Returns:
            (NA)
        """
        # clear graph
        tf.reset_default_graph()

        # [1] Probabilistic Model
        # Model z as a Normal distribution
        z = Normal(loc=tf.zeros([batch_size, latent_dims]),
                   scale=tf.ones([batch_size, latent_dims]))
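        # z is the standard normal prior p(z) = N(0, I); one latent
        # vector is drawn per example in the batch.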
        # Sample from z, project to 256 dims
        hid_prob = dense(z.value(), hdim_prob)
        # Model a Bernoulli distribution to sample 'x'
        x = Bernoulli(logits=dense(hid_prob, x_dims,
                                   activation_fn=None))
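        # This is the generative model p(x | z): a one-hidden-layer
        # decoder maps a latent draw to per-pixel Bernoulli logits,
        # which matches the binarized pixels produced by next_batch.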
        # [2] Variational Model
        # placeholder for images as input
        self._x = tf.placeholder(tf.int32, [batch_size, x_dims])
        # project to 256 dims
        hid_var = dense(tf.cast(self._x, tf.float32), hdim_var)
        # model q(z | x), the variational approximation to the posterior
        qz = Normal(loc=dense(hid_var, latent_dims),
                    scale=dense(hid_var, latent_dims,
                                activation_fn=tf.nn.softplus))
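        # The encoder network is shared across all inputs (amortized
        # inference); softplus keeps the predicted scale strictly positive.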
        # inference operation
        inference = ed.KLqp({z: qz}, data={x: self._x})
        # setup optimizer
        optimizer = tf.train.RMSPropOptimizer(lr, epsilon=1.0)
        # initialize inference
        inference.initialize(optimizer=optimizer)
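        # KLqp maximizes the evidence lower bound (ELBO),
        #   E_q[log p(x|z)] - KL(q(z|x) || p(z)),
        # using reparameterization gradients for the Normal q.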
        # attach to object
        self.inference = inference
        self.x = x

    def sample(self):
        """
        Sample from x

        Args:
            None

        Returns:
            a batch of samples from x, decoded from fresh draws of z
        """
        return self.x.eval()
def next_batch(batch_size):
    """
    Fetch next batch of MNIST

    Args:
        batch_size : batch size

    Returns:
        batch of binarized images of size batch_size
    """
    images, _ = mnist.train.next_batch(batch_size)
    # threshold at zero: any non-zero pixel becomes 1, matching
    # the Bernoulli observation model
    return np.array(images > 0, np.int32)
def train_and_sample(model, epochs=1000, batch_size=50,
                     sample_after=0):
    """
    Train VAE and sample from x after every epoch of training

    Args:
        model          : an instance of VAE model
        [epochs]       : (1000) number of epochs
        [batch_size]   : (50) batch size
        [sample_after] : (0) epoch after which samples are collected

    Returns:
        (trained_model, samples)
    """
    samples = []
    print(':: "train and sample" running')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # one pass over the 55000-image training set per epoch
        iterations = 55000 // batch_size
        pbar = tqdm(range(epochs))
        for i in pbar:
            loss = 0.
            for j in range(iterations):
                batch_x = next_batch(batch_size)
                train_info = model.inference.update(feed_dict={
                    model._x: batch_x
                })
                loss += train_info['loss']
            # report the average per-image loss (a negative ELBO estimate)
            avg_loss = loss / iterations / batch_size
            pbar.set_postfix(loss=avg_loss)
            # sample from x
            if i > sample_after:
                samples.append(model.sample())
        return model, samples
def save_samples(samples):
    """
    Save images to disk

    Args:
        samples : [ epochs x batch_size ] images
    """
    batch_size = samples[0].shape[0]
    print('\n:: writing samples to disk')
    # make sure the output directory exists before writing
    if not os.path.exists(OUT):
        os.makedirs(OUT)
    for i, sample in enumerate(samples):
        for j in range(batch_size):
            imsave(os.path.join(OUT, '{}_{}.png'.format(i, j)),
                   sample[j].reshape(IMGX, IMGY))
if __name__ == '__main__':
    # hyperparameters
    batch_size = 128
    latent_dims = 2
    # initialize VAE
    vae = VAE(latent_dims=latent_dims, batch_size=batch_size)
    # train and sample
    trained_vae, samples = train_and_sample(vae,
                                            batch_size=batch_size,
                                            epochs=1000, sample_after=900)
    # save to disk
    save_samples(samples)
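    # with these settings, each epoch after 900 contributes one batch
    # of 128 binary 28x28 digit images, written to results/ as
    # <collected-epoch-index>_<image-index>.png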