
@jeasinema
Created May 28, 2018 02:14
Different behaviours of batch normalization
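
The script below runs the same two-sample experiment through MXNet/Gluon, TensorFlow, and PyTorch to show when each framework updates the running mean/variance of a batch normalization layer, and how their momentum conventions differ.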
import numpy as np


def main_mxnet():
    import mxnet as mx
    from mxnet import gluon, autograd
    from mxnet.gluon import nn

    net = nn.Sequential()
    with net.name_scope():
        net.add(nn.BatchNorm(momentum=0.9))
    net.initialize()
    p = net.collect_params()

    a = mx.nd.array(np.array([[1], [2]]))
    b = mx.nd.array(np.array([[3], [3]]))
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})
    lossf = gluon.loss.L2Loss()

    # Outside autograd.record(), the forward pass runs in inference mode.
    output = net(a)
    print('init val of running_mean/var (in mxnet, init occurs at the first forward)')
    print(p['sequential0_batchnorm0_running_mean'].data(), p['sequential0_batchnorm0_running_var'].data())

    output = net(a)
    print('val of running_mean/var will not change during testing')
    print(p['sequential0_batchnorm0_running_mean'].data(), p['sequential0_batchnorm0_running_var'].data())

    # Inside autograd.record(), BatchNorm runs in training mode and the
    # running statistics are updated during the forward pass itself.
    with autograd.record():
        output = net(a)
        loss = lossf(output, b)
    loss.backward()
    trainer.step(2)
    print('val of running_mean/var will update during training')
    print(p['sequential0_batchnorm0_running_mean'].data(), p['sequential0_batchnorm0_running_var'].data())
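
# MXNet folds the batch statistic in with weight (1 - momentum):
#     running = momentum * running + (1 - momentum) * batch_stat
# With momentum=0.9 and batch [1, 2] (mean 1.5), the first training step
# should move running_mean from 0 to 0.9*0 + 0.1*1.5 = 0.15.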


def main_tf():
    import tensorflow as tf

    sess = tf.InteractiveSession()
    ain = tf.placeholder(tf.float32, [None, 1])
    bin = tf.placeholder(tf.float32, [None, 1])
    aout = tf.layers.batch_normalization(ain, training=True, fused=True, momentum=0.9)
    loss = tf.losses.mean_squared_error(aout, bin)

    # The moving-average updates live in the UPDATE_OPS collection; they run
    # only when explicitly executed, hence the control_dependencies below.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    grad = tf.gradients(loss, tf.trainable_variables())
    opt = tf.train.AdamOptimizer(0.01)
    with tf.control_dependencies(update_ops):
        update = opt.apply_gradients(zip(grad, tf.trainable_variables()))
    p = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)  # p[2]/p[3]: moving mean/variance
    sess.run(tf.global_variables_initializer())

    print('init val of running_mean/var')
    print(p[2].eval(), p[3].eval())

    # Running only the output op does not touch the moving averages.
    sess.run([aout], {ain: np.array([[1], [2]]), bin: np.array([[3], [3]])})
    print('val of running_mean/var will not change during testing')
    print(p[2].eval(), p[3].eval())

    # Running the train op (with UPDATE_OPS as control dependencies) does.
    sess.run([update, aout], {ain: np.array([[1], [2]]), bin: np.array([[3], [3]])})
    print('val of running_mean/var will update during training')
    print(p[2].eval(), p[3].eval())
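
# TensorFlow uses the same moving-average convention as MXNet:
#     moving = momentum * moving + (1 - momentum) * batch_stat
# but, unlike MXNet and PyTorch, the update happens only when the ops in
# tf.GraphKeys.UPDATE_OPS are actually run, not as a side effect of the
# forward pass.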


def main_pytorch():
    import torch
    import torch.nn

    # NOTE: PyTorch's momentum weights the NEW batch statistic, so
    # momentum=0.1 here corresponds to momentum=0.9 in MXNet/TF.
    model = torch.nn.BatchNorm1d(1, momentum=0.1)
    a = torch.from_numpy(np.array([[1], [2]])).float()

    print('init val of running_mean/var')
    print(model.running_mean, model.running_var)

    # In eval mode the forward pass uses, but does not update, the stats.
    model.eval()
    model(a)
    print('val of running_mean/var will not change during testing')
    print(model.running_mean, model.running_var)

    # In train mode the forward pass itself updates the running stats.
    model.train()
    model(a)
    print('val of running_mean/var will update during training')
    print(model.running_mean, model.running_var)

    # The same experiment on GPU:
    # model = model.cuda()
    # a = torch.from_numpy(np.array([[1, 1], [2, 2]])).float().cuda()
    # print(model.running_mean, model.running_var)
    # model.train()
    # model(a)
    # print(model.running_mean, model.running_var)
    # model.eval()
    # model(a)
    # print(model.running_mean, model.running_var)
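
# PyTorch's update rule is
#     running = (1 - momentum) * running + momentum * batch_stat
# and the running variance is updated with the UNBIASED batch variance,
# so the running stats can differ slightly from MXNet/TF even with
# matching momentum values.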


if __name__ == '__main__':
    main_mxnet()
    main_tf()
    main_pytorch()
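
For reference, the two momentum conventions can be checked in plain numpy. This is a minimal sketch, assuming the conventions described above; the function names are illustrative, not part of any framework API:

# Minimal numpy sketch of the two momentum conventions (illustrative names).
import numpy as np

def update_mxnet_tf(running, batch_stat, momentum=0.9):
    # MXNet / TF convention: momentum weights the OLD running value.
    return momentum * running + (1.0 - momentum) * batch_stat

def update_pytorch(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat

batch = np.array([1.0, 2.0])
mean, var_biased, var_unbiased = batch.mean(), batch.var(), batch.var(ddof=1)

# With momentum=0.9 (MXNet/TF) and momentum=0.1 (PyTorch), the mean updates agree:
print(update_mxnet_tf(0.0, mean))   # 0.15
print(update_pytorch(0.0, mean))    # 0.15

# The variance updates can still differ, because PyTorch folds in the
# unbiased batch variance while the biased variance is used here for MXNet/TF:
print(update_mxnet_tf(1.0, var_biased))    # 0.925
print(update_pytorch(1.0, var_unbiased))   # 0.95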