Skip to content

Instantly share code, notes, and snippets.

@joyhuang9473
Created March 5, 2017 07:51
Show Gist options
Save joyhuang9473/3f7ac7833d6e97b098a8e9addb93bb55 to your computer and use it in GitHub Desktop.
"""
TensorFlow 1.x queue-based CSV input pipeline: reads rows from CSV files,
shuffles them, and batches (features, label) pairs for inspection.
"""
import tensorflow as tf
def batch_generator(filenames, BATCH_SIZE):
    """Build a shuffled (features, label) batch pipeline over CSV files.

    Args:
        filenames: list of CSV file paths to read from. In this case it
            contains only heart.csv.
        BATCH_SIZE: number of examples in each returned batch.

    Returns:
        (data_batch, label_batch): tensors produced by
        ``tf.train.shuffle_batch``. They are queue-backed, so evaluate them
        in a session only after the queue runners have been started.
    """
    # NOTE(review): the column layout below assumes heart.csv has 9 feature
    # columns, a string `famhist` column ("Present"/"Absent") at index 4, and
    # an integer label as the last column -- confirm against the data file.
    n_features = 9
    famhist_index = 4

    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1)  # skip the CSV header row
    _, value = reader.read(filename_queue)

    # record_defaults fixes each column's dtype and supplies the value used
    # when a field is empty: floats for the features, a string for famhist,
    # and an int for the trailing label column.
    record_defaults = [[1.0] for _ in range(n_features)]
    record_defaults[famhist_index] = ['']
    record_defaults.append([1])

    # read data: one tensor per CSV column
    content = tf.decode_csv(value, record_defaults=record_defaults)

    # famhist is decoded as a string; map it to 1.0 / 0.0 so every feature
    # is float32 and they can be packed into a single tensor.
    content[famhist_index] = tf.cast(
        tf.equal(content[famhist_index], tf.constant('Present')), tf.float32)

    # pack all features into a tensor
    features = tf.stack(content[:n_features])
    # assign label (the last column)
    label = content[-1]

    # minimum number of elements in the queue after a dequeue, used to ensure
    # that the samples are sufficiently mixed
    min_after_dequeue = 10 * BATCH_SIZE
    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    data_batch, label_batch = tf.train.shuffle_batch(
        [features, label],
        batch_size=BATCH_SIZE,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)
    return data_batch, label_batch
def generate_batches(data_batch, label_batch, num_batches=10):
    """Evaluate and print ``num_batches`` feature batches from the pipeline.

    Args:
        data_batch: feature batch tensor (e.g. from ``batch_generator``).
        label_batch: label batch tensor (e.g. from ``batch_generator``).
        num_batches: how many batches to pull; defaults to 10.
    """
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # The batch tensors are queue-backed; their runner threads must be
        # started before sess.run can dequeue anything.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(num_batches):
                features, _labels = sess.run([data_batch, label_batch])
                # print(...) with one argument works on Python 2 and 3.
                print(features)
        finally:
            # Stop and join the runner threads even if sess.run raises, so
            # they do not outlive the session.
            coord.request_stop()
            coord.join(threads)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment