test bn
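"""Numerical tests for batch normalization in MXNet.

test_batchnorm checks mx.nd.BatchNorm and mx.nd.contrib.SyncBatchNorm against
a hand-written reference implementation (forward output, running statistics,
and the gradients w.r.t. data, gamma and beta). test_sync_batchnorm compares
gluon's contrib SyncBatchNorm with the unsynchronized gluon BatchNorm layer.
"""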
from __future__ import print_function
import sys
import os
import tempfile
import time
import multiprocessing as mp
import unittest
import random
import math
import mxnet as mx
import numpy as np
from nose.tools import assert_raises
from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
from mxnet.base import MXNetError
from mxnet import autograd
from numpy.testing import assert_allclose
from mxnet.test_utils import rand_ndarray
def with_seed():
    # No-op stand-in for the @with_seed decorator used in MXNet's test suite,
    # so the test functions below can run outside that test harness.
    def wrapper(func):
        return func
    return wrapper
def _check_batchnorm_result(input, num_devices=1, cuda=False):
    # Compare gluon BatchNorm (single device) with SyncBatchNorm
    # (input split across num_devices).
    from mxnet.gluon.utils import split_and_load

    def _find_bn(module):
        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
            return module
        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
            return module.module
        raise RuntimeError('BN not found')

    def _syncParameters(bn1, bn2, ctx):
        ctx = input.context
        bn2.gamma.set_data(bn1.gamma.data(ctx))
        bn2.beta.set_data(bn1.beta.data(ctx))
        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
        bn2.running_var.set_data(bn1.running_var.data(ctx))

    input1 = input.copy()
    input2 = input.copy()
    if cuda:
        input1 = input.as_in_context(mx.gpu(0))
        ctx_list = [mx.gpu(i) for i in range(num_devices)]
    else:
        ctx_list = [mx.cpu(0) for _ in range(num_devices)]

    nch = input.shape[1] if input.ndim > 1 else 1
    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(
        in_channels=nch, num_devices=num_devices)
    bn1.initialize(ctx=ctx_list[0])
    bn2.initialize(ctx=ctx_list)

    # using the same values for gamma and beta
    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])

    input1.attach_grad()
    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
    for xi in inputs2:
        xi.attach_grad()

    with mx.autograd.record():
        output1 = bn1(input1)
        output2 = [bn2(xi) for xi in inputs2]
        loss1 = (output1 ** 2).sum()
        loss2 = [(output ** 2).sum() for output in output2]
        mx.autograd.backward(loss1)
        mx.autograd.backward(loss2)

    output2 = mx.nd.concat(*[output.as_in_context(input.context)
                             for output in output2], dim=0)
    # assert forwarding
    assert_almost_equal(input1.asnumpy(), input2.asnumpy(),
                        atol=1e-3, rtol=1e-3)
    assert_almost_equal(output1.asnumpy(), output2.asnumpy(),
                        atol=1e-3, rtol=1e-3)
    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
                        atol=1e-3, rtol=1e-3)
    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
                        atol=1e-3, rtol=1e-3)
    input2grad = mx.nd.concat(
        *[xi.grad.as_in_context(input.context) for xi in inputs2], dim=0)
    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(),
                        atol=1e-3, rtol=1e-3)
@with_seed()
def test_sync_batchnorm():
    cfgs = [(1, False)]
    num_gpus = mx.context.num_gpus()
    for i in range(1, num_gpus + 1):
        cfgs.append((i, True))
    for ndev, cuda in cfgs:
        # check with unsync version
        for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]:
            for i in range(10):
                _check_batchnorm_result(mx.nd.random.uniform(shape=shape),
                                        num_devices=ndev, cuda=cuda)
@with_seed()
def test_batchnorm():
    momentum = 0.9
    epsilon = 1e-5
    for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
        for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]:
            for axis in range(len(shape)):
                kwargs = dict()
                if op == mx.nd.contrib.SyncBatchNorm:
                    # SyncBatchNorm only supports the channel axis 1
                    if axis != 1:
                        continue
                    key = str(op) + str(shape) + str(axis)
                    kwargs.update(dict(key=key))
                else:
                    kwargs.update(dict(axis=axis))
                print(op, shape, axis)
                nch = shape[axis]
                bn_gamma = mx.nd.random.uniform(shape=(nch,))
                bn_gamma.attach_grad()
                bn_beta = mx.nd.random.uniform(shape=(nch,))
                bn_beta.attach_grad()
                bn_running_mean = mx.nd.zeros(nch)
                bn_running_var = mx.nd.ones(nch)
                running_mean = mx.nd.zeros(nch)
                running_var = mx.nd.ones(nch)
                num_iters = 10
                expand_shape = [1] * len(shape)
                expand_shape[axis] = shape[axis]
                for _ in range(num_iters):
                    data = mx.nd.random.uniform(shape=shape)
                    data.attach_grad()
                    ograd = mx.nd.random.uniform(shape=shape)
                    with mx.autograd.record():
                        output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var,
                                    momentum=momentum, eps=epsilon, fix_gamma=False, **kwargs)
                        output.backward(ograd)
                    mx.nd.waitall()
                    # reference forward pass: normalize over every axis except `axis`
                    data_mean = data.mean(
                        axis=axis, exclude=True, keepdims=True)
                    data_var = (data - data_mean).square().mean(axis=axis,
                                                                exclude=True, keepdims=True)
                    target_output = (data - data_mean) / (data_var + epsilon).sqrt() * \
                        bn_gamma.reshape(expand_shape) + \
                        bn_beta.reshape(expand_shape)
                    # squeeze data_mean and data_var
                    data_mean_flat = data_mean.squeeze()
                    data_var_flat = data_var.squeeze()
                    # reference running statistics (exponential moving average)
                    running_mean = running_mean * momentum + \
                        data_mean_flat * (1 - momentum)
                    running_var = running_var * momentum + \
                        data_var_flat * (1 - momentum)
                    # reference backward pass for
                    # y = gamma * (x - mean) / sqrt(var + eps) + beta
                    W = bn_gamma.reshape(expand_shape)
                    dnx = ograd * W
                    xsm = data - data_mean
                    nd = 1.0 / mx.nd.sqrt(data_var + epsilon)
                    nx = xsm * nd
                    m = np.prod(shape) / shape[axis]
                    dvar = (dnx * xsm).sum(axis=axis, keepdims=True,
                                           exclude=True) * (-0.5) * mx.nd.power(nd, 3.0)
                    dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \
                        dvar * xsm.mean(axis=axis, keepdims=True,
                                        exclude=True) * 2.0
                    dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
                    dW = (ograd * nx).sum(axis=axis, exclude=True)
                    db = ograd.sum(axis=axis, exclude=True)
                    #assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_running_mean.asnumpy(), running_mean.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_running_var.asnumpy(), running_var.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(data.grad.asnumpy(), dX.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_gamma.grad.asnumpy(), dW.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
                    assert_almost_equal(bn_beta.grad.asnumpy(), db.asnumpy(),
                                        atol=1e-3, rtol=1e-3)
if __name__ == '__main__':
    # set the default context before running the test, not after
    set_default_context(mx.cpu())
    test_batchnorm()
    print("Test OK")