# Created December 9, 2019 03:22.
# Save kice/764fba62d0571618ce529fd39f62437e to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled
# differently than what appears below. To review, open the file in an editor that
# reveals hidden Unicode characters. Learn more about bidirectional Unicode characters.
# Convert an ONNX model into an MXNet symbol file (<model_name>.json) and a
# parameter file (<model_name>.param).
#
# Usage: python <this script> [model_name]
# Reads <model_name>.onnx from the current directory; model_name defaults to
# 'somemodel' when no argument is given.

import sys

import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet
import numpy as np  # not used below; kept from the original script
import onnx

# Optional first command-line argument overrides the default model name.
model_name = sys.argv[1] if len(sys.argv) > 1 else 'somemodel'
model_file = model_name + '.onnx'

onnx_model = onnx.load(model_file)
onnx.checker.check_model(onnx_model)
print('The model is checked!')

# Pick the first graph input that is not a weight. Older ONNX IR versions also
# list initializers in graph.input, and the first node's first input is not
# necessarily a graph input at all (a Constant node has none). Taking the name,
# element type and shape from the same ValueInfo keeps them describing one tensor.
initializer_names = {init.name for init in onnx_model.graph.initializer}
data_inputs = [vi for vi in onnx_model.graph.input if vi.name not in initializer_names]
if not data_inputs:
    sys.exit(f'{model_file}: graph has no non-initializer inputs')
graph_input = data_inputs[0]

input_name = graph_input.name
input_type = graph_input.type.tensor_type
dtype = 'float16' if input_type.elem_type == onnx.TensorProto.FLOAT16 else 'float32'
# dim_value is 0 for symbolic dimensions (those carry dim_param instead).
# NOTE(review): a 0 here may leave MXNet shape inference incomplete — confirm
# for models with a dynamic batch axis.
input_shape = [dim.dim_value for dim in input_type.shape.dim]

print(f'Input Name: {input_name}')
print(f'Input Shape: {input_shape}')
print(f'Input DType: {dtype}')
if input_name != 'data':
    print('WARNING: For MXNet model, input name should be "data"')

net, arg, aux = onnx_mxnet.import_model(model_file)
# Feed the shape under the model's actual input name rather than a hard-coded
# `data=` keyword. infer_shape returns (arg_shapes, out_shapes, aux_shapes);
# only the output shapes are reported.
_, out_shapes, _ = net.infer_shape(**{input_name: input_shape})
print(f'Output shape: {out_shapes}')

print('Saving MXNet symbol...')
net.save(model_name + '.json')

# Parameters are cast to the model's input dtype. Argument and auxiliary
# arrays are stored in one file and told apart by 'arg:' / 'aux:' key prefixes.
# NOTE(review): mx.model.load_checkpoint expects '<prefix>-NNNN.params'; the
# '.param' name here presumably targets a custom loader — confirm.
params = {f'arg:{k}': v.astype(dtype) for k, v in arg.items()}
params.update({f'aux:{k}': v.astype(dtype) for k, v in aux.items()})
mx.nd.save(model_name + '.param', params)
print('Done.')
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.