Created
September 22, 2020 13:26
-
-
Save ivanpanshin/f9c4227c654e4d51b222a88f3fe16c9d to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import onnx | |
import torch | |
import torchvision | |
import onnxruntime as ort | |
from onnxsim import simplify | |
import numpy as np | |
if __name__ == '__main__':
    # Export a pretrained torchvision ResNet-18 to ONNX, simplify the graph
    # with onnx-simplifier, and smoke-test the simplified model by running it
    # once through ONNX Runtime.
    EXPORT_NAME = 'resnet18'
    onnx_path = EXPORT_NAME + ".onnx"
    simplified_path = EXPORT_NAME + "_simple.onnx"
    input_names = ["input_1"]
    output_names = ["output_1"]

    model = torchvision.models.resnet18(pretrained=True).cuda().eval()
    # Tracing input; no dynamic_axes are passed, so the exported graph is
    # expected to be fixed to this (1, 3, 224, 224) shape.
    dummy_input = torch.ones(1, 3, 224, 224, device="cuda")

    # `example_outputs` is intentionally not passed: it was only consulted for
    # ScriptModule/ScriptFunction exports and is rejected as an unknown
    # keyword by later PyTorch releases.
    torch.onnx.export(
        model=model,
        args=dummy_input,
        f=onnx_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=11,
        export_params=True,
    )

    # Load the exported graph once and reuse it for both the structural check
    # and the simplifier (also avoids rebinding `model` to an ONNX proto).
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    model_simp, check = simplify(onnx_model, skip_fuse_bn=True)
    # Raise explicitly instead of `assert`, which is stripped under `python -O`.
    if not check:
        raise RuntimeError("Simplified ONNX model could not be validated")
    print("Simplified")
    onnx.save(model_simp, simplified_path)

    print("Validating result")
    # The input is fed as a host NumPy array, so the CPU provider is enough
    # for this smoke test; naming it explicitly is required by GPU-enabled
    # ONNX Runtime builds that refuse an implicit provider list.
    ort_session = ort.InferenceSession(
        simplified_path,
        providers=["CPUExecutionProvider"],
    )
    # Key the feed dict on the same name that was given to the exporter.
    outputs = ort_session.run(
        None,
        {input_names[0]: dummy_input.cpu().numpy()},
    )
    print("Validated. Output shape:", outputs[0].shape)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment