Skip to content

Instantly share code, notes, and snippets.

@muayyad-alsadi
Last active August 1, 2019 11:53
Show Gist options
  • Save muayyad-alsadi/dceca624147556076af7d645d9778382 to your computer and use it in GitHub Desktop.
Save muayyad-alsadi/dceca624147556076af7d645d9778382 to your computer and use it in GitHub Desktop.
pure python tensorflow transform_graph
#! /usr/bin/env python
from __future__ import print_function
import argparse
import os
import tensorflow as tf
# from tensorflow.python.framework import graph_util
# from tensorflow.python.framework import graph_io
#from tensorflow.python.tools import optimize_for_inference_lib
from tensorflow.tools.graph_transforms import TransformGraph
# Usage mirrors TensorFlow's C++ transform_graph tool, but needs no Bazel build:
# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms#using-the-graph-transform-tool
# transform_graph.py --in_graph=data/merge/car-brands/frozen.pb --out_graph=data/merge/car-brands/frozen2.pb --inputs='input:0' --outputs='InceptionV1/Logits/Predictions/Reshape_1:0' --transforms='remove_nodes(op=Identity, op=CheckNumerics)'
def strip_n_filter(l):
    """Return the whitespace-stripped, non-empty strings from *l*.

    Each element is stripped once; elements that are empty after stripping
    are dropped. The original order is preserved.
    """
    stripped = (item.strip() for item in l)
    return [item for item in stripped if item]
def tr(args):
    """Apply Graph Transform Tool transforms to a binary GraphDef file.

    Reads the serialized GraphDef at ``args.in_graph``, runs
    ``TransformGraph`` on it with the inputs, outputs and transforms taken
    from *args*, and writes the serialized result to ``args.out_graph``.

    Args:
        args: object (typically an ``argparse.Namespace``) with attributes
            ``in_graph`` / ``out_graph`` (file paths), ``inputs`` /
            ``outputs`` (comma-separated names) and ``transforms``
            (transform specs, one or more per line). ``None`` for
            ``inputs``/``outputs``/``transforms`` is treated as empty.

    Raises:
        ValueError: if no transform is given.
    """
    # Validate before touching the (possibly large) input graph.
    transforms = strip_n_filter((args.transforms or "").splitlines())
    if not transforms:
        raise ValueError("no transforms given (use --transforms)")

    # NOTE(review): TensorFlow's TransformGraph wrapper documents these as
    # *node* names, while this script's CLI defaults use tensor-style names
    # ("input:0"); confirm your TF version accepts the ":0" suffix.
    input_node_names = strip_n_filter((args.inputs or "").split(","))
    output_node_names = strip_n_filter((args.outputs or "").split(","))

    # tf.GraphDef exists only in TF 1.x; TF 2.x exposes it as
    # tf.compat.v1.GraphDef. Prefer the TF 1.x name so old versions
    # (without tf.compat) keep working.
    graph_def_cls = getattr(tf, "GraphDef", None) or tf.compat.v1.GraphDef
    in_graph_def = graph_def_cls()
    with open(args.in_graph, "rb") as f:
        in_graph_def.ParseFromString(f.read())

    out_graph_def = TransformGraph(
        in_graph_def,
        input_node_names,
        output_node_names,
        transforms,
    )
    with open(args.out_graph, "wb") as f:
        f.write(out_graph_def.SerializeToString())
if __name__ == "__main__":
    # Command-line flags mirror those of TensorFlow's C++ transform_graph tool.
    # The graph paths and the transform list have no sensible default, so
    # they are required; otherwise tr() would fail later with an opaque
    # TypeError/AttributeError on None.
    parser = argparse.ArgumentParser()
    parser.add_argument("--in_graph", help="input .pb file", type=str, required=True)
    parser.add_argument("--out_graph", help="output .pb file", type=str, required=True)
    parser.add_argument("--inputs", help="tensor names", type=str, default="input:0")
    parser.add_argument("--outputs", help="tensor names", type=str, default="output:0")
    parser.add_argument("--transforms", help="operators with new lines", type=str, required=True)
    args = parser.parse_args()
    tr(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment