Recreation of GANs in Action, Chapter 3: Variational Autoencoder
# 2.8 Variational autoencoder (VAE)
import matplotlib.pyplot as plt
import matplotlib as mpl
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras import objectives
from keras.datasets import mnist
import numpy as np
from itertools import product
batch_size = 100
digit_size = 28
original_dim = (28, 28)
original_dim_flat = np.prod(original_dim)
latent_dim = 2
intermediate_dim = 256
epochs = 10
epsilon_std = 1.0
def sampling(args):
    """Reparameterization trick: sample z = mean + sigma * epsilon, with epsilon ~ N(0, 1)."""
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2.0) * epsilon
# Encoder
inp = Input(shape=(*original_dim,), name="input")
f = Flatten()(inp)
h = Dense(intermediate_dim, activation="relu", name="encoding")(f)
# Means are squashed into (0, 1) so the decoder can be swept over the unit square below
z_mean = Dense(latent_dim, activation="sigmoid", name="mean")(h)
z_log_var = Dense(latent_dim, name="var")(h)
z = Lambda(sampling, output_shape=(latent_dim,), name="sampling")([z_mean, z_log_var])
encoder = Model(inp, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
# Decoder
inp_encoded = Input(shape=(latent_dim,), name="z_sampling")
decoded = Dense(intermediate_dim, activation="relu", name="decoding")(inp_encoded)
flat_decoded = Dense(original_dim_flat, activation="sigmoid", name="flat_decoded")(decoded)
out_decoded = Reshape(original_dim, name="out_decoded")(flat_decoded)
decoder = Model(inp_encoded, out_decoded, name="decoder")
decoder.summary()
# VAE
out_combined = decoder(encoder(inp)[2])
vae = Model(inp, out_combined)
vae.summary()
def vae_loss(x, x_decoded_mean):
    # Reconstruction term: per-pixel binary cross-entropy, scaled up to the full image size
    xent_loss = original_dim_flat * objectives.binary_crossentropy(K.flatten(x), K.flatten(x_decoded_mean))
    # Kullback-Leibler divergence: how far the learned latent distribution is from a
    # standard normal (zero mean, unit variance); acts as a regularizer on the encoder
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent_loss + kl_loss)
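# Quick numeric sanity check of the closed-form KL term above (a sketch, independent of Keras):
# KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
# which is exactly zero when mu = 0 and sigma = 1.
mu, log_var = np.zeros(latent_dim), np.zeros(latent_dim)
print(-0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var)))  # -> 0.0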
vae.compile(optimizer="rmsprop", loss=vae_loss)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
vae.fit(x_train, x_train, shuffle=True, epochs=epochs, batch_size=batch_size, validation_data=(x_test,x_test), verbose=1, callbacks=[EarlyStopping(patience=2)])
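# One simple way to inspect what the encoder learned (a sketch using the models defined
# above): embed the test digits and color each 2-D latent mean by its label.
z_mean_test, _, _ = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(8, 8))
plt.scatter(z_mean_test[:, 0], z_mean_test[:, 1], c=y_test, cmap="viridis", s=2)
plt.colorbar()
plt.xlabel("z[0]")
plt.ylabel("z[1]")
plt.show()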
def show_side_by_side(left, right):
    num_rows = min(len(left), len(right))
    f, axes = plt.subplots(num_rows, 2, sharex=True, sharey=True)
    f.set_size_inches(2, num_rows)
    for n_row, (l, r) in enumerate(zip(left, right)):
        axes[n_row, 0].imshow(l)
        axes[n_row, 1].imshow(r)
    f.tight_layout()
    plt.show()
out_example = vae.predict(x_test[:10]).reshape(-1, *original_dim)
show_side_by_side(x_test[:10], out_example)
decoder.predict(np.array([[0.5,0.5]]))
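# Display the digit decoded from that single latent point (any point in the unit square
# is reasonable here, since the encoder means are squashed through a sigmoid).
single_digit = decoder.predict(np.array([[0.5, 0.5]]))
plt.imshow(single_digit[0], cmap="Greys_r")
plt.show()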
def show_grid(out):
    num_items = len(out)
    num_rows = int(num_items ** 0.5)
    assert num_rows * num_rows == num_items, "Input must contain a square number of images"
    f = np.zeros(shape=(num_rows * digit_size, num_rows * digit_size))
    for idx, digit in enumerate(out):
        c = idx // num_rows
        r = idx % num_rows
        f[r * digit_size: (r + 1) * digit_size, c * digit_size: (c + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    plt.imshow(f, cmap="Greys_r")
    plt.show()
steps = np.array(list(product(np.arange(0, 1.1, 0.1), np.arange(0, 1.1, 0.1))))
out_steps = decoder.predict(steps).reshape(-1, *original_dim)
show_grid(out_steps)
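# Optionally persist the trained weights so the latent sweep can be reproduced later
# without retraining (the file names below are just examples).
encoder.save_weights("vae_encoder_weights.h5")
decoder.save_weights("vae_decoder_weights.h5")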