Skip to content

Instantly share code, notes, and snippets.

@Breta01
Created April 15, 2017 20:13
Show Gist options
  • Save Breta01/d63a498f79ca7a72699fce6cc9c3b8b7 to your computer and use it in GitHub Desktop.
Save Breta01/d63a498f79ca7a72699fce6cc9c3b8b7 to your computer and use it in GitHub Desktop.
Class for importing and running multiple TensorFlow graphs, each in its own isolated graph and session.
import tensorflow as tf
class Graph:
    """Import a saved meta-graph into its own isolated TF graph and session.

    Each instance owns a private ``tf.Graph`` and a ``tf.Session`` bound to it,
    so several saved models can be loaded side by side without sharing the
    default graph.

    NOTE(review): this uses the TensorFlow 1.x graph/session API
    (``tf.Session``, ``tf.train.import_meta_graph``); under TensorFlow 2 these
    live in ``tf.compat.v1`` — confirm the TF version this runs on.

    The session holds resources until :meth:`close` is called; the instance
    can also be used as a context manager (``with Graph(loc) as g: ...``).
    """

    def __init__(self, loc, input_name="x:0", collection="activation"):
        """Load the checkpoint at *loc* into a fresh graph and session.

        Args:
            loc: Checkpoint path prefix. ``loc + '.meta'`` is read as the
                exported meta-graph, and ``loc`` itself is passed to
                ``Saver.restore`` to load the variable values.
            input_name: Name of the tensor that :meth:`run` feeds its data
                to (default ``"x:0"``, as before).
            collection: Name of the graph collection whose first element is
                the tensor evaluated by :meth:`run` (default ``"activation"``,
                as before).

        Raises:
            IndexError: If *collection* is empty in the imported graph.
        """
        self.input_name = input_name
        # Create a local graph and a session bound to it.
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)
        try:
            with self.graph.as_default():
                # Import into the local graph (not the process-wide default);
                # clear_devices drops device placements saved in the meta-graph.
                saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
                saver.restore(self.sess, loc)
            outputs = self.graph.get_collection(collection)
            if not outputs:
                raise IndexError(
                    "collection %r is empty in meta-graph %r"
                    % (collection, loc + '.meta')
                )
            self.activation = outputs[0]
        except Exception:
            # Don't leak the session when loading fails part-way.
            self.sess.close()
            raise

    def run(self, data):
        """Evaluate the imported output tensor with *data* fed to the input.

        Args:
            data: Value fed to the ``input_name`` tensor. It must be
                compatible with that tensor's dtype/shape — not checked here.

        Returns:
            Whatever ``Session.run`` returns for the output tensor.
        """
        return self.sess.run(self.activation, feed_dict={self.input_name: data})

    def close(self):
        """Close the underlying session; :meth:`run` cannot be used afterwards."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the session; never suppress an in-flight exception.
        self.close()
        return False
### Using the class ###
# Placeholder input: a fixed scalar, not random data. Replace it with a value
# compatible with the model's "x:0" input tensor (dtype/shape not checked here).
data = 50
model = Graph('models/model_name')
try:
    result = model.run(data)
finally:
    # Release the session's resources even if run() raises.
    model.sess.close()
print(result)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment