@a-agmon
Last active July 29, 2021 17:47
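from tensorflow.keras import backend as K

# The VAE loss function: reconstruction term + KL divergence term.
# z_mean and z_log_var (the encoder outputs, with z_log_var = log(sigma^2))
# must be defined in the enclosing scope.
def vae_loss(x, x_decoded_mean):
    # sum of squared errors per sample (the MSE scaled up by the number of features);
    # assumes flattened inputs of shape (batch, features)
    reconstruction_loss = K.sum(K.square(x - x_decoded_mean), axis=-1)
    # closed-form KL(N(z_mean, sigma^2) || N(0, I)) per sample,
    # where sigma^2 = exp(z_log_var)
    kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    # return the average of the per-sample total loss over the batch
    total_loss = K.mean(reconstruction_loss + kl_loss)
    return total_loss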
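To sanity-check the loss outside Keras, here is a minimal NumPy sketch. The helpers `kl_to_standard_normal`, `vae_loss_np` and `kl_monte_carlo` are hypothetical names for illustration, not part of the gist or any library. The sketch assumes 2-D inputs of shape (batch, features) and that `z_log_var` holds log(σ²). The Monte Carlo estimate samples from N(μ, σ²) and averages log q(z) − log p(z), so it should agree with the closed-form KL only when the variance term is `exp(z_log_var)`.

```python
import numpy as np


def kl_to_standard_normal(z_mean, z_log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over the last axis.

    Assumes z_log_var holds log(sigma^2), so sigma^2 = exp(z_log_var).
    """
    return -0.5 * np.sum(1 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1)


def vae_loss_np(x, x_decoded_mean, z_mean, z_log_var):
    """Hypothetical NumPy mirror of vae_loss for 2-D inputs (batch, features)."""
    # per-sample sum of squared errors
    reconstruction_loss = np.sum(np.square(x - x_decoded_mean), axis=-1)
    # per-sample KL term
    kl_loss = kl_to_standard_normal(z_mean, z_log_var)
    # average the per-sample total over the batch
    return np.mean(reconstruction_loss + kl_loss)


def kl_monte_carlo(mu, log_var, n_samples=200_000, seed=0):
    """Monte Carlo estimate of KL(N(mu, exp(log_var)) || N(0, 1)) for one latent dim."""
    rng = np.random.default_rng(seed)
    sigma = np.exp(0.5 * log_var)
    # reparameterized samples z ~ N(mu, sigma^2)
    z = mu + sigma * rng.standard_normal(n_samples)
    # log q(z) - log p(z); the log(2*pi) constants cancel
    log_q = -0.5 * (log_var + ((z - mu) / sigma) ** 2)
    log_p = -0.5 * z**2
    return np.mean(log_q - log_p)


if __name__ == "__main__":
    mu, log_var = 0.5, 0.4
    closed_form = kl_to_standard_normal(np.array([mu]), np.array([log_var]))
    print(closed_form, kl_monte_carlo(mu, log_var))
```

The two printed values should agree closely; replacing `np.exp(z_log_var)` with its square in `kl_to_standard_normal` breaks that agreement.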