Skip to content

Instantly share code, notes, and snippets.

@batrlatom
Created September 2, 2019 14:57
Show Gist options
  • Save batrlatom/ab02f057553819b0a1afd74ba7005a4a to your computer and use it in GitHub Desktop.
Save batrlatom/ab02f057553819b0a1afd74ba7005a4a to your computer and use it in GitHub Desktop.
Export a trained PyTorch model to an ONNX file
import torch
import torch.onnx
from networks import EmbeddingNetwork
# Dimensions of the dummy tensor used to trace the network during ONNX export:
# a batch of one 3-channel image, 224 pixels high and 224 pixels wide.
sample_batch_size, channel, height, width = 1, 3, 224, 224
def load(file_path, net=None, map_location=None):
    """Load a saved state dict from *file_path* into a module and return it.

    Args:
        file_path: Path to a file holding a state dict, i.e. something
            accepted by ``torch.nn.Module.load_state_dict``.
        net: Module to load the weights into. When omitted, the module-level
            ``model`` global is used, which keeps the original behaviour.
        map_location: Forwarded to ``torch.load``. Pass ``'cpu'`` to load a
            checkpoint that was saved from GPU tensors on a machine without
            CUDA. ``None`` (the default) restores tensors to their saved device.

    Returns:
        The module the weights were loaded into.

    Raises:
        RuntimeError: From ``load_state_dict`` if the keys or shapes in the
            file do not match the module.
    """
    if net is None:
        # Backward-compatible fallback: the original function always used the
        # module-level ``model`` instance.
        net = model
    net.load_state_dict(torch.load(file_path, map_location=map_location))
    return net
# Build the network architecture. These hyper-parameters must match the ones
# the checkpoint below was trained with. Otherwise load_state_dict() (strict by
# default) raises on missing, unexpected or mis-shaped keys.
# NOTE(review): use_pretrained=True presumably fetches pretrained backbone
# weights that the checkpoint then overwrites. Confirm whether False suffices.
model = EmbeddingNetwork(model_name='googlenet',
                         embedding_dim=128,
                         feature_extracting=False,
                         use_pretrained=True,
                         attention_flag=True,
                         cross_entropy_flag=False)

# Load the trained weights. map_location='cpu' lets a checkpoint saved from a
# GPU run be read on a CPU-only machine. load_state_dict() copies the values
# into the model's own parameters either way.
state_dict = torch.load('n-pair_angular_apparel_googlenet2/model_18',
                        map_location='cpu')
model.load_state_dict(state_dict)

# Switch layers such as dropout / batch-norm to inference behaviour so the
# exported graph matches how the model is used after deployment.
model.eval()

# Random tensor with the shape the network is traced with:
# (batch, channels, height, width).
dummy_input = torch.randn(sample_batch_size, channel, height, width)

# Trace the model with the dummy input and write the ONNX graph to disk.
torch.onnx.export(model, dummy_input, "onnx_model_name.onnx",
                  input_names=['input'], output_names=['output'])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment