Created January 23, 2018 07:21
A minimal example of how you can accumulate gradients across batches, allowing you to train using much larger batch sizes than can fit in memory at the cost of speed.
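With the script's default arguments (actual_batch_size = 32, n_pseudo_batches = 128), each weight update is computed from 32 × 128 = 4096 examples, while only 32 examples need to be held in memory at any one time. The script is written against the TensorFlow 1.x graph API (tf.placeholder / tf.Session).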
import numpy as np
import tensorflow as tf
import sys
from tensorflow.examples.tutorials.mnist import input_data

# Optional positional arguments: pseudo-batches per update, in-memory batch size, training iterations
n_pseudo_batches = int(sys.argv[1]) if len(sys.argv) > 1 else 128
actual_batch_size = int(sys.argv[2]) if len(sys.argv) > 2 else 32
iterations = int(sys.argv[3]) if len(sys.argv) > 3 else 10

tf.set_random_seed(147258)
np.random.seed(123456)
def simple_model(inputs):
    # These initializers ensure that the model will always be instantiated the same, for comparison.
    hidden_initializer = tf.constant_initializer(np.random.uniform(-0.025, 0.025, size=[784, 100]))
    hidden = tf.layers.dense(inputs, 100, kernel_initializer=hidden_initializer)
    out_initializer = tf.constant_initializer(np.random.uniform(-0.025, 0.025, size=[100, 10]))
    return tf.layers.dense(tf.nn.relu(hidden), 10, kernel_initializer=out_initializer)
inp = tf.placeholder(tf.float32, [None, 784])
targ = tf.placeholder(tf.float32, [None, 10])

# Build the model, optimizer and gradients
out = simple_model(inp)
opt = tf.train.GradientDescentOptimizer(learning_rate=1e-1)
loss = tf.losses.mean_squared_error(targ, out)

# Standard gradients for a single batch
t_vars = tf.trainable_variables()
grads, graph_vars = zip(*opt.compute_gradients(loss, t_vars))
# IMPORTANT: Make sure you run the zero_ops after each update to reset the accumulated gradients.
# tl;dr - Add the section below to your code to accumulate gradients.
# -----------------------------------------------------------------------------------------------------
# Define our divisor, used to normalise gradients across pseudo-batches
divisor = tf.Variable(0, trainable=False)
div_fl = tf.to_float(divisor)
zero_divisor = divisor.assign(0)
inc_divisor = divisor.assign(divisor + 1)

# Accumulation ops and variables
# Create a copy of every trainable variable, initialised to zero
accum_grads = [tf.Variable(tf.zeros_like(t_var.initialized_value()), trainable=False) for t_var in t_vars]
# Create ops that zero all accumulator variables (and zero the divisor again)
with tf.control_dependencies([zero_divisor]):
    zero_ops = [tv.assign(tf.zeros_like(tv)) for tv in accum_grads]
# Create ops that accumulate the gradients (and add one to the divisor)
with tf.control_dependencies([inc_divisor]):
    accum_ops = [accum_grad.assign_add(grad) for (accum_grad, grad) in zip(accum_grads, grads)]
# Divide the accumulated gradients by the number of accumulation steps before applying them
normalised_accum_grads = [accum_grad / div_fl for accum_grad in accum_grads]
# -----------------------------------------------------------------------------------------------------
# Create the op that updates the weights from the normalised, accumulated gradients
train_op = opt.apply_gradients(zip(normalised_accum_grads, graph_vars))
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True, seed=764847)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for x in range(iterations):
        iteration_loss = 0
        # Accumulate gradients over n_pseudo_batches small batches
        for y in range(n_pseudo_batches):
            inp_, targ_ = mnist.train.next_batch(actual_batch_size)
            _, loss_ = sess.run((accum_ops, loss), {inp: inp_, targ: targ_})
            iteration_loss += loss_
        # To recover the mean loss, divide by the number of pseudo-batches
        print(iteration_loss / n_pseudo_batches)
        # Apply the averaged gradients in a single update
        sess.run(train_op)
        # vvvv --- VERY IMPORTANT! --- vvvv
        sess.run(zero_ops)
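Usage: the three optional positional arguments are read in the order shown at the top of the script; the file name below is only a placeholder for whatever you save the gist as.

python accumulate_gradients.py 128 32 10   # 128 pseudo-batches of 32 examples, 10 iterations

For reference, here is a minimal sketch of the same accumulation idea written against the TensorFlow 2.x eager API with tf.GradientTape. It is not part of the original gist; the model, optimizer and helper names are illustrative only.

import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(100, activation="relu", input_shape=(784,)),
    tf.keras.layers.Dense(10),
])
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-1)
loss_fn = tf.keras.losses.MeanSquaredError()

# One zero-initialised, non-trainable accumulator per trainable variable
accum_grads = [tf.Variable(tf.zeros_like(v), trainable=False) for v in model.trainable_variables]

def accumulate(x, y):
    # Compute the gradients for one small batch and add them into the accumulators
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    grads = tape.gradient(loss, model.trainable_variables)
    for acc, grad in zip(accum_grads, grads):
        acc.assign_add(grad)
    return loss

def apply_and_reset(n_pseudo_batches):
    # Apply the averaged gradients in one update, then zero the accumulators
    optimizer.apply_gradients(
        [(acc / n_pseudo_batches, var) for acc, var in zip(accum_grads, model.trainable_variables)])
    for acc in accum_grads:
        acc.assign(tf.zeros_like(acc))

A training loop would call accumulate() n_pseudo_batches times and then apply_and_reset(n_pseudo_batches) once, mirroring the accum_ops / train_op / zero_ops sequence in the script above.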
That's a very nice implementation for something that I needed. Thanks for sharing!