Skip to content

Instantly share code, notes, and snippets.

@marcoleewow
Last active July 29, 2023 21:51
Show Gist options
  • Save marcoleewow/26a0d5e23e35528262556697c62c2a50 to your computer and use it in GitHub Desktop.
Save marcoleewow/26a0d5e23e35528262556697c62c2a50 to your computer and use it in GitHub Desktop.
Convert a PyTorch BLSTM model to Core ML via ONNX. Note: you must use macOS 10.15 Catalina and iOS 13 for it to work!
import os
import numpy as np
import coremltools
from coremltools.models.neural_network import flexible_shape_utils
import torch
from torch import nn
import onnx
from onnx_coreml import convert
INT_MAX = 2 ** 30
class BLSTM(nn.Module):
    """Bidirectional multi-layer LSTM followed by a linear layer and a softmax.

    Args:
        input_size: number of features per timestep.
        hidden_size: number of hidden units per direction in each LSTM layer.
        num_layers: number of stacked LSTM layers.
        rnn_dropout: dropout probability handed to ``nn.LSTM``.
        num_classes: number of output classes per timestep.
    """

    def __init__(self,
                 input_size=10,
                 hidden_size=256,
                 num_layers=5,
                 rnn_dropout=0.2,
                 num_classes=100):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            # nn.LSTM only applies dropout between stacked
                            # layers; with a single layer it has no effect and
                            # merely triggers a warning, so pass 0 instead.
                            dropout=rnn_dropout if num_layers > 1 else 0.0,
                            bidirectional=True,
                            batch_first=True)  # NOTE: YOU MUST USE `batch_first=True`
        # forward and backward hidden states are concatenated, hence `* 2`
        self.fc = nn.Linear(hidden_size * 2, num_classes)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, h_0=None, c_0=None):
        """Return per-timestep class probabilities for ``x``.

        Args:
            x: input of shape (batch, time, input_size), since the LSTM is
                built with ``batch_first=True``.
            h_0: optional initial hidden state of shape
                (num_layers * 2, batch, hidden_size).
            c_0: optional initial cell state, same shape as ``h_0``.

        If only one of ``h_0`` / ``c_0`` is supplied, the other one is filled
        with zeros (the same default ``nn.LSTM`` uses) instead of silently
        discarding the state that was passed in.

        Returns:
            Tensor of shape (batch, time, num_classes); the softmax is taken
            over the last dimension.
        """
        if h_0 is None and c_0 is None:
            x, _ = self.lstm(x)
        else:
            if h_0 is None:
                h_0 = torch.zeros_like(c_0)
            elif c_0 is None:
                c_0 = torch.zeros_like(h_0)
            x, _ = self.lstm(x, (h_0, c_0))
        x = self.fc(x)
        x = self.softmax(x)
        return x
def _convert_slice_v9(builder, node, graph, err):
    '''
    Convert an ONNX (opset <= 9, attribute based) Slice node to a
    CoreML Slice Static Layer:
    https://github.com/apple/coremltools/blob/655b3be5cc0d42c3c4fa49f0f0e4a93a26b3e492/mlmodel/format/NeuralNetwork.proto#L5082

    builder: object exposing `add_slice_static(...)`; the layer is added to it.
    node: ONNX node wrapper; `attrs` must hold `starts` and `ends`, and may
        hold `axes` and `steps`. `inputs[0]` / `outputs[0]` name the blobs.
    graph: graph wrapper whose `shape_dict` maps blob names to shapes.
    err: unused here; kept because the converter passes it to every
        conversion function.

    Axes that the node does not mention keep begin/end id 0 with both masks
    set to True, i.e. they are passed through unsliced.
    '''
    data_shape = graph.shape_dict[node.inputs[0]]
    rank = len(data_shape)

    ip_starts = node.attrs.get('starts')
    ip_ends = node.attrs.get('ends')
    axes = node.attrs.get('axes')
    if axes is None:
        # ONNX: a missing `axes` means [0, 1, ..., len(starts) - 1], which may
        # be shorter than the rank of the data.
        axes = range(len(ip_starts))
    steps = node.attrs.get('steps', [1] * rank)

    starts = [0] * rank
    ends = [0] * rank
    # A True mask tells CoreML to ignore the matching id and use the
    # start / end of that dimension instead.
    begin_masks = [True] * rank
    end_masks = [True] * rank

    for axis, start, end in zip(axes, ip_starts, ip_ends):
        starts[axis] = start
        ends[axis] = end
        # ONNX treats any end >= dimension size (exporters commonly emit
        # INT32/INT64 max) as "up to the end of the axis", so the end id is
        # only meaningful when it lies inside the dimension.
        if end < data_shape[axis]:
            end_masks[axis] = False
        if start != 0:
            begin_masks[axis] = False

    builder.add_slice_static(
        name=node.name,
        input_name=node.inputs[0],
        output_name=node.outputs[0],
        begin_ids=starts,
        end_ids=ends,
        strides=steps,
        begin_masks=begin_masks,
        end_masks=end_masks
    )
