An efficient, batched LSTM.
""" | |
This is a batched LSTM forward and backward pass | |
""" | |
import numpy as np | |
import code | |
class LSTM:

  @staticmethod
  def init(input_size, hidden_size, fancy_forget_bias_init = 3):
    """
    Initialize the LSTM parameters (weights and biases live in a single matrix).
    A positive fancy_forget_bias_init can help (some papers go as high as 5).
    """
    # +1 for the bias; the bias will be the first row of WLSTM
    WLSTM = np.random.randn(input_size + hidden_size + 1, 4 * hidden_size) / np.sqrt(input_size + hidden_size)
    WLSTM[0,:] = 0 # initialize all biases to zero
    if fancy_forget_bias_init != 0:
      # the forget gates get a positive bias so they start mostly open (the cell remembers by default)
      # remember that Xavier initialization was used above, so the raw gate activations
      # before the nonlinearity have zero mean and standard deviation of ~1
      WLSTM[0,hidden_size:2*hidden_size] = fancy_forget_bias_init
    return WLSTM
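  # Column layout of WLSTM (and of IFOG/IFOGf below), as used by forward/backward:
  #   row 0 of WLSTM     -> bias
  #   columns [0:d]      -> input gate i  (sigmoid)
  #   columns [d:2*d]    -> forget gate f (sigmoid)
  #   columns [2*d:3*d]  -> output gate o (sigmoid)
  #   columns [3*d:4*d]  -> candidate g   (tanh)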
  @staticmethod
  def forward(X, WLSTM, c0 = None, h0 = None):
    """
    X should have shape (n,b,input_size), where n = sequence length and b = batch size
    """
    n,b,input_size = X.shape
    d = WLSTM.shape[1] // 4 # hidden size
    if c0 is None: c0 = np.zeros((b,d))
    if h0 is None: h0 = np.zeros((b,d))
    # perform the LSTM forward pass with X as the input
    xphpb = WLSTM.shape[0] # x plus h plus bias
    Hin = np.zeros((n, b, xphpb)) # input [1, xt, ht-1] to each tick of the LSTM
    Hout = np.zeros((n, b, d)) # hidden representation of the LSTM (gated cell content)
    IFOG = np.zeros((n, b, d * 4)) # input, forget, output, gate (IFOG)
    IFOGf = np.zeros((n, b, d * 4)) # after the nonlinearity
    C = np.zeros((n, b, d)) # cell content
    Ct = np.zeros((n, b, d)) # tanh of cell content
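    # The recurrence implemented in the loop below, written out per timestep:
    #   i, f, o = sigmoid of the first 3*d columns of Hin[t].dot(WLSTM)
    #   g       = tanh of the last d columns
    #   C[t]    = i * g + f * C[t-1]
    #   Hout[t] = o * tanh(C[t])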
    for t in range(n):
      # concatenate [x,h] as the input to the LSTM
      prevh = Hout[t-1] if t > 0 else h0
      Hin[t,:,0] = 1 # bias
      Hin[t,:,1:input_size+1] = X[t]
      Hin[t,:,input_size+1:] = prevh
      # compute all gate activations (most of the work is in this line)
      IFOG[t] = Hin[t].dot(WLSTM)
      # non-linearities
      IFOGf[t,:,:3*d] = 1.0/(1.0+np.exp(-IFOG[t,:,:3*d])) # sigmoids; these are the gates
      IFOGf[t,:,3*d:] = np.tanh(IFOG[t,:,3*d:]) # tanh
      # compute the cell activation
      prevc = C[t-1] if t > 0 else c0
      C[t] = IFOGf[t,:,:d] * IFOGf[t,:,3*d:] + IFOGf[t,:,d:2*d] * prevc
      Ct[t] = np.tanh(C[t])
      Hout[t] = IFOGf[t,:,2*d:3*d] * Ct[t]
    cache = {}
    cache['WLSTM'] = WLSTM
    cache['Hout'] = Hout
    cache['IFOGf'] = IFOGf
    cache['IFOG'] = IFOG
    cache['C'] = C
    cache['Ct'] = Ct
    cache['Hin'] = Hin
    cache['c0'] = c0
    cache['h0'] = h0
    # return C[t] as well, so the LSTM can be continued from this state later if needed
    return Hout, C[t], Hout[t], cache
  @staticmethod
  def backward(dHout_in, cache, dcn = None, dhn = None):
    WLSTM = cache['WLSTM']
    Hout = cache['Hout']
    IFOGf = cache['IFOGf']
    IFOG = cache['IFOG']
    C = cache['C']
    Ct = cache['Ct']
    Hin = cache['Hin']
    c0 = cache['c0']
    h0 = cache['h0']
    n,b,d = Hout.shape
    input_size = WLSTM.shape[0] - d - 1 # -1 due to bias
    # backprop the LSTM
    dIFOG = np.zeros(IFOG.shape)
    dIFOGf = np.zeros(IFOGf.shape)
    dWLSTM = np.zeros(WLSTM.shape)
    dHin = np.zeros(Hin.shape)
    dC = np.zeros(C.shape)
    dX = np.zeros((n,b,input_size))
    dh0 = np.zeros((b, d))
    dc0 = np.zeros((b, d))
    dHout = dHout_in.copy() # make a copy so we don't have any funny side effects
    if dcn is not None: dC[n-1] += dcn.copy() # carry over gradients from later
    if dhn is not None: dHout[n-1] += dhn.copy()
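    # Walk the sequence backwards. At each step dC[t] collects gradient both from the
    # output path (through tanh(C[t])) and from the next timestep's forget gate, while
    # dWLSTM accumulates the weight gradient across all timesteps.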
    for t in reversed(range(n)):
      tanhCt = Ct[t]
      dIFOGf[t,:,2*d:3*d] = tanhCt * dHout[t]
      # backprop the tanh non-linearity first, then continue the backprop
      dC[t] += (1-tanhCt**2) * (IFOGf[t,:,2*d:3*d] * dHout[t])
      if t > 0:
        dIFOGf[t,:,d:2*d] = C[t-1] * dC[t]
        dC[t-1] += IFOGf[t,:,d:2*d] * dC[t]
      else:
        dIFOGf[t,:,d:2*d] = c0 * dC[t]
        dc0 = IFOGf[t,:,d:2*d] * dC[t]
      dIFOGf[t,:,:d] = IFOGf[t,:,3*d:] * dC[t]
      dIFOGf[t,:,3*d:] = IFOGf[t,:,:d] * dC[t]
      # backprop the activation functions
      dIFOG[t,:,3*d:] = (1 - IFOGf[t,:,3*d:] ** 2) * dIFOGf[t,:,3*d:]
      y = IFOGf[t,:,:3*d]
      dIFOG[t,:,:3*d] = (y*(1.0-y)) * dIFOGf[t,:,:3*d]
      # backprop the matrix multiply
      dWLSTM += np.dot(Hin[t].transpose(), dIFOG[t])
      dHin[t] = dIFOG[t].dot(WLSTM.transpose())
      # backprop the identity transforms into Hin
      dX[t] = dHin[t,:,1:input_size+1]
      if t > 0:
        dHout[t-1,:] += dHin[t,:,input_size+1:]
      else:
        dh0 += dHin[t,:,input_size+1:]
    return dX, dWLSTM, dc0, dh0
# -------------------
# TEST CASES
# -------------------
def checkSequentialMatchesBatch():
  """ check LSTM I/O forward/backward interactions """
  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)
  # sequential forward
  cprev = c0
  hprev = h0
  caches = [{} for t in range(n)]
  Hcat = np.zeros((n,b,d))
  for t in range(n):
    xt = X[t:t+1]
    _, cprev, hprev, cache = LSTM.forward(xt, WLSTM, cprev, hprev)
    caches[t] = cache
    Hcat[t] = hprev
  # sanity check: perform a batched forward pass to check that we get the same thing
  H, _, _, batch_cache = LSTM.forward(X, WLSTM, c0, h0)
  assert np.allclose(H, Hcat), "Sequential and Batch forward don't match!"
  # eval loss
  wrand = np.random.randn(*Hcat.shape)
  loss = np.sum(Hcat * wrand)
  dH = wrand
  # get the batched version gradients
  BdX, BdWLSTM, Bdc0, Bdh0 = LSTM.backward(dH, batch_cache)
  # now perform the sequential backward
  dX = np.zeros_like(X)
  dWLSTM = np.zeros_like(WLSTM)
  dc0 = np.zeros_like(c0)
  dh0 = np.zeros_like(h0)
  dcnext = None
  dhnext = None
  for t in reversed(range(n)):
    dht = dH[t].reshape(1, b, d)
    dx, dWLSTMt, dcprev, dhprev = LSTM.backward(dht, caches[t], dcnext, dhnext)
    dhnext = dhprev
    dcnext = dcprev
    dWLSTM += dWLSTMt # accumulate the LSTM gradient
    dX[t] = dx[0]
    if t == 0:
      dc0 = dcprev
      dh0 = dhprev
  # and make sure the gradients match
  print('Making sure batched version agrees with sequential version: (should all be True)')
  print(np.allclose(BdX, dX))
  print(np.allclose(BdWLSTM, dWLSTM))
  print(np.allclose(Bdc0, dc0))
  print(np.allclose(Bdh0, dh0))
def checkBatchGradient():
  """ check that the batch gradient is correct """
  # let's gradient check this beast
  n,b,d = (5, 3, 4) # sequence length, batch size, hidden size
  input_size = 10
  WLSTM = LSTM.init(input_size, d) # input size, hidden size
  X = np.random.randn(n,b,input_size)
  h0 = np.random.randn(b,d)
  c0 = np.random.randn(b,d)
  # batch forward backward
  H, Ct, Ht, cache = LSTM.forward(X, WLSTM, c0, h0)
  wrand = np.random.randn(*H.shape)
  loss = np.sum(H * wrand) # a weighted sum is a convenient scalar "hash" of the output
  dH = wrand
  dX, dWLSTM, dc0, dh0 = LSTM.backward(dH, cache)

  def fwd():
    h,_,_,_ = LSTM.forward(X, WLSTM, c0, h0)
    return np.sum(h * wrand)

  # now gradient check everything
  delta = 1e-5
  rel_error_thr_warning = 1e-2
  rel_error_thr_error = 1
  tocheck = [X, WLSTM, c0, h0]
  grads_analytic = [dX, dWLSTM, dc0, dh0]
  names = ['X', 'WLSTM', 'c0', 'h0']
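  # For each parameter we compare the analytic gradient against a numerical one obtained
  # by central differences, grad ~ (f(x + delta) - f(x - delta)) / (2 * delta),
  # and report the relative error |ga - gn| / |ga + gn|.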
  for j in range(len(tocheck)):
    mat = tocheck[j]
    dmat = grads_analytic[j]
    name = names[j]
    # gradcheck
    for i in range(mat.size):
      old_val = mat.flat[i]
      mat.flat[i] = old_val + delta
      loss0 = fwd()
      mat.flat[i] = old_val - delta
      loss1 = fwd()
      mat.flat[i] = old_val
      grad_analytic = dmat.flat[i]
      grad_numerical = (loss0 - loss1) / (2 * delta)
      if grad_numerical == 0 and grad_analytic == 0:
        rel_error = 0 # both are zero, OK
        status = 'OK'
      elif abs(grad_numerical) < 1e-7 and abs(grad_analytic) < 1e-7:
        rel_error = 0 # not enough precision to check this
        status = 'VAL SMALL WARNING'
      else:
        rel_error = abs(grad_analytic - grad_numerical) / abs(grad_numerical + grad_analytic)
        status = 'OK'
        if rel_error > rel_error_thr_warning: status = 'WARNING'
        if rel_error > rel_error_thr_error: status = '!!!!! NOTOK'
      # print stats
      print('%s checking param %s index %s (val = %+8f), analytic = %+8f, numerical = %+8f, relative error = %+8f'
            % (status, name, repr(np.unravel_index(i, mat.shape)), old_val, grad_analytic, grad_numerical, rel_error))
if __name__ == "__main__":
  checkSequentialMatchesBatch()
  input('check OK, press enter to continue to gradient check')
  checkBatchGradient()
  print('every line should start with OK. Have a nice day!')
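
For reference (not part of the file above), a minimal usage sketch. The shapes and random data are illustrative, and the import assumes the file was saved as lstm.py (a hypothetical module name):

import numpy as np
from lstm import LSTM  # hypothetical module name for the file above

n, b, input_size, hidden_size = 5, 3, 10, 4
WLSTM = LSTM.init(input_size, hidden_size)    # shape (input_size + hidden_size + 1, 4 * hidden_size)
X = np.random.randn(n, b, input_size)         # a batch of b sequences of length n

# forward pass; c0 and h0 default to zeros when omitted
Hout, cn, hn, cache = LSTM.forward(X, WLSTM)  # Hout: (n, b, hidden_size); cn, hn: final cell and hidden state

# backward pass, given some gradient on the hidden states
dHout = np.random.randn(*Hout.shape)
dX, dWLSTM, dc0, dh0 = LSTM.backward(dHout, cache)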