@gaphex
Last active June 13, 2019 13:28
Building a tf.Estimator from a serialized GraphDef
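import tensorflow as tf  # TensorFlow 1.x API (tf.gfile, tf.GraphDef)

# GRAPH_PATH (path to the serialized GraphDef) and INPUT_NAMES (names of the
# graph's input placeholders) are assumed to be defined elsewhere.

def model_fn(features, mode):
    # Read and parse the serialized GraphDef
    with tf.gfile.GFile(GRAPH_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the graph, feeding each input placeholder from the matching feature
    output = tf.import_graph_def(graph_def,
                                 input_map={k + ':0': features[k]
                                            for k in INPUT_NAMES},
                                 return_elements=['final_encodes:0'])
    # return_elements yields a list; output[0] is the 'final_encodes' tensor
    return tf.estimator.EstimatorSpec(mode=mode,
                                      predictions={'output': output[0]})

estimator = tf.estimator.Estimator(model_fn=model_fn)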
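The `input_map` keys and the `return_elements` entry rely on TensorFlow's tensor-naming convention, `<op_name>:<output_index>`, so `k + ':0'` selects output 0 of each input op. The following pure-Python sketch (no TensorFlow required) mirrors that convention, along with the `import/` prefix that `tf.import_graph_def` prepends by default (its `name` argument) to the tensors it creates. The helper functions and the input names in the demo are hypothetical, chosen only for illustration.

```python
from typing import Dict, Iterable, Tuple


def tensor_name(op_name: str, output_index: int = 0) -> str:
    """Build a TensorFlow-style tensor name '<op_name>:<output_index>'."""
    return f"{op_name}:{output_index}"


def split_tensor_name(name: str) -> Tuple[str, int]:
    """Split '<op_name>:<output_index>' back into (op_name, output_index)."""
    op_name, _, index = name.rpartition(":")
    return op_name, int(index)


def input_map_keys(input_names: Iterable[str]) -> Dict[str, str]:
    """Mirror the gist's `k + ':0'`: feature key -> placeholder tensor name."""
    return {k: tensor_name(k) for k in input_names}


def imported_name(name: str, prefix: str = "import") -> str:
    """Name a tensor gets after import_graph_def with the default prefix."""
    return f"{prefix}/{name}"


if __name__ == "__main__":
    # Hypothetical input names, for illustration only
    names = ["input_ids", "input_mask", "input_type_ids"]
    print(input_map_keys(names))
    print(split_tensor_name("final_encodes:0"))
    print(imported_name("final_encodes:0"))
```

The `input_map` keys and `return_elements` refer to names as stored in the GraphDef, without the prefix; only the tensors returned by the import carry it.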