Skip to content

Instantly share code, notes, and snippets.

Forked from ericjang/
Created January 11, 2018 21:18
Show Gist options
Save yzh119/fae4b4fc89b9f777207ebeb9261e87bf to your computer and use it in GitHub Desktop.
def sample_gumbel(shape, eps=1e-20):
    """Sample from Gumbel(0, 1) by transforming uniform noise.

    Args:
      shape: shape of the tensor to return.
      eps: small constant added inside each logarithm so neither log sees 0.

    Returns:
      A tensor of `shape` holding independent Gumbel(0, 1) samples.
    """
    uniform = tf.random_uniform(shape, minval=0, maxval=1)
    # Inverse-CDF transform g = -log(-log(u)); eps keeps both logs finite.
    neg_log_u = -tf.log(uniform + eps) + eps
    return -tf.log(neg_log_u)
def gumbel_softmax_sample(logits, temperature):
    """Draw a sample from the Gumbel-Softmax distribution.

    Args:
      logits: tensor of unnormalized log-probabilities; the softmax is taken
        over its last axis.
      temperature: positive scalar that divides the Gumbel-perturbed logits
        before the softmax.

    Returns:
      A tensor with the same shape and dtype as `logits` whose entries sum to
      1 along the last axis.
    """
    # sample_gumbel draws float32 noise (tf.random_uniform's default dtype);
    # cast it so float16/float64 logits do not fail with a dtype mismatch.
    # For float32 logits the cast is a no-op.
    noise = tf.cast(sample_gumbel(tf.shape(logits)), logits.dtype)
    y = logits + noise
    return tf.nn.softmax(y / temperature)
def gumbel_softmax(logits, temperature, hard=False):
    """Sample from the Gumbel-Softmax distribution and optionally discretize.

    Args:
      logits: [batch_size, n_class] unnormalized log-probs. Higher ranks also
        work; classes are taken along the last axis.
      temperature: positive scalar (it divides the logits).
      hard: if True, take argmax, but differentiate w.r.t. soft sample y.

    Returns:
      [batch_size, n_class] sample from the Gumbel-Softmax distribution.
      If hard=True, then the returned sample will be one-hot, otherwise it will
      be a probability distribution that sums to 1 across classes.
    """
    y = gumbel_softmax_sample(logits, temperature)
    if hard:
        k = tf.shape(logits)[-1]
        # argmax + one_hot guarantees exactly one 1 per sample even if several
        # classes tie for the maximum (comparing against reduce_max would mark
        # every tied class).
        y_hard = tf.one_hot(tf.argmax(y, axis=-1), k, dtype=y.dtype)
        # Straight-through trick: the forward value equals y_hard, while the
        # stop_gradient term contributes no gradient, so gradients flow
        # through the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment