Last active
June 30, 2019 14:40
-
-
Save chenyaofo/3b319e233b45f4167ef999d8fbcbdf70 to your computer and use it in GitHub Desktop.
An API comparison between PyTorch and MXNet
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy | |
import torch | |
import torch.nn | |
import mxnet | |
def _as_mx_array(tensor):
    """Copy a PyTorch tensor's values into a new MXNet NDArray."""
    return mxnet.nd.array(tensor.data.numpy())


def test_forward_template(module_name, data, pt_module, mx_module, is_train=True):
    """Feed the same input to a PyTorch module and an MXNet symbol and print the output gap.

    The PyTorch module's parameters are copied into the MXNet executor so that
    both frameworks compute with the same values. Two numbers are printed: the
    largest absolute difference and the largest difference relative to the
    PyTorch output.

    Args:
        module_name: Label used in the log line. The value ``"bn"`` also
            selects the batch-norm binding (gamma/beta plus moving statistics).
        data: NumPy array fed to both frameworks.
        pt_module: PyTorch module. It must expose ``weight`` and, for ``"bn"``,
            also ``bias``, ``running_mean`` and ``running_var``.
        mx_module: MXNet symbol whose input variable is named ``data`` and
            whose parameters carry the ``module_`` prefix.
        is_train: Whether both frameworks run the forward pass in training mode.
    """
    print(f"Start to test forward process {module_name}, mode = {'train' if is_train else 'validating'}")

    # Set the mode in both directions: a module left in eval mode by an earlier
    # call must be switched back when is_train is True.
    pt_module.train(is_train)

    # The PyTorch forward pass runs before the parameters are copied, so any
    # state it updates is what the MXNet executor receives.
    pt_output = pt_module(torch.from_numpy(data)).detach().numpy()

    if module_name == "bn":
        args = {
            "data": mxnet.nd.array(data),
            "module_gamma": _as_mx_array(pt_module.weight),
            "module_beta": _as_mx_array(pt_module.bias),
        }
        aux_states = {
            "module_moving_mean": _as_mx_array(pt_module.running_mean),
            "module_moving_var": _as_mx_array(pt_module.running_var),
        }
    else:
        # The same weight is offered under both names so that one binding
        # serves symbols calling their parameter "weight" and symbols calling
        # it "gamma". NOTE(review): assumes bind ignores keys the symbol does
        # not declare -- confirm against the MXNet version in use.
        args = {
            "data": mxnet.nd.array(data),
            "module_weight": _as_mx_array(pt_module.weight),
            "module_gamma": _as_mx_array(pt_module.weight),
        }
        aux_states = None

    executor = mx_module.bind(mxnet.cpu(), args, aux_states=aux_states)
    mx_output = executor.forward(is_train=is_train)[0].asnumpy()

    abs_diff = numpy.abs(pt_output - mx_output)
    with numpy.errstate(divide="ignore", invalid="ignore"):
        rel_diff = abs_diff / numpy.abs(pt_output)
    # Where both outputs are exactly zero the quotient is 0/0 = nan, which
    # would hide every other entry behind a nan maximum; count it as no
    # difference. A zero reference with a non-zero difference stays inf.
    rel_diff = numpy.where(abs_diff == 0, 0.0, rel_diff)

    print(f"The absolute max diff is {abs_diff.max()}, the relative max diff is {rel_diff.max()}")
# def test_backward_template(module_name, pt_output, mx_output, is_train): | |
# if is_train: | |
# pass | |
# else: | |
# print(f"Start to test backward process {module_name}, mode = {'train' if is_train else 'validating'}") | |
# pt_output.sum().backward() | |
def test_conv2d():
    """Compare a bias-free 3x3 convolution (3 -> 64 channels, padding 1) across both frameworks.

    The input is a random float32 array of shape (1, 3, 224, 224).
    """
    data = numpy.random.rand(1, 3, 224, 224).astype(numpy.float32)

    torch_layer = torch.nn.Conv2d(3, 64, 3, padding=1, bias=False)

    input_symbol = mxnet.symbol.Variable(name="data")
    mxnet_layer = mxnet.sym.Convolution(
        data=input_symbol,
        num_filter=64,
        kernel=(3, 3),
        num_group=1,
        stride=(1, 1),
        pad=(1, 1),
        no_bias=True,
        name="module",
    )

    test_forward_template("conv2d", data, torch_layer, mxnet_layer)
def test_linear():
    """Compare a bias-free 1024 -> 512 fully connected layer across both frameworks.

    The input is a random float32 array of shape (1, 1024).
    """
    data = numpy.random.rand(1, 1024).astype(numpy.float32)

    torch_layer = torch.nn.Linear(1024, 512, bias=False)

    input_symbol = mxnet.symbol.Variable(name="data")
    mxnet_layer = mxnet.sym.FullyConnected(
        data=input_symbol,
        num_hidden=512,
        no_bias=True,
        name="module",
    )

    test_forward_template("linear", data, torch_layer, mxnet_layer)
def test_batch_normalization_2d():
    """Compare 64-channel 2D batch normalisation in training mode, then in inference mode.

    The input is a random float32 array of shape (10, 64, 56, 56). The same
    PyTorch module and MXNet symbol are reused for both passes.

    NOTE(review): PyTorch is given momentum=0.1 and MXNet momentum=0.9;
    presumably the two frameworks weight the running statistics in opposite
    directions, making these equivalent -- confirm against both docs.
    """
    data = numpy.random.rand(10, 64, 56, 56).astype(numpy.float32)

    torch_bn = torch.nn.BatchNorm2d(
        num_features=64,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
    )

    input_symbol = mxnet.symbol.Variable(name="data")
    mxnet_bn = mxnet.sym.BatchNorm(
        data=input_symbol,
        axis=1,
        eps=1e-5,
        momentum=0.9,
        fix_gamma=False,
        name="module",
    )

    # Training pass first, inference pass second.
    for training in (True, False):
        test_forward_template("bn", data, torch_bn, mxnet_bn, is_train=training)
def test_leakyrelu():
    """Compare PyTorch ``PReLU`` with MXNet ``LeakyReLU(act_type="prelu")``.

    The input is a float32 array of shape (1, 1024) drawn uniformly from
    [-0.5, 0.5).
    """
    # numpy.random.rand alone yields only non-negative values, for which PReLU
    # is the identity; shifting by 0.5 makes roughly half the inputs negative
    # so the learned slope takes part in the comparison.
    data = (numpy.random.rand(1, 1024) - 0.5).astype(numpy.float32)
    pt_prelu = torch.nn.PReLU()
    mx_prelu = mxnet.sym.LeakyReLU(
        data=mxnet.symbol.Variable(name="data"),
        act_type="prelu",
        name="module",
    )
    test_forward_template("prelu", data, pt_prelu, mx_prelu)
# Script entry point: run every framework comparison in a fixed order.
if __name__ == '__main__':
    for run_comparison in (
        test_conv2d,
        test_linear,
        test_batch_normalization_2d,
        test_leakyrelu,
    ):
        run_comparison()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment