Recreation of the variational autoencoder (VAE) from GANs in Action, chapter 3.
# 2.8 Variational autoencoder (VAE)
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras import objectives  # renamed to keras.losses in later Keras releases
from keras.datasets import mnist
import numpy as np
from itertools import product
batch_size = 100
digit_size = 28
original_dim = (28, 28)
original_dim_flat = np.prod(original_dim)
latent_dim = 2          # 2-D latent space, easy to visualize on a grid
intermediate_dim = 256
epochs = 10
epsilon_std = 1.0       # std of the noise used in the reparameterization trick
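# Reparameterization trick: draw z = mean + sigma * epsilon with epsilon ~ N(0, epsilon_std),
# so the sample stays differentiable with respect to z_mean and z_log_var.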
def sampling(args):
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(latent_dim,), mean=0.0, stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2.0) * epsilon
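# The encoder maps each 28x28 image to the parameters of a 2-D Gaussian
# (z_mean, z_log_var) plus one sample z drawn from it.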
# Encoder
inp = Input(shape=(*original_dim,), name="input")
f = Flatten()(inp)
h = Dense(intermediate_dim, activation="relu", name="encoding")(f)
# sigmoid keeps the latent means in (0, 1), matching the [0, 1] grid sampled below
z_mean = Dense(latent_dim, activation="sigmoid", name="mean")(h)
z_log_var = Dense(latent_dim, name="var")(h)
z = Lambda(sampling, output_shape=(latent_dim,), name="sampling")([z_mean, z_log_var])
encoder = Model(inp, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
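# The decoder inverts the mapping: a 2-D latent point back to a 28x28 image,
# with sigmoid outputs so pixel values land in [0, 1].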
# Decoder
inp_encoded = Input(shape=(latent_dim,), name="z_sampling")
decoded = Dense(intermediate_dim, activation="relu", name="decoding")(inp_encoded)
flat_decoded = Dense(original_dim_flat, activation="sigmoid", name="flat_decoded")(decoded)
out_decoded = Reshape(original_dim, name="out_decoded")(flat_decoded)
decoder = Model(inp_encoded, out_decoded, name="decoder")
decoder.summary()
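# End-to-end model: encode, sample, decode. Index [2] selects the sampled z
# from the encoder's [z_mean, z_log_var, z] outputs.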
# VAE
out_combined = decoder(encoder(inp)[2])
vae = Model(inp, out_combined)
vae.summary()
def vae_loss(x, x_decoded_mean):
    # reconstruction term: per-pixel binary crossentropy, summed over the image
    xent_loss = original_dim_flat * objectives.binary_crossentropy(K.flatten(x), K.flatten(x_decoded_mean))
    # Kullback-Leibler divergence - how far the posterior N(mu, sigma^2) is from
    # the N(0, 1) prior; acts as a regularizer on the latent space. Closed form:
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    return K.mean(xent_loss + kl_loss)
vae.compile(optimizer="rmsprop", loss=vae_loss)
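# Train on MNIST with the images as their own targets; pixels are scaled to
# [0, 1] to match the sigmoid reconstructions and the binary crossentropy loss.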
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0
vae.fit(x_train, x_train, shuffle=True, epochs=epochs, batch_size=batch_size,
        validation_data=(x_test, x_test), verbose=1, callbacks=[EarlyStopping(patience=2)])
def show_side_by_side(left, right):
    """Plot pairs of images (e.g. originals vs. reconstructions) in two columns."""
    num_rows = min(len(left), len(right))
    f, axes = plt.subplots(num_rows, 2, sharex=True, sharey=True)
    f.set_size_inches(2, num_rows)
    for n_row, (l, r) in enumerate(zip(left, right)):
        axes[n_row, 0].imshow(l)
        axes[n_row, 1].imshow(r)
    f.tight_layout()
    plt.show()
# Compare the first 10 test digits with their reconstructions
out_example = vae.predict(x_test[:10]).reshape(-1, *original_dim)
show_side_by_side(x_test[:10], out_example)
# Decode a single point from the middle of the latent space
decoder.predict(np.array([[0.5, 0.5]]))
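# A quick way to look at that decoded sample (an addition, not in the original gist):
plt.imshow(decoder.predict(np.array([[0.5, 0.5]]))[0], cmap="Greys_r")
plt.show()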
def show_grid(out):
    """Tile a square number of digit images into one num_rows x num_rows image."""
    num_items = len(out)
    num_rows = int(num_items ** 0.5)
    assert num_rows * num_rows == num_items, "Input must be a square number of images"
    f = np.zeros(shape=(num_rows * digit_size, num_rows * digit_size))
    for idx, digit in enumerate(out):
        c = idx // num_rows
        r = idx % num_rows
        f[r * digit_size: (r + 1) * digit_size, c * digit_size: (c + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    plt.imshow(f, cmap="Greys_r")
    plt.show()
# Decode an 11x11 grid of latent points spanning [0, 1] in each dimension
steps = np.array(list(product(np.arange(0, 1.1, 0.1), np.arange(0, 1.1, 0.1))))
out_steps = decoder.predict(steps).reshape(-1, *original_dim)
show_grid(out_steps)
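# Extra sanity check, not part of the original gist: scatter the 2-D latent
# means of the test digits, colored by label, to see how the classes arrange
# themselves in latent space.
z_means, _, _ = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(8, 8))
plt.scatter(z_means[:, 0], z_means[:, 1], c=y_test, cmap="tab10", s=2)
plt.colorbar()
plt.show()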