Skip to content

Instantly share code, notes, and snippets.

@bzamecnik
Created October 28, 2017 13:38
Show Gist options
  • Save bzamecnik/05ff3bda1f82b16e325a3a1857fd2994 to your computer and use it in GitHub Desktop.
Save bzamecnik/05ff3bda1f82b16e325a3a1857fd2994 to your computer and use it in GitHub Desktop.
TF StagingArea with variable batch size
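# Does StagingArea support batches of variable size? Yes.
#
# The training or validation set might not be exactly divisible by the batch
# size, so the last batch might be smaller. We can either ignore that batch
# (which is incorrect with respect to the loss) or provide batches of variable
# size. On the other hand, we'd like to ensure that all data points (samples)
# have the same shape.
#
# It turns out that StagingArea accepts batches of variable size, and that
# declaring the batch dimension as None (e.g. shapes=[(None, 2)]) still
# enforces the per-sample shape. Without `shapes`, no shape is enforced at all.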
import numpy as np
import tensorflow as tf
from tensorflow.python.ops.data_flow_ops import StagingArea
# Basic StagingArea with constant batch size.
# Both the batch size and the per-sample shape are fixed by `shapes`, so a put
# with any other shape is rejected with a ValueError (a static shape check,
# judging by the "Shapes ... are incompatible" message).
with tf.Session() as sess:
    # batches of 4 samples with shape (1,)
    area = StagingArea(dtypes=[tf.int32], shapes=[(4, 1)])
    sess.run(area.put([tf.constant([[1], [2], [3], [4]])]))

    # smaller batch: must be rejected
    try:
        sess.run(area.put([tf.constant([[5], [6]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (2, 1) and (4, 1) are incompatible"
    else:
        # Without this branch the demo would silently "pass" if no error was raised.
        raise AssertionError("expected a batch of 2 samples to be rejected")

    # different data sample shape: must be rejected
    try:
        sess.run(area.put([tf.constant([[1, 1], [2, 2], [3, 3], [4, 4]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (4, 2) and (4, 1) are incompatible"
    else:
        raise AssertionError("expected samples of shape (2,) to be rejected")
# StagingArea supports batches of variable size.
# With no `shapes` argument, neither the batch size nor the sample shape is checked.
with tf.Session() as sess:
    area = StagingArea(dtypes=[tf.int32])
    sess.run(area.put([tf.constant([1, 2, 3, 4])]))
    # smaller batch
    sess.run(area.put([tf.constant([5, 6])]))
    # Batches come back in the order they were put.
    # np.array_equal also compares shapes; np.all(a == b) can pass on a
    # wrongly-shaped result through broadcasting.
    assert np.array_equal(sess.run(area.get()), [1, 2, 3, 4])
    assert np.array_equal(sess.run(area.get()), [5, 6])
    # data sample shape not enforced :/ (a rank-2 batch is accepted too)
    sess.run(area.put([tf.constant([[5, 5], [6, 6]])]))
    assert np.array_equal(sess.run(area.get()), [[5, 5], [6, 6]])
# StagingArea supports batches of variable number of samples, each with the same shape.
with tf.Session() as sess:
    # None leaves the batch dimension free while each sample must have shape (2,).
    area = StagingArea(dtypes=[tf.int32], shapes=[(None, 2)])
    # normal batch
    sess.run(area.put([tf.constant([[1, 1], [2, 2], [3, 3], [4, 4]])]))
    # smaller batch (e.g. at the end of the training set)
    sess.run(area.put([tf.constant([[5, 5], [6, 6]])]))
    # np.array_equal also compares shapes, unlike np.all(a == b).
    assert np.array_equal(sess.run(area.get()), [[1, 1], [2, 2], [3, 3], [4, 4]])
    assert np.array_equal(sess.run(area.get()), [[5, 5], [6, 6]])
    # data sample shape IS enforced here: samples of shape (3,) are rejected
    try:
        sess.run(area.put([tf.constant([[5, 5, 5], [6, 6, 6]])]))
    except ValueError as ex:
        assert ex.args[0] == "Shapes (2, 3) and (?, 2) are incompatible"
    else:
        # Without this branch the demo would silently "pass" if no error was raised.
        raise AssertionError("expected samples of shape (3,) to be rejected")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment