-
-
Save mbsariyildiz/f4fe854c93bd4cad2bdae94b45fd0d3a to your computer and use it in GitHub Desktop.
Simple example of using tf.data.Dataset to create a data input pipeline from RAM to GPU.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Minimal tf.data demo using the TensorFlow 1.x graph/session API (written
# against TF 1.4): slice in-memory NumPy arrays into a Dataset, then
# shuffle/repeat/batch/prefetch it. Drain it through an initializable iterator
# and count how many times each sample was delivered.
import numpy as np
import tensorflow as tf

# print() with a single pre-formatted string behaves the same on Python 2 and 3.
print('tf_version: %s' % tf.__version__)  # written against 1.4.0; needs the TF1 API
np.set_printoptions(linewidth=150, precision=3, suppress=True)

M = 10  # number of samples
d = 2   # feature dimension of each sample

# samples: one d-dimensional row per sample
X = tf.constant(np.random.randn(M, d), 'float32')
# ids of samples: a permutation of 0..M-1, so every sample has a distinct id.
# Kept as int32 (shape (M, 1)) because the ids are used as array indices below.
Y = tf.constant(np.expand_dims(np.random.permutation(M), 1), 'int32')
dset_items = (X, Y)

# from_tensor_slices slices every component along its first dimension, so each
# component needs rank >= 1 and all first dimensions must agree.
first_dims = [item.shape.as_list()[0] for item in dset_items]
assert len(set(first_dims)) == 1, 'first dimensions differ: %s' % (first_dims,)

batch_size = 2
n_epochs = 5

dset = tf.data.Dataset.from_tensor_slices(dset_items)
dset = dset.shuffle(M)        # buffer of M elements covers the whole dataset
dset = dset.repeat(n_epochs)  # applied after shuffle: each pass is a full epoch
dset = dset.batch(batch_size)
dset = dset.prefetch(2)

dset_iterator = dset.make_initializable_iterator()
next_batch = dset_iterator.get_next()

occurrence = np.zeros([M], 'int32')
it = 0
# The context manager closes the session even if an unexpected error escapes.
with tf.Session() as sess:
    sess.run(dset_iterator.initializer)
    while True:
        try:
            xb, yb = sess.run(next_batch)
        except tf.errors.OutOfRangeError:
            # Raised by get_next() once the repeated dataset is exhausted.
            print('end of dataset')
            break
        it += 1
        # np.add.at accumulates repeated indices. `occurrence[idx] += 1` would
        # count an id only once if it appeared twice in the same batch. That can
        # happen when a batch straddles an epoch boundary (M % batch_size != 0).
        np.add.at(occurrence, yb.ravel(), 1)
        print('%03d, x:%s, y:%s ' % (it, str(xb.ravel()), str(yb.ravel())))

# repeat(n_epochs) yields every sample exactly n_epochs times (not M times).
print('occurrence array: %s' % (occurrence,))
assert np.all(occurrence == n_epochs), 'each sample should be fetched n_epochs times'
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment