An efficient, batched LSTM.
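As a quick orientation, here is a minimal usage sketch (not part of the original file) showing how the LSTM class defined in the file below might be called: initialize the combined weight matrix, run a batched forward pass, and backprop a dummy loss through the cache. The sequence length, batch size, hidden size, and input size used here are illustrative assumptions.

import numpy as np

# hypothetical shapes for illustration: sequence length n, batch size b, hidden size d
n, b, d, input_size = 5, 3, 4, 10
WLSTM = LSTM.init(input_size, d)                    # weights and biases packed into one matrix
X = np.random.randn(n, b, input_size)               # a random batch of input sequences
Hout, cT, hT, cache = LSTM.forward(X, WLSTM)        # c0/h0 default to zeros
dHout = np.random.randn(*Hout.shape)                # stand-in for dLoss/dHout
dX, dWLSTM, dc0, dh0 = LSTM.backward(dHout, cache)  # gradients w.r.t. inputs, weights, initial states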
""" | |
This is a batched LSTM forward and backward pass | |
""" | |
import numpy as np | |
import code | |
class LSTM:

    @staticmethod
    def init(input_size, hidden_size, fancy_forget_bias_init=3):
        """
        Initialize parameters of the LSTM (both weights and biases in one matrix).
        One may want to use a positive fancy_forget_bias_init number (e.g. maybe even up to 5, as in some papers).
        """
        # +1 for the biases, which will be the first row of WLSTM
        WLSTM = np.random.randn(input_size + hidden_size + 1, 4 * hidden_size) / np.sqrt(input_size + hidden_size)
        WLSTM[0, :] = 0  # initialize biases to zero
        if fancy_forget_bias_init != 0:
            # forget gates get a bit of positive bias initially to encourage them to stay open (i.e. remember);
            # recall that due to the Xavier initialization above, the raw gate activations before the
            # nonlinearity are zero mean and on the order of standard deviation ~1
            WLSTM[0, hidden_size:2*hidden_size] = fancy_forget_bias_init
        return WLSTM

    @staticmethod
    def forward(X, WLSTM, c0=None, h0=None):
        """
        X should be of shape (n, b, input_size), where n = length of sequence, b = batch size.
        """
        n, b, input_size = X.shape
        d = WLSTM.shape[1] // 4  # hidden size
        if c0 is None: c0 = np.zeros((b, d))
        if h0 is None: h0 = np.zeros((b, d))

        # Perform the LSTM forward pass with X as the input
        xphpb = WLSTM.shape[0]  # x plus h plus bias, lol
        Hin = np.zeros((n, b, xphpb))    # input [1, xt, ht-1] to each tick of the LSTM
        Hout = np.zeros((n, b, d))       # hidden representation of the LSTM (gated cell content)
        IFOG = np.zeros((n, b, d * 4))   # input, forget, output, gate (IFOG)
        IFOGf = np.zeros((n, b, d * 4))  # after nonlinearity
        C = np.zeros((n, b, d))          # cell content
        Ct = np.zeros((n, b, d))         # tanh of cell content
        for t in range(n):
            # concat [x, h] as input to the LSTM
            prevh = Hout[t-1] if t > 0 else h0
            Hin[t, :, 0] = 1  # bias
            Hin[t, :, 1:input_size+1] = X[t]
            Hin[t, :, input_size+1:] = prevh
            # compute all gate activations; most of the work is in this matrix multiply
            IFOG[t] = Hin[t].dot(WLSTM)
            # non-linearities
            IFOGf[t, :, :3*d] = 1.0 / (1.0 + np.exp(-IFOG[t, :, :3*d]))  # sigmoids; these are the gates
            IFOGf[t, :, 3*d:] = np.tanh(IFOG[t, :, 3*d:])  # tanh
            # compute the cell activation
            prevc = C[t-1] if t > 0 else c0
            C[t] = IFOGf[t, :, :d] * IFOGf[t, :, 3*d:] + IFOGf[t, :, d:2*d] * prevc
            Ct[t] = np.tanh(C[t])
            Hout[t] = IFOGf[t, :, 2*d:3*d] * Ct[t]

        cache = {}
        cache['WLSTM'] = WLSTM
        cache['Hout'] = Hout
        cache['IFOGf'] = IFOGf
        cache['IFOG'] = IFOG
        cache['C'] = C
        cache['Ct'] = Ct
        cache['Hin'] = Hin
        cache['c0'] = c0
        cache['h0'] = h0
        # also return the final cell state C[t] so the LSTM can be continued from this state if needed
        return Hout, C[t], Hout[t], cache

    @staticmethod
    def backward(dHout_in, cache, dcn=None, dhn=None):
        WLSTM = cache['WLSTM']
        Hout = cache['Hout']
        IFOGf = cache['IFOGf']
        IFOG = cache['IFOG']
        C = cache['C']
        Ct = cache['Ct']
        Hin = cache['Hin']
        c0 = cache['c0']
        h0 = cache['h0']
        n, b, d = Hout.shape
        input_size = WLSTM.shape[0] - d - 1  # -1 due to bias

        # backprop the LSTM
        dIFOG = np.zeros(IFOG.shape)
        dIFOGf = np.zeros(IFOGf.shape)
        dWLSTM = np.zeros(WLSTM.shape)
        dHin = np.zeros(Hin.shape)
        dC = np.zeros(C.shape)
        dX = np.zeros((n, b, input_size))
        dh0 = np.zeros((b, d))
        dc0 = np.zeros((b, d))
        dHout = dHout_in.copy()  # make a copy so we don't have any funny side effects
        if dcn is not None: dC[n-1] += dcn.copy()  # carry over gradients from later
        if dhn is not None: dHout[n-1] += dhn.copy()
        for t in reversed(range(n)):
            tanhCt = Ct[t]
            dIFOGf[t, :, 2*d:3*d] = tanhCt * dHout[t]
            # backprop the tanh non-linearity first, then continue backprop
            dC[t] += (1 - tanhCt**2) * (IFOGf[t, :, 2*d:3*d] * dHout[t])

            if t > 0:
                dIFOGf[t, :, d:2*d] = C[t-1] * dC[t]
                dC[t-1] += IFOGf[t, :, d:2*d] * dC[t]
            else:
                dIFOGf[t, :, d:2*d] = c0 * dC[t]
                dc0 = IFOGf[t, :, d:2*d] * dC[t]
            dIFOGf[t, :, :d] = IFOGf[t, :, 3*d:] * dC[t]
            dIFOGf[t, :, 3*d:] = IFOGf[t, :, :d] * dC[t]

            # backprop the activation functions
            dIFOG[t, :, 3*d:] = (1 - IFOGf[t, :, 3*d:] ** 2) * dIFOGf[t, :, 3*d:]
            y = IFOGf[t, :, :3*d]
            dIFOG[t, :, :3*d] = (y * (1.0 - y)) * dIFOGf[t, :, :3*d]

            # backprop the matrix multiply
            dWLSTM += np.dot(Hin[t].transpose(), dIFOG[t])
            dHin[t] = dIFOG[t].dot(WLSTM.transpose())

            # backprop the identity transforms into Hin
            dX[t] = dHin[t, :, 1:input_size+1]
            if t > 0:
                dHout[t-1, :] += dHin[t, :, input_size+1:]
            else:
                dh0 += dHin[t, :, input_size+1:]

        return dX, dWLSTM, dc0, dh0


# -------------------
# TEST CASES
# -------------------

def checkSequentialMatchesBatch():
    """ check LSTM I/O forward/backward interactions """

    n, b, d = (5, 3, 4)  # sequence length, batch size, hidden size
    input_size = 10
    WLSTM = LSTM.init(input_size, d)  # input size, hidden size
    X = np.random.randn(n, b, input_size)
    h0 = np.random.randn(b, d)
    c0 = np.random.randn(b, d)

    # sequential forward
    cprev = c0
    hprev = h0
    caches = [{} for t in range(n)]
    Hcat = np.zeros((n, b, d))
    for t in range(n):
        xt = X[t:t+1]
        _, cprev, hprev, cache = LSTM.forward(xt, WLSTM, cprev, hprev)
        caches[t] = cache
        Hcat[t] = hprev

    # sanity check: perform the batch forward to check that we get the same thing
    H, _, _, batch_cache = LSTM.forward(X, WLSTM, c0, h0)
    assert np.allclose(H, Hcat), "Sequential and batch forward don't match!"

    # eval loss
    wrand = np.random.randn(*Hcat.shape)
    loss = np.sum(Hcat * wrand)
    dH = wrand

    # get the batched version gradients
    BdX, BdWLSTM, Bdc0, Bdh0 = LSTM.backward(dH, batch_cache)

    # now perform the sequential backward
    dX = np.zeros_like(X)
    dWLSTM = np.zeros_like(WLSTM)
    dc0 = np.zeros_like(c0)
    dh0 = np.zeros_like(h0)
    dcnext = None
    dhnext = None
    for t in reversed(range(n)):
        dht = dH[t].reshape(1, b, d)
        dx, dWLSTMt, dcprev, dhprev = LSTM.backward(dht, caches[t], dcnext, dhnext)
        dhnext = dhprev
        dcnext = dcprev

        dWLSTM += dWLSTMt  # accumulate LSTM gradient
        dX[t] = dx[0]
        if t == 0:
            dc0 = dcprev
            dh0 = dhprev

    # and make sure the gradients match
    print('Making sure the batched version agrees with the sequential version (should all be True):')
    print(np.allclose(BdX, dX))
    print(np.allclose(BdWLSTM, dWLSTM))
    print(np.allclose(Bdc0, dc0))
    print(np.allclose(Bdh0, dh0))

def checkBatchGradient():
    """ check that the batch gradient is correct """

    # let's gradient check this beast
    n, b, d = (5, 3, 4)  # sequence length, batch size, hidden size
    input_size = 10
    WLSTM = LSTM.init(input_size, d)  # input size, hidden size
    X = np.random.randn(n, b, input_size)
    h0 = np.random.randn(b, d)
    c0 = np.random.randn(b, d)

    # batch forward backward
    H, Ct, Ht, cache = LSTM.forward(X, WLSTM, c0, h0)
    wrand = np.random.randn(*H.shape)
    loss = np.sum(H * wrand)  # weighted sum is a nice hash to use I think
    dH = wrand
    dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)

    def fwd():
        h, _, _, _ = LSTM.forward(X, WLSTM, c0, h0)
        return np.sum(h * wrand)

    # now gradient check all
    delta = 1e-5
    rel_error_thr_warning = 1e-2
    rel_error_thr_error = 1
    tocheck = [X, WLSTM, c0, h0]
    grads_analytic = [dX, dWLSTM, dc0, dh0]
    names = ['X', 'WLSTM', 'c0', 'h0']
    for j in range(len(tocheck)):
        mat = tocheck[j]
        dmat = grads_analytic[j]
        name = names[j]
        # gradcheck
        for i in range(mat.size):
            old_val = mat.flat[i]
            mat.flat[i] = old_val + delta
            loss0 = fwd()
            mat.flat[i] = old_val - delta
            loss1 = fwd()
            mat.flat[i] = old_val

            grad_analytic = dmat.flat[i]
            grad_numerical = (loss0 - loss1) / (2 * delta)

            if grad_numerical == 0 and grad_analytic == 0:
                rel_error = 0  # both are zero, OK
                status = 'OK'
            elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
                rel_error = 0  # not enough precision to check this
                status = 'VAL SMALL WARNING'
            else:
                rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
                status = 'OK'
                if rel_error > rel_error_thr_warning: status = 'WARNING'
                if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'

            # print stats
            print('%s checking param %s index %s (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f'
                  % (status, name, repr(np.unravel_index(i, mat.shape)), old_val, grad_analytic, grad_numerical, rel_error))
if __name__ == "__main__": | |
checkSequentialMatchesBatch() | |
raw_input('check OK, press key to continue to gradient check') | |
checkBatchGradient() | |
print 'every line should start with OK. Have a nice day!' |