@himaprasoonpt
Created July 22, 2019 06:42
Understanding return_sequences and return_state in the TensorFlow 2.0 Keras RNN layer.
from tensorflow import keras  # public API; tensorflow.python is a private module
output_dim = 3
time_steps = 10
input_dim = 4
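# Three stacked cells: at each time step the output of l1 feeds l2, and the
# output of l2 feeds l3. The same cell objects are reused by every RNN layer
# below, so all of those layers share the same weights.
cells = [
    keras.layers.LSTMCell(10, name="l1"),
    keras.layers.SimpleRNNCell(11, name="l2"),
    keras.layers.LSTMCell(12, name="l3"),
]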
inputs = keras.Input((time_steps, input_dim))
simple = keras.layers.RNN(cells, return_sequences=False, return_state=False, name="A")(inputs)
# Tensor of shape (None, 12): the output of the last cell (l3, an LSTM with 12 units) at the last time step only
print(simple)
sequence = keras.layers.RNN(cells, return_sequences=True, return_state=False, name="B")(inputs)
# Tensor of shape (None, 10, 12): the output of the last cell (l3) at each of the 10 time steps
print(sequence)
output_plus_states = keras.layers.RNN(cells, return_sequences=False, return_state=True, name="C")(inputs)
print(len(output_plus_states))  # 4: the output plus one final state per cell
"""
Returns a list of length 4:
1. The output of the last cell (l3), as in the `simple` case (last time step only), shape (None, 12).
The other 3 elements are the final states of the cells, one per cell (3 cells, so 3 states):
2. Final state of the l1 LSTM cell: a list of 2 tensors [h, c], because an LSTM has two states; each has shape (None, 10).
3. Final state of the l2 SimpleRNN cell: a single tensor of shape (None, 11).
4. Final state of the l3 LSTM cell: a list of 2 tensors [h, c], each of shape (None, 12); h equals element 1.
"""
print(output_plus_states)
output_plus_states_plus_sequences = keras.layers.RNN(cells, return_sequences=True, return_state=True, name="D")(inputs)  # layer names must be unique within a model
print(len(output_plus_states_plus_sequences))  # 4: the output sequence plus one final state per cell
"""
Returns a list of length 4:
1. The output of the last cell (l3) at each time step, as in the `sequence` case, shape (None, 10, 12).
The other 3 elements are the final states of the cells, one per cell. return_sequences does not affect them;
they are still taken at the last time step only:
2. Final state of the l1 LSTM cell: a list of 2 tensors [h, c], because an LSTM has two states; each has shape (None, 10).
3. Final state of the l2 SimpleRNN cell: a single tensor of shape (None, 11).
4. Final state of the l3 LSTM cell: a list of 2 tensors [h, c], each of shape (None, 12).
"""
print(output_plus_states_plus_sequences)
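The output structure described in the script can be reproduced without TensorFlow. The following is a minimal NumPy sketch under stated assumptions; it is not Keras's implementation. The names `SimpleRNNCellSketch`, `LSTMCellSketch` and `run_stacked_rnn` are hypothetical. The weights are random, and Keras details such as initializers, masking and dropout are ignored. The sketch shows three things. At every time step, each cell's output feeds the next cell. `return_sequences` only decides whether the last cell's output is returned for all time steps or only the last one. `return_state` appends one final state per cell.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


class SimpleRNNCellSketch:
    """Hypothetical stand-in for a SimpleRNN cell: h_t = tanh(x_t W + h_{t-1} U + b)."""

    def __init__(self, units, input_dim, rng):
        self.units = units
        self.W = rng.normal(scale=0.1, size=(input_dim, units))
        self.U = rng.normal(scale=0.1, size=(units, units))
        self.b = np.zeros(units)

    def initial_state(self, batch):
        return np.zeros((batch, self.units))  # a single state tensor

    def step(self, x, h):
        h = np.tanh(x @ self.W + h @ self.U + self.b)
        return h, h  # (output, new state)


class LSTMCellSketch:
    """Hypothetical stand-in for an LSTM cell with input, forget, candidate and output gates."""

    def __init__(self, units, input_dim, rng):
        self.units = units
        self.W = rng.normal(scale=0.1, size=(input_dim, 4 * units))
        self.U = rng.normal(scale=0.1, size=(units, 4 * units))
        self.b = np.zeros(4 * units)

    def initial_state(self, batch):
        return [np.zeros((batch, self.units)), np.zeros((batch, self.units))]  # [h, c]

    def step(self, x, state):
        h, c = state
        i, f, g, o = np.split(x @ self.W + h @ self.U + self.b, 4, axis=-1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        return h, [h, c]  # the output is the hidden state h


def run_stacked_rnn(cells, inputs, return_sequences=False, return_state=False):
    """Run `cells` as a stack over inputs of shape (batch, time_steps, features)."""
    batch, time_steps, _ = inputs.shape
    states = [cell.initial_state(batch) for cell in cells]
    outputs = []
    for t in range(time_steps):
        x = inputs[:, t, :]
        for k, cell in enumerate(cells):
            x, states[k] = cell.step(x, states[k])  # cell k's output feeds cell k + 1
        outputs.append(x)  # output of the last cell at step t
    # return_sequences only chooses between every time step and the last one.
    output = np.stack(outputs, axis=1) if return_sequences else outputs[-1]
    # return_state appends each cell's final state, whatever return_sequences is.
    return [output] + states if return_state else output


rng = np.random.default_rng(0)
cells = [
    LSTMCellSketch(10, 4, rng),        # like l1
    SimpleRNNCellSketch(11, 10, rng),  # like l2
    LSTMCellSketch(12, 11, rng),       # like l3
]
x = rng.normal(size=(2, 10, 4))  # batch of 2, 10 time steps, 4 features

simple = run_stacked_rnn(cells, x)
sequence = run_stacked_rnn(cells, x, return_sequences=True)
result = run_stacked_rnn(cells, x, return_state=True)
print(simple.shape)    # (2, 12)
print(sequence.shape)  # (2, 10, 12)
print(len(result))     # 4
```

Every call reuses the same cell objects, so `sequence[:, -1]` matches `simple`. Likewise, the first element of the last LSTM's state, `result[3][0]`, is the same hidden state h as `result[0]`. The Keras layers in the script reuse one `cells` list in the same way, so they also share weights.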