Skip to content

Instantly share code, notes, and snippets.

@mmourafiq
Created May 3, 2017 14:05
Show Gist options
  • Save mmourafiq/88e1cb5c391f6a6a7df6371fe89f3880 to your computer and use it in GitHub Desktop.
Save mmourafiq/88e1cb5c391f6a6a7df6371fe89f3880 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
class FeederConfig(object):
    """The FeederConfig holds information needed to create data feeders for training and evaluating.

    Arguments:
        num_threads: `int`. Total number of simultaneous threads to process data.
        max_queue: `int`. Maximum number of data stored in a queue.
        capacity: `int`. Capacity used for the feeder queues.
        shuffle: `bool`. If True, data will be shuffled.
        ensure_data_order: `bool`. Ensure that data order is kept when using
            'next' to retrieve data (processing will be slower). This forces a
            single thread, `max_queue=1` and disables shuffling.
    """

    def __init__(self, num_threads=4, max_queue=32, capacity=2000, shuffle=True,
                 ensure_data_order=False):
        self.num_threads = num_threads
        self.max_queue = max_queue
        self.shuffle = shuffle
        self.capacity = capacity
        if ensure_data_order:
            # Order can only be preserved with a single producer thread and
            # without shuffling; a shuffled queue would reorder the data anyway.
            self.num_threads = 1
            self.max_queue = 1
            self.shuffle = False
