import onnxruntime as rt
import onnx
import numpy as np
import torch
import torch.nn as nn
import torch.onnx
def onnx_export(model: torch.nn.Module, input_shape, filename: str,
                input_names=["network_input"],
                output_names=["network_output"]) -> None:
    # Create a dummy input on the same device as the model parameters
    device = next(model.parameters()).device
    dummy_input = torch.rand(input_shape).to(device)
    # Convert the PyTorch model to ONNX
    torch.onnx.export(model,
                      dummy_input,
                      filename,
                      verbose=False,
                      input_names=input_names,
                      output_names=output_names)
# Build model
input_shape = (1, 8, 16, 16)
model = nn.Sequential(
    nn.Conv2d(8, 8, kernel_size=(3, 3), padding=(1, 1)),
    nn.ReLU(inplace=True),
    nn.BatchNorm2d(8)
)
# ONNX export if the model doesn't exist yet.
# Model layers are randomly initialized by torch,
# so we want to keep this model once generated.
import os.path
if not os.path.isfile("model.onnx"):
    onnx_export(model, input_shape, filename="model.onnx")
# Edit model: node[1] of this graph is the exported ReLU node;
# rewriting its op_type swaps the activation for Elu.
model = onnx.load("model.onnx")
model.graph.node[1].op_type = "Elu"
onnx.save(model, "model-.onnx")
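# Optional sanity check (not in the original gist): onnx.checker verifies
# that the edited graph is still a well-formed ONNX model before running it.
onnx.checker.check_model(model)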
# Load the edited ONNX model and run it
sess = rt.InferenceSession("model-.onnx")
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name
X_test = np.ones(input_shape)
pred = sess.run([label_name], {input_name: X_test.astype(np.float32)})[0]
print(pred)
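# Comparison sketch (an addition, assuming the unedited "model.onnx" is still
# on disk): running the original graph on the same input quantifies how
# swapping ReLU for Elu changes the network output.
sess_orig = rt.InferenceSession("model.onnx")
pred_orig = sess_orig.run([label_name], {input_name: X_test.astype(np.float32)})[0]
print("max |Elu - ReLU| output difference:", np.abs(pred - pred_orig).max())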