@ivanpanshin
Created September 22, 2020 13:27
# Export an SSD MobileNetV2 detector (from the `ssd` package imported below) to ONNX
# and simplify the resulting graph with onnx-simplifier.
import onnx
import onnxruntime as ort
import torch
from onnxsim import simplify

from ssd.config import cfg
from ssd.modeling.detector import build_detection_model
if __name__ == '__main__':
    # Build the detector from the config and load the trained weights.
    cfg.merge_from_file('configs/mobilenet_v2_ssd320_voc0712.yaml')
    model = build_detection_model(cfg)
    state_dict = torch.load('weights/mobilenet_v2_ssd320_voc0712_v2.pth',
                            map_location=lambda storage, loc: storage)['model']
    model.load_state_dict(state_dict)
    model.cuda()
    model.eval()

    # Trace with a dummy 320x320 RGB input; keep the output shape for the export call.
    dummy_input = torch.randn(1, 3, 320, 320, device='cuda')
    pre_det_res = model(dummy_input)
    dummy_output = torch.ones(*pre_det_res[0].shape, device="cuda")
    EXPORT_NAME = 'ssd320_mobilenet_v2'
    input_names = ["input_1"]
    output_names = ["output_1"]

    # Note: example_outputs is only required when exporting a scripted model;
    # it was accepted (and unused) for eager modules in the PyTorch versions of
    # this era and has been removed in newer releases.
    torch.onnx.export(
        model=model,
        args=dummy_input,
        f=EXPORT_NAME + ".onnx",
        input_names=input_names,
        output_names=output_names,
        example_outputs=dummy_output,
        opset_version=11,
    )
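    # Untested sketch: torch.onnx.export also accepts a dynamic_axes argument if a
    # variable batch size is needed; whether this SSD's post-processing tolerates a
    # dynamic batch dimension has not been verified here, so it is left commented out.
    # torch.onnx.export(
    #     model=model,
    #     args=dummy_input,
    #     f=EXPORT_NAME + "_dynamic.onnx",
    #     input_names=input_names,
    #     output_names=output_names,
    #     dynamic_axes={"input_1": {0: "batch"}, "output_1": {0: "batch"}},
    #     opset_version=11,
    # )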
    # Validate the exported graph, then simplify it (reusing the already-loaded model).
    onnx_model = onnx.load(EXPORT_NAME + ".onnx")
    onnx.checker.check_model(onnx_model)

    model_simp, check = simplify(onnx_model, skip_fuse_bn=True)
    assert check, "Simplified ONNX model could not be validated"
    print("Simplified")
    onnx.save(model_simp, EXPORT_NAME + "_simple.onnx")
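    # Minimal sanity check with onnxruntime (a sketch: it only confirms that the
    # simplified model loads and runs on the CPU provider; it does not compare
    # values against the PyTorch output, since the output layout is not verified here).
    sess = ort.InferenceSession(EXPORT_NAME + "_simple.onnx",
                                providers=["CPUExecutionProvider"])
    ort_input_name = sess.get_inputs()[0].name
    ort_outputs = sess.run(None, {ort_input_name: dummy_input.cpu().numpy()})
    print("onnxruntime output shapes:", [o.shape for o in ort_outputs])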