Created
December 20, 2017 11:40
-
-
Save RomanSteinberg/54e516cc20ebfaed1a01cfc8b0b5765c to your computer and use it in GitHub Desktop.
VAE (MNIST)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
# import tensorflow as tf | |
import tensorflow as tf, numpy as np | |
from tensorflow import nn | |
from tensorflow import keras as ke | |
from tensorflow.examples.tutorials.mnist import input_data | |
class VAELoss(ke.layers.Layer):
    """Pass-through Keras layer that attaches the VAE objective to the model.

    The layer is called on the list ``[input_var, mu, logstd, reconstruction]``.
    It registers the negative variational lower bound through ``add_loss`` and
    returns the first element of that list unchanged.
    """

    def __init__(self, **kwargs):
        # The flag is assigned before the base-class initialiser runs.
        self.is_placeholder = True
        super().__init__(**kwargs)

    def __call__(self, *args, **kwargs):
        """Delegate straight to ``Layer.__call__``."""
        return super().__call__(*args, **kwargs)

    def vae_loss(self, input_var, mu, logstd, reconstruction):
        """Return the negative variational lower bound, averaged over axis 0.

        Args:
            input_var: Target tensor; used as the Bernoulli target.
            mu: Mean of the approximate posterior.
            logstd: Log of the posterior standard deviation (the KL term
                uses ``2 * logstd`` as the log-variance).
            reconstruction: Decoder output, compared against ``input_var``.

        Returns:
            A scalar tensor: ``-mean(log_likelihood - KL)``.
        """
        eps = 1e-9  # keeps both logarithms finite at 0 and 1

        # Bernoulli log-likelihood, summed over axis 1.
        positive_part = input_var * tf.log(reconstruction + eps)
        negative_part = (1 - input_var) * tf.log(1 - reconstruction + eps)
        log_likelihood = tf.reduce_sum(positive_part + negative_part, axis=1)

        # Closed-form KL(N(mu, sigma^2) || N(0, 1)) with sigma = exp(logstd).
        log_variance = 2 * logstd
        kl_summand = 1 + log_variance - tf.pow(mu, 2) - tf.exp(log_variance)
        kl_divergence = -.5 * tf.reduce_sum(kl_summand, axis=1)

        lower_bound = tf.reduce_mean(log_likelihood - kl_divergence)
        return -lower_bound

    def call(self, inputs, **kwargs):
        """Register the loss computed from ``inputs`` and return ``inputs[0]``."""
        objective = self.vae_loss(*inputs)
        self.add_loss(objective, inputs=inputs)
        return inputs[0]
class VAE:
    """Fully connected variational autoencoder built from tf.keras layers.

    The decoder layers are created once in ``__init__`` so that the training
    model and the stand-alone generator returned by ``create_train_pass``
    apply the same layer objects (and therefore share weights).
    """

    def __init__(self):
        self.latent_dim = 20  # size of the latent code z
        self.h_dim = 500  # width of the hidden layers
        # Length of the flattened input vector (presumably 28 * 28 MNIST pixels).
        self.image_shape = 784  # (1280, 5000)
        self.decoder_layers = [
            ke.layers.Dense(self.h_dim, activation='tanh'),
            ke.layers.Dense(self.image_shape, activation='sigmoid'),
        ]

    def fc_encoder(self, previous):
        """Apply a new ``h_dim``-unit tanh Dense layer to ``previous``."""
        out = ke.layers.Dense(self.h_dim, activation='tanh')(previous)
        return out

    def fc_decoder(self, previous):
        """Run ``previous`` through the shared decoder layers, in order."""
        out = previous
        for layer in self.decoder_layers:
            out = layer(out)
        return out

    def distribution_layers(self, previous):
        """Build the posterior parameters and draw a reparameterised sample.

        Args:
            previous: Encoder output tensor.

        Returns:
            A tuple ``(sample, mu, logstd)`` where
            ``sample = mu + eps * exp(logstd)`` and ``eps`` is standard normal.
        """
        # NOTE(review): tanh bounds mu and logstd to (-1, 1); kept from the
        # original architecture -- confirm this restriction is intended.
        mu = ke.layers.Dense(self.latent_dim, activation='tanh')(previous)
        logstd = ke.layers.Dense(self.latent_dim, activation='tanh')(previous)

        def reparameterize(args):
            mean, log_sigma = args
            # Draw one noise vector per example. A fixed [1, latent_dim] shape
            # would broadcast a single draw across the whole batch.
            noise = tf.random_normal(tf.shape(mean))
            # log_sigma is the log standard deviation, matching the KL term in
            # VAELoss (which uses 2 * logstd as the log-variance), so
            # sigma = exp(log_sigma) rather than exp(0.5 * log_sigma).
            return mean + noise * tf.exp(log_sigma)

        sample = ke.layers.Lambda(reparameterize)([mu, logstd])
        return sample, mu, logstd

    def loss(self, input_var, mu, logstd, reconstruction):
        """Attach a ``VAELoss`` layer to the given tensors and return its output."""
        out = VAELoss()([input_var, mu, logstd, reconstruction])
        return out

    def create_train_pass(self):
        """Build the training model and a generator that reuses the decoder.

        Returns:
            ``(model, generator)``: ``model`` maps an input vector through
            encoder, sampling, decoder and the loss layer; ``generator`` maps
            a latent vector through the same decoder layers.
        """
        inp = ke.layers.Input(shape=(self.image_shape,))
        h = self.fc_encoder(inp)
        z, mu, logstd = self.distribution_layers(h)
        decoded = self.fc_decoder(z)
        out = self.loss(inp, mu, logstd, decoded)
        model = ke.models.Model(inputs=inp, outputs=out)

        # Second graph over the same decoder layer objects, fed by a latent input.
        gen_inp = ke.layers.Input(shape=(self.latent_dim,))
        reconstruction = self.fc_decoder(gen_inp)
        generator = ke.models.Model(inputs=gen_inp, outputs=reconstruction)
        return model, generator

    def create_gen_path(self):
        """Placeholder; not implemented (the generator is built in ``create_train_pass``)."""
        pass
def train(train_data):
    """Fit a fresh VAE on ``train_data`` for one epoch and save its generator.

    The model is compiled with ``loss=None`` and fitted without targets.
    The generator model is written to ``gen.h5``.
    """
    trainer, decoder_model = VAE().create_train_pass()
    trainer.compile(optimizer='adam', loss=None)

    fit_options = {
        'shuffle': True,
        'batch_size': 128,
        'epochs': 1,
        'verbose': 1,
    }
    trainer.fit(train_data, **fit_options)

    decoder_model.save('gen.h5')
def main():
    """Load MNIST via ``input_data.read_data_sets`` and train on the training images."""
    datasets = input_data.read_data_sets("/tmp/data/", one_hot=True)
    training_images = datasets.train.images
    train(training_images)
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment