# Created January 9, 2020 05:37
# Save pekaalto/026d0248b7a5477380dd21c4ca637c09 to your computer and use it in GitHub Desktop.
# Investigate keras-lstm inputs, outputs and weights.
# This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
""" | |
Investigate keras-lstm inputs, outputs and weights. | |
Needs tensorflow 2.0 | |
Note: The explanation of weights matches the CPU-implementation of LSTM-layer. | |
In GPU-implementation the weights are organized slightly differently | |
""" | |
import numpy as np
import tensorflow as tf
from scipy.special import expit as sigmoid
# Sizes shared by the Keras layer, the random inputs and all shape checks.
LSTM_UNITS = 5  # hidden units: size of the h-state and the c-state
TIME_STEPS = 3  # sequence length of the input
INPUT_DIM = 6  # features per time-step
BATCH_SIZE = 2  # number of independent sequences in the batch
# A single LSTM layer that returns both the full h-state sequence and the
# final (h, c) states, so every output can be inspected.
lstm_layer = tf.keras.layers.LSTM(
    units=LSTM_UNITS,
    return_sequences=True,
    return_state=True,
    use_bias=True,
    # Random (non-zero) biases so the manual re-implementation also has to
    # get the bias terms right.
    bias_initializer="uniform",
)

input_x_tensor = tf.random.normal(shape=(BATCH_SIZE, TIME_STEPS, INPUT_DIM))
initial_c_state_tensor = tf.random.normal(shape=(BATCH_SIZE, LSTM_UNITS))
initial_h_state_tensor = tf.random.normal(shape=(BATCH_SIZE, LSTM_UNITS))

# With return_sequences=True and return_state=True the layer returns
# [h-state sequence, last h-state, last c-state]; initial_state is [h, c].
h_state_sequence, h_state_last, cell_state_last = lstm_layer(
    input_x_tensor, initial_state=[initial_h_state_tensor, initial_c_state_tensor]
)

# The returned last h-state is exactly the final step of the h-state sequence.
np.testing.assert_array_equal(h_state_sequence[:, -1, :], h_state_last)
assert h_state_sequence.shape == (BATCH_SIZE, TIME_STEPS, LSTM_UNITS)
assert h_state_last.shape == (BATCH_SIZE, LSTM_UNITS)
assert cell_state_last.shape == (BATCH_SIZE, LSTM_UNITS)

# Three weight arrays: input kernel, recurrent kernel and bias. Each one packs
# the four gates side by side along its last axis (hence the 4 * LSTM_UNITS).
lstm_weights = lstm_layer.get_weights()
assert [w.shape for w in lstm_weights] == [
    (INPUT_DIM, 4 * LSTM_UNITS),
    (LSTM_UNITS, 4 * LSTM_UNITS),
    (4 * LSTM_UNITS,),
]
kernel, recurrent_kernel, bias = lstm_weights

# Put the recurrent kernel on top of the input kernel so the rows line up with
# the concatenated [h_{t-1}, x_t] vector. Then transpose to
# (4 * LSTM_UNITS, LSTM_UNITS + INPUT_DIM) so each gate is a block of rows.
big_w = np.concatenate([recurrent_kernel, kernel], axis=0).T
# Gate order in the packed weights: input, forget, cell candidate, output.
W_i, W_f, W_c, W_o = np.split(big_w, indices_or_sections=4, axis=0)
b_i, b_f, b_c, b_o = np.split(bias, indices_or_sections=4, axis=0)
for w in [W_i, W_f, W_c, W_o]:
    # Columns: first LSTM_UNITS act on h_{t-1}, the remaining INPUT_DIM on x_t.
    assert w.shape == (LSTM_UNITS, LSTM_UNITS + INPUT_DIM)
for b in [b_i, b_f, b_c, b_o]:
    assert b.shape == (LSTM_UNITS,)
