Created January 12, 2018 12:25
ST-Gumbel-Softmax-Pytorch
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


def sample_gumbel(shape, eps=1e-20):
    # Sample Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    U = torch.rand(shape).cuda()
    return -Variable(torch.log(-torch.log(U + eps) + eps))


def gumbel_softmax_sample(logits, temperature):
    # Soft sample: softmax of (logits + Gumbel noise) / temperature.
    y = logits + sample_gumbel(logits.size())
    return F.softmax(y / temperature, dim=-1)


def gumbel_softmax(logits, temperature):
    """
    input: [*, n_class]
    return: [*, n_class] a one-hot vector
    """
    y = gumbel_softmax_sample(logits, temperature)
    shape = y.size()
    _, ind = y.max(dim=-1)
    y_hard = torch.zeros_like(y).view(-1, shape[-1])
    y_hard.scatter_(1, ind.view(-1, 1), 1)
    y_hard = y_hard.view(*shape)
    # Straight-through estimator: the forward pass returns the hard one-hot
    # y_hard, while gradients flow through the soft sample y.
    return (y_hard - y).detach() + y


if __name__ == '__main__':
    import math
    print(gumbel_softmax(
        Variable(torch.cuda.FloatTensor(
            [[math.log(0.1), math.log(0.4), math.log(0.3), math.log(0.2)]] * 20000)),
        0.8).sum(dim=0))
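A minimal device-agnostic sketch of the same idea, assuming a recent PyTorch version (the names sample_gumbel_like and gumbel_softmax_st are mine, not part of the gist): Variable is no longer needed, and torch.rand_like creates the noise on the same device as the logits, so the hard-coded .cuda() call can be dropped.

import torch
import torch.nn.functional as F

def sample_gumbel_like(logits, eps=1e-20):
    # Gumbel(0, 1) noise with the same shape, dtype, and device as `logits`.
    U = torch.rand_like(logits)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_st(logits, temperature):
    # Straight-through Gumbel-Softmax: hard one-hot forward, soft gradients backward.
    y_soft = F.softmax((logits + sample_gumbel_like(logits)) / temperature, dim=-1)
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return (y_hard - y_soft).detach() + y_soft

# Usage, mirroring the gist's __main__ check: class counts over 20000 samples
# should be roughly [2000, 8000, 6000, 4000].
logits = torch.log(torch.tensor([[0.1, 0.4, 0.3, 0.2]])).expand(20000, 4)
print(gumbel_softmax_st(logits, 0.8).sum(dim=0))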
@ibrahim10h Right, it is okay in this case because the example passes the actual log of normalized probabilities. But in a general neural network, we refer to the network's output as logits, which could be the log of normalized probabilities plus an arbitrary offset. This standalone example is correct, but it could cause problems for people who carelessly copy-paste it.
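If the inputs really are arbitrary network outputs, the normalization discussed here could be done by the caller before sampling. A minimal sketch, assuming the gist's gumbel_softmax is in scope (the wrapper name is mine, not part of the gist):

import torch.nn.functional as F

def gumbel_softmax_from_logits(logits, temperature):
    # Convert unnormalized logits to log-probabilities first, then reuse the
    # gist's gumbel_softmax, which expects the log of normalized probabilities.
    return gumbel_softmax(F.log_softmax(logits, dim=-1), temperature)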
Hi, I am trying to apply this Gumbel-Softmax trick to a VAE for data synthesis. Here is my implementation. Am I doing something wrong? Thank you.
import logging

import pandas as pd
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.contrib.distributions import (Bernoulli, OneHotCategorical,
                                              RelaxedOneHotCategorical,
                                              kl_divergence)
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.models import Model

logging.getLogger('tensorflow').disabled = True


class DiscreteVAE:

    def encoder(self, latent_dim, input_dim):
        encoder_input = layers.Input(shape=(input_dim, ), name='encoder_input')
        x = encoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dropout(0.3)(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        x = tf.keras.layers.Dense(latent_dim)(x)
        encoder_model = Model(inputs=encoder_input, outputs=x)
        encoder_model.summary()
        return encoder_model

    def decoder(self, latent_dim, input_dim):
        decoder_input = layers.Input(latent_dim, name='decoder_input')
        x = decoder_input
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_1')(x)
        x = layers.Dense(256,
                         activation='relu',
                         kernel_initializer='random_uniform',
                         name='Dense_2')(x)
        decoded_input = layers.Dense(input_dim, name='decoded_input')(x)
        decoder_model = Model(decoder_input, decoded_input)
        decoder_model.summary()
        return decoder_model

    def sample_gumbel(self, shape, eps=1e-20):
        """Sample from Gumbel(0, 1)"""
        U = tf.random_uniform(shape, minval=0, maxval=1, dtype=tf.float32)
        return -tf.log(-tf.log(U + eps) + eps)

    def gumbel_softmax_sample(self, logits, temperature):
        """Draw a sample from the Gumbel-Softmax distribution"""
        y = logits + self.sample_gumbel(tf.shape(logits))
        return tf.nn.softmax(y / temperature)

    def gumbel_softmax(self, args):
        """Sample from the Gumbel-Softmax distribution and optionally discretize.
        Args:
            logits: [batch_size, n_class] unnormalized log-probs
            temperature: non-negative scalar
            hard: if True, take argmax, but differentiate w.r.t. soft sample y
        Returns:
            [batch_size, n_class] sample from the Gumbel-Softmax distribution.
            If hard=True, then the returned sample will be one-hot, otherwise it
            will be a probability distribution that sums to 1 across classes.
        """
        logits, temperature = args
        y = self.gumbel_softmax_sample(logits, temperature)
        # k = tf.shape(logits)[-1]
        # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)),
                         y.dtype)
        y = tf.stop_gradient(y_hard - y) + y
        return y

    def CatVAE_loss(self, encoded_input, decoded_input, z, x, tau, latent_dim):
        reconstruction_error = tf.reduce_sum(
            Bernoulli(logits=decoded_input).log_prob(x), 1)
        logits_pz = tf.ones_like(decoded_input) * (1. / latent_dim)
        q_cat_z = OneHotCategorical(logits=encoded_input)
        p_cat_z = OneHotCategorical(logits=logits_pz)
        KL_qp = kl_divergence(q_cat_z, p_cat_z)
        ELBO = tf.reduce_mean(reconstruction_error - KL_qp)
        loss = -ELBO
        return loss

    def build_vae(self, latent_dim, input_dim, opt, data):
        tau = 0.5
        input_x = layers.Input(shape=input_dim, name='vae_input')
        encoder_m = self.encoder(latent_dim, input_dim)
        logits_y = encoder_m(input_x)
        z = layers.Lambda(self.gumbel_softmax)([logits_y, tau])
        decoder_m = self.decoder(latent_dim, input_dim)
        decoded_input = decoder_m(z)
        # loss = self.vae_loss(input_x, input_dim, decoded_input, data)
        loss = self.CatVAE_loss(logits_y, decoded_input, z, input_x, tau,
                                latent_dim)
        vae = Model(input_x, decoded_input)
        vae.add_loss(loss)
        vae.compile(optimizer=opt)
        return vae, decoder_m, encoder_m
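For reference, this is roughly how I call it; a hypothetical usage sketch, where x_train and the hyperparameter values are assumptions and not part of the code above:

# Build the VAE from the class above and train it on a 2-D array x_train.
vae_builder = DiscreteVAE()
vae, decoder_m, encoder_m = vae_builder.build_vae(latent_dim=10,
                                                  input_dim=x_train.shape[1],
                                                  opt='adam',
                                                  data=x_train)
# The loss is attached via add_loss, so no separate targets are passed to fit.
vae.fit(x_train, epochs=50, batch_size=64)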
@JACKHAHA363 Line 30 already provides the 'logits' parameter as the log of what appears to be the softmax of a vector: math.log([0.1, 0.4, 0.3, 0.2]). This may be why F.log_softmax(logits) was not applied on Line 12.