Skip to content

Instantly share code, notes, and snippets.

@timvieira
Created August 8, 2017 17:43
Show Gist options
  • Save timvieira/aceb64047aed1b13bf4e4da3b9a4c0ea to your computer and use it in GitHub Desktop.
Save timvieira/aceb64047aed1b13bf4e4da3b9a4c0ea to your computer and use it in GitHub Desktop.
Memory-efficient backpropagation in an RNN. Accompanies blog post: http://timvieira.github.io/blog/post/2016/10/01/reversing-a-sequence-with-sublinear-space/
"""
Memory-efficient backpropagation in an RNN.
Accompanies blog post:
http://timvieira.github.io/blog/post/2016/10/01/reversing-a-sequence-with-sublinear-space/
"""
import numpy as np
from arsenal.math.checkgrad import fdcheck
def rnn(f, x, n, df=None, adj=None):
    """Simple RNN where we only care about the final state.

    Runs the recurrence ``x_{t+1} = f(x_t)`` for ``n`` steps starting from
    ``x_0 = x`` and stores every intermediate state (O(n) memory).

    It should be straightforward to extend this to have shared params at each
    time step and to output things other than the last state.

    Args:
        f: transition function. The backward pass multiplies elementwise by
            ``df(x_t)``, i.e. it treats the Jacobian of ``f`` as diagonal, so
            ``f`` is assumed to act elementwise.
        x: initial state ``x_0``.
        n: number of recurrence steps.
        df: derivative of ``f``. When given, the gradient is also computed.
        adj: adjoint (upstream gradient) of the final state; defaults to
            ``np.ones_like(x_n)``.

    Returns:
        ``x_n`` when ``df`` is None; otherwise ``(x_n, grad)``, where
        ``grad = adj * df(x_{n-1}) * ... * df(x_0)`` is the gradient of
        ``adj . x_n`` with respect to ``x_0``. For ``n == 0`` this is ``adj``.
    """
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)
    if df is None:
        return xs[-1]
    if adj is None:
        adj = np.ones_like(x)
    # x_n = f(x_{n-1}), so the chain rule involves df only at x_{n-1}, ..., x_0.
    # The final state itself contributes no factor, so it is excluded.
    for x in reversed(xs[:-1]):
        adj = adj * df(x)
    return xs[-1], adj
def rnn_memory_efficient(f, x, n, df=None, adj=None):
    """Simple RNN with memory-efficient backprop.

    Computes the same result as ``rnn``, but the forward pass keeps only the
    first and last states. During the backward pass, the hidden states
    ``x_{n-1}, ..., x_0`` are reconstructed in reverse order by re-running
    ``f`` from ``x_0`` via the ``recursive`` helper. This trades extra calls
    to ``f`` for memory.

    Args:
        f: transition function. The backward pass multiplies elementwise by
            ``df(x_t)``, so ``f`` is assumed to act elementwise.
        x: initial state ``x_0``.
        n: number of recurrence steps.
        df: derivative of ``f``. When given, the gradient is also computed.
        adj: adjoint (upstream gradient) of the final state; defaults to
            ``np.ones_like(x_n)``.

    Returns:
        ``x_n`` when ``df`` is None; otherwise ``(x_n, grad)`` with ``grad``
        the gradient of ``adj . x_n`` with respect to ``x_0`` (just ``adj``
        when there are no steps).
    """
    x0 = x * 1  # keep a copy of the first state
    for _ in range(n):
        x = f(x)
    xn = x * 1  # keep a copy of the last state
    if df is None:
        return xn
    if adj is None:
        adj = np.ones_like(x)
    # With no steps the Jacobian is the identity. Guarding here also ensures
    # `recursive` is only ever asked for a non-empty range of states.
    if n > 0:
        # Use the log-space helper to *reconstruct* the hidden state at each
        # time step *in reverse* (x_{n-1} first, x_0 last).
        for x in recursive(f, x0, 0, n):
            adj = adj * df(x)
    return xn, adj
def recursive(f, s0, b, e):
    """Helper function for memory-efficient sequence reversal.

    ``s0`` is the state at index ``b``. The state at index ``i`` is ``f``
    applied ``i - b`` times to ``s0``. This generator yields the states at
    indices ``e-1, e-2, ..., b`` (in that order) without storing the sequence.

    It does this by halving the range: it advances from ``s0`` to the
    midpoint, reverses the upper half recursively, and then reverses the
    lower half. The recursion depth is about ``log2(e - b)``. The cost is
    recomputation: roughly ``(e - b) * log2(e - b)`` calls to ``f``.

    An empty range (``e <= b``) yields nothing.
    """
    if e <= b:
        # Without this check d would be 0 and the call would recurse on the
        # same range forever.
        return
    if e - b == 1:
        yield s0
        return
    # Do O(n/2) work to find the midpoint with O(1) extra space.
    d = (e - b) // 2
    s = s0
    for _ in range(d):
        s = f(s)
    # Upper half first, because it holds the later states.
    for s in recursive(f, s, b + d, e):
        yield s
    for s in recursive(f, s0, b, b + d):
        yield s
def test():
    """Check the gradient of ``rnn_memory_efficient`` against finite differences.

    The sine nonlinearity (derivative: cosine) is used below. Other choices
    can be swapped in, e.g. the logistic sigmoid with
    ``df = sigmoid(x) * (1 - sigmoid(x))``, or ``x**2`` with ``df = 2*x``.
    """
    # Initial hidden state (drawn before the weights, as the RNG order matters).
    x = np.random.uniform(-1, 1, size=100)
    steps = 100  # length of RNN sequence
    f, df = np.sin, np.cos

    # Weight the final state randomly to get a scalar objective, w . x_n,
    # which is easier to test with finite differences.
    w = np.random.uniform(-1, 1, size=x.shape)
    _, grad = rnn_memory_efficient(f, x, steps, df, adj=w)

    def objective():
        return rnn(f, x, steps).dot(w)

    fdcheck(objective, x, grad).show()
if __name__ == '__main__':
    # Run the finite-difference gradient check when executed as a script.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment