Skip to content

Instantly share code, notes, and snippets.

@kice
Created December 9, 2019 03:22
Show Gist options
  • Save kice/764fba62d0571618ce529fd39f62437e to your computer and use it in GitHub Desktop.
Save kice/764fba62d0571618ce529fd39f62437e to your computer and use it in GitHub Desktop.
import sys

import mxnet as mx
import mxnet.contrib.onnx as onnx_mxnet
import numpy as np
import onnx
# Convert an ONNX model (<model_name>.onnx) into an MXNet symbol file
# (<model_name>.json) and a parameter file (<model_name>.param) whose arrays
# are cast to the dtype of the model's input.
#
# Usage: python <this script> [model_name]    (model_name defaults to 'somemodel')
model_name = sys.argv[1] if len(sys.argv) > 1 else 'somemodel'
model_file = model_name + '.onnx'

onnx_model = onnx.load(model_file)
onnx.checker.check_model(onnx_model)
print('The model is checked!')

# graph.input may also list initializers (weights) in older ONNX IR versions,
# so skip those to find the real data input. Name and type are read from the
# same ValueInfoProto so they always describe the same tensor (the first
# node's first input is not guaranteed to be the graph's data input).
initializer_names = {init.name for init in onnx_model.graph.initializer}
data_inputs = [vi for vi in onnx_model.graph.input
               if vi.name not in initializer_names]
if not data_inputs:
    raise ValueError(f'{model_file} declares no non-initializer graph input')
input_name = data_inputs[0].name
input_type = data_inputs[0].type.tensor_type

# Parameters are saved as float16 only when the input is FLOAT16; every other
# element type falls back to float32.
dtype = 'float16' if input_type.elem_type == onnx.TensorProto.FLOAT16 else 'float32'

# Symbolic/unset dimensions report dim_value == 0; substitute 1 so MXNet
# shape inference can produce concrete shapes.
input_dims = input_type.shape.dim
input_shape = [d.dim_value if d.dim_value > 0 else 1 for d in input_dims]
if any(d.dim_value <= 0 for d in input_dims):
    print('NOTE: dynamic input dimensions were replaced with 1 for shape inference')

print(f'Input Name: {input_name}')
print(f'Input Shape: {input_shape}')
print(f'Input DType: {dtype}')
if input_name != 'data':
    print('WARNING: For MXNet model, input name should be "data"')

net, arg, aux = onnx_mxnet.import_model(model_file)
# infer_shape keys shapes by the symbol's argument names, which follow the
# ONNX input name, so bind by input_name rather than a hard-coded `data=`.
# Shapes must be tuples. The call returns (arg_shapes, out_shapes,
# aux_shapes); only the output shapes are reported.
_, output_shape, _ = net.infer_shape(**{input_name: tuple(input_shape)})
print(f'Output shape: {output_shape}')

print('Saving MXNet symbol...')
net.save(model_name + '.json')

# MXNet checkpoint convention: 'arg:' / 'aux:' key prefixes distinguish
# arguments from auxiliary states (as split by mx.model.load_checkpoint).
# NOTE(review): MXNet tooling conventionally uses the '.params' extension;
# '.param' is kept so existing consumers of this script's output still work.
params = {f'arg:{k}': v.astype(dtype) for k, v in arg.items()}
params.update({f'aux:{k}': v.astype(dtype) for k, v in aux.items()})
mx.nd.save(model_name + '.param', params)
print('Done.')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment