import argparse
import shutil

import torch
import torchvision

# Helper functions adapted from:
# https://github.com/vita-epfl/openpifpaf/blob/master/openpifpaf/export_onnx.py
try:
    import onnx
    import onnx.utils
except ImportError:
    onnx = None

try:
    import onnxsim
except ImportError:
    onnxsim = None
def optimize(infile, outfile=None):
    # Note: onnx.optimizer was removed from onnx >= 1.9; the equivalent passes
    # now live in the separate 'onnxoptimizer' package.
    if outfile is None:
        assert infile.endswith('.onnx')
        outfile = infile
        infile = infile.replace('.onnx', '.unoptimized.onnx')
        shutil.copyfile(outfile, infile)
    model = onnx.load(infile)
    optimized_model = onnx.optimizer.optimize(model)
    onnx.save(optimized_model, outfile)
def check(modelfile):
    model = onnx.load(modelfile)
    onnx.checker.check_model(model)
def polish(infile, outfile=None):
    if outfile is None:
        assert infile.endswith('.onnx')
        outfile = infile
        infile = infile.replace('.onnx', '.unpolished.onnx')
        shutil.copyfile(outfile, infile)
    model = onnx.load(infile)
    polished_model = onnx.utils.polish_model(model)
    onnx.save(polished_model, outfile)
def simplify(infile, outfile=None):
    if outfile is None:
        assert infile.endswith('.onnx')
        outfile = infile
        infile = infile.replace('.onnx', '.unsimplified.onnx')
        shutil.copyfile(outfile, infile)
    # Note: recent onnx-simplifier versions return a (model, check_ok) tuple
    # rather than a bare model; adjust the unpacking if you upgrade.
    simplified_model = onnxsim.simplify(infile, check_n=0, perform_optimization=False)
    onnx.save(simplified_model, outfile)
parser = argparse.ArgumentParser()
parser.add_argument("--opset", type=int, default=11, help="ONNX opset version to generate models with.")
parser.add_argument('--outfile', default='bar.onnx')
# With default=True and action='store_true', this flag is effectively always on:
# dynamic batch export is the default behaviour of this script.
parser.add_argument('--dynamic-dimensions', dest='dynamic', default=True, action='store_true')
args = parser.parse_args()
dummy_input = torch.randn(10, 3, 224, 224, device='cuda')
model = torchvision.models.alexnet(pretrained=True).cuda()

input_names = ["actual_input_1"]  # + ["learned_%d" % i for i in range(16)]
output_names = ["output1"]
if args.dynamic:
    # Dynamic shape: mark the batch dimension of the input and output as dynamic
    dynamic_axes = {"actual_input_1": {0: "batch_size"}, "output1": {0: "batch_size"}}
    print(dynamic_axes)
    torch.onnx.export(model, dummy_input, args.outfile, verbose=True, opset_version=args.opset,
                      input_names=input_names, output_names=output_names,
                      dynamic_axes=dynamic_axes)
else:
    # Fixed shape: the exported graph is locked to the dummy input's batch size (10)
    torch.onnx.export(model, dummy_input, args.outfile, verbose=True, opset_version=args.opset,
                      input_names=input_names, output_names=output_names)
if onnx:
    if onnxsim:
        simplify(args.outfile)
    # Disabled by default; uncomment to enable:
    # optimize(args.outfile)
    # polish(args.outfile)
    check(args.outfile)
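
# --- Optional sanity check (not part of the original gist) ---
# A minimal sketch of verifying the exported model with onnxruntime, assuming the
# 'onnxruntime' (or 'onnxruntime-gpu') package is installed. Running a batch size
# different from the dummy input's (10) confirms the dynamic batch axis works.
try:
    import numpy as np
    import onnxruntime as ort
except ImportError:
    ort = None

if ort is not None and args.dynamic:
    session = ort.InferenceSession(args.outfile, providers=["CPUExecutionProvider"])
    test_batch = np.random.randn(2, 3, 224, 224).astype(np.float32)
    outputs = session.run(["output1"], {"actual_input_1": test_batch})
    print("onnxruntime output shape:", outputs[0].shape)  # expected: (2, 1000) for AlexNet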