Match-LSTM to match passage encodes with question encodes
# Python 3.6 + TF 1.6
# Han Xiao ([email protected])

import tensorflow as tf
import tensorflow.contrib as tc
#### Usage: matching passage encodes with question encodes
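# A minimal sketch of the inputs assumed by the usage below; the names and shapes are
# illustrative only, and p_encodes / q_encodes would normally come from an upstream
# (e.g. bi-LSTM) encoder rather than raw placeholders.
p_encodes = tf.placeholder(tf.float32, [None, None, 32], name='p_encodes')  # B x T_p x D passage encodes
q_encodes = tf.placeholder(tf.float32, [None, None, 32], name='q_encodes')  # B x T_q x D question encodes
p_length = tf.placeholder(tf.int32, [None], name='p_length')                # true passage lengths, B
p_mask = tf.placeholder(tf.float32, [None, None], name='p_mask')            # B x T_p, 1.0 for tokens, 0.0 for padding
q_mask = tf.placeholder(tf.float32, [None, None], name='q_mask')            # B x T_q, 1.0 for tokens, 0.0 for padding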
ml = MatchLSTMLayer(hidden_size=16,
                    control_gate=False,
                    pooling_window=5,  # pool the attended context; otherwise attention over a long passage may run out of memory
                    name='demo',
                    act_fn=tf.nn.relu,
                    attend_hidden_size=16)

ml.match(input_encodes=p_encodes,
         attended_encodes=q_encodes,
         input_length=p_length,
         input_mask=p_mask,
         attended_mask=q_mask)
#### Details

def attend_pooling(pooling_vectors, ref_vector, hidden_size, scope=None,
                   pooling_mask=None, activation_fn=tf.tanh, output_logit=False):
    """
    Applies attention pooling to a set of vectors according to a reference vector.
    Args:
        pooling_vectors: the vectors to pool, of size B x T x D
        ref_vector: the reference vector at a single time step, of size B x D'; D' may differ from D
        hidden_size: the hidden size of the attention function
        scope: variable scope name
        pooling_mask: optional mask of size B x T, 1.0 for valid positions and 0.0 for padding
        activation_fn: activation applied inside the attention function
        output_logit: if True, return the unnormalized logits instead of the softmax scores
    Returns:
        the pooled vector of size B x D, and the attention scores (or logits) of size B x T x 1
    """
    with tf.variable_scope(scope or 'attend_pooling'):
        # pooling_vectors must be B x T x D
        assert pooling_vectors.get_shape().ndims == 3
        # ref_vector must be B x D', i.e. a single time step
        assert ref_vector.get_shape().ndims == 2
        if pooling_mask is not None:
            # pooling_mask must be B x T
            assert pooling_mask.get_shape().ndims == 2
        U = activation_fn(tc.layers.fully_connected(pooling_vectors,
                                                    num_outputs=hidden_size,
                                                    activation_fn=None)
                          + tf.expand_dims(tc.layers.fully_connected(ref_vector,
                                                                     num_outputs=hidden_size,
                                                                     activation_fn=None), axis=1))
        logits = tc.layers.fully_connected(U, num_outputs=1, activation_fn=None)
        if pooling_mask is not None:
            # push masked positions to a large negative value so they receive ~0 attention
            logits -= tf.expand_dims(1.0 - pooling_mask, axis=2) * 1e30
        scores = tf.nn.softmax(logits, 1)
        pooled_vector = tf.reduce_sum(pooling_vectors * scores, axis=1)
        # pooled_vector is B x D, scores are B x T x 1
        return pooled_vector, (logits if output_logit else scores)
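
# A minimal, self-contained sketch of attend_pooling on hypothetical tensors:
# 8 candidate vectors per example are pooled into one vector, weighted by their
# similarity to a reference vector (whose dimensionality may differ from theirs).
demo_cand = tf.random_normal([4, 8, 32])   # B x T x D vectors to pool
demo_ref = tf.random_normal([4, 24])       # B x D' reference vector
demo_mask = tf.ones([4, 8])                # B x T, 1.0 marks valid positions
demo_pooled, demo_scores = attend_pooling(demo_cand, demo_ref, hidden_size=16,
                                          pooling_mask=demo_mask,
                                          scope='attend_pooling_demo')
# demo_pooled has shape B x D, demo_scores has shape B x T x 1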
class MatchLSTMAttnCell(tc.rnn.LSTMCell):
    """
    Implements the Match-LSTM attention cell.
    """

    def __init__(self, num_units, context_to_attend, control_gate, pooling_window, attend_mask, act_fn,
                 attend_hidden_size):
        super().__init__(num_units, state_is_tuple=True)
        if pooling_window:
            # max-pool the context to attend; useful when the context is long and
            # full attention would exhaust memory
            self.context_to_attend = tf.nn.pool(context_to_attend,
                                                window_shape=[pooling_window],
                                                strides=[pooling_window],
                                                pooling_type='MAX', padding='SAME',
                                                name='max_pool_on_context')
            # apply the same max-pooling to the mask:
            # the mask is B x T, so expand it to B x T x 1 for pooling,
            # then squeeze it back to B x T
            self.attend_mask = tf.squeeze(tf.nn.pool(tf.expand_dims(attend_mask, -1),
                                                     window_shape=[pooling_window],
                                                     strides=[pooling_window],
                                                     pooling_type='MAX', padding='SAME',
                                                     name='max_pool_on_context_mask'), -1)
        else:
            self.context_to_attend = context_to_attend
            self.attend_mask = attend_mask
        self.control_gate = control_gate
        self.act_fn = act_fn
        self.attend_hidden_size = attend_hidden_size

    def __call__(self, inputs, state, scope=None):
        (c_prev, h_prev) = state
        with tf.variable_scope(scope or type(self).__name__):
            # attend to the context using the current input and the previous hidden state
            ref_vector = tf.concat([inputs, h_prev], -1)
            attended_context, scores = attend_pooling(self.context_to_attend,
                                                      ref_vector,
                                                      self.attend_hidden_size,
                                                      pooling_mask=self.attend_mask,
                                                      activation_fn=self.act_fn)
            new_inputs = tf.concat([inputs, attended_context,
                                    inputs - attended_context,
                                    inputs * attended_context],
                                   -1)
            if self.control_gate:
                # add an extra sigmoid gate on the augmented input;
                # the gate must match the last dimension of new_inputs
                control_gate = tc.layers.fully_connected(new_inputs,
                                                         num_outputs=new_inputs.get_shape().as_list()[-1],
                                                         activation_fn=tf.nn.sigmoid)
                new_inputs *= control_gate
            return super().__call__(new_inputs, state, scope)
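
# A minimal sketch of MatchLSTMAttnCell used on its own with a unidirectional
# dynamic_rnn; all tensors below are hypothetical and the shapes illustrative only.
demo_p = tf.random_normal([4, 20, 32])     # B x T_p x D passage encodes (the RNN inputs)
demo_q = tf.random_normal([4, 6, 32])      # B x T_q x D question encodes to attend over
demo_q_mask = tf.ones([4, 6])              # B x T_q question mask
demo_cell = MatchLSTMAttnCell(num_units=16, context_to_attend=demo_q,
                              control_gate=False, pooling_window=None,
                              attend_mask=demo_q_mask, act_fn=tf.tanh,
                              attend_hidden_size=16)
with tf.variable_scope('match_lstm_cell_demo'):
    demo_out, demo_state = tf.nn.dynamic_rnn(demo_cell, demo_p, dtype=tf.float32)
# demo_out has shape B x T_p x num_units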
class MatchLSTMLayer:
    """
    Implements the Match-LSTM layer, which attends to the question dynamically in an LSTM fashion.
    """

    def __init__(self, hidden_size: int, control_gate: bool,
                 pooling_window: int, name: str, act_fn, attend_hidden_size: int):
        self.hidden_size = hidden_size
        self.output_size = hidden_size * 2  # bi-directional
        self.act_fn = act_fn
        self.name = name
        self.control_gate = control_gate
        self.pooling_window = pooling_window  # useful when the attended context is long
        self.attend_hidden_size = attend_hidden_size

    def match(self, input_encodes, attended_encodes, input_length, input_mask, attended_mask):
        """
        Matches the passage encodes with the question encodes using the Match-LSTM algorithm.
        """
        with tf.variable_scope(self.name):
            cell_fw = MatchLSTMAttnCell(self.hidden_size, attended_encodes, self.control_gate, self.pooling_window,
                                        attended_mask, self.act_fn, self.attend_hidden_size)
            cell_bw = MatchLSTMAttnCell(self.hidden_size, attended_encodes, self.control_gate, self.pooling_window,
                                        attended_mask, self.act_fn, self.attend_hidden_size)
            # input_mask is kept for API symmetry; padding on the input side is already
            # handled by passing sequence_length to the dynamic RNN
            outputs, state = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw,
                                                             inputs=input_encodes,
                                                             sequence_length=input_length,
                                                             dtype=tf.float32)
            match_outputs = tf.concat(outputs, 2)
            state_fw, state_bw = state
            c_fw, h_fw = state_fw
            c_bw, h_bw = state_bw
            match_state = tf.concat([h_fw, h_bw], 1)
            return match_outputs, match_state