Last active
October 12, 2018 06:23
-
-
Save marcoleewow/2afb5762ed74f5244c9bd85eae35147d to your computer and use it in GitHub Desktop.
Test for converting an RNN model from PyTorch to ONNX to Core ML
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import tempfile

import onnx
import pytest
import torch
import torch.nn.functional as F
from coremltools.proto import NeuralNetwork_pb2  # type: ignore
from onnx_coreml import convert
from torch import nn
""" | |
For the conversion functions, I am using the examples from here: | |
https://github.com/onnx/onnx-coreml/blob/dc58afc4fce51d3376f434b20498a9abcf20edf9/tests/custom_layers_test.py#L75 | |
and read ops from here | |
https://github.com/onnx/onnx/blob/master/docs/Operators.md#ConstantFill | |
""" | |
def convert_gather(node):
    """Build Core ML custom-layer parameters for an ONNX ``Gather`` node.

    The custom layer is named after the node's op type, and the node's
    ``axis`` attribute (0 when the node does not set it) is forwarded as an
    integer parameter.

    Args:
        node: presumably the node wrapper that ``onnx_coreml.convert`` hands
            to entries of ``custom_conversion_functions``; only ``op_type``
            and the ``attrs`` mapping are read here.

    Returns:
        A populated ``NeuralNetwork_pb2.CustomLayerParams`` message.
    """
    op_name = node.op_type
    layer = NeuralNetwork_pb2.CustomLayerParams()
    layer.className = op_name
    layer.description = "Custom layer that corresponds to the ONNX op {}".format(op_name)
    gather_axis = node.attrs.get('axis', 0)
    layer.parameters["axis"].intValue = gather_axis
    return layer
def convert_constantfill(node):
    """Build Core ML custom-layer parameters for an ONNX ``ConstantFill`` node.

    The custom layer is named after the node's op type. The node's ``dtype``
    (default 1), ``input_as_shape`` (default 0) and ``value`` (default 0.0)
    attributes are forwarded as custom-layer parameters.

    Args:
        node: presumably the node wrapper that ``onnx_coreml.convert`` hands
            to entries of ``custom_conversion_functions``; only ``op_type``
            and the ``attrs`` mapping are read here.

    Returns:
        A populated ``NeuralNetwork_pb2.CustomLayerParams`` message.
    """
    params = NeuralNetwork_pb2.CustomLayerParams()
    params.className = node.op_type
    params.description = "Custom layer that corresponds to the ONNX op {}".format(node.op_type)
    attrs = node.attrs
    params.parameters["dtype"].intValue = int(attrs.get('dtype', 1))
    # ``input_as_shape`` is an optional attribute of ConstantFill; indexing it
    # directly raised KeyError for nodes that omit it. 0 means "shape is not
    # taken from the input tensor".
    params.parameters["input_as_shape"].intValue = int(attrs.get('input_as_shape', 0))
    params.parameters["value"].doubleValue = float(attrs.get('value', 0.0))
    # NOTE(review): the ``shape`` / ``extra_shape`` attributes are not
    # forwarded, so when the output shape is not supplied as an input the
    # custom layer has no way to recover it. Confirm whether the Core ML-side
    # implementation of this layer needs them.
    return params
class RNNModel(nn.Module):
    """Minimal LSTM wrapper used to reproduce the ONNX / Core ML conversion issues.

    Despite the attribute name ``blstm``, the LSTM is unidirectional unless
    ``bidirectional=True`` is passed. The attribute name is kept so existing
    ``state_dict`` keys (``blstm.*``) remain valid.

    Args:
        num_layers: number of stacked LSTM layers.
        input_size: features per time step (default 256, as before).
        hidden_size: hidden units per direction (default 100, as before).
        bidirectional: run the LSTM in both directions (default False, as before).
    """

    def __init__(self, num_layers, *, input_size=256, hidden_size=100, bidirectional=False):
        super().__init__()
        self.blstm = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_size,
                             num_layers=num_layers,
                             bidirectional=bidirectional)

    def forward(self, x):
        """Run the LSTM over ``x`` and return only the per-step outputs.

        ``nn.LSTM`` is built without ``batch_first``, so ``x`` is expected as
        ``(seq_len, batch, input_size)``. The result has shape
        ``(seq_len, batch, hidden_size * num_directions)``; the final hidden
        and cell states are discarded.
        """
        output, _ = self.blstm(x)
        return output
