Skip to content

Instantly share code, notes, and snippets.

@ChaiBapchya
Created July 11, 2020 06:44
Show Gist options
  • Select an option

  • Save ChaiBapchya/70366c13dee97d7b48af45f015333367 to your computer and use it in GitHub Desktop.

Select an option

Save ChaiBapchya/70366c13dee97d7b48af45f015333367 to your computer and use it in GitHub Desktop.
Test MX Model export to ONNX
import mxnet as mx
import numpy as np
from mxnet.contrib import onnx as onnx_mxnet
import logging
logging.basicConfig(level=logging.INFO)
# Pre-trained models to export. Each entry holds three paths, all relative to
# the model zoo base URL: [params file, symbol json file, synset file].
model_list = [
    [
        'imagenet-11k/resnet-152/resnet-152-0000.params',
        'imagenet-11k/resnet-152/resnet-152-symbol.json',
        'imagenet-11k/synset.txt',
    ],
    [
        'imagenet/resnet/101-layers/resnet-101-0000.params',
        'imagenet/resnet/101-layers/resnet-101-symbol.json',
        'imagenet/resnet/synset.txt',
    ],
    [
        'imagenet/resnet/152-layers/resnet-152-0000.params',
        'imagenet/resnet/152-layers/resnet-152-symbol.json',
        'imagenet/resnet/synset.txt',
    ],
    [
        'imagenet/resnet/18-layers/resnet-18-0000.params',
        'imagenet/resnet/18-layers/resnet-18-symbol.json',
        'imagenet/resnet/synset.txt',
    ],
    [
        'imagenet/resnext/50-layers/resnext-50-0000.params',
        'imagenet/resnext/50-layers/resnext-50-symbol.json',
        'imagenet/resnext/synset.txt',
    ],
    [
        'imagenet/squeezenet/squeezenet_v1.1-0000.params',
        'imagenet/squeezenet/squeezenet_v1.1-symbol.json',
        'imagenet/synset.txt',
    ],
    [
        'imagenet/squeezenet/squeezenet_v1.0-0000.params',
        'imagenet/squeezenet/squeezenet_v1.0-symbol.json',
        'imagenet/synset.txt',
    ],
    [
        'imagenet/vgg/vgg16-0000.params',
        'imagenet/vgg/vgg16-symbol.json',
        'imagenet/synset.txt',
    ],
]
# ONNX is only needed to validate the exported files; import it once here
# instead of on every loop iteration.
import onnx
from onnx import checker

# Base URL that every relative path in model_list is appended to.
path = 'http://data.mxnet.io/models/'
# Standard Imagenet input - batch of 1, 3 channels, 224*224
input_shape = (1, 3, 224, 224)

# For every pre-trained model: download its params, symbol json and synset
# files, export the model to ONNX, then validate the exported graph.
for model in model_list:
    param, symbol, synset = model
    # Model name = symbol file name without its trailing '-symbol.json' part,
    # with the remaining '-' separators removed (e.g. 'resnet-152' -> 'resnet152').
    model_name = ''.join(symbol.split('/')[-1].split('-')[0:-1])
    print('--------------------'+model_name+'------------------------')

    # Several entries share the same file name (both resnet-152 variants, and
    # every synset.txt). download() skips files that already exist, so saving
    # everything into the current directory would silently reuse the first
    # variant's files for the second one. Mirror the remote directory layout
    # locally so each model gets its own files.
    params = mx.test_utils.download(path+param, dirname=param.rpartition('/')[0])
    sym = mx.test_utils.download(path+symbol, dirname=symbol.rpartition('/')[0])
    mx.test_utils.download(path+synset, dirname=synset.rpartition('/')[0])

    # Path of the output file.
    # NOTE: both resnet-152 variants map to the same model_name, so the later
    # export overwrites the earlier .onnx file; each file is validated right
    # after it is written, so the check itself is unaffected.
    onnx_file = './mxnet_exported_'+model_name+'.onnx'

    # Invoke export model API. It returns path of the converted onnx model
    converted_model_path = onnx_mxnet.export_model(sym, params, [input_shape], np.float32, onnx_file)

    # Load onnx model
    model_proto = onnx.load_model(converted_model_path)
    # Check if converted ONNX protobuf is valid; raises if the graph is invalid
    checker.check_graph(model_proto.graph)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment