Creates a variational autoencoder (VAE) in Keras. The encoder maps each 28x28 image to a 2-dimensional latent mean and log-variance, a sampling layer draws a latent code via the reparameterization trick, and the decoder reconstructs the image from that code. Training minimizes per-pixel binary cross-entropy plus a KL-divergence term that keeps the latent codes close to a standard normal.
import matplotlib.pyplot as plt
from keras.layers import Input, Dense, Lambda, Flatten, Reshape
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras import losses  # "objectives" was renamed to "losses" in Keras 2
import numpy as np
from itertools import product
batch_size = 100
digit_size = 28
original_dim = (digit_size, digit_size)
original_dim_flat = np.prod(original_dim)
latent_dim = 2
intermediate_dim = 512
epochs = 100
epsilon_std = 1.0
def sampling(args):
    """ Draws a latent sample via the reparameterization trick: z = mean + std * epsilon """
    z_mean, z_log_var = args
    epsilon = K.random_normal(shape=K.shape(z_mean), mean=0.0, stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2.0) * epsilon
def create_vae():
    """ Returns a variational autoencoder as (training model, evaluation model) """
    inp = Input(shape=original_dim, name="input")

    # Encoder: flatten the image and map it to a latent mean and log-variance
    f = Flatten(name="flatten")(inp)
    h = Dense(intermediate_dim, activation="relu", name="encoding")(f)
    z_mean = Dense(latent_dim, activation="sigmoid", name="mean")(h)
    z_log_var = Dense(latent_dim, name="var")(h)
    z = Lambda(sampling, output_shape=(latent_dim,), name="sampling")([z_mean, z_log_var])
    encoder = Model(inp, [z, z_mean], name="encoder")

    # Decoder: map a latent code back to a flattened image, then reshape to 28x28
    inp_encoded = Input(shape=(latent_dim,), name="decoder_input")
    decoded = Dense(intermediate_dim, activation="relu", name="decoding")(inp_encoded)
    flat_decoded = Dense(original_dim_flat, activation="sigmoid", name="flat_decoded")(decoded)
    out_decoded = Reshape(original_dim, name="out_decoded")(flat_decoded)
    decoder = Model(inp_encoded, out_decoded, name="decoder")

    def vae_loss(x, x_decoded_mean):
        # Reconstruction term: per-pixel binary cross-entropy, scaled up to a per-image sum
        xent_loss = original_dim_flat * losses.binary_crossentropy(K.flatten(x), K.flatten(x_decoded_mean))
        # Kullback-Leibler divergence - how far the learned latent distribution is from a
        # standard normal (zero mean, unit variance). Acts as a regularizer on the latent space.
        kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return K.mean(xent_loss + kl_loss)

    # VAE: the training model decodes the sampled z, the evaluation model decodes z_mean
    out_combined_train = decoder(encoder(inp)[0])
    out_combined_eval = decoder(encoder(inp)[1])
    vae_train = Model(inp, out_combined_train)
    vae_eval = Model(inp, out_combined_eval)
    vae_train.compile(optimizer="rmsprop", loss=vae_loss)
    return vae_train, vae_eval
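Below is a usage sketch that is not part of the gist itself: it assumes MNIST is loaded via keras.datasets.mnist, trains the VAE with early stopping, then decodes an evenly spaced grid of 2-D latent points into digits with matplotlib. The grid range, number of grid steps, and the early-stopping patience are illustrative choices, not values from the original.

# --- Usage sketch (assumptions noted above, not part of the original gist) ---
from keras.datasets import mnist

(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

vae_train, vae_eval = create_vae()
vae_train.fit(x_train, x_train,
              epochs=epochs,
              batch_size=batch_size,
              validation_data=(x_test, x_test),
              callbacks=[EarlyStopping(patience=5)])

# Pull the decoder sub-model out of the evaluation model and decode a grid of
# latent (z1, z2) coordinates into a single large image of generated digits.
decoder = vae_eval.get_layer("decoder")
n = 15  # assumed grid resolution
grid = np.linspace(-3, 3, n)  # assumed sweep range over the latent space
figure = np.zeros((digit_size * n, digit_size * n))
for (i, z1), (j, z2) in product(enumerate(grid), enumerate(grid)):
    digit = decoder.predict(np.array([[z1, z2]]))[0]
    figure[i * digit_size:(i + 1) * digit_size, j * digit_size:(j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap="Greys_r")
plt.show()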