class Feeder(object):
    """Manage the queues and background threads that fill them with data.

    Three queues (train, validation, test) share the same dtypes and shapes and
    are combined into one selector queue (`QueueBase.from_list`). A single
    dequeue op therefore serves all of them; the queue actually read is chosen
    at run time by feeding `queue_index`.

    Arguments:
        inputs: input placeholder tensor; its leading dimension is treated as
            the batch dimension (only `shape[1:]` is used as the element shape).
        outputs: output placeholder tensor, or `None` when only inputs are fed.
        config: a `FeederConfig`-like object (reads `shuffle`, `capacity` and
            `num_threads`).
    """

    class QueueIndex(object):
        """Values fed to `queue_index` to select a queue."""
        TRAIN = 0
        VAL = 1
        TEST = 2

    def __init__(self, inputs, outputs, config):
        self.inputs = inputs
        self.outputs = outputs
        # Fed at run time to pick which queue `dequeue_op` reads from.
        self.queue_index = tf.placeholder(dtype=tf.int32, shape=[])
        # Fed at run time to choose how many elements to dequeue.
        self.batch_size = tf.placeholder(dtype=tf.int32, shape=[])
        self.config = config
        self.num_samples = 0
        self.step = 0
        self.epoch = 0
        self.current_iter = 0
        self.queue_train = None
        self.queue_val = None
        self.queue_test = None
        self.queue = None
        self.set_queue()
        self.enqueue_train_op = self.queue_train.enqueue_many(self.placeholders)
        self.enqueue_val_op = self.queue_val.enqueue_many(self.placeholders)
        # The test queue has the same components as the other queues (required
        # by `from_list`), so it must be fed every placeholder as well.
        self.enqueue_test_op = self.queue_test.enqueue_many(self.placeholders)
        self.dequeue_op = self.queue.dequeue_many(self.batch_size)
        # Build the close ops once so `close_queue` neither grows the graph on
        # every call nor needs `queue_index` to be fed.
        self._close_ops = [q.close(cancel_pending_enqueues=True)
                           for q in (self.queue_train, self.queue_val, self.queue_test)]
        tf.add_to_collection(name='queues', value=self.dequeue_op)

    @property
    def placeholders(self):
        """The tensors stored in each queue element: `[inputs, outputs]` or `[inputs]`."""
        return [self.inputs, self.outputs] if self.outputs is not None else [self.inputs]

    @staticmethod
    def get_shape(x):
        """Return the static shape of tensor `x` as a Python list."""
        return x.get_shape().as_list()

    def set_queue(self):
        """Create the train/val/test queues and the selector queue over them."""
        dtypes = [x.dtype for x in self.placeholders]
        # Queue elements are single examples, so the leading batch dim is dropped.
        shapes = [self.get_shape(x)[1:] for x in self.placeholders]
        capacity = self.config.capacity
        if self.config.shuffle:
            # RandomShuffleQueue requires min_after_dequeue < capacity; keep the
            # historical value of 1000 whenever the capacity allows it.
            self.queue_train = tf.RandomShuffleQueue(
                dtypes=dtypes, shapes=shapes, capacity=capacity,
                min_after_dequeue=min(1000, max(capacity - 1, 0)))
        else:
            self.queue_train = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue_val = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue_test = tf.FIFOQueue(dtypes=dtypes, shapes=shapes, capacity=capacity)
        self.queue = tf.QueueBase.from_list(
            index=self.queue_index,
            queues=[self.queue_train, self.queue_val, self.queue_test])

    def _update_counters(self, index, batch_size):
        """Advance step/epoch counters after a batch from the training queue.

        Other queues leave the counters untouched. An epoch boundary can only be
        detected once `num_samples` has been set to a positive value.
        """
        if index != self.QueueIndex.TRAIN:
            return
        self.step += 1
        self.current_iter = min(self.step * batch_size, self.num_samples)
        # With num_samples unknown (0), every step would otherwise look like
        # the end of an epoch.
        if self.num_samples > 0 and self.current_iter == self.num_samples:
            self.epoch += 1
            self.step = 0

    def get_inputs(self, session, queue_index, batch_size):
        """Dequeue one batch from the queue selected by `queue_index`.

        Returns:
            A `(X_batch, y_batch)` tuple. When the feeder was built without
            outputs, only X is queued and `y_batch` is `None`.
        """
        batch = session.run(self.dequeue_op,
                            {self.queue_index: queue_index, self.batch_size: batch_size})
        # Count the step only once a batch was actually dequeued.
        self._update_counters(queue_index, batch_size)
        if self.outputs is None:
            # A single-component queue dequeues one array, not a list; unpacking
            # it would split the batch itself.
            if isinstance(batch, (list, tuple)):
                batch = batch[0]
            return batch, None
        X_batch, y_batch = batch
        return X_batch, y_batch

    def get_enqueue_op(self, queue_index):
        """Return the placeholder-fed enqueue op for `queue_index`, or None if unknown."""
        if queue_index == self.QueueIndex.TRAIN:
            return self.enqueue_train_op
        if queue_index == self.QueueIndex.VAL:
            return self.enqueue_val_op
        if queue_index == self.QueueIndex.TEST:
            return self.enqueue_test_op
        return None

    def get_queue(self, queue_index):
        """Return the concrete queue for `queue_index`, or None if unknown."""
        if queue_index == self.QueueIndex.TRAIN:
            return self.queue_train
        if queue_index == self.QueueIndex.VAL:
            return self.queue_val
        if queue_index == self.QueueIndex.TEST:
            return self.queue_test
        return None

    def close_queue(self, session):
        """Close the train, val and test queues, cancelling pending enqueues.

        Cancelling pending enqueues releases producer threads blocked on a full
        queue, so they can stop.
        """
        session.run(self._close_ops)

    def enqueue(self, queue_index, X, y=None):
        """Register a queue runner that repeatedly enqueues `X` (and `y`).

        Arguments:
            queue_index: one of the `QueueIndex` values.
            X: batch of inputs (leading dimension is the batch dimension).
            y: optional batch of outputs matching `X`.

        Raises:
            ValueError: if `queue_index` does not name a known queue.
        """
        queue = self.get_queue(queue_index)
        if queue is None:
            raise ValueError('Unknown queue_index: {}'.format(queue_index))
        enqueue_op = queue.enqueue_many([X, y] if y is not None else [X])
        # Attach the runner to the concrete queue: the selector queue needs
        # `queue_index` fed, so the runner could not close it on errors.
        # One copy of the op per configured producer thread.
        num_threads = max(1, self.config.num_threads)
        queue_runner = tf.train.QueueRunner(queue=queue,
                                            enqueue_ops=[enqueue_op] * num_threads)
        tf.train.add_queue_runner(queue_runner)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment