import time
import numpy as np
import tensorflow as tf
from keras.datasets import mnist, cifar10, cifar100
import matplotlib.pyplot as plt

from utils import get_minibatches_idx

# Based on https://jmetzen.github.io/2015-11-27/vae.html

eps = 1e-10  # small constant to keep the logs in the reconstruction loss finite
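
# Note: utils.get_minibatches_idx comes from the author's local utils module,
# which is not included in this gist. Judging from its use in train() below,
# it returns one collection of example indices per minibatch. A minimal
# compatible sketch (an assumption, not the original helper):
#
#   def get_minibatches_idx(n, minibatch_size, shuffle=False):
#       indices = np.arange(n, dtype='int64')
#       if shuffle:
#           np.random.shuffle(indices)
#       return [indices[i:i + minibatch_size]
#               for i in range(0, n, minibatch_size)]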
class VariationalAutoencoder(object):
    # See "Auto-Encoding Variational Bayes" by Kingma and Welling for details
    def __init__(self, batch_size=100):
        self.x = tf.placeholder(tf.float32, [None, 784])

        # Encoder: map each image to the mean and log-variance of q(z|x)
        h = tf.contrib.layers.fully_connected(self.x, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.z_mean = tf.contrib.layers.fully_connected(h, 20, activation_fn=tf.identity)
        self.z_log_sigma_sq = tf.contrib.layers.fully_connected(h, 20, activation_fn=tf.identity)

        # Reparameterization trick: draw z = mean + sigma*noise, noise ~ N(0, I).
        # Taking the shape from z_mean (rather than the fixed batch_size)
        # lets the graph accept any batch size at run time.
        noise = tf.random_normal(tf.shape(self.z_mean), 0, 1, dtype=tf.float32)
        self.z = self.z_mean + noise * tf.sqrt(tf.exp(self.z_log_sigma_sq))

        # Decoder: map z back to Bernoulli means over the 784 pixels
        h = tf.contrib.layers.fully_connected(self.z, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.x_reconstr_mean = tf.contrib.layers.fully_connected(h, 784, activation_fn=tf.nn.sigmoid)

        # Reconstruction loss: per-pixel binary cross-entropy
        reconstr_loss = -tf.reduce_sum(self.x * tf.log(eps + self.x_reconstr_mean)
                                       + (1 - self.x) * tf.log(eps + 1 - self.x_reconstr_mean), 1)
        # KL-divergence from the standard normal prior
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq
                                           - tf.square(self.z_mean)
                                           - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)  # average over the batch
        self.train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(self.cost)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
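
    # For reference, the cost above is the negative ELBO. With a diagonal
    # q(z|x) = N(mean, sigma^2) and prior p(z) = N(0, I), the KL term has the
    # closed form used in latent_loss:
    #   KL(q || p) = -0.5 * sum_j (1 + log sigma_j^2 - mean_j^2 - sigma_j^2)
    # where log sigma_j^2 is z_log_sigma_sq, so the KL needs no sampling; only
    # the Bernoulli reconstruction term is estimated from the single z drawn above.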
    def generate(self, z_mu=None):
        # Decode latent vectors into images, sampling from the prior if none
        # is given. z_mu must be rank-2, (num_samples, 20), to match self.z.
        if z_mu is None:
            z_mu = np.random.normal(size=(1, 20))
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.z: z_mu})

    def reconstruct(self, X):
        # Encode and then decode a batch of images
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.x: X})
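
    # Example usage after training (illustrative, not part of the original gist):
    #   samples = vae.generate(np.random.normal(size=(16, 20)))  # 16 new digits
    #   recons = vae.reconstruct(X_test[:16])
    # Both return arrays of shape (num_samples, 784); reshape each row to
    # 28x28 to display it.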
    def train(self, X, batch_size=100, training_epochs=10):
        print("\nStarting training")
        for epoch in range(training_epochs):
            avg_cost = 0.0
            train_indices = get_minibatches_idx(len(X), batch_size, shuffle=True)
            for it in train_indices:
                batch_x = [X[i] for i in it]
                _, cost = self.sess.run((self.train_step, self.cost), feed_dict={self.x: batch_x})
                # Accumulate the mean cost over the whole dataset
                avg_cost += cost * len(batch_x) / len(X)
            print("Epoch:", '%d' % (epoch + 1), "cost=", "{:.3f}".format(avg_cost))
def main():
    dataset = 'mnist'  # one of: mnist, cifar10, cifar100

    # Load the data, downloading it first if necessary
    if dataset == 'mnist':
        (X_train, _), (X_test, _) = mnist.load_data()
        img_size = 28
        num_channels = 1
    elif dataset == 'cifar10':
        (X_train, _), (X_test, _) = cifar10.load_data()
        img_size = 32
        num_channels = 3
    elif dataset == 'cifar100':
        (X_train, _), (X_test, _) = cifar100.load_data()
        img_size = 32
        num_channels = 3
    else:
        raise ValueError("Unknown dataset: %s" % dataset)

    # Scale the pixel values to [0, 1]
    X_train = X_train.astype('float32') / 255
    X_test = X_test.astype('float32') / 255
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    # Flatten each image into a vector. Note that the model's input layer is
    # fixed at 784 units, so only MNIST (28x28x1) runs as-is; the CIFAR
    # options would need a matching input size and plotting code.
    X_train = np.reshape(X_train, [-1, img_size * img_size * num_channels])
    X_test = np.reshape(X_test, [-1, img_size * img_size * num_channels])

    vae = VariationalAutoencoder(batch_size=100)
    print("Model built")
    vae.train(X_train, training_epochs=5)
    # Plot a few test images alongside their reconstructions
    x_sample = X_test[:100]
    x_reconstruct = vae.reconstruct(x_sample)

    #x_gen = vae.generate()
    #print(x_gen.shape)

    plt.figure(figsize=(8, 12))
    for i in range(5):
        plt.subplot(5, 2, 2*i + 1)
        plt.imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
        plt.title("Test input")
        plt.subplot(5, 2, 2*i + 2)
        plt.imshow(x_reconstruct[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
        plt.title("Reconstruction")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()