Created September 22, 2020 13:27
-
-
Save ivanpanshin/e1c5c18ee4fd1ff1f3c318cc48fbe88f to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Export an SSD detector (MobileNetV2 SSD320, VOC0712 by default) to ONNX.

Steps:
  1. Build the model from an ``ssd`` YAML config and load a checkpoint.
  2. Trace it with a random dummy input via ``torch.onnx.export`` (opset 11),
     writing ``<export-name>.onnx``.
  3. Validate that file with ``onnx.checker.check_model``.
  4. Simplify it with ``onnxsim.simplify`` and write ``<export-name>_simple.onnx``.

All CLI options default to the values this script originally hard-coded, so
running it without arguments behaves as before (except that it falls back to
CPU when CUDA is unavailable).
"""
# NOTE(review): several of these imports (os, nn, ort, Image, the datasets,
# np, build_transforms, mkdir, CheckPointer) are unused by this script; kept
# as-is to avoid changing import-time behavior or order.
import os
import onnx
import torch
import torch.nn as nn
import onnxruntime as ort
from onnxsim import simplify
from PIL import Image
from ssd.config import cfg
from ssd.data.datasets import COCODataset, VOCDataset
import argparse
import numpy as np
from ssd.data.transforms import build_transforms
from ssd.modeling.detector import build_detection_model
from ssd.utils import mkdir
from ssd.utils.checkpoint import CheckPointer

DEFAULT_CONFIG = 'configs/mobilenet_v2_ssd320_voc0712.yaml'
DEFAULT_WEIGHTS = 'weights/mobilenet_v2_ssd320_voc0712_v2.pth'
DEFAULT_EXPORT_NAME = 'ssd320_mobilenet_v2'
# Spatial size of the square dummy input; 320 matches the "ssd320" config above.
DEFAULT_INPUT_SIZE = 320


def parse_args(argv=None):
    """Parse command-line options; every option defaults to the original constant.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``config_file``, ``weights``,
        ``export_name``, ``input_size`` and ``device``.
    """
    parser = argparse.ArgumentParser(
        description='Export an SSD detector to ONNX and simplify it.')
    parser.add_argument('--config-file', default=DEFAULT_CONFIG,
                        help='ssd YAML config to merge into the global cfg')
    parser.add_argument('--weights', default=DEFAULT_WEIGHTS,
                        help="checkpoint file containing a 'model' state dict")
    parser.add_argument('--export-name', default=DEFAULT_EXPORT_NAME,
                        help='output basename (".onnx" / "_simple.onnx" appended)')
    parser.add_argument('--input-size', type=int, default=DEFAULT_INPUT_SIZE,
                        help='height/width of the square dummy input')
    parser.add_argument('--device',
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='device to trace on (defaults to CUDA when available)')
    return parser.parse_args(argv)


def _load_model(config_file, weights, device):
    """Build the detector from *config_file*, load *weights*, move to *device*, set eval mode."""
    cfg.merge_from_file(config_file)
    model = build_detection_model(cfg)
    # Load onto CPU first so the checkpoint opens regardless of which device it
    # was saved from. torch.load unpickles: only use trusted checkpoint files.
    checkpoint = torch.load(weights, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()
    return model


def _export_onnx(model, onnx_path, input_size, device):
    """Trace *model* with a random (1, 3, input_size, input_size) input and write ONNX to *onnx_path*."""
    dummy_input = torch.randn(1, 3, input_size, input_size, device=device)
    # No ``example_outputs``: it was only meaningful for ScriptModules and was
    # removed from torch.onnx.export in later PyTorch releases.
    torch.onnx.export(
        model=model,
        args=(dummy_input,),
        f=onnx_path,
        input_names=['input_1'],
        output_names=['output_1'],
        opset_version=11,
    )


def _check_and_simplify(onnx_path, simple_path):
    """Validate the ONNX file at *onnx_path*, simplify it, and save the result to *simple_path*.

    Raises:
        RuntimeError: if onnxsim reports the simplified model failed its check.
    """
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    # skip_fuse_bn=True keeps BatchNorm nodes unfused, as in the original export.
    model_simp, check = simplify(onnx_model, skip_fuse_bn=True)
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not check:
        raise RuntimeError('Simplified ONNX model could not be validated')
    print("Simplified")
    onnx.save(model_simp, simple_path)


def main(argv=None):
    """Run the load -> export -> check -> simplify pipeline."""
    args = parse_args(argv)
    model = _load_model(args.config_file, args.weights, args.device)
    onnx_path = args.export_name + '.onnx'
    _export_onnx(model, onnx_path, args.input_size, args.device)
    _check_and_simplify(onnx_path, args.export_name + '_simple.onnx')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.