class LstmSimpleForward:
    """Plain NumPy LSTM forward pass, computed one gate at a time.

    The gate equations follow
    https://colah.github.io/posts/2015-08-Understanding-LSTMs/ and use
    per-gate weight matrices and biases (e.g. split out of a Keras LSTM layer).
    """

    def __init__(self, W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o):
        """
        W's have shape (LSTM_UNITS, LSTM_UNITS + INPUT_DIM): the first
        LSTM_UNITS columns multiply the previous h-state, the remaining
        INPUT_DIM columns multiply the input x_t.
        b's have shape (LSTM_UNITS,)

        Suffixes: i = input gate, f = forget gate, c = cell candidate,
        o = output gate.
        """
        self.W_i = W_i
        self.W_f = W_f
        self.W_c = W_c
        self.W_o = W_o
        self.b_i = b_i
        self.b_f = b_f
        self.b_c = b_c
        self.b_o = b_o

    def step_one(self, h_t1, c_t1, xt):
        """
        Calculates one time-step in lstm
        :param h_t1: previous h-state, shape [BATCH_SIZE, LSTM_UNITS]
        :param c_t1: previous c-state, shape [BATCH_SIZE, LSTM_UNITS]
        :param xt: input at this step, shape [BATCH_SIZE, INPUT_DIM]
        :return: new [h-state, c-state] pair, each [BATCH_SIZE, LSTM_UNITS]
        """
        # hx: [BATCH_SIZE, LSTM_UNITS + INPUT_DIM]. The order (h first, then x)
        # must match the column order of the W matrices.
        hx = np.concatenate([h_t1, xt], axis=-1)
        # Note that we could also concatenate the weights
        # into one big matrix and split the result.
        # That would be a cleaner implementation,
        # but here we want to align with the operations described in
        # https://colah.github.io/posts/2015-08-Understanding-LSTMs/
        i_raw, f_raw, c_hat_raw, o_raw = [
            np.dot(hx, W.T) + b
            for (W, b) in zip(
                [self.W_i, self.W_f, self.W_c, self.W_o],
                [self.b_i, self.b_f, self.b_c, self.b_o],
            )
        ]
        i = sigmoid(i_raw)
        f = sigmoid(f_raw)
        o = sigmoid(o_raw)
        c_hat = np.tanh(c_hat_raw)
        # Keep the forget-gated part of the old cell state and add the
        # input-gated candidate.
        c = f * c_t1 + i * c_hat
        h = o * np.tanh(c)
        return [np.array(t) for t in [h, c]]

    def step_all(self, input_x, initial_h_state, initial_c_state):
        """
        Runs the LSTM over every time-step of a batch-major input.
        :param input_x: shape [BATCH_SIZE, TIME_STEPS, INPUT_DIM]
        :param initial_h_state: shape [BATCH_SIZE, LSTM_UNITS]
        :param initial_c_state: shape [BATCH_SIZE, LSTM_UNITS]
        :return: (h-state sequence [BATCH_SIZE, TIME_STEPS, LSTM_UNITS],
                  last c-state [BATCH_SIZE, LSTM_UNITS])
        """
        h_state_sequence = []
        h_state, c_state = initial_h_state, initial_c_state
        for t in range(input_x.shape[1]):
            h_state, c_state = self.step_one(h_state, c_state, input_x[:, t, :])
            h_state_sequence.append(h_state)
        # Stacking along axis 1 yields batch-major [BATCH, TIME, UNITS] directly
        # (same values as building a time-major array and swapping axes 0/1).
        return np.stack(h_state_sequence, axis=1), c_state
# Run the NumPy re-implementation with the extracted weights and the exact same
# inputs and initial states that were fed to the Keras layer.
h_state_sequence_2, cell_state_manual_2 = LstmSimpleForward(
    W_i, W_f, W_c, W_o, b_i, b_f, b_c, b_o
).step_all(
    input_x=input_x_tensor.numpy(),
    initial_h_state=initial_h_state_tensor.numpy(),
    initial_c_state=initial_c_state_tensor.numpy(),
)

# TensorFlow works in float32 by default (tf.random.normal's default dtype),
# and its matmul/tanh/sigmoid kernels differ from NumPy/SciPy, so results can
# differ by a few ulps. A fixed 7-decimal absolute tolerance (1.5e-7) is
# smaller than one float32 ulp for |values| >= 2, which makes it flaky.
# Use a relative + small absolute tolerance instead.
np.testing.assert_allclose(cell_state_last, cell_state_manual_2, rtol=1e-5, atol=1e-6)
np.testing.assert_allclose(h_state_sequence, h_state_sequence_2, rtol=1e-5, atol=1e-6)
# Keras' returned last h-state must also match the manual sequence's last step.
np.testing.assert_allclose(
    h_state_last, h_state_sequence_2[:, -1, :], rtol=1e-5, atol=1e-6
)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.