def pytorch_to_coreml(pytorch_model, custom_conversion_functions=None,
                      output_path='./failed_rnn.mlmodel'):
    """Export ``pytorch_model`` to ONNX, convert it to Core ML and save it.

    The model is traced with a random dummy input of shape ``(1000, 2, 256)``
    (sequence length, batch size, input size). The ONNX file lives only in a
    private temporary directory.

    Args:
        pytorch_model: module that accepts the ``(1000, 2, 256)`` dummy input.
        custom_conversion_functions: optional mapping of ONNX op type to a
            converter function. When it is not ``None``, conversion runs with
            ``add_custom_layers=True``.
        output_path: where the ``.mlmodel`` is written. Defaults to the
            previously hard-coded ``./failed_rnn.mlmodel``.

    Returns:
        The converted Core ML model. This is an addition; the function used to
        return ``None``.
    """
    # PyTorch -> ONNX
    dummy_input = torch.rand(1000, 2, 256)  # (random_seq_length, batch_size, input_size)
    assert pytorch_model(dummy_input) is not None  # check that the forward pass works

    # Use a private temporary directory instead of a fixed '/tmp/temp.onnx'.
    # The fixed path was POSIX-only and collided between concurrent runs. It
    # also leaked the file whenever export or load raised before os.remove.
    with tempfile.TemporaryDirectory() as tmp_dir:
        onnx_tmp_path = os.path.join(tmp_dir, 'temp.onnx')
        torch.onnx.export(pytorch_model, args=dummy_input, f=onnx_tmp_path, verbose=True)
        onnx_model = onnx.load(onnx_tmp_path)  # fully loaded before the dir is removed

    # ONNX -> Core ML
    if custom_conversion_functions is not None:
        coreml_model = convert(onnx_model,
                               add_custom_layers=True,
                               custom_conversion_functions=custom_conversion_functions)
    else:
        coreml_model = convert(onnx_model)
    coreml_model.author = 'Marco Lee'
    coreml_model.short_description = 'Failed RNN Example'
    coreml_model.save(output_path)
    return coreml_model
class TestRNNModelPyTorchtoONNXtoCoreML:
    """Reproducers for converting an LSTM model from PyTorch via ONNX to Core ML."""

    def test_num_rnn_layer_equal_1(self):
        """Single-layer LSTM, default conversion.

        Observed failure:
        RuntimeError: Inferred shape and existing shape differ in rank: (4) vs (3)
        """
        single_layer = RNNModel(num_layers=1)
        pytorch_to_coreml(single_layer)

    def test_num_rnn_layers_equal_2(self):
        """Two-layer LSTM, default conversion (no custom layers).

        Observed failure:
        NotImplementedError: Unsupported ONNX ops of type: Gather,ConstantFill
        """
        two_layers = RNNModel(num_layers=2)
        pytorch_to_coreml(two_layers)

    def test_num_rnn_layers_equal_2_with_custom_conversion_functions(self):
        """Two-layer LSTM, converting Gather/ConstantFill via our own custom layers.

        Conversion itself completes, but opening the saved mlmodel in Xcode gives:
        There was a problem decoding this CoreML document
        validator error: Layer '31' of type 500 has 0 inputs but expects at least 1.
        """
        two_layers = RNNModel(num_layers=2)
        converters = dict(Gather=convert_gather, ConstantFill=convert_constantfill)
        pytorch_to_coreml(two_layers, custom_conversion_functions=converters)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment