Created July 24, 2019 01:25
import numpy as np
import tensorflow as tf

from ray.rllib.models.tf.tf_modelv2 import TFModelV2
# add_time_dimension reshapes a flat batch into (batch, time, ...);
# its import path may differ across RLlib versions.
from ray.rllib.policy.rnn_sequencing import add_time_dimension


class MaskingLayerRNNmodel(TFModelV2):

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kw):
        super(MaskingLayerRNNmodel, self).__init__(obs_space, action_space, num_outputs, model_config, name, **kw)
        self.initialize_lstm_with_prev_state = model_config['custom_options']['initialize_lstm_with_prev_state']

        # Observations are (batch, time, obs_dim) once forward() has added the time dimension.
        self.input_layer = tf.keras.layers.Input(
            shape=(None, obs_space.shape[0]),
            name='inputLayer')
        # Recurrent state inputs; note shape must be a tuple, not a bare int.
        self.state_in_c = tf.keras.layers.Input(
            shape=(model_config['lstm_cell_size'], ),
            name='c')
        self.state_in_h = tf.keras.layers.Input(
            shape=(model_config['lstm_cell_size'], ),
            name='h')
        # Per-sample sequence lengths, used to build the LSTM mask.
        self.seq_in = tf.keras.layers.Input(
            shape=(),
            dtype=tf.int32,
            name='seqLens')

        dense_layer_1 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][0],
            activation=tf.nn.relu,
            name='denseLayer1')(self.input_layer)
        # masking_layer = tf.keras.layers.Masking(
        #     mask_value=0.0)(dense_layer_1)
        lstm_out, state_h, state_c = tf.keras.layers.LSTM(
            model_config['lstm_cell_size'],
            return_sequences=True,
            return_state=True,
            name='lstmLayer')(
                inputs=dense_layer_1,
                mask=tf.sequence_mask(self.seq_in) if self.model_config['max_seq_len'] > 1 else None,
                initial_state=[self.state_in_c, self.state_in_h])
        # note that initial_states=None (not correct); how could we pass 'state' here?
        # If we had access to the batch shape, we could set stateful=True on the LSTM
        # and call reset_states() instead of passing state explicitly.
        # reshape_layer does not accept the mask that gets propagated through the
        # model when Masking() is used upstream, so that combination FAILS!
        reshape_layer = tf.keras.layers.Lambda(
            lambda x: tf.reshape(x, [-1, model_config['lstm_cell_size']]))(lstm_out)
        dense_layer_2 = tf.keras.layers.Dense(
            model_config['fcnet_hiddens'][1],
            activation=tf.nn.relu,
            name='denseLayer2')(lstm_out)
        logits_layer = tf.keras.layers.Dense(
            self.num_outputs,
            activation=tf.keras.activations.linear,
            name='logitsLayer')(dense_layer_2)
        value_layer = tf.keras.layers.Dense(
            1,
            activation=None,
            name='valueLayer')(dense_layer_2)
        state = [state_h, state_c]

        self.base_model = tf.keras.Model(
            inputs=[self.input_layer, self.state_in_c, self.state_in_h, self.seq_in],
            outputs=[logits_layer, value_layer, state_h, state_c])
        self.register_variables(self.base_model.variables)
        self.base_model.summary()

    # Implement the core forward method
    def forward(self, input_dict, state, seq_lens):
        x = input_dict['obs']
        # Add a time dimension if the obs batch is still flat (batch, obs_dim).
        if x.shape.ndims < 3:
            x = add_time_dimension(x, seq_lens)
        logits, self._value_out, h, c = self.base_model((x, state[0], state[1], seq_lens))
        # Flatten (batch, time, num_outputs) back to (batch * time, num_outputs).
        return tf.reshape(logits, [-1, self.num_outputs]), [h, c]

    def get_initial_state(self):
        return [np.zeros(self.model_config['lstm_cell_size'], np.float32),
                np.zeros(self.model_config['lstm_cell_size'], np.float32)]

    def value_function(self):
        return tf.reshape(self._value_out, [-1])
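
For context, here is a minimal usage sketch (not part of the original gist) showing how a custom TFModelV2 like this is typically wired into RLlib: register it with ModelCatalog, then reference it from the trainer's model config. The environment, the choice of PPO, and all config values below (cell size, hidden sizes, max_seq_len, and the custom_options flag read in __init__) are illustrative assumptions, not settings taken from the gist.

import ray
from ray.rllib.models import ModelCatalog
from ray.rllib.agents.ppo import PPOTrainer

# Register the custom model under a name RLlib can look up from the config.
ModelCatalog.register_custom_model('masking_lstm_model', MaskingLayerRNNmodel)

ray.init()
trainer = PPOTrainer(
    env='CartPole-v0',  # illustrative env; any discrete-action env works
    config={
        'model': {
            'custom_model': 'masking_lstm_model',
            'custom_options': {'initialize_lstm_with_prev_state': False},
            'lstm_cell_size': 64,        # read directly by __init__ above
            'fcnet_hiddens': [64, 64],   # sizes of denseLayer1 / denseLayer2
            'max_seq_len': 20,           # > 1, so the sequence mask is built
        },
    })
print(trainer.train())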