Skip to content

Instantly share code, notes, and snippets.

@jrimestad
Created February 12, 2019 01:21
Show Gist options
  • Save jrimestad/fa9f0b2d6a53ec1dcdb89b4de0dcb6e3 to your computer and use it in GitHub Desktop.
from __future__ import absolute_import, division
import sys
from os.path import dirname
sys.path.append(dirname(dirname(__file__)))
from keras import initializers
from keras.engine import InputSpec, Layer
from keras import backend as K
class AttentionWeightedAverage2D(Layer):
    """Attention-weighted average over the two spatial axes of a 4D tensor.

    A learned (channels, 1) projection produces one scalar logit per spatial
    position; the logits are softmaxed jointly over (height, width) and the
    layer returns the attention-weighted sum of the input features.

    Input shape:  (batch, height, width, channels)
    Output shape: (batch, channels)
    """

    def __init__(self, **kwargs):
        self.init = initializers.get('uniform')
        super(AttentionWeightedAverage2D, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the (channels, 1) attention projection weight."""
        self.input_spec = [InputSpec(ndim=4)]
        assert len(input_shape) == 4
        # add_weight already registers W as a trainable variable; the old
        # explicit `self.trainable_weights = [self.W]` was redundant and
        # raises AttributeError on modern Keras, where trainable_weights
        # is a read-only property -- so it is dropped here.
        self.W = self.add_weight(shape=(input_shape[3], 1),
                                 name='{}_W'.format(self.name),
                                 initializer=self.init)
        super(AttentionWeightedAverage2D, self).build(input_shape)

    def call(self, x):
        """Return the attention-weighted spatial average of `x`."""
        # (batch, h, w, c) . (c, 1) -> (batch, h, w, 1), then squeeze to
        # (batch, h, w) via a dynamic reshape.
        logits = K.dot(x, self.W)
        x_shape = K.shape(x)
        logits = K.reshape(logits, (x_shape[0], x_shape[1], x_shape[2]))
        # Numerically stable softmax over the joint spatial axes (1, 2).
        ai = K.exp(logits - K.max(logits, axis=[1, 2], keepdims=True))
        att_weights = ai / (K.sum(ai, axis=[1, 2], keepdims=True) + K.epsilon())
        # Broadcast the per-position weights across channels and reduce
        # the spatial axes, leaving (batch, channels).
        weighted_input = x * K.expand_dims(att_weights)
        return K.sum(weighted_input, axis=[1, 2])

    def get_output_shape_for(self, input_shape):
        # Keras 1.x compatibility alias for compute_output_shape.
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        # Spatial axes are reduced away; channel count is preserved.
        return (input_shape[0], input_shape[3])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment