import tensorflow as tf


def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1) by inverse transform sampling of uniform noise.

    eps guards against taking log(0).
    """
    U = tf.random.uniform(shape, minval=0, maxval=1)
    return -tf.math.log(-tf.math.log(U + eps) + eps)


def gumbel_softmax_sample(logits, temperature):
    """Draw a sample from the Gumbel-Softmax distribution."""
    y = logits + sample_gumbel(tf.shape(logits))
    return tf.nn.softmax(y / temperature)


def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
      logits: [batch_size, n_class] unnormalized log-probs
      temperature: positive scalar; lower values yield samples closer to one-hot
      hard: if True, return a one-hot sample, but differentiate w.r.t. the soft sample y

    Returns:
      [batch_size, n_class] sample from the Gumbel-Softmax distribution.
      If hard=True, the returned sample will be one-hot; otherwise it will
      be a probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        # Discretize: put a 1 at the maximum entry of each row, 0 elsewhere.
        # An equivalent alternative is tf.one_hot(tf.argmax(y, 1), tf.shape(logits)[-1]).
        y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keepdims=True)), y.dtype)
        # Straight-through estimator: the forward pass emits the one-hot y_hard,
        # while the backward pass uses the gradient of the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y
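Below is a minimal usage sketch, not part of the original gist: it assumes TensorFlow 2.x eager execution, and the toy objective, batch size, and annealing schedule are illustrative choices. It shows the point of the straight-through estimator: the forward pass emits discrete one-hot samples, yet gradients still flow back to the logits.

# Toy example: learn logits so that sampled one-hot vectors favor class 2.
logits = tf.Variable(tf.zeros([8, 4]))  # batch of 8, 4 classes

tau = 1.0
for step in range(500):
    tau = max(0.2, tau * 0.995)  # anneal the temperature toward 0.2
    with tf.GradientTape() as tape:
        y = gumbel_softmax(logits, tau, hard=True)  # one-hot forward pass
        loss = -tf.reduce_mean(y[:, 2])             # reward mass on class 2
    grad = tape.gradient(loss, logits)              # flows through the soft sample
    logits.assign_sub(0.5 * grad)                   # plain gradient descent step

print(tf.argmax(logits, axis=-1).numpy())  # expected: mostly 2s

Annealing matters here: a high temperature keeps the soft samples near uniform so gradients are informative early on, while a low temperature makes samples nearly one-hot at the cost of higher-variance gradients.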