@charliememory
Created August 28, 2018 16:27
TensorFlow implementation of Bernoulli sampling
#################### Bernoulli Sample #####################
## ref code: https://r2rt.com/binary-stochastic-neurons-in-tensorflow.html
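import tensorflow as tf  # written against the TensorFlow 1.x graph-mode API
from tensorflow.python.framework import ops


def bernoulliSample(x):
    """
    Uses a tensor whose values are in [0, 1] to sample a tensor with values in {0, 1},
    using the straight-through estimator for the gradient.
    E.g., if x is 0.6, bernoulliSample(x) will be 1 with probability 0.6 and 0 otherwise,
    and the gradient will be passed through unchanged (identity).
    """
    g = tf.get_default_graph()
    with ops.name_scope("BernoulliSample") as name:
        # Inside this block, Ceil gets an identity gradient and Sub gets the
        # custom straight-through gradient registered below.
        with g.gradient_override_map({"Ceil": "Identity", "Sub": "BernoulliSample_ST"}):
            # For u ~ U[0, 1), ceil(x - u) is 1 when u < x and 0 otherwise,
            # i.e. 1 with probability x.
            noise = tf.random_uniform(tf.shape(x), dtype=x.dtype)
            return tf.ceil(x - noise, name=name)


@ops.RegisterGradient("BernoulliSample_ST")
def bernoulliSample_ST(op, grad):
    # Gradient of the overridden Sub op (x - noise): pass the upstream gradient
    # to x unchanged and send zeros (same shape and dtype) to the noise input.
    return [grad, tf.zeros_like(op.inputs[1])]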
###########################################################
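The docstring makes two claims: ceil(x - u) with uniform noise u is 1 with probability x, and the backward pass treats the sampler as the identity. Both can be checked without TensorFlow. The following is a minimal pure-Python sketch, not the gist's code: `bernoulli_sample`, `straight_through_grad`, and the toy `train_probability` loop (its squared-error loss, learning rate, step count, and clamping of p to [0, 1]) are hypothetical names and choices made only for illustration.

```python
import math
import random


def bernoulli_sample(p, rng=random):
    """Forward pass: ceil(p - u) with u ~ U[0, 1) is 1 iff u < p, i.e. 1 with probability p."""
    u = rng.random()
    return math.ceil(p - u)


def straight_through_grad(upstream_grad):
    """Backward pass: treat ceil(p - u) as the identity in p, so the gradient passes through unchanged."""
    return upstream_grad


def train_probability(target, steps=2000, lr=0.01, seed=0):
    """Hypothetical toy problem: learn p so that binary samples match a fixed target (0 or 1).

    The per-step loss is (b - target)**2 with b = bernoulli_sample(p).
    """
    rng = random.Random(seed)
    p = 0.5
    for _ in range(steps):
        b = bernoulli_sample(p, rng)
        dloss_db = 2.0 * (b - target)               # d/db of (b - target)**2
        dloss_dp = straight_through_grad(dloss_db)  # straight-through: db/dp treated as 1
        p = min(max(p - lr * dloss_dp, 0.0), 1.0)   # assumption: clamp p to stay a valid probability
    return p
```

For example, `train_probability(1)` drives p toward 1 and `train_probability(0)` drives it toward 0, even though the sampling step itself has zero gradient almost everywhere. The clamp is this sketch's own choice; the gist leaves keeping x inside [0, 1] to the caller.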