@post2web
Created May 4, 2018 04:49
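# Requires TensorFlow 1.x (graph mode): make_one_shot_iterator, tf.assign and
# tf.Session are not part of the top-level TF 2.x API.
import tensorflow as tf

# Two independent pipelines: training yields 0..9, testing yields 10..19.
train_dataset = tf.data.Dataset.range(10)
train_iterator = train_dataset.make_one_shot_iterator()
test_dataset = tf.data.Dataset.range(10, 20)
test_iterator = test_dataset.make_one_shot_iterator()

# Non-trainable boolean flag that selects which pipeline feeds the graph.
training = tf.Variable(True, trainable=False)
set_testing = tf.assign(training, False)
set_training = tf.assign(training, True)

# Pass the bound get_next methods (not precomputed tensors) so each get_next op
# is created inside its branch and only the selected iterator advances.
next_element = tf.cond(
    training,
    train_iterator.get_next,
    test_iterator.get_next
)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

print(sess.run(next_element))  # 0  (training iterator)
sess.run(set_testing)
print(sess.run(next_element))  # 10 (testing iterator)
sess.run(set_training)
print(sess.run(next_element))  # 1  (training iterator resumes where it left off)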
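Why pass `train_iterator.get_next` rather than a tensor built outside the `tf.cond`? In TF 1.x graph mode, tensors created outside the branch functions are evaluated whether or not their branch is taken, so both iterators would advance on every `sess.run`. The sketch below is a plain-Python analogy only (no TensorFlow); the `cond` helper is a hypothetical stand-in for `tf.cond` that calls just the selected branch, and `eager_step` models precomputing both values before choosing.

```python
def cond(pred, true_fn, false_fn):
    """Hypothetical stand-in for tf.cond: only the selected callable runs."""
    return true_fn() if pred else false_fn()


# Callables: only the chosen iterator is advanced on each step.
train_it = iter(range(10))
test_it = iter(range(10, 20))
lazy = [
    cond(flag, lambda: next(train_it), lambda: next(test_it))
    for flag in (True, False, True)
]
print(lazy)  # [0, 10, 1]

# Precomputed values: both iterators advance on every step,
# so each pipeline silently skips the elements consumed by the other branch.
train_it = iter(range(10))
test_it = iter(range(10, 20))


def eager_step(pred):
    train_value = next(train_it)
    test_value = next(test_it)
    return train_value if pred else test_value


eager = [eager_step(flag) for flag in (True, False, True)]
print(eager)  # [0, 11, 2]
```

The lazy version reproduces the gist's sequence (0, 10, 1); the eager version loses 10 from the test pipeline and 1 from the training pipeline.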