if __name__ == "__main__":
    # End-to-end check: build the PyTorch model, export it to ONNX, convert
    # that to CoreML, make the shapes flexible, and compare both outputs.

    # these are fixed hyperparameters
    input_size = 10
    hidden_size = 256
    num_layers = 5
    num_classes = 100
    # first dimension of h_0 / c_0: one state per layer and per direction
    num_state_slots = num_layers * 2

    # these two can be whatever, as we will change the flexible shape later
    num_timesteps = 1000
    batch_size = 2

    pytorch_model = BLSTM(input_size=input_size,
                          hidden_size=hidden_size,
                          num_layers=num_layers,
                          num_classes=num_classes)

    # load pretrained weights
    # model_state_dict = torch.load('pytorch_model_weights.pth', map_location='cpu')['model_state_dict']
    # pytorch_model.load_state_dict(model_state_dict)

    # Always switch to inference mode, with or without pretrained weights:
    # the LSTM has inter-layer dropout, so a forward pass in training mode is
    # random and can never match the converted model's output.
    pytorch_model.eval()

    # pytorch forward pass
    dummy_input = torch.rand(batch_size, num_timesteps, input_size)
    h_0 = torch.zeros(num_state_slots, batch_size, hidden_size)
    c_0 = torch.zeros(num_state_slots, batch_size, hidden_size)
    with torch.no_grad():
        pytorch_output = pytorch_model(dummy_input, h_0, c_0)

    # convert to onnx
    onnx_tmp_path = 'test_pytorch_rnn.onnx'
    torch.onnx.export(pytorch_model, args=(dummy_input, h_0, c_0), f=onnx_tmp_path, verbose=True)

    # load back the onnx model
    try:
        onnx_model = onnx.load(onnx_tmp_path)
    finally:
        # the file is only a hand-over to onnx.load; remove it even if loading fails
        os.remove(onnx_tmp_path)
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

    # ONNX to coreML
    coreml_model = convert(onnx_model,
                           disable_coreml_rank5_mapping=True,
                           custom_conversion_functions={"Slice": _convert_slice_v9})
    coreml_model.author = 'Marco Lee'
    coreml_model.short_description = 'LSTM CoreML Model'

    spec = coreml_model.get_spec()
    print(spec.description)

    # you need to find the names of the features from spec
    # (they are generated by the exporter, so check the printed description
    # above if any of these renames does not match)
    coremltools.utils.rename_feature(spec, 'input.1', 'inputs')
    coremltools.utils.rename_feature(spec, '1', 'h_0')
    coremltools.utils.rename_feature(spec, '2', 'c_0')
    coremltools.utils.rename_feature(spec, '306', 'softmaxOutputs')

    # add flexible shape to all inputs and softmax output
    # each entry: (feature name, channel range, height range, width range)
    state_ranges = ((num_state_slots, num_state_slots), (1, -1), (hidden_size, hidden_size))
    flexible_ranges = (
        ("inputs", (1, -1), (1, -1), (input_size, input_size)),
        ("h_0",) + state_ranges,
        ("c_0",) + state_ranges,
        ("softmaxOutputs", (1, -1), (1, -1), (num_classes, num_classes)),
    )
    for feature_name, channel_range, height_range, width_range in flexible_ranges:
        shape_ranges = flexible_shape_utils.NeuralNetworkMultiArrayShapeRange()
        shape_ranges.add_channel_range(channel_range)
        shape_ranges.add_height_range(height_range)
        shape_ranges.add_width_range(width_range)
        flexible_shape_utils.update_multiarray_shape_range(spec,
                                                           feature_name=feature_name,
                                                           shape_range=shape_ranges)

    coreml_model = coremltools.models.MLModel(spec)
    print(coreml_model.get_spec().description)

    # these are the feature names, please read it from the printable graph log
    coreml_output = coreml_model.predict({"inputs": dummy_input.numpy(),
                                          "h_0": h_0.numpy(),
                                          "c_0": c_0.numpy()})
    coreml_output = np.squeeze(np.atleast_1d(coreml_output)[0]['softmaxOutputs'])

    coreml_model_path = "LSTM.mlmodel"
    coreml_model.save(coreml_model_path)

    np.testing.assert_array_almost_equal(coreml_output, pytorch_output.detach().numpy())
atomicwrites==1.3.0
attrs==19.1.0
Click==7.0
coremltools==3.0b3
importlib-metadata==0.20
more-itertools==7.2.0
numpy==1.17.1
onnx==1.5.0
onnx-coreml==1.0b2
packaging==19.1
pluggy==0.12.0
protobuf==3.9.1
py==1.8.0
pyparsing==2.4.2
pytest==5.1.2
six==1.12.0
torch==1.2.0
typing==3.7.4.1
typing-extensions==3.7.4
wcwidth==0.1.7
zipp==0.6.0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment