Created
February 3, 2022 07:44
-
-
Save ryujaehun/3c914acb83dec7b453ae63b72d79098b to your computer and use it in GitHub Desktop.
ONNX file generator through PyTorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Export pretrained torchvision classifiers to ONNX.

Every model builder in the ResNet, DenseNet, SqueezeNet and VGG families is
instantiated with pretrained weights, traced with a random 1x3x224x224 input
and written to ``<model_name>.onnx`` in the current working directory.
Pass ``-n/--network <name>`` (e.g. ``-n resnet50``) to export only one model.
"""
import argparse

import torch
import torchvision.models as models

# Shape of the dummy tensor used to trace each network (N, C, H, W).
input_size = (1, 3, 224, 224)

# Trace on the GPU when one is present. ONNX export does not need CUDA, so a
# CPU-only machine falls back to CPU instead of failing on device="cuda".
device = "cuda" if torch.cuda.is_available() else "cpu"
dummy_input = torch.randn(*input_size, device=device)


def _model_builders(module):
    """Return the model-builder function names listed in ``module.__all__``.

    Builders are the all-lowercase entries (``resnet18``, ``vgg16_bn``, ...).
    Filtering on case skips the architecture class (``ResNet``, ``VGG``, ...)
    that a bare ``__all__[1:]`` slice relied on being first. It also skips
    the ``*_Weights`` enums that torchvision 0.13+ lists in ``__all__``; the
    slice would otherwise try to call those enums as models.
    """
    return [name for name in module.__all__ if name.islower()]


# Maps each torchvision model module to the builder names exported from it.
MODEL_LIST = {
    module: _model_builders(module)
    for module in (models.resnet, models.densenet, models.squeezenet, models.vgg)
}

# Only the image tensor is a genuine graph input. The extra "learned_%d" names
# formerly appended here were copied from the AlexNet example in the PyTorch
# ONNX docs. For these networks they merely relabelled the first 16 parameter
# tensors, so they are no longer passed.
input_names = ["actual_input_1"]
output_names = ["output1"]


def main(argv=None):
    """Export the selected torchvision models to ``<model_name>.onnx``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Export pretrained torchvision models to ONNX."
    )
    parser.add_argument(
        "-n", "--network", help="network for onnx file", required=False
    )
    args = parser.parse_args(argv)

    available = [name for names in MODEL_LIST.values() for name in names]
    if args.network is not None and args.network not in available:
        # Fail fast with the valid choices instead of silently exporting nothing.
        parser.error(
            f"unknown network {args.network!r}; choose from: {', '.join(available)}"
        )

    for model_type, model_names in MODEL_LIST.items():
        for model_name in model_names:
            if args.network is not None and model_name != args.network:
                continue
            # eval() makes BatchNorm/Dropout trace with inference behaviour
            # explicitly rather than relying on the exporter's default mode.
            model = getattr(model_type, model_name)(pretrained=True).to(device).eval()
            torch.onnx.export(
                model,
                dummy_input,
                f"{model_name}.onnx",
                verbose=True,
                input_names=input_names,
                output_names=output_names,
            )
            # Release this network before building the next one. Otherwise the
            # old and new weights would both be resident during construction.
            del model


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment