MXNet gluon.nn.BatchNorm bug report
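Minimal script to reproduce the behaviour discussed in the threads linked at the top of the file: a single gluon.nn.BatchNorm layer is driven with repeated forward passes (autograd recording on, but no trainer and no backward pass), once on CPU and once on GPU. The running statistics, running_var in particular, should evolve identically on both contexts; the script prints the parameter means per iteration, compares the CPU and GPU traces, and plots running_mean and running_var for both.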
# https://discuss.mxnet.io/t/batchnorm-running-var-value-depends-on-context/3401
# https://discuss.mxnet.io/t/hw-9-how-to-prevent-overflow-to-nan/3853
# https://github.com/apache/incubator-mxnet/issues/18100
# https://github.com/apache/incubator-mxnet/issues/18209
# https://github.com/apache/incubator-mxnet/issues/14710
import matplotlib.pyplot as plt
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from mxnet import autograd
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock

BN_PARAMS = ["gamma", "beta", "running_mean", "running_var"]

def check_params(net, ctx, iteration=0):
    """Print each BatchNorm parameter on the given context and return their means."""
    mean_params = []
    for name, param in net.collect_params().items():
        for var in BN_PARAMS:
            if var in name:
                print(f"[{str(iteration):>4}]: {var:>15} on {str(ctx)}: {param.data(ctx).asnumpy()}")
                mean_params.append(nd.mean(param.data(ctx)).asscalar())
    return mean_params

class MyNet(HybridBlock):
    """Minimal network: a single BatchNorm layer."""

    def __init__(self, **kwargs):
        super(MyNet, self).__init__(**kwargs)
        self.body = nn.HybridSequential()
        self.body.add(nn.BatchNorm())

    def hybrid_forward(self, F, x, *args, **kwargs):
        return self.body(x)

def init_forward_net(ctx_list):
    params_list = []
    net = MyNet()
    net.initialize(mx.init.Xavier(), ctx=ctx_list[0])
    X = nd.ones(shape=(8, 3, 32, 32), ctx=ctx_list[0])
    net.summary(X)  # summary() prints the table itself and returns None
    params_list.append(check_params(net, ctx_list[0]))
    # Repeated forward passes with autograd recording enabled, but without a
    # trainer or backward pass; only the BatchNorm running statistics change.
    for iteration in range(100):
        with autograd.record(train_mode=True):
            net(X)
        params_list.append(check_params(net, ctx_list[0], iteration))
    return np.array(params_list)

if __name__ == '__main__':
    params_cpu = init_forward_net([mx.cpu()])
    params_gpu = init_forward_net([mx.gpu(0)])
    print()
    tolerance = 1.0e-5
    for i, var in enumerate(BN_PARAMS):
        is_equal = np.allclose(params_cpu[:, i], params_gpu[:, i], atol=tolerance)
        diff = np.abs(params_cpu[:, i] - params_gpu[:, i])
        print(f"{var:>20} on CPU and GPU are (almost) equal: {str(is_equal):>6}, "
              f"err: {np.mean(diff):0.5f}+-{np.std(diff):0.5f}")

    plt.figure()
    plt.plot(params_cpu[:, 2], label="running_mean CPU")
    plt.plot(params_gpu[:, 2], label="running_mean GPU")
    plt.plot(params_cpu[:, 3], label="running_var CPU")
    plt.plot(params_gpu[:, 3], label="running_var GPU")
    plt.title("Parameter comparison of BatchNorm on CPU and GPU")
    plt.ylim((-0.1, 1.1))
    plt.xlabel("Iteration")
    plt.legend()
    plt.savefig("plot_bn_test.png")
    plt.show()
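For quick experiments, the same check can be condensed to a few lines. The sketch below is not part of the original report; it assumes MXNet 1.x with a CUDA build (mx.gpu(0) available) and reads the running_var parameter directly:

import mxnet as mx
from mxnet import autograd
from mxnet.gluon import nn

def running_var_after(ctx, steps=100):
    # Hypothetical helper, not from the original gist.
    bn = nn.BatchNorm()
    bn.initialize(ctx=ctx)  # shape is inferred on the first forward pass
    x = mx.nd.ones((8, 3, 32, 32), ctx=ctx)
    for _ in range(steps):
        with autograd.record(train_mode=True):
            bn(x)  # forward only: no backward pass, no trainer step
    return bn.running_var.data(ctx).asnumpy()

print(running_var_after(mx.cpu()))
print(running_var_after(mx.gpu(0)))  # reported to drift away from the CPU result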