Last active
July 15, 2017 03:37
-
-
Save kevinduh/68e07ca68940e8252861dc7c48d25a06 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mxnet as mx | |
import numpy as np | |
## Example of Gumbel-softmax ##
## user settings
# The demo logits matrix has `batch_size` rows, each one a categorical
# distribution over `cardinality` categories.
batch_size, cardinality = 2, 3
# Number of draws made with each method (Gumbel-softmax graph and numpy).
num_samples = 5
# Softmax temperature bound into the Gumbel-softmax graph on every eval.
temperature = 1.0
def gumbel_softmax_sample(output_dimension, hard=True):
    """Build an MXNet symbol computing a Gumbel-Softmax sample.

    No sampling happens when this is called: it returns a symbolic graph.
    Randomness enters only at evaluation time through the ``uniform`` input.

    Free input variables that must be bound when the symbol is evaluated:

    * ``logits``      -- unnormalised log-probabilities; the category axis
                         is the LAST axis.
    * ``uniform``     -- noise with the same shape as ``logits``; presumably
                         U(0, 1) draws (the demo code in this file binds
                         ``mx.nd.random_uniform(low=0, high=1, ...)``).
    * ``temperature`` -- softmax temperature, broadcast-divided into the
                         perturbed logits (the demo binds a 1-element array).

    Args:
        output_dimension: number of categories, i.e. the size of the last
            axis of ``logits``. Only used as the ``one_hot`` depth when
            ``hard`` is True.
        hard: if True, the forward value is the one-hot vector of the
            argmax (straight-through estimator) while gradients flow
            through the soft relaxation. If False, return the soft
            relaxation itself.

    Returns:
        An ``mx.sym.Symbol``; its output has the same shape as ``logits``
        (for ``hard=True`` this assumes ``output_dimension`` equals the
        size of the last axis of ``logits``).
    """
    eps = 1e-20  # keeps both log() calls away from log(0)
    uniform = mx.sym.Variable('uniform')
    logits = mx.sym.Variable('logits')
    temperature = mx.sym.Variable('temperature')
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)).
    gumbel = -mx.sym.log(-mx.sym.log(uniform + eps) + eps)
    # softmax normalises over its default axis (the last one).
    y = mx.sym.softmax(mx.sym.broadcast_div(logits + gumbel, temperature))
    if hard:
        # argmax must reduce over the SAME axis softmax normalised over:
        # one_hot appends a new trailing axis of size `depth`, so reducing
        # the last axis makes y_hard's shape match y's for any rank
        # (axis=1 only worked for 2-D logits).
        y_hard = mx.sym.one_hot(indices=mx.sym.argmax(y, axis=-1),
                                depth=output_dimension)
        # Straight-through: the forward value equals y_hard, but BlockGrad
        # stops the gradient of (y_hard - y), so backward only sees y.
        y = mx.sym.BlockGrad(y_hard - y) + y
    return y
def usual_sample(logits, num_samples):
    """Draw categorical samples from softmax(logits) the usual way (numpy).

    Each row of ``logits`` holds the unnormalised log-probabilities of one
    categorical distribution. ``num_samples`` category indices are drawn
    independently per row with ``np.random.choice``. Every draw is then
    printed as a one-hot matrix (one row per distribution), so it can be
    compared with the Gumbel-softmax samples.

    Args:
        logits: 2-D array-like of shape (rows, categories).
        num_samples: number of draws per row.

    Returns:
        dict mapping each row index to a 1-D array of ``num_samples``
        sampled category indices.
    """
    logits = np.asarray(logits)
    # Subtract the row-wise max before exponentiating. softmax is invariant
    # to this shift, and it stops exp() from overflowing to inf for large
    # logits; inf/inf would give NaN probabilities and make
    # np.random.choice raise.
    e = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probabilities = e / np.sum(e, axis=1, keepdims=True)
    num_rows, num_categories = probabilities.shape
    c = {}  # row index -> sampled category indices
    for row in range(num_rows):
        c[row] = np.random.choice(num_categories, num_samples,
                                  p=probabilities[row])
    # Single-argument print(...) gives identical output on Python 2 and 3.
    print("probabilities (each row is a distribution):\n{}".format(probabilities))
    for n in range(num_samples):
        onehot = np.zeros(probabilities.shape)
        for row in range(num_rows):
            onehot[row, c[row][n]] = 1
        print("usual sample %d:" % n)
        print(onehot)
    return c
## create computational graph for gumbel
# Random logits: `batch_size` rows, `cardinality` categories per row.
x = mx.nd.array(np.random.randn(batch_size, cardinality))
# Two graphs over the same input variables: y1 yields straight-through
# one-hot samples, y2 yields the soft Gumbel-softmax relaxation.
y1 = gumbel_softmax_sample(cardinality, hard=True)
y2 = gumbel_softmax_sample(cardinality, hard=False)
## print out samples
# Header: the logits being sampled from and the problem size.
print("logits x for discrete distribution exp(x_k)/sum_{j=1}^K{exp(x_j)}:\n" + str(x.asnumpy()))
print("batch_size=%d, cardinality k=%d " % (batch_size, cardinality))
print("\nsamples from gumbel-softmax, temperature=%f" % temperature)
# Each iteration draws one uniform noise tensor and feeds the same noise to
# both graphs, so the hard and soft outputs printed together are paired.
for i in range(num_samples):
    uniform_sample = mx.nd.random_uniform(
        low=0, high=1, shape=(batch_size, cardinality))
    ex1 = y1.eval(mx.cpu(), logits=x, uniform=uniform_sample,
                  temperature=mx.nd.array([temperature]))
    ex2 = y2.eval(mx.cpu(), logits=x, uniform=uniform_sample,
                  temperature=mx.nd.array([temperature]))
    print("gumbel-softmax sample %d (straight-through & original soft version):" % i)
    print(ex1[0].asnumpy())
    print(ex2[0].asnumpy())
# Baseline: sample the same distributions directly with numpy.
print("\nsamples from usual procedure:")
usual_sample(x.asnumpy(), num_samples)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment