Created November 9, 2016 05:30
import tensorflow as tf

def sample_gumbel(shape, eps=1e-20):
  """Sample from Gumbel(0, 1)."""
  U = tf.random_uniform(shape, minval=0, maxval=1)
  return -tf.log(-tf.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
  """Draw a sample from the Gumbel-Softmax distribution."""
  y = logits + sample_gumbel(tf.shape(logits))
  return tf.nn.softmax(y / temperature)

def gumbel_softmax(logits, temperature, hard=False):
  """Sample from the Gumbel-Softmax distribution and optionally discretize.
  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probability distribution that sums to 1 across classes.
  """
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    k = tf.shape(logits)[-1]
    # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
    # Straight-through estimator: forward pass uses y_hard, gradients flow through y.
    y = tf.stop_gradient(y_hard - y) + y
  return y
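A minimal usage sketch (not part of the gist; it assumes TensorFlow 1.x graph mode, and the variable names and shapes are illustrative):

  # Usage sketch; assumes TF 1.x and the gumbel_softmax function defined above.
  logits = tf.get_variable("class_logits", shape=[32, 10])  # hypothetical [batch_size, n_class] logits
  tau = tf.placeholder(tf.float32, shape=[])                 # temperature, typically annealed during training

  soft_y = gumbel_softmax(logits, tau)              # differentiable sample; each row sums to 1
  hard_y = gumbel_softmax(logits, tau, hard=True)   # one-hot forward pass, soft gradients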
Hi, Eric! Thanks for sharing the code. I have a question regarding lines 24-26 (the commented-out tf.one_hot line and the tf.equal line). You implemented two ways to compute the one-hot vector. It seems that line 26 could lead to multiple ones when a tie occurs, though that is very unlikely. Which way is better? Which is tested and used in your paper?
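For readers comparing the two variants, a small sketch of the difference (assumes the TF 1.x API, with y and k as defined inside gumbel_softmax; not from the gist): the argmax form always yields exactly one 1 per row, while the reduce_max/equal form marks every element tied for the row maximum.

  # Illustrative comparison of the two hard-sample variants.
  y_hard_argmax = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)                   # exactly one 1 per row, even on ties
  y_hard_equal  = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)  # more than one 1 if an exact tie occurs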