Created
March 5, 2017 07:51
-
-
Save joyhuang9473/3f7ac7833d6e97b098a8e9addb93bb55 to your computer and use it in GitHub Desktop.
reader template: https://github.com/chiphuyen/tf-stanford-tutorials/blob/master/examples/05_csv_reader.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
""" | |
import tensorflow as tf | |
def batch_generator(filenames, BATCH_SIZE):
    """Build a shuffled-batch input pipeline over CSV files.

    filenames is the list of files you want to read from. In this case,
    it contains only heart.csv.

    Args:
        filenames: list of CSV file paths to read.
        BATCH_SIZE: number of samples per dequeued batch.

    Returns:
        (data_batch, label_batch): tensors that yield BATCH_SIZE feature
        vectors and BATCH_SIZE labels each time they are evaluated.
    """
    # NOTE(review): assumes the CS20 heart.csv schema -- 9 feature columns
    # (column 4, 'famhist', is a string) followed by an integer label.
    # Confirm against the actual data file.
    N_FEATURES = 9

    filename_queue = tf.train.string_input_producer(filenames)
    reader = tf.TextLineReader(skip_header_lines=1)  # skip the first line in the file
    _, value = reader.read(filename_queue)

    # record_defaults are the default values in case some of our columns
    # are empty; they also fix each column's dtype.
    record_defaults = [[1.0] for _ in range(N_FEATURES)]
    record_defaults[4] = ['']    # 'famhist' is categorical ('Present'/'Absent')
    record_defaults.append([1])  # label column is an int

    # read data
    content = tf.decode_csv(value, record_defaults=record_defaults)

    # convert the categorical 'famhist' column to a 1.0/0.0 float
    content[4] = tf.cond(
        tf.equal(content[4], tf.constant('Present')),
        lambda: tf.constant(1.0),
        lambda: tf.constant(0.0))

    # pack all features into a tensor
    features = tf.stack(content[:N_FEATURES])

    # assign label: the last column of the row
    label = content[-1]

    # minimum number of elements in the queue after a dequeue, used to
    # ensure that the samples are sufficiently mixed
    min_after_dequeue = 10 * BATCH_SIZE

    # the maximum number of elements in the queue
    capacity = 20 * BATCH_SIZE

    # shuffle the data to generate BATCH_SIZE sample pairs
    data_batch, label_batch = tf.train.shuffle_batch(
        [features, label], batch_size=BATCH_SIZE,
        capacity=capacity, min_after_dequeue=min_after_dequeue)

    return data_batch, label_batch
def generate_batches(data_batch, label_batch):
    """Run the input pipeline and print 10 batches of features.

    Starts the queue-runner threads that feed the pipeline built by
    batch_generator, evaluates 10 (features, labels) batches, then shuts
    the reader threads down cleanly.

    Args:
        data_batch: batched feature tensor from batch_generator.
        label_batch: batched label tensor from batch_generator.
    """
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        for _ in range(10):  # generate 10 batches
            features, labels = sess.run([data_batch, label_batch])
            # was a Python 2 print statement; use the print function
            print(features)
        coord.request_stop()
        coord.join(threads)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment