Fast-slow LSTM with variational unit (VU)
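A TensorFlow 1.x sketch of a Fast-Slow RNN cell (stacked fast LSTM layers wrapped around a single slow LSTM layer) extended with a variational unit: both the slow layer and the last fast layer emit the mean and sigma of a Normal distribution, and the KL divergence between the two distributions is exposed so it can be added to a training loss. The script below builds a small cell, runs it on toy inputs, and prints the KL term and the distribution parameters.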
import tensorflow as tf
import numpy as np

latent_dim = 2

class FSRNNCell_VU(tf.compat.v1.nn.rnn_cell.RNNCell):

    def __init__(self, fast_cells, slow_cell, input_keep_prob=1.0, keep_prob=1.0, training=True):
        """Initialize the basic Fast-Slow RNN with a variational unit (VU).

        Args:
          fast_cells: A list of RNN cells that will be used for the fast RNN.
            The cells must be callable, implement zero_state() and all have the
            same hidden size, for example tf.contrib.rnn.BasicLSTMCell.
          slow_cell: A single RNN cell for the slow RNN.
          input_keep_prob: Keep probability for dropout on the inputs.
          keep_prob: Keep probability for the non-recurrent dropout. Any kind of
            recurrent dropout should be implemented in the RNN cells.
          training: If False, no dropout is applied.
        """
        self.fast_layers = len(fast_cells)
        assert self.fast_layers >= 2, 'At least two fast layers are needed'
        self.fast_cells = fast_cells
        self.slow_cell = slow_cell
        self.keep_prob = keep_prob
        self.input_keep_prob = input_keep_prob
        # VU: (mean, sigma) of the slow and fast Normal distributions,
        # populated on each call.
        self.S_mean = None
        self.S_sigma = None
        self.F_mean = None
        self.F_sigma = None
        if not training:
            self.keep_prob = 1.0
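    # Data flow per call (implemented in __call__ below):
    #   inputs -> Fast_0 -> Slow (VU heads: S_mean, S_sigma) -> Fast_1
    #   -> Fast_2..k -> (VU heads: F_mean, F_sigma) -> dropout -> output
    # The KL divergence between the slow and fast distributions can then be
    # used as an auxiliary loss term (see the driver code at the bottom).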
    def __call__(self, inputs, state, scope='FS-RNN'):
        F_state = state[0]
        S_state = state[1]
        # VU: shared initializer for the mean/sigma dense layers.
        a_w = tf.random_normal_initializer(seed=10)
        with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
            inputs = tf.nn.dropout(inputs, self.input_keep_prob)

            with tf.variable_scope('Fast_0'):
                F_output, F_state = tf.nn.dynamic_rnn(
                    cell=self.fast_cells[0], inputs=inputs,
                    initial_state=F_state, time_major=True)
            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)

            with tf.variable_scope('Slow'):
                S_output, S_state = tf.nn.dynamic_rnn(
                    cell=self.slow_cell, inputs=F_output_drop,
                    initial_state=S_state, time_major=True)
                # VU: slow distribution parameters; softplus keeps sigma positive.
                self.S_mean = tf.layers.dense(S_output, latent_dim, activation=None,
                                              kernel_initializer=a_w, name='mean',
                                              trainable=True)
                self.S_sigma = tf.layers.dense(S_output, latent_dim, tf.nn.softplus,
                                               kernel_initializer=a_w, name='sigma',
                                               trainable=True)
            S_output_drop = tf.nn.dropout(S_output, self.keep_prob)

            with tf.variable_scope('Fast_1'):
                F_output, F_state = tf.nn.dynamic_rnn(
                    cell=self.fast_cells[1], inputs=S_output_drop,
                    initial_state=F_state, time_major=True)

            for i in range(2, self.fast_layers):
                with tf.variable_scope('Fast_' + str(i)):
                    # The original FS-RNN feeds a zeroed input here (inputs
                    # cannot be empty for many RNN cells); this variant feeds
                    # the previous fast output instead.
                    F_output, F_state = tf.nn.dynamic_rnn(
                        cell=self.fast_cells[i], inputs=F_output,
                        initial_state=F_state, time_major=True)

            # VU: fast distribution parameters from the last fast layer.
            self.F_mean = tf.layers.dense(F_output, latent_dim, activation=None,
                                          kernel_initializer=a_w, name='mean',
                                          trainable=True)
            self.F_sigma = tf.layers.dense(F_output, latent_dim, tf.nn.softplus,
                                           kernel_initializer=a_w, name='sigma',
                                           trainable=True)

            F_output_drop = tf.nn.dropout(F_output, self.keep_prob)
            return F_output_drop, (F_state, S_state), \
                   (self.S_mean, self.S_sigma), (self.F_mean, self.F_sigma)

    def zero_state(self, batch_size, dtype):
        F_state = self.fast_cells[0].zero_state(batch_size, dtype)
        S_state = self.slow_cell.zero_state(batch_size, dtype)
        return (F_state, S_state)
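# --- Smoke test: build a small FS-RNN with VU heads, run it on toy inputs,
# --- and check the slow/fast distribution parameters and their KL divergence.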
batch_size = 3
cell_size = 5
state_size = 7

# Create one slow and three fast cells.
slow = tf.contrib.rnn.BasicLSTMCell(cell_size)
fast = [tf.contrib.rnn.BasicLSTMCell(cell_size),
        tf.contrib.rnn.BasicLSTMCell(cell_size),
        tf.contrib.rnn.BasicLSTMCell(cell_size)]

# Create a single FS-RNN using the cells.
fs_lstm_vu = FSRNNCell_VU(fast, slow)

# With time_major=True the inputs below have shape (max_time, batch, depth)
# = (batch_size, 1, state_size), so the RNN batch dimension is 1.
init_state = fs_lstm_vu.zero_state(1, tf.float32)

# Calling the cell twice exercises variable reuse (reuse=tf.AUTO_REUSE);
# only the second call's outputs are inspected below.
output, final_state, S_N_args, F_N_args = fs_lstm_vu(
    np.zeros((batch_size, 1, state_size), np.float32), init_state)
output, final_state, S_N_args, F_N_args = fs_lstm_vu(
    np.ones((batch_size, 1, state_size), np.float32), init_state)

# Build Normal distributions from the slow and fast (mean, sigma) pairs.
S_norm_dist = tf.distributions.Normal(loc=S_N_args[0], scale=S_N_args[1])
F_norm_dist = tf.distributions.Normal(loc=F_N_args[0], scale=F_N_args[1])
F_norm_dist_sample_z = tf.squeeze(F_norm_dist.sample(1), axis=0)  # e.g. sampling an action
KL = tf.distributions.kl_divergence(S_norm_dist, F_norm_dist)  # auxiliary term to minimize in the loss
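# Illustrative only (not part of this script): the KL term would typically be
# folded into a downstream objective with a weighting coefficient, e.g.
#   total_loss = task_loss + beta * tf.reduce_mean(KL)
# where task_loss and beta are hypothetical names for the task objective and
# the KL weight.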
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print('KL', sess.run(KL))
    print('S_N_args[0]', sess.run(S_N_args[0]))
    print('S_N_args[1]', sess.run(S_N_args[1]))
    print('F_N_args[0]', sess.run(F_N_args[0]))
    print('F_N_args[1]', sess.run(F_N_args[1]))
    # The slow and fast parameters come from different layers with different
    # inputs, so they should differ elementwise.
    assert (sess.run(S_N_args[0]) != sess.run(F_N_args[0])).all()
    assert (sess.run(S_N_args[1]) != sess.run(F_N_args[1])).all()