Last active
May 29, 2021 08:31
-
-
Save annarailton/083140321e77fc00676ce3511903e8c3 to your computer and use it in GitHub Desktop.
Save and restore a Tensorflow model with a tf.data.Dataset + initialisable iterators.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TensorFlow 1.8.0
import tensorflow as tf | |
import numpy as np | |
def make_iterators(train_dataset, test_dataset):
    """Build the feedable-iterator plumbing that train() reads back.

    A scalar string placeholder is created and stored in the 'handle' graph
    collection. A generic iterator is then built from that placeholder using
    the output types and shapes of ``train_dataset``, and its ``get_next()``
    tensor is stored in the 'element' collection.

    Args:
        train_dataset: Dataset whose output types/shapes define the generic
            iterator's structure, and for which an initialisable iterator
            is made.
        test_dataset: Dataset for which a second initialisable iterator is
            made.

    Returns:
        A ``(train_iter, test_iter)`` tuple of initialisable iterators.
    """
    handle_ph = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('handle', handle_ph)

    # The generic iterator is driven purely by whatever string handle is fed
    # into the placeholder at run time.
    feedable = tf.data.Iterator.from_string_handle(
        handle_ph,
        train_dataset.output_types,
        train_dataset.output_shapes,
    )
    next_element = feedable.get_next()
    tf.add_to_collection('element', next_element)

    return (
        train_dataset.make_initializable_iterator(),
        test_dataset.make_initializable_iterator(),
    )
def train(train_dataset, test_dataset):
    """Build a graph around a Dataset and an Iterator, run it, and save it.

    The iterators come from make_iterators(); the placeholder and next-element
    tensor are read back from the 'handle' and 'element' collections. Each
    element is multiplied by 0.5, and that op is stored in the 'some_op'
    collection. All of the training data and then all of the test data is
    pushed through the op and printed, after which a checkpoint is written to
    'checkpoints/fufu'.

    Args:
        train_dataset: Dataset consumed in the "Training" pass.
        test_dataset: Dataset consumed in the "Testing" pass.
    """
    train_iter, test_iter = make_iterators(train_dataset, test_dataset)
    handle = tf.get_collection('handle')[0]
    element = tf.get_collection('element')[0]

    some_op = tf.multiply(element, 0.5)
    tf.add_to_collection('some_op', some_op)

    # NOTE(review): this scalar variable is never read; presumably it exists
    # only so the Saver has a variable to save -- confirm.
    tf.get_variable('v', initializer=tf.zeros([]))
    saver = tf.train.Saver()

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        # Concrete string handles that get fed into the placeholder.
        train_handle = session.run(train_iter.string_handle())
        test_handle = session.run(test_iter.string_handle())

        # Run data iterator initialisation
        session.run(train_iter.initializer)
        session.run(test_iter.initializer)

        # "Training" followed by "Test evaluation": drain each iterator in
        # turn until it reports that it is exhausted.
        phases = (("Training", train_handle), ("Testing", test_handle))
        for banner, handle_value in phases:
            print(banner)
            while True:
                try:
                    batch = session.run(
                        some_op, feed_dict={handle: handle_value})
                except tf.errors.OutOfRangeError:
                    break
                print(batch)

        saver.save(session, 'checkpoints/fufu')
def eval(dataset):
    """Reload the saved model from disk and run new data through it.

    The graph is imported from 'checkpoints/fufu.meta' and its variables are
    restored from 'checkpoints/fufu'. A one-shot iterator is made for
    ``dataset`` and its string handle is fed into the placeholder found in
    the 'handle' collection; every result of the op found in the 'some_op'
    collection is printed until the iterator is exhausted.

    Args:
        dataset: Dataset to push through the restored op.
    """
    with tf.Session() as session:
        restorer = tf.train.import_meta_graph('checkpoints/fufu.meta')
        restorer.restore(session, 'checkpoints/fufu')
        handle_ph = tf.get_collection('handle')[0]

        # Make new iterator
        fresh_iter = dataset.make_one_shot_iterator()
        dataset_handle = session.run(fresh_iter.string_handle())

        # No new get_next() call is made here: the op from the 'some_op'
        # collection is run as-is, fed only through the handle placeholder.
        restored_op = tf.get_collection('some_op')[0]

        # "Further evaluation"
        print("More testing")
        feed = {handle_ph: dataset_handle}
        while True:
            try:
                result = session.run(restored_op, feed_dict=feed)
            except tf.errors.OutOfRangeError:
                break
            print(result)
if __name__ == '__main__':
    # Random integer data in [0, 100), cast to float32: 5 training rows and
    # 2 test rows, each with 2 columns.
    train_values = np.random.randint(0, 100, (5, 2))
    train_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(train_values, dtype=tf.float32))

    test_values = np.random.randint(0, 100, (2, 2))
    test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(test_values, dtype=tf.float32))

    train(train_dataset, test_dataset)

    # Now want to evaluate the results of another test dataset (4 rows),
    # this time through the model restored from the checkpoint.
    another_values = np.random.randint(0, 100, (4, 2))
    another_test_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(another_values, dtype=tf.float32))

    eval(another_test_dataset)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment