Skip to content

Instantly share code, notes, and snippets.

@tlaitinen
Created December 29, 2017 14:06
Show Gist options
  • Select an option

  • Save tlaitinen/963d7c4d68edd0f4eda7a8e7bb739a6c to your computer and use it in GitHub Desktop.

Select an option

Save tlaitinen/963d7c4d68edd0f4eda7a8e7bb739a6c to your computer and use it in GitHub Desktop.
import tensorflow as tf
# Demo: cut a 1-D tensor into fixed-length sliding windows and route each
# window to a train / test / validation split based on its start index
# modulo 10 (remainders 0-6 -> train, 7-8 -> test, 9 -> validation), then
# print the first few batches of every split.

window_size = 10
batch_size = 10
num_batches_to_print = 10

data = tf.range(100)
# Dataset of window start positions (0..89).
indices = tf.data.Dataset.range(90)


def window_fn(x):
    """Return the slice of `data` that starts at `x` and spans `window_size` items."""
    return data[x:x + window_size]


def _make_split(keep):
    """Build the window dataset for one split.

    `keep` is called with a start index from `indices` and returns a boolean
    tensor telling whether that window belongs to the split.  The surviving
    indices are turned into windows, grouped into batches of `batch_size`
    and repeated indefinitely.  Batching is applied before `repeat()`, so
    the final batch of each pass may hold fewer than `batch_size` windows.
    """
    return indices.filter(keep).map(window_fn).batch(batch_size).repeat()


train = _make_split(lambda x: tf.less(tf.mod(x, 10), 7))
test = _make_split(
    lambda x: tf.logical_and(
        tf.greater_equal(tf.mod(x, 10), 7),
        tf.less(tf.mod(x, 10), 9),
    )
)
validation = _make_split(lambda x: tf.equal(tf.mod(x, 10), 9))

train_iterator = train.make_one_shot_iterator()
test_iterator = test.make_one_shot_iterator()
validation_iterator = validation.make_one_shot_iterator()

next_train_element = train_iterator.get_next()
next_test_element = test_iterator.get_next()
next_validation_element = validation_iterator.get_next()

# The context manager closes the session (releasing its resources) even if
# one of the `run` calls raises.
with tf.Session() as sess:
    # Retained from the original script; nothing above creates a variable.
    sess.run(tf.global_variables_initializer())

    for label, next_element in (
        ('train', next_train_element),
        ('test', next_test_element),
        ('validation', next_validation_element),
    ):
        print(label)
        for _ in range(num_batches_to_print):
            print(sess.run(next_element))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment