Last active: July 6, 2017 11:48
Save pranv/4a94c8a151703c910472c5d023be9bf4 to your computer and use it in GitHub Desktop.
An Efficient, Batched, Stateful LSTM Layer in NumPy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
from utils import orthogonal, tanh, sigmoid, dtanh, dsigmoid | |
class LSTM(object):
    """Long Short Term Memory Unit

    Parameters
    ----------
    d_input : int
        Length of input per time step
    d_hidden : int
        Number of LSTM cells
    f_bias_init : float
        Forget Gate bias initialization. In long term memory tasks,
        having a larger positive value could be crucial for learning
        long range dependencies
    name: str
        A label for debugging purposes

    Notes
    -----
    All four gates share one weight matrix ``W`` of shape
    ``(4 * d_hidden, d_input + d_hidden + 1)``.  Its columns are
    ``[input | previous hidden | bias]`` and its row blocks are, in order,
    Z (tanh candidate), I (input gate), F (forget gate), O (output gate).
    Data is laid out feature-major: a sequence is ``(T, d_input, B)`` and
    the hidden outputs are ``(T, d_hidden, B)``.
    """

    def __init__(self, d_input, d_hidden, f_bias_init=1.0, name=''):
        # np.zeros (not np.empty): every bias except the forget gate's must
        # really start at 0 instead of whatever memory happened to hold.
        W = np.zeros((4 * d_hidden, d_input + d_hidden + 1))
        for i in range(4):
            rows = slice(i * d_hidden, (i + 1) * d_hidden)
            # orthogonal input -> hidden, identity hidden -> hidden
            W[rows, :d_input] = orthogonal((d_hidden, d_input))
            W[rows, d_input:-1] = np.eye(d_hidden)
        # forget gate is the third row block; only its bias is non-zero
        W[2 * d_hidden:3 * d_hidden, -1] = f_bias_init
        self.W = W
        self.d_input, self.d_hidden, self.name = d_input, d_hidden, name
        self.forget()

    def __call__(self, X):
        """Advance the layer by one time step.

        Parameters
        ----------
        X : ndarray, shape (1, d_input, B)
            A single time step of input for a batch of B sequences.

        Returns
        -------
        ndarray, shape (1, d_hidden, B)
            The hidden output of this step.  Intermediates are cached under
            the current step index for ``backward``.
        """
        X = X[0]
        B = X.shape[1]
        d_input, d_hidden = self.d_input, self.d_hidden
        if self.t == 0:
            # start of a sequence: zero previous cell state and hidden state
            self.c_acc[-1] = np.zeros((d_hidden, B))
            self.H_p = np.zeros((d_hidden, B))
        t = self.t
        inp = np.zeros((d_input + d_hidden + 1, B))
        inp[:d_input] = X
        inp[d_input:-1] = self.H_p
        # constant-1 row so the last column of W acts as the gate biases;
        # left at 0 the biases would have no effect and get no gradient
        inp[-1] = 1.0
        V = np.dot(self.W, inp)
        V[:d_hidden] = tanh(V[:d_hidden])
        V[d_hidden:] = sigmoid(V[d_hidden:])
        Z, I, F, O = np.split(V, 4, axis=0)
        c = Z * I + F * self.c_acc[t - 1]
        C = tanh(c)
        H = O * C
        # accumulate for backprop
        self.c_acc[t] = c
        self.C_acc[t] = C
        self.Z_acc[t] = Z
        self.I_acc[t] = I
        self.F_acc[t] = F
        self.O_acc[t] = O
        self.inp_acc[t] = inp
        self.H_p = H
        self.t += 1
        return H[np.newaxis]

    def forward(self, X):
        """Run the layer over a whole sequence.

        Parameters
        ----------
        X : ndarray, shape (T, d_input, B)

        Returns
        -------
        ndarray, shape (T, d_hidden, B)
            Hidden output at every time step.
        """
        T, _, B = X.shape
        H = np.empty((T, self.d_hidden, B))
        for t in range(T):
            H[t] = self.__call__(X[t:t + 1])
        return H

    def backward(self, dH):
        """Backpropagate through the time steps cached by the forward pass.

        Assumes (TODO confirm with callers) that ``dH`` covers the same T
        steps that were run since the last ``forget()``.

        Parameters
        ----------
        dH : ndarray, shape (T, d_hidden, B)
            Gradient of the loss w.r.t. the hidden output of each step.

        Returns
        -------
        ndarray, shape (T, d_input, B)
            Gradient w.r.t. the inputs.  The weight gradient is stored in
            ``self.dW`` and the cached forward state is cleared via
            ``forget()``.
        """
        T, _, B = dH.shape
        d_input, d_hidden = self.d_input, self.d_hidden
        dW = np.zeros_like(self.W)
        dX = np.zeros((T, d_input, B))
        dh_p = np.zeros((d_hidden, B))  # grad into H[t] coming from step t+1
        dc_p = np.zeros((d_hidden, B))  # grad into c[t] coming from step t+1
        for t in reversed(range(T)):
            C = self.C_acc[t]
            Z = self.Z_acc[t]
            I = self.I_acc[t]
            F = self.F_acc[t]
            O = self.O_acc[t]
            inp = self.inp_acc[t]
            c_prev = self.c_acc[t - 1]  # at t == 0: the zero initial state
            dh = dH[t] + dh_p
            dO = C * dh
            dc = dtanh(O * dh, C) + dc_p
            dF = c_prev * dc
            dc_p = F * dc
            dI = Z * dc
            dZ = I * dc
            # gradients w.r.t. the pre-activation gate values, in W's row order
            dV = np.concatenate(
                [dtanh(dZ, Z), dsigmoid(dI, I), dsigmoid(dF, F), dsigmoid(dO, O)],
                axis=0,
            )
            dW += np.dot(dV, inp.T)
            dinp = np.dot(self.W.T, dV)
            dX[t] = dinp[:d_input]
            dh_p = dinp[d_input:-1]  # the bias row of dinp is discarded
        self.dW = dW
        self.forget()
        return dX

    def set_parameters(self, P):
        """Load the weights from a flat vector (same layout as get_parameters)."""
        self.W = np.reshape(P, self.W.shape)

    def get_parameters(self):
        """Return a flattened copy of the weight matrix."""
        return self.W.flatten()

    def get_gradients(self):
        """Return a flattened copy of the weight gradient from the last backward()."""
        return self.dW.flatten()

    def forget(self):
        """Reset the time-step counter and drop all cached forward state."""
        self.t = 0
        self.c_acc = {}
        self.C_acc = {}
        self.Z_acc = {}
        self.I_acc = {}
        self.F_acc = {}
        self.O_acc = {}
        self.inp_acc = {}
        self.H_p = None

    def __repr__(self):
        return 'Layer: ' + self.name + '\tNumber of Parameters: ' + str(self.W.size)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def orthogonal(shape, gain=1.0):
    """Return a random (semi-)orthogonal array of the given shape.

    Adapted from Lasagne's orthogonal initializer:
    https://github.com/Lasagne/Lasagne/blob/master/lasagne/init.py#L327-L367

    Parameters
    ----------
    shape : sequence of int
        Target shape with at least two dimensions.  It may be a tuple or a
        list.  Trailing dimensions are flattened for the SVD, so the
        orthonormality holds for the ``(shape[0], prod(shape[1:]))`` view.
    gain : float, optional
        Scale factor applied to the result (default 1.0, i.e. unscaled).

    Raises
    ------
    ValueError
        If ``shape`` has fewer than two dimensions.
    """
    # Normalize to a tuple: comparing an ndarray shape (tuple) against a list
    # is always False, which previously made a list shape pick the wrong factor.
    shape = tuple(shape)
    if len(shape) < 2:
        raise ValueError(
            'orthogonal initialization needs at least a 2-D shape, got %r' % (shape,))
    flat_shape = (shape[0], int(np.prod(shape[1:])))
    a = np.random.normal(0.0, 1.0, flat_shape)
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # u is (rows, k) and v is (k, cols) with k = min(rows, cols); pick the
    # factor whose shape matches (either one works when the matrix is square).
    q = u if u.shape == flat_shape else v
    return gain * q.reshape(shape)
def tanh(X):
    """Elementwise hyperbolic tangent (thin wrapper over ``np.tanh``)."""
    activated = np.tanh(X)
    return activated
def sigmoid(X):
    """Elementwise logistic sigmoid ``1 / (1 + exp(-X))``.

    For very negative ``X``, ``exp(-X)`` overflows to ``inf`` and the result
    correctly saturates to 0.0.  That overflow is expected, so it is silenced
    locally rather than emitted as a RuntimeWarning (or raised as
    FloatingPointError under ``np.seterr(all='raise')``).  The returned
    values are unchanged.
    """
    with np.errstate(over='ignore'):
        return 1.0 / (1.0 + np.exp(-X))
def dtanh(dY, Y):
    """Backpropagate ``dY`` through tanh, given the tanh *output* ``Y``.

    Uses d/dx tanh(x) = 1 - tanh(x)**2, evaluated from the stored output.
    """
    local_slope = 1.0 - Y ** 2
    return local_slope * dY
def dsigmoid(dY, Y):
    """Backpropagate ``dY`` through the sigmoid, given the sigmoid *output* ``Y``.

    Uses d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
    """
    local_slope = Y * (1.0 - Y)
    return local_slope * dY
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Hey, can you put up an example file showing how to use this for an NLP task?