Skip to content

Instantly share code, notes, and snippets.

@ZaxR
Created March 3, 2022 23:14
Show Gist options
  • Save ZaxR/a0c3f53c6f81bfa8f846128cfde0a316 to your computer and use it in GitHub Desktop.
Save ZaxR/a0c3f53c6f81bfa8f846128cfde0a316 to your computer and use it in GitHub Desktop.
from tensorflow.keras import backend as K, metrics
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import concatenate, Dense, Input, Layer
from tensorflow.keras.models import Model, model_from_json
from tensorflow.python.framework.ops import disable_eager_execution
import tensorflow as tf
from tensorflow import keras
class Sampling(Layer):
    """Draw the latent vector ``z`` from ``(z_mean, z_log_var)``.

    Implements ``z = z_mean + exp(0.5 * z_log_var) * eps`` where ``eps`` is
    drawn from a normal distribution with mean 0 and standard deviation
    ``epsilon_std`` (the reparameterization trick).
    """

    def call(self, inputs, epsilon_std=1.0):
        """Sample ``z`` for one batch.

        Args:
            inputs: pair ``(z_mean, z_log_var)``. The noise shape is taken from
                the first two dynamic dimensions of ``z_mean``, so 2-D
                ``(batch, latent)`` tensors are assumed.
            epsilon_std: standard deviation of the noise ``eps``.

        Returns:
            Tensor ``z_mean + exp(0.5 * z_log_var) * eps``.
        """
        mean, log_var = inputs
        dynamic_shape = tf.shape(mean)
        # Noise drawn with the (batch, dim) shape of the mean tensor.
        noise = K.random_normal(
            shape=(dynamic_shape[0], dynamic_shape[1]),
            mean=0.0,
            stddev=epsilon_std,
        )
        # exp(0.5 * log_var) turns a log-variance into a standard deviation.
        sigma = tf.exp(0.5 * log_var)
        return mean + sigma * noise
class Encoder(Layer):
    """Encode raw inputs into the triplet ``(z_mean, z_log_var, z)``.

    A shared ReLU hidden layer feeds two linear heads (mean and
    log-variance); a ``Sampling`` layer then draws ``z`` from them.
    """

    def __init__(self, latent_dim=32, intermediate_dim=64, id: str = "a", **kwargs):
        """Create the encoder's sub-layers.

        Args:
            latent_dim: output width of the mean and log-variance heads.
            intermediate_dim: width of the ReLU hidden layer.
            id: suffix for the layer name (``encoder_<id>``).
            **kwargs: forwarded to ``Layer.__init__``.
        """
        super().__init__(name=f"encoder_{id}", **kwargs)
        # Attribute names and creation order are kept stable: Keras derives
        # tracked-weight / default sub-layer names from them.
        self.dense_proj = Dense(intermediate_dim, activation="relu")
        self.dense_mean = Dense(latent_dim)
        self.dense_log_var = Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        """Return ``(z_mean, z_log_var, z)`` for ``inputs``."""
        hidden = self.dense_proj(inputs)
        z_mean = self.dense_mean(hidden)
        z_log_var = self.dense_log_var(hidden)
        return z_mean, z_log_var, self.sampling((z_mean, z_log_var))
class Decoder(Layer):
    """Map an encoded vector ``z`` back to the raw-input space.

    A ReLU hidden layer is followed by a sigmoid output layer of width
    ``original_dim``, so every output is squashed into the sigmoid range.
    """

    def __init__(self, original_dim: int, intermediate_dim: int = 64, id: str = "a", **kwargs):
        """Create the decoder's sub-layers.

        Args:
            original_dim: width of the reconstructed output.
            intermediate_dim: width of the ReLU hidden layer.
            id: suffix for the layer name (``decoder_<id>``).
            **kwargs: forwarded to ``Layer.__init__``.
        """
        super().__init__(name=f"decoder_{id}", **kwargs)
        self.dense_proj = Dense(intermediate_dim, activation="relu")
        self.dense_output = Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        """Return the reconstruction of the latent batch ``inputs``."""
        return self.dense_output(self.dense_proj(inputs))
class VariationalAutoEncoder(keras.Model):
    """Variational autoencoder combining an ``Encoder`` and a ``Decoder``.

    Combines the encoder and decoder into an end-to-end model for training.
    ``call`` returns the reconstruction and registers, via ``add_loss``, a
    single scalar loss: batch-mean reconstruction error plus a KL term.

    NOTE(review): this class applies no input corruption and no sparsity
    penalty, so it is a plain VAE. If a denoising / sparse variant is
    intended (e.g. for the non-embedding-layer LVA model), that must happen
    outside this class.

    See:
        https://arxiv.org/pdf/1801.01586.pdf
    """

    def __init__(
        self,
        original_dim: int,
        intermediate_dim: int = 64,
        latent_dim: int = 32,
        id: str = "a",
        **kwargs
    ):
        """Build the encoder/decoder pair.

        Args:
            original_dim: feature width of the inputs and reconstructions.
            intermediate_dim: hidden width used by both encoder and decoder.
            latent_dim: width of the latent space.
            id: suffix for this model's and its sub-layers' names.
            **kwargs: forwarded to ``keras.Model.__init__``.
        """
        super().__init__(name=f"vae_{id}", **kwargs)
        self.original_dim = original_dim
        self.id = id
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim, id=id)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim, id=id)

    def call(self, inputs):
        """Encode, sample, decode; register the VAE loss; return the reconstruction."""
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)

        # MSE is averaged over the last axis; multiplying by original_dim turns
        # it back into a per-sample sum of squared errors (shape (batch,) for
        # 2-D inputs).
        per_sample_reconstruction = self.original_dim * metrics.mean_squared_error(
            inputs, reconstructed
        )
        # add_loss must receive a scalar. Keras does not reduce a (batch,)
        # vector, so the objective silently becomes a batch *sum* (gradient
        # scales with batch size) and it cannot be tf.add_n-ed with a scalar
        # compiled loss. Average over the batch, like the KL term below.
        reconstruction_loss = tf.reduce_mean(per_sample_reconstruction)

        # KL(q(z|x) || N(0, I)) for a diagonal Gaussian.
        # NOTE(review): reduce_mean averages over both batch AND latent dims,
        # which down-weights KL by latent_dim relative to the textbook ELBO
        # (sum over latent dims). Kept as in the original.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment