Created
June 22, 2018 10:02
-
-
Save annarailton/98920701a744762a91f2c67f3f5bb467 to your computer and use it in GitHub Desktop.
Save and restore a Tensorflow model with a tf.data.Dataset + one_shot_iterator()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Tensorflow 1.8.0 | |
import tensorflow as tf | |
import numpy as np | |
def save(dataset): | |
""" | |
Create graph with an Dataset and Iterator and save the model. | |
There is some op that is applied to the data from the iterator. | |
""" | |
iterator_handle = tf.placeholder(tf.string, shape=[]) | |
tf.add_to_collection('iterator_handle', iterator_handle) | |
iterator = tf.data.Iterator.from_string_handle( | |
iterator_handle, | |
dataset.output_types, | |
dataset.output_shapes) | |
dataset_iterator = dataset.make_one_shot_iterator() | |
element = iterator.get_next() | |
some_op = tf.multiply(element, 0.5) | |
tf.add_to_collection('some_op', some_op) | |
v = tf.get_variable('v', initializer=tf.zeros([])) # Needs a variable to save the model | |
saver = tf.train.Saver() | |
with tf.Session() as session: | |
session.run(tf.global_variables_initializer()) | |
handle_val = session.run(dataset_iterator.string_handle()) | |
for _ in range(4): | |
print(session.run(some_op, | |
feed_dict={iterator_handle: handle_val})) | |
saver.save(session, 'checkpoints/fufu') | |
def restore(dataset): | |
"""Restore the model from file and pass some new data through it""" | |
with tf.Session() as session: | |
saver = tf.train.import_meta_graph('checkpoints/fufu.meta') | |
saver.restore(session, 'checkpoints/fufu') | |
iterator_handle = tf.get_collection('iterator_handle')[0] | |
# Make new iterator | |
iterator = dataset.make_one_shot_iterator() | |
new_handle = session.run(iterator.string_handle()) | |
# Don't need to call iterator.get_next() again as `some_op` will use | |
# restored `element` | |
some_op = tf.get_collection('some_op')[0] | |
for _ in range(4): | |
print(session.run(some_op, {iterator_handle: new_handle})) | |
if __name__ == '__main__': | |
raw_data = np.array([[0, 0], | |
[0, 1], | |
[1, 0], | |
[1, 1]]) | |
dataset1 = tf.data.Dataset.from_tensor_slices(tf.constant(raw_data, dtype=tf.float32)) | |
dataset2 = tf.data.Dataset.from_tensor_slices(tf.constant(raw_data * 2, dtype=tf.float32)) | |
save(dataset1) | |
# Restore works with any data of the same shape in a tf.data.Dataset | |
# To use different shaped data use initialisable iterator | |
restore(dataset1) | |
restore(dataset2) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment