Created
July 22, 2019 06:42
Understand return_sequences and return_state in the TensorFlow 2.0 Keras RNN layer.
from tensorflow import keras  # public Keras API; tensorflow.python is a private namespace

output_dim = 3  # not used below
time_steps = 10
input_dim = 4

# Three stacked cells; the output of each cell feeds the next one.
cells = [
    keras.layers.LSTMCell(10, name="l1"),
    keras.layers.SimpleRNNCell(11, name="l2"),
    keras.layers.LSTMCell(12, name="l3"),
]
inputs = keras.Input((time_steps, input_dim))

simple = keras.layers.RNN(cells, return_sequences=False, return_state=False, name="A")(inputs)
# Returns a tensor of shape (None, 12): the output of the last LSTM cell `l3` at the last time step [12 = units of l3].
print(simple)

sequence = keras.layers.RNN(cells, return_sequences=True, return_state=False, name="B")(inputs)
# Returns a tensor of shape (None, 10, 12): the output of the last LSTM cell `l3` at every time step [10 = time_steps].
print(sequence)

output_plus_states = keras.layers.RNN(cells, return_sequences=False, return_state=True, name="C")(inputs)
print(len(output_plus_states))  # 4 outputs
"""
Returns a list of length 4:
1. The output of the last LSTM cell `l3` at the last time step, as in the `simple` scenario.
The other 3 entries are the final states of the cells (3 because there are 3 cells):
2. The final state of the l1 LSTM cell. It is a list of length 2 because an LSTM has two states.
3. The final state of the l2 RNN cell. It is a single tensor.
4. The final state of the l3 LSTM cell. It is a list of length 2 because an LSTM has two states.
"""
print(output_plus_states)

# This layer is named "D"; the original reused "C", which duplicates the previous layer's name.
output_plus_states_plus_sequences = keras.layers.RNN(cells, return_sequences=True, return_state=True, name="D")(inputs)
print(len(output_plus_states_plus_sequences))  # 4 outputs
"""
Returns a list of length 4:
1. A tensor of shape (None, 10, 12): the output of the last LSTM cell `l3` at every time step.
The other 3 entries are the final states of the cells (3 because there are 3 cells):
2. The final state of the l1 LSTM cell. It is a list of length 2 because an LSTM has two states.
3. The final state of the l2 RNN cell. It is a single tensor.
4. The final state of the l3 LSTM cell. It is a list of length 2 because an LSTM has two states.
"""
print(output_plus_states_plus_sequences)
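
The four combinations above can be easier to follow with the framework stripped away. The block below is a minimal pure-Python sketch of what the two flags select, assuming a toy single-cell RNN with scalar inputs and scalar weights. It is not how Keras implements the layer; `toy_rnn_step`, `run_toy_rnn`, `w_x`, and `w_h` are hypothetical names introduced only for illustration.

```python
import math


def toy_rnn_step(x_t, h_prev, w_x, w_h):
    """One step of a toy scalar RNN cell: h_t = tanh(w_x * x_t + w_h * h_prev)."""
    h_t = math.tanh(w_x * x_t + w_h * h_prev)
    # For a simple RNN cell, the step output and the new state are the same value.
    return h_t, h_t


def run_toy_rnn(inputs, return_sequences=False, return_state=False, w_x=0.5, w_h=0.1):
    """Run the toy cell over `inputs` and select what to return via the two flags."""
    state = 0.0  # initial state (assumed to be zero for this sketch)
    outputs = []
    for x_t in inputs:
        out, state = toy_rnn_step(x_t, state, w_x, w_h)
        outputs.append(out)

    # return_sequences: every per-step output, or only the last one.
    result = outputs if return_sequences else outputs[-1]

    # return_state: additionally hand back the final state of the cell.
    if return_state:
        return [result, state]
    return result


xs = [1.0, 2.0, 3.0]
last_only = run_toy_rnn(xs)
all_steps = run_toy_rnn(xs, return_sequences=True)
last_and_state = run_toy_rnn(xs, return_state=True)
steps_and_state = run_toy_rnn(xs, return_sequences=True, return_state=True)

print(len(all_steps))       # one output per time step
print(len(last_and_state))  # output plus one final state
```

With a single simple cell, the last output and the final state coincide. With stacked cells, as in the script above, the output comes from the last cell only, while one final state is returned per cell, and an LSTM cell contributes two tensors to its state.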