____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 32, 10)        0
____________________________________________________________________________________________________
lstm_1 (LSTM)                    (None, 32, 10)        840         input_1[0][0]
____________________________________________________________________________________________________
add_1 (Add)                      (None, 32, 10)        0           input_1[0][0]
                                                                   lstm_1[0][0]
____________________________________________________________________________________________________
lstm_2 (LSTM)                    (None, 32, 10)        840         add_1[0][0]
____________________________________________________________________________________________________
add_2 (Add)                      (None, 32, 10)        0           add_1[0][0]
                                                                   lstm_2[0][0]
____________________________________________________________________________________________________
lstm_3 (LSTM)                    (None, 32, 10)        840         add_2[0][0]
____________________________________________________________________________________________________
add_3 (Add)                      (None, 32, 10)        0           add_2[0][0]
                                                                   lstm_3[0][0]
____________________________________________________________________________________________________
lstm_4 (LSTM)                    (None, 32, 10)        840         add_3[0][0]
____________________________________________________________________________________________________
add_4 (Add)                      (None, 32, 10)        0           add_3[0][0]
                                                                   lstm_4[0][0]
____________________________________________________________________________________________________
lstm_5 (LSTM)                    (None, 32, 10)        840         add_4[0][0]
____________________________________________________________________________________________________
add_5 (Add)                      (None, 32, 10)        0           add_4[0][0]
                                                                   lstm_5[0][0]
____________________________________________________________________________________________________
lstm_6 (LSTM)                    (None, 32, 10)        840         add_5[0][0]
____________________________________________________________________________________________________
add_6 (Add)                      (None, 32, 10)        0           add_5[0][0]
                                                                   lstm_6[0][0]
____________________________________________________________________________________________________
lstm_7 (LSTM)                    (None, 32, 10)        840         add_6[0][0]
____________________________________________________________________________________________________
add_7 (Add)                      (None, 32, 10)        0           add_6[0][0]
                                                                   lstm_7[0][0]
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 10)            0           add_7[0][0]
____________________________________________________________________________________________________
lstm_8 (LSTM)                    (None, 10)            840         add_7[0][0]
____________________________________________________________________________________________________
add_8 (Add)                      (None, 10)            0           lambda_1[0][0]
                                                                   lstm_8[0][0]
====================================================================================================
Total params: 6,720
Trainable params: 6,720
Non-trainable params: 0
____________________________________________________________________________________________________
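As a sanity check, the 840 parameters reported for each LSTM layer agree with the standard LSTM parameter count (4 gates, each with input weights, recurrent weights and a bias). A quick computation, not part of the gist:

# params = 4 * (input_dim * units + units * units + units)
units, input_dim = 10, 10
print(4 * (input_dim * units + units * units + units))  # 840; 8 layers * 840 = 6,720 total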
# Stacked LSTM with residual connections in depth direction.
#
# An LSTM naturally has something like residual connections in time.
# Here we add residual connections in depth.
#
# Inspired by Google's Neural Machine Translation System (https://arxiv.org/abs/1609.08144).
# They observed that residual connections allow them to use much deeper stacked RNNs.
# Without residual connections they were limited to around 4 layers of depth.
#
# It uses the Keras 2 API.

from keras.layers import LSTM, Lambda
from keras.layers.merge import add


def make_residual_lstm_layers(input, rnn_width, rnn_depth, rnn_dropout):
    """
    The intermediate LSTM layers return sequences, while the last one returns a single element.
    The input is also a sequence. In order to sum the input and output of an LSTM layer,
    their shapes have to match; this holds directly for all layers but the last, where we
    first slice the last element of the previous output sequence.
    """
    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = LSTM(rnn_width, recurrent_dropout=rnn_dropout, dropout=rnn_dropout, return_sequences=return_sequences)(x)
        if return_sequences:
            # Intermediate layers return sequences, and the input is also a sequence.
            if i > 0 or input.shape[-1] == rnn_width:
                x = add([x, x_rnn])
            else:
                # The input size and the RNN output size have to match, due to the sum operation.
                # If we want a different rnn_width, we have to perform the sum from layer 2 on.
                x = x_rnn
        else:
            # The last layer does not return sequences, just the last element,
            # so we select only the last element of the previous output.
            def slice_last(x):
                return x[..., -1, :]
            x = add([Lambda(slice_last)(x), x_rnn])
    return x


if __name__ == '__main__':
    # Example usage:
    from keras.layers import Input
    from keras.models import Model

    input = Input(shape=(32, 10))
    output = make_residual_lstm_layers(input, rnn_width=10, rnn_depth=8, rnn_dropout=0.2)
    model = Model(inputs=input, outputs=output)
    model.summary()
I fixed several bugs (since the code was not properly tested...), upgraded to the Keras 2 API, added support for a residual connection at the last layer (by selecting the last element of the previous output sequence), and made the residual connection at the input optional, applied only when the input width matches the RNN output width.
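For illustration, here is a hypothetical variant of the example above (the input width of 16 is my choice, not from the gist) in which the input width does not match rnn_width, so the first residual connection is skipped and the sums start from layer 2 on:

from keras.layers import Input
from keras.models import Model

# Hypothetical: input width 16 != rnn_width 10, so at the first layer the
# residual add is skipped (x = x_rnn) and residual sums apply from layer 2 on.
input = Input(shape=(32, 16))
output = make_residual_lstm_layers(input, rnn_width=10, rnn_depth=4, rnn_dropout=0.2)
model = Model(inputs=input, outputs=output)
model.summary()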
@Seanny123 Thanks for the tip. This was Google's Neural Machine Translation System (https://arxiv.org/abs/1609.08144). What is different in their architecture?
Do you mean 'and' instead of 'or' in the condition `if i > 0 or input.shape[-1] == rnn_width:`?
Could you please tell me how to implement the same architecture using the Bidirectional LSTM layer?
https://colab.research.google.com/drive/1qNWzuYQdwG8LQT2yT5SwRb9J-aqJIFND?usp=sharing
@thingumajig - aah, thanks.
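For anyone looking for a bidirectional variant, here is a minimal, untested sketch (my assumption, not taken from the linked notebook). Wrapping each LSTM in Bidirectional with merge_mode='sum' adds the forward and backward outputs, so each layer still emits rnn_width features and the residual sums line up exactly as before:

from keras.layers import LSTM, Bidirectional, Lambda
from keras.layers.merge import add

def make_residual_bilstm_layers(input, rnn_width, rnn_depth, rnn_dropout):
    # Same residual wiring as make_residual_lstm_layers above; merge_mode='sum'
    # keeps the output width equal to rnn_width instead of doubling it.
    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = Bidirectional(
            LSTM(rnn_width, recurrent_dropout=rnn_dropout, dropout=rnn_dropout,
                 return_sequences=return_sequences),
            merge_mode='sum')(x)
        if return_sequences:
            if i > 0 or input.shape[-1] == rnn_width:
                x = add([x, x_rnn])
            else:
                x = x_rnn
        else:
            # Slice the last timestep of the previous sequence to match shapes.
            def slice_last(x):
                return x[..., -1, :]
            x = add([Lambda(slice_last)(x), x_rnn])
    return x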