Skip to content

Instantly share code, notes, and snippets.

@annarailton
Last active May 29, 2021 08:31
Show Gist options
  • Save annarailton/083140321e77fc00676ce3511903e8c3 to your computer and use it in GitHub Desktop.
Save annarailton/083140321e77fc00676ce3511903e8c3 to your computer and use it in GitHub Desktop.
Save and restore a TensorFlow model that reads from a tf.data.Dataset through a feedable iterator backed by initialisable iterators.
# Tensorflow 1.8.0
import tensorflow as tf
import numpy as np
def make_iterators(train_dataset, test_dataset):
    """Build a feedable iterator plus one initialisable iterator per dataset.

    The feedable iterator is driven by a string-handle placeholder, so the
    same downstream ops can be fed from either concrete iterator at run time.
    The placeholder is registered in the 'handle' collection and the
    iterator's next element in the 'element' collection, so callers (and a
    later import_meta_graph) can look them up by name.

    Args:
        train_dataset: tf.data.Dataset whose output types/shapes define the
            structure of the feedable iterator.
        test_dataset: tf.data.Dataset for evaluation; presumably must share
            train_dataset's structure (TODO confirm).

    Returns:
        (train_iter, test_iter): initialisable iterators over the two datasets.
    """
    # The string fed here selects which concrete iterator supplies data.
    handle_ph = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('handle', handle_ph)

    feedable = tf.data.Iterator.from_string_handle(
        handle_ph,
        train_dataset.output_types,
        train_dataset.output_shapes,
    )
    tf.add_to_collection('element', feedable.get_next())

    # Tuple elements are evaluated left to right: train iterator is built first.
    return (
        train_dataset.make_initializable_iterator(),
        test_dataset.make_initializable_iterator(),
    )
def _run_until_exhausted(sess, op, feed_dict):
    """Print sess.run(op) repeatedly until the feeding iterator is exhausted.

    Stops when the session raises tf.errors.OutOfRangeError, which is how a
    tf.data iterator signals that it has no more elements.
    """
    while True:
        try:
            print(sess.run(op, feed_dict=feed_dict))
        except tf.errors.OutOfRangeError:
            break


def train(train_dataset, test_dataset, checkpoint_prefix='checkpoints/fufu'):
    """Create a graph with a Dataset and feedable Iterator, run it, and save it.

    A toy op (element * 0.5) stands in for a model. The training dataset
    and then the test dataset are each drained once, printing every result,
    and finally the graph and variables are checkpointed.

    Args:
        train_dataset: tf.data.Dataset for the "training" pass; its output
            types/shapes also define the feedable iterator.
        test_dataset: tf.data.Dataset for the "test" pass; presumably must
            match train_dataset's structure (TODO confirm).
        checkpoint_prefix: Path prefix handed to Saver.save; the graph
            definition ends up at checkpoint_prefix + '.meta'. Missing
            parent directories are created. Defaults to the original path.
    """
    import os

    train_iter, test_iter = make_iterators(train_dataset, test_dataset)
    handle = tf.get_collection('handle')[0]
    element = tf.get_collection('element')[0]

    some_op = tf.multiply(element, 0.5)
    # Registered in a collection so it can be found again after
    # tf.train.import_meta_graph.
    tf.add_to_collection('some_op', some_op)

    # Dummy variable: tf.train.Saver() refuses to build ("No variables to
    # save") on a graph with no variables, and this toy graph has none otherwise.
    tf.get_variable('v', initializer=tf.zeros([]))
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_handle = sess.run(train_iter.string_handle())
        test_handle = sess.run(test_iter.string_handle())

        # Run data iterator initialisation
        sess.run(train_iter.initializer)
        sess.run(test_iter.initializer)

        # "Training"
        print("Training")
        _run_until_exhausted(sess, some_op, {handle: train_handle})

        # "Test evaluation"
        print("Testing")
        _run_until_exhausted(sess, some_op, {handle: test_handle})

        # Saver.save does not create missing parent directories; MakeDirs is
        # recursive and succeeds if the directory already exists.
        checkpoint_dir = os.path.dirname(checkpoint_prefix)
        if checkpoint_dir:
            tf.gfile.MakeDirs(checkpoint_dir)
        saver.save(sess, checkpoint_prefix)
def eval(dataset, checkpoint_prefix='checkpoints/fufu'):
    """Restore the model from file and pass some new data through it.

    Args:
        dataset: tf.data.Dataset to evaluate; presumably its output
            types/shapes must match the dataset the graph was saved with
            (TODO confirm).
        checkpoint_prefix: The prefix given to train(); the graph is read
            from checkpoint_prefix + '.meta'. Defaults to the original path.

    Notes:
        * This name shadows the builtin ``eval``; it is kept for
          backward compatibility with existing callers.
        * import_meta_graph imports into the current default graph, and the
          'handle' / 'some_op' collections are read at index [0]. If those
          collections already hold entries (e.g. train() ran in the same
          graph), index [0] refers to the pre-existing nodes rather than the
          imported ones. Call this on a fresh default graph.
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(checkpoint_prefix + '.meta')
        saver.restore(sess, checkpoint_prefix)
        handle = tf.get_collection('handle')[0]

        # Make new iterator; its string handle is fed into the restored
        # feedable-iterator placeholder.
        iterator = dataset.make_one_shot_iterator()
        new_handle = sess.run(iterator.string_handle())

        # Don't need to call iterator.get_next() again as `some_op` will use
        # restored `element`
        some_op = tf.get_collection('some_op')[0]

        # "Further evaluation"
        print("More testing")
        while True:
            try:
                print(sess.run(some_op, feed_dict={handle: new_handle}))
            except tf.errors.OutOfRangeError:
                break
if __name__ == '__main__':
    train_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(np.random.randint(0, 100, (5, 2)), dtype=tf.float32))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(np.random.randint(0, 100, (2, 2)), dtype=tf.float32))
    train(train_dataset, test_dataset)

    # Give eval() an empty default graph. Otherwise the imported meta graph
    # lands alongside the ops train() just built, and eval()'s
    # tf.get_collection(...)[0] lookups return train()'s nodes, so the
    # restore-from-disk path would not really be exercised.
    tf.reset_default_graph()

    # Now want to evaluate the results of another test dataset
    another_test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(np.random.randint(0, 100, (4, 2)), dtype=tf.float32))
    eval(another_test_dataset)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment