Optimizing a TensorFlow model for inference

This TF 1.x script loads a frozen GraphDef and passes it through optimize_for_inference_lib, which removes nodes that are only needed during training and applies other inference-time simplifications, then writes the leaner graph back to disk for deployment.
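The script below assumes the frozen graph generated_model/generated_model.pb already exists. For reference, a frozen GraphDef can be produced from a TF 1.x checkpoint with the freeze_graph tool; this is a minimal sketch, and the graph.pbtxt and checkpoint paths are hypothetical:

from tensorflow.python.tools import freeze_graph

# Sketch only: the input graph and checkpoint paths below are hypothetical
# and depend on how the model was saved during training.
freeze_graph.freeze_graph(
    input_graph="generated_model/graph.pbtxt",      # text GraphDef (hypothetical path)
    input_saver="",
    input_binary=False,                             # True if input_graph is a binary .pb
    input_checkpoint="generated_model/model.ckpt",  # checkpoint prefix (hypothetical)
    output_node_names="g_conv10/BiasAdd",           # comma-separated output node names
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph="generated_model/generated_model.pb",
    clear_devices=True,
    initializer_nodes="")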
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib


def convert_frozen_to_inference(model_path="generated_model",
                                frozen_file="generated_model.pb",
                                inputs=["input"],
                                outputs=["output"],
                                out_file="generated_model_opt.pb"):
    # Load the frozen GraphDef (binary protobuf, so read in "rb" mode).
    frozen_graph = tf.GraphDef()
    with tf.gfile.Open(model_path + "/" + frozen_file, "rb") as f:
        frozen_graph.ParseFromString(f.read())
    # Strip training-only nodes between the named input and output nodes.
    optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
        frozen_graph, inputs, outputs, tf.float32.as_datatype_enum)
    # Serialize the optimized graph back to disk ("wb" for binary protobuf).
    with tf.gfile.FastGFile(model_path + "/" + out_file, "wb") as f:
        f.write(optimized_graph_def.SerializeToString())


def main():
    inputs = ["Placeholder"]  # Set this list to your graph's input node names
    # outputs = ["DepthToSpace"]  # alternative output node, model-dependent
    outputs = ["g_conv10/BiasAdd"]  # Set this list to your graph's output node names
    convert_frozen_to_inference(inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    main()
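Once generated_model/generated_model_opt.pb has been written, the optimized graph can be loaded and executed directly. A minimal sketch follows; the 256x256x3 input shape and the dummy NHWC tensor are assumptions about the model, and the tensor names come from the node names used above:

import numpy as np
import tensorflow as tf

# Load the optimized GraphDef produced by the script above.
graph_def = tf.GraphDef()
with tf.gfile.Open("generated_model/generated_model_opt.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph; name="" keeps the original node names.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name="")
    input_tensor = graph.get_tensor_by_name("Placeholder:0")
    output_tensor = graph.get_tensor_by_name("g_conv10/BiasAdd:0")

with tf.Session(graph=graph) as sess:
    # Dummy NHWC input; the spatial size is an assumption about the model.
    dummy = np.random.rand(1, 256, 256, 3).astype(np.float32)
    result = sess.run(output_tensor, feed_dict={input_tensor: dummy})
    print(result.shape)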