Skip to content

Instantly share code, notes, and snippets.

@ivanpanshin
Created September 22, 2020 13:26
Show Gist options
  • Save ivanpanshin/f9c4227c654e4d51b222a88f3fe16c9d to your computer and use it in GitHub Desktop.
Save ivanpanshin/f9c4227c654e4d51b222a88f3fe16c9d to your computer and use it in GitHub Desktop.
import onnx
import torch
import torchvision
import onnxruntime as ort
from onnxsim import simplify
import numpy as np
if __name__ == '__main__':
    # Export a pretrained torchvision ResNet-18 to ONNX, simplify the graph
    # with onnx-simplifier, and run the simplified graph with ONNX Runtime.
    #
    # Files written to the current working directory:
    #   resnet18.onnx         - raw export from torch.onnx.export
    #   resnet18_simple.onnx  - onnxsim-simplified graph
    EXPORT_NAME = 'resnet18'
    onnx_path = EXPORT_NAME + ".onnx"
    simple_path = EXPORT_NAME + "_simple.onnx"
    input_names = ["input_1"]
    output_names = ["output_1"]

    # Prefer CUDA, but fall back to CPU so the script also runs on machines
    # without a GPU (the original hard-coded .cuda()). The exported ONNX
    # graph itself does not depend on the device used for tracing.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # NOTE: newer torchvision deprecates `pretrained=` in favour of
    # `weights=`; `pretrained=True` is kept so older torchvision still works.
    model = torchvision.models.resnet18(pretrained=True).to(device).eval()
    dummy_input = torch.ones(1, 3, 224, 224, device=device)

    # Reference output from PyTorch, compared against ONNX Runtime below.
    with torch.no_grad():
        pre_det_res = model(dummy_input)

    # `example_outputs` is intentionally not passed: it was only needed for
    # ScriptModule exports and was removed from torch.onnx.export in
    # PyTorch 1.11 (passing it raises TypeError). Tracing an nn.Module
    # infers the outputs on its own.
    torch.onnx.export(
        model=model,
        args=dummy_input,
        f=onnx_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
        export_params=True,
    )

    # Load the export once and reuse it for both the checker and onnxsim.
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # skip_fuse_bn=True tells onnxsim not to fold BatchNorm into the
    # preceding Conv. NOTE(review): the reason for keeping BN separate is
    # not stated in this script - confirm it is still required.
    model_simp, check = simplify(onnx_model, skip_fuse_bn=True)
    # Explicit raise rather than `assert`, which is stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    print("Simplified")
    onnx.save(model_simp, simple_path)

    print("Validating result")
    # ONNX Runtime builds that ship more than one execution provider
    # (e.g. onnxruntime-gpu, >= 1.9) refuse to create a session without an
    # explicit provider list; the CPU provider is always available.
    ort_session = ort.InferenceSession(
        simple_path, providers=["CPUExecutionProvider"]
    )
    outputs = ort_session.run(
        None, {input_names[0]: dummy_input.cpu().numpy()}
    )
    print("Validated. Output shape:", outputs[0].shape)

    # Report how closely the simplified ONNX graph matches PyTorch. Small
    # numerical differences are expected when the reference was computed
    # on a different device/kernel set than ONNX Runtime's CPU provider.
    max_abs_diff = np.abs(outputs[0] - pre_det_res.cpu().numpy()).max()
    print("Max abs difference vs PyTorch:", max_abs_diff)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment