Skip to content

Instantly share code, notes, and snippets.

@ivanpanshin
Created September 11, 2020 08:26
Show Gist options
  • Save ivanpanshin/e44b62823a5f536564bc04b10df5e28c to your computer and use it in GitHub Desktop.
Save ivanpanshin/e44b62823a5f536564bc04b10df5e28c to your computer and use it in GitHub Desktop.
import onnx
import torch
import numpy as np
import onnxruntime as ort
import torchvision as tv
from onnxsim import simplify
# Export a (randomly initialised) torchvision ResNet-50 to ONNX, simplify the
# graph with onnx-simplifier, then run the simplified model in ONNX Runtime.
#
# Steps:
#   1. run the PyTorch model once to obtain a reference output,
#   2. torch.onnx.export          -> f"{EXPORT_NAME}.onnx",
#   3. onnxsim.simplify           -> f"{EXPORT_NAME}_simple.onnx",
#   4. ONNX Runtime on the simplified model, compared against step 1.

# Fall back to CPU so the script also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

dummy_input = torch.ones(1, 3, 300, 300, device=device)
pt_model_det = tv.models.resnet50().to(device).eval().requires_grad_(False)
pre_det_res = pt_model_det(dummy_input)
print(f"Torch output shape: {pre_det_res.shape}")

print("Converting")
input_names = ["input_1"]
output_names = ["output_1"]
EXPORT_NAME = "res50"
onnx_path = f"{EXPORT_NAME}.onnx"
simple_path = f"{EXPORT_NAME}_simple.onnx"
# `example_outputs` is intentionally not passed: it was only meaningful for
# ScriptModules, was deprecated in PyTorch 1.10 and later removed, so passing
# it makes export fail with a TypeError on current releases.
torch.onnx.export(
    model=pt_model_det,
    args=dummy_input,
    f=onnx_path,
    input_names=input_names,
    output_names=output_names,
    opset_version=11,
)
print("Converted")

print("Simplifying")
model = onnx.load(onnx_path)
# NOTE(review): skip_fuse_bn=True presumably disables folding BatchNorm into
# the preceding Conv; confirm against the installed onnxsim version.
model_simp, check = simplify(model, skip_fuse_bn=True)
# Explicit raise rather than `assert`, which is stripped under `python -O`.
if not check:
    raise RuntimeError("Simplified ONNX model could not be validated")
print("Simplified")
onnx.save(model_simp, simple_path)

print("Validating result")
# Pin the execution provider explicitly: ONNX Runtime >= 1.9 GPU builds refuse
# to create a session without `providers`, and the CPU provider always exists.
ort_session = ort.InferenceSession(
    simple_path, providers=["CPUExecutionProvider"]
)
outputs = ort_session.run(None, {input_names[0]: dummy_input.cpu().numpy()})
reference = pre_det_res.cpu().numpy()
if outputs[0].shape != reference.shape:
    raise RuntimeError(
        f"ONNX output shape {outputs[0].shape} does not match "
        f"Torch output shape {reference.shape}"
    )
print("Validated. Output shape:", outputs[0].shape)
# Report (rather than assert) the numeric gap: the Torch reference may have
# been computed on GPU kernels, which can differ slightly from ORT's CPU ones.
max_abs_diff = np.abs(outputs[0] - reference).max()
print(f"Max abs difference vs Torch: {max_abs_diff:.3e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment