
@ericjang
Created November 9, 2016 05:32
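import tensorflow as tf  # TensorFlow 1.x API (tf.contrib was removed in 2.x)
slim = tf.contrib.slim
Bernoulli = tf.contrib.distributions.Bernoulli

# gumbel_softmax, logits_y (encoder output), N and K are assumed to be defined earlier
# Gumbel-Softmax temperature
tau = tf.Variable(5.0, name="temperature")
# sample and reshape back (shape=(batch_size,N,K))
# set hard=True for the straight-through (ST) Gumbel-Softmax
y = tf.reshape(gumbel_softmax(logits_y, tau, hard=False), [-1, N, K])
# generative model p(x|y), i.e. the decoder
# flattened input shape=(batch_size,N*K); net shape=(batch_size,512)
net = slim.stack(slim.flatten(y), slim.fully_connected, [256, 512])
# Bernoulli logits, one per pixel (shape=(batch_size,784))
logits_x = slim.fully_connected(net, 784, activation_fn=None)
p_x = Bernoulli(logits=logits_x)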
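The `gumbel_softmax` helper called above is not part of this gist. As a rough, framework-free sketch (plain Python rather than TensorFlow, and not the author's implementation), a Gumbel-Softmax sample adds Gumbel(0, 1) noise to the logits, divides by the temperature `tau`, and applies a softmax; with `hard=True` the forward value becomes the one-hot argmax of that relaxed sample. The names `sample_gumbel`, `softmax`, and `gumbel_softmax_sample`, and the three-class probabilities in the usage lines, are hypothetical and chosen only for illustration.

```python
import math
import random


def sample_gumbel(rng, eps=1e-20):
    """Draw one Gumbel(0, 1) sample by inverse transform: -log(-log(U))."""
    u = rng.random()
    return -math.log(-math.log(u + eps) + eps)


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sample(logits, tau, rng, hard=False):
    """Relaxed one-hot sample from Categorical(softmax(logits)) at temperature tau.

    hard=True returns the one-hot argmax of the relaxed sample, i.e. only the
    forward value of the straight-through estimator (no gradients in plain Python).
    """
    y = softmax([(logit + sample_gumbel(rng)) / tau for logit in logits])
    if hard:
        k = max(range(len(y)), key=y.__getitem__)
        y = [1.0 if i == k else 0.0 for i in range(len(y))]
    return y


# Usage: a single categorical variable with K = 3 classes (illustrative values).
rng = random.Random(0)
probs = [0.1, 0.6, 0.3]
logits = [math.log(p) for p in probs]

soft = gumbel_softmax_sample(logits, tau=5.0, rng=rng)
hard = gumbel_softmax_sample(logits, tau=5.0, rng=rng, hard=True)

# Hard samples follow the categorical distribution (the Gumbel-max trick).
n = 20000
counts = [0, 0, 0]
for _ in range(n):
    sample = gumbel_softmax_sample(logits, tau=5.0, rng=rng, hard=True)
    counts[sample.index(1.0)] += 1
freqs = [c / n for c in counts]

print("soft sample:", soft)
print("empirical class frequencies:", freqs)
```

Dividing by `tau` does not change which entry is largest, so the hard samples are exact categorical draws regardless of temperature; `tau` only controls how peaked the soft sample is, with a large value such as the initial 5.0 giving smoother, less one-hot vectors. In an autodiff framework the straight-through variant is commonly written as `tf.stop_gradient(y_hard - y) + y`, so the forward pass is one-hot while gradients flow through the soft sample `y`.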
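The last line parameterizes the pixel likelihood by logits rather than probabilities. For a binary pixel x_i with logit l_i, the log-probability can be computed directly as x_i * l_i - log(1 + exp(l_i)), which avoids taking the log of a saturated sigmoid. The sketch below is plain Python, not the TensorFlow distributions API; `softplus`, `bernoulli_log_prob`, and `naive_log_prob` are hypothetical helper names, and summing over pixels assumes the usual VAE reconstruction term log p(x|y) for binarized inputs.

```python
import math


def softplus(z):
    """Numerically stable log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bernoulli_log_prob(x, logits):
    """Sum over pixels of log Bernoulli(x_i; sigmoid(l_i)), computed from logits.

    For binary x_i in {0, 1}: log p(x_i) = x_i * l_i - softplus(l_i).
    """
    return sum(xi * li - softplus(li) for xi, li in zip(x, logits))


def naive_log_prob(x, logits):
    """Same quantity via explicit probabilities; fails for large |logits|."""
    total = 0.0
    for xi, li in zip(x, logits):
        p = 1.0 / (1.0 + math.exp(-li))
        total += xi * math.log(p) + (1 - xi) * math.log(1 - p)
    return total


# Usage: four binary "pixels" with illustrative decoder logits.
x = [1, 0, 1, 1]
logits = [2.0, -1.5, 0.3, -0.7]
stable = bernoulli_log_prob(x, logits)
naive = naive_log_prob(x, logits)

# A saturated logit: sigmoid(800.0) rounds to exactly 1.0 in floating point.
extreme = bernoulli_log_prob([0], [800.0])

print("stable:", stable, "naive:", naive)
print("extreme:", extreme)
```

For moderate logits both versions agree. With a logit of 800 and x = 0, the logit form returns -800.0, whereas `naive_log_prob` would compute log(1 - 1.0) and raise a `ValueError`.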