Skip to content

Instantly share code, notes, and snippets.

@batrlatom
batrlatom / export_onnx_file.py
Created September 2, 2019 14:57
export onnx file
import torch
import torch.onnx
from networks import EmbeddingNetwork
# Dimensions of a sample input: batch size, channel count, height and width.
# NOTE(review): presumably these size the dummy input for torch.onnx export;
# the call that consumes them is not visible here -- confirm.
sample_batch_size, channel, height, width = 1, 3, 224, 224
# NOTE(review): the body of load() is missing from this copy -- the line that
# follows it is an unrelated import, and a `def` with no indented body does
# not parse. Restore the implementation from the original export_onnx_file.py.
def load(file_path):
import tensorflow as tf
import argparse
import numpy as np
import cv2
def load_graph(frozen_graph_filename):
    """Read and parse a serialized TensorFlow GraphDef from a file.

    Args:
        frozen_graph_filename: path of the serialized graph file to open.

    NOTE(review): as shown, this function builds ``graph_def`` but returns
    nothing (i.e. it returns None). The rest of the body appears to be
    missing from this copy, because the next line is another file's imports.
    Confirm what it should return against the original source.
    """
    # NOTE(review): tf.gfile.GFile and tf.GraphDef are TensorFlow 1.x names;
    # under TF 2.x they are tf.io.gfile.GFile / tf.compat.v1.GraphDef --
    # confirm the TensorFlow version this script targets.
    # Opened in binary mode because ParseFromString consumes raw bytes.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        # Fill graph_def from the file's serialized protobuf bytes.
        graph_def.ParseFromString(f.read())
import onnx
from onnx_tf.backend import prepare
import tensorflow as tf
from tensorflow.python.client import timeline
import time
# Prepare the inputs; here we use NumPy to generate random inputs for demonstration purposes.
import numpy as np
# Random demo input: one standard-normal sample of shape (1, 3, 224, 224),
# cast to float32. Uses NumPy's legacy global RNG, exactly as randn does.
img = np.asarray(np.random.standard_normal(size=(1, 3, 224, 224)), dtype=np.float32)