-
-
Save sdwfrost/8b0d515883c48b0ce93c770c47912c48 to your computer and use it in GitHub Desktop.
An Efficient, Batched, Stateful LSTM layer in Numpy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from utils import orthogonal, tanh, sigmoid, dtanh, dsigmoid | |
class LSTM(object):
    """Long Short Term Memory Unit

    All four gate blocks share one weight matrix ``W`` of shape
    ``(4 * d_hidden, d_input + d_hidden + 1)``. Row blocks are ordered
    cell input (Z), input gate (I), forget gate (F), output gate (O);
    columns are ordered input, previous hidden state, bias.

    The layer is stateful: hidden and cell state carry over between calls
    until ``forget()`` is called (``backward()`` calls it automatically).

    Parameters
    ----------
    d_input : int
        Length of input per time step
    d_hidden : int
        Number of LSTM cells
    f_bias_init : float
        Forget Gate bias initialization. In long term memory tasks,
        having a larger positive value could be crucial for learning
        long range dependencies
    name: str
        A label for debugging purposes
    """

    def __init__(self, d_input, d_hidden, f_bias_init=1.0, name=''):
        # Use a single concatenated matrix for all gates and input.
        # Start from zeros (not np.empty) so every bias other than the
        # forget gate's really is 0 instead of uninitialized memory.
        W = np.zeros((4 * d_hidden, d_input + d_hidden + 1))
        # orthogonal input -> hidden, identity hidden -> hidden
        for i in range(4):
            rows = slice(i * d_hidden, (i + 1) * d_hidden)
            W[rows, :d_input] = orthogonal((d_hidden, d_input))
            W[rows, d_input:-1] = np.eye(d_hidden)
        # The forget gate is the third row block; its bias is the last column.
        W[2 * d_hidden:3 * d_hidden, -1] = f_bias_init
        self.W = W
        self.d_input, self.d_hidden, self.name = d_input, d_hidden, name
        self.forget()

    def __call__(self, X):
        """Advance the layer by one time step.

        Parameters
        ----------
        X : ndarray, shape (1, d_input, B)
            Input for a single time step; B is the batch size.

        Returns
        -------
        ndarray, shape (1, d_hidden, B)
            Hidden state after this step.
        """
        X = X[0]
        B = X.shape[1]
        d_input, d_hidden = self.d_input, self.d_hidden
        if self.t == 0:
            # Initial cell state is stored under key -1 so that the
            # ``c_acc[t - 1]`` lookup also works for t == 0.
            self.c_acc[-1] = np.zeros((d_hidden, B))
            self.H_p = np.zeros((d_hidden, B))
        t = self.t
        # Stack [input; previous hidden; constant 1] so a single matrix
        # product applies input weights, recurrent weights and biases.
        inp = np.empty((d_input + d_hidden + 1, B))
        inp[:d_input] = X
        inp[d_input:-1] = self.H_p
        inp[-1] = 1.0  # bias input; with 0 here the bias column has no effect
        V = np.dot(self.W, inp)
        V[:d_hidden] = tanh(V[:d_hidden])
        V[d_hidden:] = sigmoid(V[d_hidden:])
        Z, I, F, O = np.split(V, 4, axis=0)
        c = Z * I + F * self.c_acc[t - 1]
        C = tanh(c)
        H = O * C
        # accumulate for backprop
        self.c_acc[t] = c
        self.C_acc[t] = C
        self.Z_acc[t] = Z
        self.I_acc[t] = I
        self.F_acc[t] = F
        self.O_acc[t] = O
        self.inp_acc[t] = inp
        self.H_p = H
        self.t += 1
        return H[np.newaxis]

    def forward(self, X):
        """Run the layer over a whole sequence.

        Parameters
        ----------
        X : ndarray, shape (T, d_input, B)
            T time steps of batched input.

        Returns
        -------
        ndarray, shape (T, d_hidden, B)
            Hidden state at every time step.
        """
        T, n, B = X.shape
        H = np.empty((T, self.d_hidden, B))
        # ``range`` rather than ``xrange`` so this runs on Python 2 and 3.
        for t in range(T):
            H[t] = self.__call__(X[t:t + 1])
        return H

    def backward(self, dH):
        """Backpropagate through time over the accumulated steps.

        Stores the weight gradient in ``self.dW`` and then resets the
        layer state via ``forget()``.

        Parameters
        ----------
        dH : ndarray, shape (T, d_hidden, B)
            Gradient of the loss w.r.t. the hidden state at each step.
            Steps are looked up by index 0..T-1 in the accumulators, so
            this presumably expects exactly T steps since the last
            ``forget()`` -- TODO confirm against callers.

        Returns
        -------
        ndarray, shape (T, d_input, B)
            Gradient of the loss w.r.t. the inputs.
        """
        T, _, B = dH.shape
        d_input, d_hidden = self.d_input, self.d_hidden
        dW = np.zeros_like(self.W)
        dX = np.zeros((T, d_input, B))
        # Gradients flowing back from step t + 1 into step t.
        dh_p = np.zeros((d_hidden, B))
        dc_p = np.zeros((d_hidden, B))
        for t in reversed(range(T)):
            C = self.C_acc[t]
            Z = self.Z_acc[t]
            I = self.I_acc[t]
            F = self.F_acc[t]
            O = self.O_acc[t]
            inp = self.inp_acc[t]
            dh = dH[t] + dh_p
            dO = C * dh
            dC = O * dh
            # Through C = tanh(c), plus the cell gradient from step t + 1.
            dc = (1.0 - C ** 2) * dC
            dc = dc + dc_p
            dF = self.c_acc[t - 1] * dc
            dc_p = F * dc
            dI = Z * dc
            dZ = I * dc
            # Through the gate nonlinearities to the pre-activations.
            dz = dtanh(dZ, Z)
            di = dsigmoid(dI, I)
            df = dsigmoid(dF, F)
            do = dsigmoid(dO, O)
            dV = np.concatenate([dz, di, df, do], axis=0)
            dW += np.dot(dV, inp.T)
            dinp = np.dot(self.W.T, dV)
            dX[t] += dinp[:d_input]
            dh_p = dinp[d_input:-1]
        self.dW = dW
        self.forget()
        return dX

    def set_parameters(self, P):
        """Replace the weights with the flat vector P, reshaped to W's shape."""
        self.W = np.reshape(P, self.W.shape)

    def get_parameters(self):
        """Return the weights as a flat copy."""
        return self.W.flatten()

    def get_gradients(self):
        """Return the gradient stored by the last ``backward()`` as a flat copy."""
        return self.dW.flatten()

    def forget(self):
        """Reset the time index, the per-step accumulators and the hidden state."""
        self.t = 0
        self.c_acc = {}
        self.C_acc = {}
        self.Z_acc = {}
        self.I_acc = {}
        self.F_acc = {}
        self.O_acc = {}
        self.inp_acc = {}
        self.H_p = None

    def __repr__(self):
        return 'Layer: ' + self.name + '\tNumber of Parameters: ' + str(self.W.size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def orthogonal(shape):
    """Return a random matrix of the given shape with orthonormal rows or columns.

    A Gaussian random matrix is drawn and decomposed with a reduced SVD;
    whichever factor has the requested shape is returned.

    taken from: https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py#L327-L367

    Parameters
    ----------
    shape : tuple or list of int
        Shape of the matrix to generate.

    Returns
    -------
    ndarray
        Array of the requested shape.
    """
    a = np.random.normal(0.0, 1.0, shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # Compare against a.shape (always a tuple) rather than the caller's
    # ``shape``: a list such as [5, 3] never equals the tuple u.shape, which
    # would wrongly select ``v`` even when ``u`` has the correct shape.
    q = u if u.shape == a.shape else v  # pick the one with the correct shape
    return q
def tanh(X):
    """Elementwise hyperbolic tangent of X (delegates to ``np.tanh``)."""
    activated = np.tanh(X)
    return activated
def sigmoid(X):
    """Elementwise logistic function 1 / (1 + exp(-X)), computed without overflow.

    The direct formula evaluates ``exp(-X)``, which overflows for large
    negative X. Here the exponential is only ever taken of a non-positive
    number, so it stays within (0, 1].

    Parameters
    ----------
    X : scalar or ndarray

    Returns
    -------
    Same shape as X, values in [0, 1].
    """
    X = np.asarray(X)
    e = np.exp(-np.abs(X))  # exp of a non-positive value: never overflows
    # X >= 0: 1 / (1 + exp(-X));  X < 0: exp(X) / (1 + exp(X)) -- same function.
    out = np.where(X >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # Indexing with () turns a 0-d result back into a scalar and leaves
    # higher-dimensional arrays as they are.
    return out[()]
def dtanh(dY, Y):
    """Backpropagate dY through tanh, given the tanh output Y.

    Returns dY scaled by the local derivative ``1 - Y**2``.
    """
    local_slope = 1.0 - Y ** 2
    return local_slope * dY
def dsigmoid(dY, Y):
    """Backpropagate dY through the sigmoid, given the sigmoid output Y.

    Returns dY scaled by the local derivative ``Y * (1 - Y)``.
    """
    local_slope = Y * (1.0 - Y)
    return local_slope * dY
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment