Created September 2, 2019 14:57
-
-
Save batrlatom/ab02f057553819b0a1afd74ba7005a4a to your computer and use it in GitHub Desktop.
Export ONNX file
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.onnx | |
from networks import EmbeddingNetwork | |
# Dimensions of the dummy input tensor passed to the ONNX exporter, in the
# order they are given to torch.randn(): (batch, channels, height, width).
sample_batch_size = 1
channel = 3
height = 224
width = 224
def load(file_path):
    """Load saved weights from ``file_path`` into the module-level ``model``.

    Args:
        file_path: Path to a checkpoint readable by ``torch.load`` that holds
            a state dict compatible with ``model``.

    Returns:
        The module-level ``model`` object itself, after its weights have been
        replaced by the ones read from ``file_path``.
    """
    # map_location='cpu' allows a checkpoint that was saved from a GPU to be
    # read on a host without CUDA; load_state_dict() copies the values into
    # the model's existing parameters, so their device is left unchanged.
    state_dict = torch.load(file_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
# A model class instance (class not shown). The constructor arguments are kept
# exactly as given; NOTE(review): they presumably have to match the settings
# the checkpoint was trained with so that load_state_dict() finds the same
# parameter names -- confirm against the training script.
model = EmbeddingNetwork(
    model_name='googlenet',
    embedding_dim=128,
    feature_extracting=False,
    use_pretrained=True,
    attention_flag=True,
    cross_entropy_flag=False,
)

# Load the weights from a file (.pth usually). map_location='cpu' lets a
# checkpoint that was saved from a GPU be read on a machine without CUDA.
state_dict = torch.load(
    'n-pair_angular_apparel_googlenet2/model_18',
    map_location='cpu',
)

# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)

# Put the module in evaluation mode so that any dropout / batch-norm layers
# use their inference behaviour in the exported graph.
model.eval()

# Create the right input shape (e.g. for an image)
dummy_input = torch.randn(sample_batch_size, channel, height, width)

# Run the model once on the dummy input and write the resulting graph to disk,
# naming the graph's input and output so consumers can refer to them.
torch.onnx.export(
    model,
    dummy_input,
    "onnx_model_name.onnx",
    input_names=['input'],
    output_names=['output'],
)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment