Created
August 8, 2017 17:43
-
-
Save timvieira/aceb64047aed1b13bf4e4da3b9a4c0ea to your computer and use it in GitHub Desktop.
Memory-efficient backpropagation in an RNN. Accompanies blog post: http://timvieira.github.io/blog/post/2016/10/01/reversing-a-sequence-with-sublinear-space/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Memory-efficient backpropagation in an RNN. | |
Accompanies blog post: | |
http://timvieira.github.io/blog/post/2016/10/01/reversing-a-sequence-with-sublinear-space/ | |
""" | |
import numpy as np | |
from arsenal.math.checkgrad import fdcheck | |
def rnn(f, x, n, df=None, adj=None):
    """Simple RNN where we only care about the final state.

    Applies ``f`` to the state ``n`` times, storing every intermediate state
    so that the backward pass can revisit them in reverse order.

    It should be straightforward to extend this to have shared params at each
    time step and to output things other than the last state.

    Args:
        f: transition function mapping a state to the next state.
        x: initial state.
        n: number of time steps (applications of ``f``).
        df: optional derivative of ``f``, applied elementwise to a state.
            When given, the gradient is computed as well.
        adj: optional adjoint (upstream gradient) of the final state;
            defaults to ``np.ones_like`` of the final state.

    Returns:
        The final state when ``df`` is not given; otherwise the pair
        ``(final_state, gradient)`` where ``gradient`` is ``adj`` multiplied
        by ``df`` evaluated at each of the ``n`` states that were fed to ``f``.
    """
    xs = [x]
    for _ in range(n):
        x = f(x)
        xs.append(x)

    if not df:
        return xs[-1]

    if adj is None:
        adj = np.ones_like(x)

    # Only x_0 .. x_{n-1} were inputs to `f`; the final state x_n is purely an
    # output, so it must not contribute a `df` factor to the chain rule.
    for x_t in reversed(xs[:-1]):
        adj = adj * df(x_t)
    return xs[-1], adj
def rnn_memory_efficient(f, x, n, df=None, adj=None):
    """Simple RNN with memory-efficient backprop.

    The forward pass keeps only a copy of the first and the last state
    instead of the whole sequence.  When ``df`` is given, the intermediate
    states are regenerated in reverse order by ``recursive`` and ``adj`` is
    multiplied by ``df`` of each of them.

    Returns the final state, or ``(final_state, gradient)`` when ``df`` is
    supplied.  ``adj`` defaults to ``np.ones_like`` of the final state.
    """
    first = x * 1  # copy of the initial state
    state = x
    for _ in range(n):
        state = f(state)
    last = state * 1  # copy of the final state

    if not df:
        return last

    if adj is None:
        adj = np.ones_like(state)

    # The helper *reconstructs* the hidden state at each time step
    # *in reverse*, so the states never have to be stored all at once.
    for past in recursive(f, first, 0, n):
        adj = adj * df(past)
    return last, adj
def recursive(f, s0, b, e):
    """Helper function for memory-efficient sequence reversal.

    Given ``s0``, the state at step ``b``, yield the states at steps
    ``e-1, e-2, ..., b`` (i.e. in reverse order), where the state at step
    ``t+1`` is ``f`` applied to the state at step ``t``.  States are
    recomputed from ``s0`` rather than stored.

    Args:
        f: transition function mapping a state to the next state.
        s0: state at step ``b``.
        b: first step of the range (inclusive).
        e: end of the range (exclusive).

    Yields:
        The states of steps ``b <= t < e`` in decreasing order of ``t``.
        Nothing is yielded when the range is empty (``e <= b``).
    """
    span = e - b
    if span <= 0:
        # Empty range.  Without this guard the midpoint split below would
        # have d == 0, so the sub-range never shrinks and the recursion
        # never terminates.
        return
    if span == 1:
        yield s0
        return

    # do O(n/2) work to find the midpoint with O(1) space.
    d = span // 2
    mid = s0
    for _ in range(d):
        mid = f(mid)

    # Second half first (it holds the later steps), then the first half.
    for s in recursive(f, mid, b + d, e):
        yield s
    for s in recursive(f, s0, b, b + d):
        yield s
def test():
    """Compare the memory-efficient RNN gradient against ``fdcheck``.

    Draws a random input and a random weight vector, computes the gradient
    of ``rnn(f, x, n).dot(w)`` with ``rnn_memory_efficient`` and hands the
    objective, the input and that gradient to ``fdcheck`` for display.
    """
    steps = 100  # length of RNN sequence
    x = np.random.uniform(-1, 1, size=100)

    # Non-linearity and its derivative.  Alternatives that can be swapped in:
    #   sigmoid (scipy.special.expit): df = lambda x: sigmoid(x) * (1-sigmoid(x))
    #   square:  f = lambda x: x**2,   df = lambda x: 2*x
    f, df = np.sin, np.cos

    # Randomly weight the encoding vector to get a simpler scalar objective,
    # which is easier to test with finite-differences.
    w = np.random.uniform(-1, 1, size=x.shape)

    # The plain implementation could be checked instead via
    # rnn(f, x, steps, df, adj=w).
    _, grad = rnn_memory_efficient(f, x, steps, df, adj=w)

    def objective():
        return rnn(f, x, steps).dot(w)

    fdcheck(objective, x, grad).show()
# Script entry point: run the gradient check when executed directly.
if __name__ == "__main__":
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment