Created
September 4, 2017 11:58
-
-
Save tomrunia/dc6253228490fbb2ceca07d00c9f8343 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def load_and_enqueue(input_dir, sess, coord, enqueue_op, queue_inputs, queue_targets, num_examples, examples_per_file=100, rewrite_targets=True):
    """Repeatedly read examples from HDF5 files and feed them to an enqueue op.

    Files matching ``<input_dir>/train/*.h5`` are sorted, and only the first
    ``ceil(num_examples / examples_per_file)`` of them are used.
    ``num_examples`` is first capped at the number of examples the files can
    supply. Each selected file is read with ``read_examples_from_file`` and its
    inputs/targets are pushed with one ``sess.run(enqueue_op, ...)`` call.
    After every full pass over the files their order is shuffled. The function
    loops until ``coord.should_stop()`` returns true, then returns.

    Args:
        input_dir: Directory containing a ``train`` sub-directory of ``.h5``
            files.
        sess: Session used to run ``enqueue_op``.
        coord: Object polled via ``should_stop()``; presumably a
            ``tf.train.Coordinator`` (confirm against caller).
        enqueue_op: Op run once per file to enqueue the fed values.
        queue_inputs: Feed key that receives the inputs read from a file.
        queue_targets: Feed key that receives the targets read from a file.
        num_examples: Number of examples requested; capped at
            ``len(files) * examples_per_file``.
        examples_per_file: Number of examples assumed to be stored per file.
        rewrite_targets: Forwarded to ``read_examples_from_file``.

    Raises:
        ValueError: if ``num_examples`` or ``examples_per_file`` is not
            positive, or if no ``.h5`` files are found.
    """
    # Validate up front: a non-positive count would leave the file list empty,
    # and the endless loop below would then spin forever without ever calling
    # coord.should_stop().
    if examples_per_file <= 0:
        raise ValueError("examples_per_file must be positive, got %r" % (examples_per_file,))
    if num_examples <= 0:
        raise ValueError("num_examples must be positive, got %r" % (num_examples,))

    train_dir = os.path.join(input_dir, "train")
    # Sort first so that the subset of files selected below is deterministic.
    filenames_queue = sorted(glob.glob(os.path.join(train_dir, "*.h5")))
    if not filenames_queue:
        raise ValueError("No HDF5 files found in %s" % train_dir)

    # Check if we have a sufficient number of HDF5 files to load all the samples.
    examples_available = len(filenames_queue) * examples_per_file
    num_examples = min(examples_available, num_examples)
    # Force true division: under Python 2 int / int truncates, which would
    # make ceil a no-op and drop the file holding the remainder examples.
    num_files = int(math.ceil(float(num_examples) / examples_per_file))
    filenames_queue = filenames_queue[:num_files]

    while True:
        for filename in filenames_queue:
            # Check before doing any work so that a stop request does not
            # trigger one more file read.
            if coord.should_stop():
                return
            # Read one file's examples; shuffle=True asks the reader to
            # shuffle the examples within that file. The third return value
            # is not needed here.
            inputs, targets, _ = read_examples_from_file(filename, shuffle=True, rewrite_targets=rewrite_targets)
            # Feed the examples to the (FIFO) queue.
            # NOTE(review): if the queue is full this call blocks, and the
            # should_stop() check is not reached until it returns. The queue
            # presumably needs to be closed with pending enqueues cancelled on
            # shutdown; confirm with the caller.
            sess.run(enqueue_op, feed_dict={
                queue_inputs: inputs,
                queue_targets: targets,
            })
        # Visit the files in a new random order on the next pass.
        random.shuffle(filenames_queue)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment