@rmdort · Last active December 18, 2017
Similar to the attention layer in https://github.com/synthesio/hierarchical-attention-networks, ported to Keras 2.
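For reference, the layer learns a single weight vector w and, given timestep outputs h_1..h_T and an optional padding mask m_t ∈ {0, 1}, returns attention weights a_t = m_t·exp(wᵀh_t) / (Σ_s m_s·exp(wᵀh_s) + ε), i.e. a masked softmax over the timestep axis.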
from keras import backend as K
from keras import initializers, regularizers, constraints
from keras.engine.topology import Layer


class AttentionLayer(Layer):
    '''
    Attention layer. Produces a (batch, timesteps, 1) tensor of attention
    weights over the timesteps of its input.

    Usage (Keras 2 functional API):
        lstm_out = LSTM(dim, return_sequences=True)(inputs)
        attention = AttentionLayer()(lstm_out)
        sentence_emb = Lambda(lambda t: K.sum(t[0] * t[1], axis=1),
                              output_shape=lambda s: (s[0][0], s[0][2]))([lstm_out, attention])
    '''
    def __init__(self, init='glorot_uniform', kernel_regularizer=None,
                 bias_regularizer=None, kernel_constraint=None,
                 bias_constraint=None, **kwargs):
        self.supports_masking = True
        # Use the requested initializer for the attention kernel.
        self.kernel_initializer = initializers.get(init)
        # Bias options are accepted for API symmetry, but no bias weight is created.
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        super(AttentionLayer, self).__init__(**kwargs)
    def build(self, input_shape):
        # One scoring weight per input feature; keyword arguments avoid the
        # change in add_weight's positional order between Keras 2 releases.
        self.kernel = self.add_weight(shape=(input_shape[-1], 1),
                                      initializer=self.kernel_initializer,
                                      name='{}_W'.format(self.name),
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.built = True
    def compute_mask(self, inputs, mask=None):
        # Pass the mask through so downstream layers still see it.
        return mask

    def call(self, x, mask=None):
        # Unnormalised scores: exp(w . h_t) for every timestep t.
        multData = K.exp(K.dot(x, self.kernel))
        if mask is not None:
            # Zero out the scores of padded timesteps.
            mask = K.cast(mask, K.floatx())
            mask = K.expand_dims(mask)
            multData = mask * multData
        # Normalise over the timestep axis (a masked softmax).
        output = multData / (K.sum(multData, axis=1, keepdims=True) + K.epsilon())
        return output
    def compute_output_shape(self, input_shape):
        # Same shape as the input, with the feature dimension collapsed to 1.
        output_shape = list(input_shape)
        output_shape[-1] = 1
        return tuple(output_shape)
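Below is a minimal end-to-end sketch of how the layer might be wired into a binary classifier. The variable names, dimensions and the Dense head are illustrative assumptions, not part of the original gist; masking is left off because the plain Lambda reduction here is not mask-aware.

from keras.layers import Input, Embedding, LSTM, Lambda, Dense
from keras.models import Model
from keras import backend as K

# Illustrative hyper-parameters (assumptions).
vocab_size, maxlen, emb_dim, lstm_dim = 20000, 100, 128, 64

words = Input(shape=(maxlen,), dtype='int32')
x = Embedding(vocab_size, emb_dim)(words)        # (batch, maxlen, emb_dim)
h = LSTM(lstm_dim, return_sequences=True)(x)     # (batch, maxlen, lstm_dim)
a = AttentionLayer()(h)                          # (batch, maxlen, 1) attention weights

# Attention-weighted sum over time -> fixed-size sentence embedding.
sent = Lambda(lambda t: K.sum(t[0] * t[1], axis=1),
              output_shape=lambda s: (s[0][0], s[0][2]))([h, a])
pred = Dense(1, activation='sigmoid')(sent)

model = Model(inputs=words, outputs=pred)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

With mask_zero=True on the Embedding, the mask would propagate into AttentionLayer as intended, but the Lambda reduction would then need to be made mask-aware.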