import time
import numpy as np
import tensorflow as tf
from keras.datasets import mnist, cifar10, cifar100
import matplotlib.pyplot as plt

from utils import get_minibatches_idx
# Based on https://jmetzen.github.io/2015-11-27/vae.html

eps = 1e-10  # numerical-stability constant for the log terms in the loss
class VariationalAutoencoder(object):
    # See "Auto-Encoding Variational Bayes" by Kingma and Welling for more details
    def __init__(self, batch_size=100):
        self.batch_size = batch_size
        self.x = tf.placeholder(tf.float32, [None, 784])

        # Encoder: map each image to the mean and log-variance of q(z|x)
        h = tf.contrib.layers.fully_connected(self.x, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.z_mean = tf.contrib.layers.fully_connected(h, 20, activation_fn=tf.identity)
        self.z_log_sigma_sq = tf.contrib.layers.fully_connected(h, 20, activation_fn=tf.identity)

        # Reparameterization trick: draw one sample z = mu + sigma*noise, noise ~ N(0, I)
        noise = tf.random_normal((batch_size, 20), 0, 1, dtype=tf.float32)
        self.z = self.z_mean + noise*tf.sqrt(tf.exp(self.z_log_sigma_sq))

        # Decoder
        h = tf.contrib.layers.fully_connected(self.z, 500)
        h = tf.contrib.layers.fully_connected(h, 500)
        self.x_reconstr_mean = tf.contrib.layers.fully_connected(h, 784, activation_fn=tf.nn.sigmoid)

        # Bernoulli reconstruction loss (per-pixel binary cross-entropy)
        reconstr_loss = -tf.reduce_sum(self.x * tf.log(eps + self.x_reconstr_mean)
                                       + (1-self.x) * tf.log(eps + 1 - self.x_reconstr_mean), 1)
        # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, I):
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq
                                           - tf.square(self.z_mean)
                                           - tf.exp(self.z_log_sigma_sq), 1)
        self.cost = tf.reduce_mean(reconstr_loss + latent_loss)  # average over batch
        self.train_step = tf.train.AdamOptimizer(learning_rate=0.001).minimize(self.cost)

        self.sess = tf.Session()
        self.sess.run(tf.global_variables_initializer())
    def generate(self, z_mu=None):
        if z_mu is None:
            # z has a fixed batch dimension, so sample a full batch from the prior
            z_mu = np.random.normal(size=(self.batch_size, 20))
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.z: z_mu})
    def reconstruct(self, X):
        return self.sess.run(self.x_reconstr_mean, feed_dict={self.x: X})
    def train(self, X, batch_size=100, training_epochs=10):
        print("\nStarting training")
        for epoch in range(training_epochs):
            avg_cost = 0.0
            train_indices = get_minibatches_idx(len(X), batch_size, shuffle=True)
            for it in train_indices:
                batch_x = [X[i] for i in it]
                _, cost = self.sess.run((self.train_step, self.cost), feed_dict={self.x: batch_x})
                avg_cost += cost * batch_size / len(X)
            print("Epoch:", '%d' % (epoch+1), "cost=", "{:.3f}".format(avg_cost))
def main():
    dataset = 'mnist'  # mnist, cifar10, cifar100

    # Load the data; it will be downloaded first if necessary
    if dataset == 'mnist':
        (X_train, _), (X_test, _) = mnist.load_data()
        img_size = 28
        num_channels = 1
    elif dataset == 'cifar10':
        (X_train, _), (X_test, _) = cifar10.load_data()
        img_size = 32
        num_channels = 3
    elif dataset == 'cifar100':
        (X_train, _), (X_test, _) = cifar100.load_data()
        img_size = 32
        num_channels = 3

    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train = np.reshape(X_train, [-1, img_size, img_size, num_channels])
    X_test = np.reshape(X_test, [-1, img_size, img_size, num_channels])
    X_train /= 255
    X_test /= 255

    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')
    # Flatten for the fully connected encoder. The network is hard-coded to
    # 784 inputs, so only the 28x28 single-channel MNIST images fit as-is.
    X_train = np.reshape(X_train, [-1, img_size*img_size*num_channels])
    X_test = np.reshape(X_test, [-1, img_size*img_size*num_channels])
    vae = VariationalAutoencoder(batch_size=100)
    print("Model compiled")
    vae.train(X_train, training_epochs=5)

    x_sample = X_test[:100]
    x_reconstruct = vae.reconstruct(x_sample)
    #x_gen = vae.generate()
    #print(x_gen.shape)

    # Plot test inputs next to their reconstructions
    plt.figure(figsize=(8, 12))
    for i in range(5):
        plt.subplot(5, 2, 2*i + 1)
        plt.imshow(x_sample[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
        plt.title("Test input")
        plt.subplot(5, 2, 2*i + 2)
        plt.imshow(x_reconstruct[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
        plt.title("Reconstruction")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
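get_minibatches_idx is imported from a local utils module that is not included in this gist. A minimal sketch of a compatible helper, assuming all it needs to do is return shuffled lists of at most batch_size indices (with MNIST's 60,000 training examples every minibatch is full; since the graph above fixes the batch size, a trailing partial minibatch would otherwise have to be dropped):

import numpy as np

def get_minibatches_idx(n, minibatch_size, shuffle=False):
    # Return a list of index arrays, one array per minibatch
    idx_list = np.arange(n, dtype="int64")
    if shuffle:
        np.random.shuffle(idx_list)
    return [idx_list[start:start + minibatch_size]
            for start in range(0, n, minibatch_size)]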
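The commented-out generate() call in main() samples new images from the prior p(z) = N(0, I) instead of reconstructing test inputs. A hedged sketch of how it could be used after training, assuming the default batch_size=100 from the constructor:

# Decode a batch of latent codes drawn from the prior and plot a few
x_gen = vae.generate()  # shape (100, 784)
plt.figure(figsize=(8, 2))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(x_gen[i].reshape(28, 28), vmin=0, vmax=1, cmap="gray")
    plt.title("Sample")
plt.tight_layout()
plt.show()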