Skip to content

Instantly share code, notes, and snippets.

@ericjang
Created January 17, 2018 18:24
Show Gist options
  • Save ericjang/0e02789384999a016ed5d964555d7464 to your computer and use it in GitHub Desktop.
Save ericjang/0e02789384999a016ed5d964555d7464 to your computer and use it in GitHub Desktop.
class BatchNorm(tfb.Bijector):
    """Batch-normalization layer expressed as a bijector.

    The inverse pass (data -> latent) normalizes with the statistics of the
    current minibatch and, as a side effect, updates running estimates of the
    mean and variance. The forward pass (latent -> data) undoes the
    normalization using those running estimates.

    Args:
        eps: Small constant added to the variance before taking the square
            root / logarithm, for numerical stability.
        decay: Retention factor of the running statistics. Each inverse call
            updates them as ``running = decay * running + (1 - decay) * batch``.
        validate_args: Forwarded to the ``Bijector`` base class.
        name: Bijector name; also used as the variable scope for the
            variables this layer creates.
    """

    def __init__(self, eps=1e-5, decay=0.95, validate_args=False, name="batch_norm"):
        super(BatchNorm, self).__init__(
            event_ndims=1, validate_args=validate_args, name=name)
        # Variables are created lazily on first use, because the feature
        # dimension is only known once an input tensor is seen.
        self._vars_created = False
        self.eps = eps
        self.decay = decay

    def _create_vars(self, x):
        """Create the layer's variables, sized from axis 1 of ``x``."""
        n = x.get_shape().as_list()[1]
        with tf.variable_scope(self.name):
            # beta / gamma use the variable scope's default initializer.
            # The scale is applied as exp(gamma), so it is always positive.
            self.beta = tf.get_variable('beta', [1, n], dtype=DTYPE)
            self.gamma = tf.get_variable('gamma', [1, n], dtype=DTYPE)
            # Running statistics start at the identity normalization
            # (mean 0, variance 1). Without an explicit initializer the
            # running mean would fall back to the scope's default initializer.
            self.train_m = tf.get_variable(
                'mean', [1, n], dtype=DTYPE,
                initializer=tf.zeros_initializer, trainable=False)
            self.train_v = tf.get_variable(
                'var', [1, n], dtype=DTYPE,
                initializer=tf.ones_initializer, trainable=False)
        self._vars_created = True

    def _forward(self, u):
        """Map ``u`` back to data space using the running statistics."""
        if not self._vars_created:
            self._create_vars(u)
        # Exact algebraic inverse of `_inverse`, with the minibatch moments
        # replaced by the running mean / variance.
        return ((u - self.beta) * tf.exp(-self.gamma)
                * tf.sqrt(self.train_v + self.eps) + self.train_m)

    def _inverse(self, x):
        """Normalize ``x`` with minibatch statistics and update running ones.

        Eq 22. Called during training of a normalizing flow.
        """
        if not self._vars_created:
            self._create_vars(x)
        # Statistics of the current minibatch (reduced over axis 0).
        m, v = tf.nn.moments(x, axes=[0], keep_dims=True)
        # Update the running statistics via exponential moving average:
        #   running <- decay * running + (1 - decay) * batch
        # i.e. subtract (1 - decay) * (running - batch). `decay` is the
        # fraction of the old estimate that is retained.
        update_train_m = tf.assign_sub(
            self.train_m, (1. - self.decay) * (self.train_m - m))
        update_train_v = tf.assign_sub(
            self.train_v, (1. - self.decay) * (self.train_v - v))
        # Normalize using current minibatch statistics, followed by BN scale
        # and shift. The control dependency forces the running-average
        # updates to run whenever the normalized output is evaluated.
        with tf.control_dependencies([update_train_m, update_train_v]):
            return (x - m) * 1. / tf.sqrt(v + self.eps) * tf.exp(self.gamma) + self.beta

    def _inverse_log_det_jacobian(self, x):
        """Return the log-determinant of the Jacobian of `_inverse` at ``x``.

        At training time the log-det-Jacobian is computed from the statistics
        of the current minibatch, matching what `_inverse` normalizes with.
        """
        if not self._vars_created:
            self._create_vars(x)
        _, v = tf.nn.moments(x, axes=[0], keep_dims=True)
        # Per-feature derivative of `_inverse` is exp(gamma) / sqrt(v + eps);
        # its log is gamma - 0.5 * log(v + eps), summed over features.
        abs_log_det_J_inv = tf.reduce_sum(
            self.gamma - .5 * tf.log(v + self.eps))
        return abs_log_det_J_inv
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment