Creates a variational autoencoder (VAE) in Keras
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras import objectives
import numpy as np
from itertools import product
batch_size = 100
digit_size = 28
original_dim = (digit_size, digit_size)
original_dim_flat = np.prod(original_dim)
latent_dim = 2
intermediate_dim = 512
epochs = 100
epsilon_std = 1.0
def sampling(args):
    """ Returns a noisy sample from a normal distribution.

    Implements the reparameterization trick: z = mean + sigma * epsilon,
    which keeps the sampling step differentiable with respect to the encoder.
    """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=(latent_dim,), mean=0.0, stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2.0) * epsilon
def create_vae():
    """ Returns a variational autoencoder as (training model, evaluation model) """
    inp = Input(shape=(*original_dim,), name="input")

    # Encoder: flatten the image and map it to the latent mean and log-variance
    f = Flatten(name="flatten")(inp)
    h = Dense(intermediate_dim, activation="relu", name="encoding")(f)
    z_mean = Dense(latent_dim, activation="sigmoid", name="mean")(h)
    z_log_var = Dense(latent_dim, name="var")(h)
    z = Lambda(sampling, output_shape=(latent_dim,), name="sampling")([z_mean, z_log_var])
    encoder = Model(inp, [z, z_mean], name="encoder")

    # Decoder: map a latent vector back to a reconstructed image
    inp_encoded = Input(shape=(latent_dim,), name="decoder_input")
    decoded = Dense(intermediate_dim, activation="relu", name="decoding")(inp_encoded)
    flat_decoded = Dense(original_dim_flat, activation="sigmoid", name="flat_decoded")(decoded)
    out_decoded = Reshape(original_dim, name="out_decoded")(flat_decoded)
    decoder = Model(inp_encoded, out_decoded, name="decoder")

    def vae_loss(x, x_decoded_mean):
        # Reconstruction loss: per-pixel binary cross-entropy, scaled by image size
        xent_loss = original_dim_flat * objectives.binary_crossentropy(K.flatten(x), K.flatten(x_decoded_mean))
        # Kullback-Leibler divergence - how one probability distribution diverges from another
        # (i.e. how far the latent distribution is from zero mean and unit variance).
        # Acts as a regularizer on the latent space.
        kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    # VAE: train on the noisy sample z, evaluate on the deterministic mean
    out_combined_train = decoder(encoder(inp)[0])
    out_combined_eval = decoder(encoder(inp)[1])
    vae_train = Model(inp, out_combined_train)
    vae_eval = Model(inp, out_combined_eval)
    vae_train.compile(optimizer="rmsprop", loss=vae_loss)
    return vae_train, vae_eval
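
# --- Usage sketch (illustrative, not part of the original gist) ---
# A minimal example of how create_vae() might be used, assuming MNIST from
# keras.datasets: train the VAE, then decode a grid of latent points to
# visualize the learned 2-D manifold. The epoch count, grid size and plotting
# details below are arbitrary choices.
if __name__ == "__main__":
    from keras.datasets import mnist

    (x_train, _), (x_test, _) = mnist.load_data()
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    vae_train, vae_eval = create_vae()

    # The VAE reconstructs its own input, so the target is the input itself
    vae_train.fit(
        x_train, x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test),
        callbacks=[EarlyStopping(monitor="val_loss", patience=3)],
    )

    # Decode a grid of latent points. The grid stays in (0, 1) because the
    # latent mean above is sigmoid-activated.
    decoder = vae_eval.get_layer("decoder")
    n = 15
    grid = np.linspace(0.05, 0.95, n)
    figure = np.zeros((digit_size * n, digit_size * n))
    for (i, yi), (j, xi) in product(enumerate(grid), enumerate(grid)):
        digit = decoder.predict(np.array([[xi, yi]]))[0]
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    plt.imshow(figure, cmap="Greys_r")
    plt.show()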