Skip to content

Instantly share code, notes, and snippets.

@lucidfrontier45
Created October 17, 2023 00:44
Show Gist options
  • Save lucidfrontier45/23658ee478b6bcf6e0ded11181899100 to your computer and use it in GitHub Desktop.
Save lucidfrontier45/23658ee478b6bcf6e0ded11181899100 to your computer and use it in GitHub Desktop.
PyTorch Multiple-Input Model: ONNX Export and ONNX Runtime Inference Example
import torch
import torch.nn.functional as F
from torch.nn import Linear, Module
class MyModel(Module):
    """Two-head MLP that fuses three feature vectors of different widths.

    Each input is projected to 8 features by its own linear layer
    (``x1``: 3 -> 8, ``x2``: 5 -> 8, ``x3``: 10 -> 8). The projections are
    concatenated (24 features), passed through ReLU, a 24 -> 128 linear
    layer and another ReLU, and then fed to two independent 128 -> 1 heads.
    """

    def __init__(self):
        super().__init__()
        # Per-input projections to a common width of 8.
        self.linear11 = Linear(3, 8)
        self.linear12 = Linear(5, 8)
        self.linear13 = Linear(10, 8)
        # Shared trunk over the 3 * 8 = 24 concatenated features.
        self.linear2 = Linear(24, 128)
        # Two independent output heads.
        self.linear31 = Linear(128, 1)
        self.linear32 = Linear(128, 1)

    def forward(self, x1, x2, x3):
        """Run the model on three inputs.

        Args:
            x1: Tensor whose last dimension is 3.
            x2: Tensor whose last dimension is 5.
            x3: Tensor whose last dimension is 10.
            All three must share the same leading (batch) dimensions, which
            may be absent entirely (unbatched 1-D inputs).

        Returns:
            Tuple ``(y1, y2)`` of tensors that keep the inputs' leading
            dimensions and have a last dimension of 1.
        """
        z1 = self.linear11(x1)
        z2 = self.linear12(x2)
        z3 = self.linear13(x3)
        # Concatenate along the feature (last) axis instead of a fixed dim=1,
        # so 1-D and multi-batch-dim inputs work as well as 2-D ones.
        # torch.cat is the long-standing canonical name; torch.concatenate is
        # an alias that older releases do not provide.
        z = torch.cat((z1, z2, z3), dim=-1)
        z = F.relu(z)
        z = F.relu(self.linear2(z))
        return self.linear31(z), self.linear32(z)
# Build the model, sanity-check one forward pass in PyTorch, then export it
# to "model.onnx".
model = MyModel()
# Export in inference mode. MyModel has no dropout/batch-norm today, but
# exporting in training mode would bake training behavior into the graph if
# such layers were ever added.
model.eval()

# One sample per input. The feature sizes (3, 5, 10) must match the
# in_features of linear11 / linear12 / linear13.
dummy_x = (torch.randn(1, 3), torch.randn(1, 5), torch.randn(1, 10))
with torch.no_grad():
    y = model(*dummy_x)
print(y)

input_names = ["x1", "x2", "x3"]
output_names = ["y1", "y2"]
# Only the file written to disk is needed, so the return value of
# torch.onnx.export is not bound to a name.
torch.onnx.export(
    model,
    dummy_x,
    "model.onnx",
    verbose=True,
    input_names=input_names,
    output_names=output_names,
    # Mark dim 0 as dynamic so the exported graph accepts any batch size,
    # not only the batch of 1 used as the example input.
    dynamic_axes={name: {0: "batch"} for name in input_names + output_names},
)
import onnx
import onnxruntime as ort
import numpy as np
# Load the exported graph, inspect its signature, and run it with
# ONNX Runtime on random inputs.
model = onnx.load("model.onnx")
# Fail early with a clear message if the exported graph is malformed.
onnx.checker.check_model(model)
print(model.graph.input)
print(model.graph.output)

# Pin the execution provider explicitly: onnxruntime builds that ship more
# than one provider (e.g. GPU builds) raise if `providers` is omitted.
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

rng = np.random.default_rng()
# float32 matches the dtype of the torch.randn example inputs used for export.
dummy_x = (
    rng.standard_normal((1, 3), dtype=np.float32),
    rng.standard_normal((1, 5), dtype=np.float32),
    rng.standard_normal((1, 10), dtype=np.float32),
)
# Feed keys must be the input names given at export time.
feeds = dict(zip(("x1", "x2", "x3"), dummy_x))
# Passing None as the output list returns every graph output, in order.
y = sess.run(None, feeds)
print(y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment