Skip to content

Instantly share code, notes, and snippets.

@annarailton
Created June 22, 2018 10:02
Show Gist options
  • Save annarailton/98920701a744762a91f2c67f3f5bb467 to your computer and use it in GitHub Desktop.
Save annarailton/98920701a744762a91f2c67f3f5bb467 to your computer and use it in GitHub Desktop.
Save and restore a Tensorflow model with a tf.data.Dataset + one_shot_iterator()
# Written for TensorFlow 1.8.0 (uses the TF 1.x graph/session API).
import tensorflow as tf
import numpy as np
def save(dataset, checkpoint_prefix='checkpoints/fufu', num_elements=4):
    """Build a graph fed through a string-handle iterator and checkpoint it.

    A scalar string placeholder stands in for the iterator handle, so the
    graph is not tied to one concrete iterator: the handle of any iterator
    with the same output types and shapes as ``dataset`` can be fed at run
    time. The placeholder and the op applied to each element are stored in
    the graph collections ``'iterator_handle'`` and ``'some_op'`` so they
    can be looked up again after the meta graph is re-imported.

    Args:
        dataset: A ``tf.data.Dataset`` built in the current default graph.
            It should hold at least ``num_elements`` elements; running the
            op past the end of a one-shot iterator raises
            ``tf.errors.OutOfRangeError``.
        checkpoint_prefix: Path prefix passed to ``tf.train.Saver.save``.
            Defaults to ``'checkpoints/fufu'``.
        num_elements: Number of elements to pull through the op and print
            before the checkpoint is written. Defaults to 4.
    """
    # Placeholder through which the handle of a concrete iterator is fed.
    iterator_handle = tf.placeholder(tf.string, shape=[])
    tf.add_to_collection('iterator_handle', iterator_handle)

    # Generic iterator defined only by element types/shapes; which data it
    # actually reads is decided by the handle fed at session.run() time.
    iterator = tf.data.Iterator.from_string_handle(
        iterator_handle,
        dataset.output_types,
        dataset.output_shapes)
    dataset_iterator = dataset.make_one_shot_iterator()

    element = iterator.get_next()
    some_op = tf.multiply(element, 0.5)
    tf.add_to_collection('some_op', some_op)

    # The graph needs at least one variable for the model to be saved; the
    # variable's value itself is never used.
    tf.get_variable('v', initializer=tf.zeros([]))
    saver = tf.train.Saver()

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        handle_val = session.run(dataset_iterator.string_handle())

        for _ in range(num_elements):
            print(session.run(some_op,
                              feed_dict={iterator_handle: handle_val}))

        saver.save(session, checkpoint_prefix)
def restore(dataset, checkpoint_prefix='checkpoints/fufu', num_elements=4):
    """Restore the saved model and pass the elements of ``dataset`` through it.

    The meta graph written next to the checkpoint is imported, the variables
    are restored, and the iterator-handle placeholder and the op are fetched
    back from the ``'iterator_handle'`` and ``'some_op'`` collections. A new
    one-shot iterator over ``dataset`` supplies the handle that is fed to
    the placeholder.

    NOTE(review): ``tf.train.import_meta_graph`` imports into the current
    default graph. When this runs in the same process as ``save()``, that
    graph already holds the original ops and collection entries -- confirm
    that index ``[0]`` of each collection is the tensor you intend to use.

    Args:
        dataset: A ``tf.data.Dataset`` built in the current default graph.
            Its element types and shapes are expected to match those of
            the dataset used when the model was saved.
        checkpoint_prefix: Path prefix the checkpoint was saved under; the
            meta graph is read from ``checkpoint_prefix + '.meta'``.
            Defaults to ``'checkpoints/fufu'``.
        num_elements: Number of elements to pull through the op and print.
            Defaults to 4.
    """
    with tf.Session() as session:
        saver = tf.train.import_meta_graph(checkpoint_prefix + '.meta')
        saver.restore(session, checkpoint_prefix)
        iterator_handle = tf.get_collection('iterator_handle')[0]

        # Make a new iterator over the new data; only its string handle is
        # needed to redirect the restored graph to this dataset.
        iterator = dataset.make_one_shot_iterator()
        new_handle = session.run(iterator.string_handle())

        # No need to call get_next() again: `some_op` already consumes the
        # `element` tensor that is driven by the handle placeholder.
        some_op = tf.get_collection('some_op')[0]

        for _ in range(num_elements):
            print(session.run(some_op,
                              feed_dict={iterator_handle: new_handle}))
if __name__ == '__main__':
    # Demo data: four rows of two values each.
    raw_data = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
    doubled_data = raw_data * 2

    original_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(raw_data, dtype=tf.float32))
    doubled_dataset = tf.data.Dataset.from_tensor_slices(
        tf.constant(doubled_data, dtype=tf.float32))

    # Build the graph, run it on the original data and write the checkpoint.
    save(original_dataset)

    # Restoring works with any tf.data.Dataset whose data has the same shape
    # as the data used when saving. For differently shaped data, use an
    # initialisable iterator instead.
    for dataset_to_restore in (original_dataset, doubled_dataset):
        restore(dataset_to_restore)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment