# test bn
from __future__ import print_function
import sys
import os
import tempfile
import time
import multiprocessing as mp
import unittest
import random
import mxnet as mx
import numpy as np
import math
from nose.tools import assert_raises
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
from mxnet.base import MXNetError
from mxnet import autograd
from numpy.testing import assert_allclose
from mxnet.test_utils import rand_ndarray
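

# `with_seed` below is a minimal no-op stand-in so this gist runs standalone; it is
# assumed to mirror the decorator of the same name from MXNet's test suite, but here
# it simply returns the wrapped test function unchanged (no RNG seeding).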
def with_seed():
    def wrapper(func):
        return func
    return wrapper


def _check_batchnorm_result(input, num_devices=1, cuda=False):
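    """Compare gluon BatchNorm on a single device with SyncBatchNorm applied to the
    same batch split across `num_devices` devices. Forward outputs, running statistics
    and input gradients should agree, since SyncBatchNorm aggregates batch statistics
    across the whole device group."""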
    from mxnet.gluon.utils import split_and_load

    def _find_bn(module):
        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
            return module
        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
            return module.module
        raise RuntimeError('BN not found')

    def _syncParameters(bn1, bn2, ctx):
        ctx = input.context
        bn2.gamma.set_data(bn1.gamma.data(ctx))
        bn2.beta.set_data(bn1.beta.data(ctx))
        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
        bn2.running_var.set_data(bn1.running_var.data(ctx))

    input1 = input.copy()
    input2 = input.copy()
    if cuda:
        input1 = input.as_in_context(mx.gpu(0))
        ctx_list = [mx.gpu(i) for i in range(num_devices)]
    else:
        ctx_list = [mx.cpu(0) for _ in range(num_devices)]

    nch = input.shape[1] if input.ndim > 1 else 1
    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(
        in_channels=nch, num_devices=num_devices)

    bn1.initialize(ctx=ctx_list[0])
    bn2.initialize(ctx=ctx_list)

    # using the same values for gamma and beta
    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])

    input1.attach_grad()
    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
    for xi in inputs2:
        xi.attach_grad()

    with mx.autograd.record():
        output1 = bn1(input1)
        output2 = [bn2(xi) for xi in inputs2]
        loss1 = (output1 ** 2).sum()
        loss2 = [(output ** 2).sum() for output in output2]
        mx.autograd.backward(loss1)
        mx.autograd.backward(loss2)

    output2 = mx.nd.concat(*[output.as_in_context(input.context)
                             for output in output2], dim=0)
    # assert forwarding
    assert_almost_equal(input1.asnumpy(), input2.asnumpy(),
                        atol=1e-3, rtol=1e-3)
    assert_almost_equal(output1.asnumpy(),
                        output2.asnumpy(), atol=1e-3, rtol=1e-3)
    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
                        atol=1e-3, rtol=1e-3)
    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
                        atol=1e-3, rtol=1e-3)
    input2grad = mx.nd.concat(
        *[output.grad.as_in_context(input.context) for output in inputs2], dim=0)
    assert_almost_equal(input1.grad.asnumpy(),
                        input2grad.asnumpy(), atol=1e-3, rtol=1e-3)


@with_seed()
def test_sync_batchnorm():
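    # Run the consistency check above on CPU, and on 1..num_gpus GPUs when available.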
    cfgs = [(1, False)]
    num_gpus = mx.context.num_gpus()
    for i in range(1, num_gpus + 1):
        cfgs.append((i, True))
    for ndev, cuda in cfgs:
        # check with unsync version
        for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]:
            for i in range(10):
                _check_batchnorm_result(mx.nd.random.uniform(shape=shape),
                                        num_devices=ndev, cuda=cuda)


@with_seed()
def test_batchnorm():
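    # Check the NDArray BatchNorm / contrib.SyncBatchNorm operators against a reference
    # implemented with mx.nd ops: forward output (its assert is commented out below),
    # running statistics, and gradients w.r.t. data, gamma and beta.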
    momentum = 0.9
    epsilon = 1e-5

    for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
        for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]:
            for axis in range(len(shape)):
                kwargs = dict()
                if op == mx.nd.contrib.SyncBatchNorm:
                    if axis != 1:
                        continue
                    key = str(op) + str(shape) + str(axis)
                    kwargs.update(dict(key=key))
                else:
                    kwargs.update(dict(axis=axis))
                print(op, shape, axis)

                nch = shape[axis]
                bn_gamma = mx.nd.random.uniform(shape=(nch,))
                bn_gamma.attach_grad()
                bn_beta = mx.nd.random.uniform(shape=(nch,))
                bn_beta.attach_grad()
                bn_running_mean = mx.nd.zeros(nch)
                bn_running_var = mx.nd.ones(nch)

                running_mean = mx.nd.zeros(nch)
                running_var = mx.nd.ones(nch)
                num_iters = 10
                expand_shape = [1] * len(shape)
                expand_shape[axis] = shape[axis]
                for _ in range(num_iters):
                    data = mx.nd.random.uniform(shape=shape)
                    data.attach_grad()
                    ograd = mx.nd.random.uniform(shape=shape)
                    with mx.autograd.record():
                        output = op(data, bn_gamma, bn_beta,
                                    bn_running_mean, bn_running_var,
                                    momentum=momentum, eps=epsilon,
                                    fix_gamma=False, **kwargs)
                        output.backward(ograd)
                    mx.nd.waitall()

                    data_mean = data.mean(
                        axis=axis, exclude=True, keepdims=True)
                    data_var = (data - data_mean).square().mean(
                        axis=axis, exclude=True, keepdims=True)

                    target_output = (data - data_mean) / \
                        (data_var + epsilon).sqrt() * \
                        bn_gamma.reshape(expand_shape) + \
                        bn_beta.reshape(expand_shape)

                    # squeeze data_mean and data_var
                    data_mean_flat = data_mean.squeeze()
                    data_var_flat = data_var.squeeze()

                    running_mean = running_mean * momentum + \
                        data_mean_flat * (1 - momentum)
                    running_var = running_var * momentum + \
                        data_var_flat * (1 - momentum)
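
                    # Manual backward pass. With x_hat = (x - mean) / sqrt(var + eps),
                    # y = gamma * x_hat + beta, upstream gradient g = ograd and
                    # m elements per channel, the standard batch-norm gradients are:
                    #   dvar   = sum(g * gamma * (x - mean)) * (-1/2) * (var + eps)^(-3/2)
                    #   dmean  = -sum(g * gamma) / sqrt(var + eps) - 2 * dvar * mean(x - mean)
                    #   dx     = g * gamma / sqrt(var + eps) + dvar * 2 * (x - mean) / m + dmean / m
                    #   dgamma = sum(g * x_hat)
                    #   dbeta  = sum(g)
                    # The expressions below compute exactly these with mx.nd ops.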
                    W = bn_gamma.reshape(expand_shape)
                    dnx = ograd * W
                    xsm = data - data_mean
                    nd = 1.0 / mx.nd.sqrt(data_var + epsilon)
                    nx = xsm * nd
                    m = np.prod(shape) / shape[axis]
                    dvar = (dnx * xsm).sum(axis=axis, keepdims=True,
                                           exclude=True) * (-0.5) * mx.nd.power(nd, 3.0)
                    dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \
                        dvar * xsm.mean(axis=axis, keepdims=True,
                                        exclude=True) * 2.0
                    dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
                    dW = (ograd * nx).sum(axis=axis, exclude=True)
                    db = ograd.sum(axis=axis, exclude=True)

                    #assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_running_mean.asnumpy(),
                                        running_mean.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_running_var.asnumpy(),
                                        running_var.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(data.grad.asnumpy(),
                                        dX.asnumpy(), atol=1e-3, rtol=1e-3)
                    assert_almost_equal(
                        bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=1e-3, rtol=1e-3)
                    assert_almost_equal(
                        bn_beta.grad.asnumpy(), db.asnumpy(), atol=1e-3, rtol=1e-3)


if __name__ == '__main__':
    set_default_context(mx.cpu())
    test_batchnorm()
    print("Test OK")