import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K, metrics
from tensorflow.keras.layers import Dense, Layer
class Sampling(Layer):
    """Uses (z_mean, z_log_var) to sample z, the vector encoding."""

    def call(self, inputs, epsilon_std=1.0):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        # Sample z from an isotropic normal.
        epsilon = K.random_normal(shape=(batch, dim), mean=0.0, stddev=epsilon_std)
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon
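
# Note on the sampling step above (added commentary, not in the original gist):
# this is the standard reparameterization trick, z = mean + sigma * eps with
# sigma = exp(0.5 * log_var) and eps ~ N(0, I). Drawing eps outside the
# learned parameters keeps z differentiable with respect to z_mean and
# z_log_var, so gradients can flow through the sampling step during training.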
class Encoder(Layer):
    """Maps raw inputs to a triplet (z_mean, z_log_var, z)."""

    def __init__(self, latent_dim=32, intermediate_dim=64, id: str = "a", **kwargs):
        super(Encoder, self).__init__(name=f"encoder_{id}", **kwargs)
        self.dense_proj = Dense(intermediate_dim, activation="relu")
        self.dense_mean = Dense(latent_dim)
        self.dense_log_var = Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z
class Decoder(Layer):
    """Converts z, the encoded vector, back into a readable raw input."""

    def __init__(self, original_dim: int, intermediate_dim: int = 64, id: str = "a", **kwargs):
        super(Decoder, self).__init__(name=f"decoder_{id}", **kwargs)
        self.dense_proj = Dense(intermediate_dim, activation="relu")
        # Sigmoid output assumes inputs are scaled to [0, 1].
        self.dense_output = Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)
class VariationalAutoEncoder(keras.Model):
    """A denoising, sparse autoencoder for use with the non-embedding-layer version of the LVA model.

    Combines the encoder and decoder into an end-to-end model for training.
    Trains with a reconstruction loss plus a KL-divergence loss.

    See:
        https://arxiv.org/pdf/1801.01586.pdf
    """

    def __init__(
        self,
        original_dim: int,
        intermediate_dim: int = 64,
        latent_dim: int = 32,
        id: str = "a",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=f"vae_{id}", **kwargs)
        self.original_dim = original_dim
        self.id = id
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim, id=id)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim, id=id)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Scale the per-sample MSE by the input width, then average over the
        # batch so add_loss receives a scalar.
        reconstruction_loss = tf.reduce_mean(
            self.original_dim * metrics.mean_squared_error(inputs, reconstructed)
        )
        # KL divergence between q(z|x) = N(z_mean, exp(z_log_var)) and the
        # standard normal prior, averaged over batch and latent dimensions.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
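
# A minimal usage sketch (added for illustration; the dimensions, optimizer,
# and random training data below are assumptions, not part of the original
# gist). Because call() registers the total loss via add_loss, compile()
# needs no explicit loss argument and fit() needs no targets.
if __name__ == "__main__":
    import numpy as np

    original_dim = 784  # e.g. flattened 28x28 inputs; illustrative only
    x_train = np.random.rand(256, original_dim).astype("float32")

    vae = VariationalAutoEncoder(original_dim, intermediate_dim=64, latent_dim=32)
    vae.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3))
    vae.fit(x_train, epochs=2, batch_size=32)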