Residual LSTM in Keras
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
input_1 (InputLayer)             (None, 32, 10)        0
____________________________________________________________________________________________________
lstm_1 (LSTM)                    (None, 32, 10)        840         input_1[0][0]
____________________________________________________________________________________________________
add_1 (Add)                      (None, 32, 10)        0           input_1[0][0]
                                                                   lstm_1[0][0]
____________________________________________________________________________________________________
lstm_2 (LSTM)                    (None, 32, 10)        840         add_1[0][0]
____________________________________________________________________________________________________
add_2 (Add)                      (None, 32, 10)        0           add_1[0][0]
                                                                   lstm_2[0][0]
____________________________________________________________________________________________________
lstm_3 (LSTM)                    (None, 32, 10)        840         add_2[0][0]
____________________________________________________________________________________________________
add_3 (Add)                      (None, 32, 10)        0           add_2[0][0]
                                                                   lstm_3[0][0]
____________________________________________________________________________________________________
lstm_4 (LSTM)                    (None, 32, 10)        840         add_3[0][0]
____________________________________________________________________________________________________
add_4 (Add)                      (None, 32, 10)        0           add_3[0][0]
                                                                   lstm_4[0][0]
____________________________________________________________________________________________________
lstm_5 (LSTM)                    (None, 32, 10)        840         add_4[0][0]
____________________________________________________________________________________________________
add_5 (Add)                      (None, 32, 10)        0           add_4[0][0]
                                                                   lstm_5[0][0]
____________________________________________________________________________________________________
lstm_6 (LSTM)                    (None, 32, 10)        840         add_5[0][0]
____________________________________________________________________________________________________
add_6 (Add)                      (None, 32, 10)        0           add_5[0][0]
                                                                   lstm_6[0][0]
____________________________________________________________________________________________________
lstm_7 (LSTM)                    (None, 32, 10)        840         add_6[0][0]
____________________________________________________________________________________________________
add_7 (Add)                      (None, 32, 10)        0           add_6[0][0]
                                                                   lstm_7[0][0]
____________________________________________________________________________________________________
lambda_1 (Lambda)                (None, 10)            0           add_7[0][0]
____________________________________________________________________________________________________
lstm_8 (LSTM)                    (None, 10)            840         add_7[0][0]
____________________________________________________________________________________________________
add_8 (Add)                      (None, 10)            0           lambda_1[0][0]
                                                                   lstm_8[0][0]
====================================================================================================
Total params: 6,720
Trainable params: 6,720
Non-trainable params: 0
____________________________________________________________________________________________________
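As a sanity check on the Param # column: each LSTM layer reports 840 parameters, which matches the standard LSTM parameter count of 4 * (m*n + n^2 + n) for input dimension m and layer width n. With m = n = 10 that is 4 * (100 + 100 + 10) = 840. The Add and Lambda layers have no trainable parameters, so the total is 8 * 840 = 6,720, as shown above.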
# Stacked LSTM with residual connections in the depth direction.
#
# An LSTM naturally has something like residual connections in the time
# direction; here we add residual connections in the depth direction as well.
#
# Inspired by Google's Neural Machine Translation System (https://arxiv.org/abs/1609.08144).
# They observed that residual connections allow them to train much deeper
# stacked RNNs; without residual connections they were limited to a depth
# of around 4 layers.
#
# It uses the Keras 2 API.

from keras.layers import LSTM, Lambda, add


def make_residual_lstm_layers(input, rnn_width, rnn_depth, rnn_dropout):
    """
    The intermediate LSTM layers return sequences, while the last one returns
    a single element. Since the input is also a sequence, the shapes of a
    layer's input and output match (and can be summed) for all layers but the
    last; for the last layer we sum its output with the last element of the
    previous sequence instead.
    """
    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = LSTM(rnn_width, recurrent_dropout=rnn_dropout, dropout=rnn_dropout,
                     return_sequences=return_sequences)(x)
        if return_sequences:
            # Intermediate layers return sequences and their input is also a sequence.
            if i > 0 or input.shape[-1] == rnn_width:
                x = add([x, x_rnn])
            else:
                # The input width and the RNN width have to match for the sum to work.
                # If we want a different rnn_width, we have to start summing
                # from the second layer on.
                x = x_rnn
        else:
            # The last layer does not return sequences, just the last element,
            # so we select only the last element of the previous output.
            def slice_last(x):
                return x[..., -1, :]
            x = add([Lambda(slice_last)(x), x_rnn])
    return x


if __name__ == '__main__':
    # Example usage
    from keras.layers import Input
    from keras.models import Model

    input = Input(shape=(32, 10))
    output = make_residual_lstm_layers(input, rnn_width=10, rnn_depth=8, rnn_dropout=0.2)
    model = Model(inputs=input, outputs=output)
    model.summary()
Could you please tell me how to implement the same architecture using the Bidirectional LSTM layer?
https://colab.research.google.com/drive/1qNWzuYQdwG8LQT2yT5SwRb9J-aqJIFND?usp=sharing
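One possible approach (a minimal sketch assuming the Keras 2 API, not necessarily what the linked notebook does) is to wrap each LSTM in the Bidirectional wrapper with merge_mode='sum'. Summing the forward and backward outputs keeps the layer output width equal to rnn_width, so the residual additions still see matching shapes; with the default merge_mode='concat' the output would be twice as wide as the input and add() would fail. The function name make_residual_bilstm_layers is just an illustrative choice:

# A sketch of a bidirectional variant of the function above (Keras 2 API).
# merge_mode='sum' keeps the output width equal to rnn_width so that the
# residual connections still have matching shapes.
from keras.layers import LSTM, Lambda, Bidirectional, add

def make_residual_bilstm_layers(input, rnn_width, rnn_depth, rnn_dropout):
    x = input
    for i in range(rnn_depth):
        return_sequences = i < rnn_depth - 1
        x_rnn = Bidirectional(
            LSTM(rnn_width, recurrent_dropout=rnn_dropout, dropout=rnn_dropout,
                 return_sequences=return_sequences),
            merge_mode='sum')(x)
        if return_sequences:
            if i > 0 or input.shape[-1] == rnn_width:
                x = add([x, x_rnn])
            else:
                x = x_rnn
        else:
            # As in the original: select the last element of the previous
            # sequence so it matches the non-sequence output of the last layer.
            def slice_last(x):
                return x[..., -1, :]
            x = add([Lambda(slice_last)(x), x_rnn])
    return x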