# Created November 23, 2018 09:17
# Attentive convolution with a custom Attention layer.
from keras import activations
from keras import backend as K
from keras.layers import Lambda, Reshape, RepeatVector, Concatenate, Conv1D, Activation
from keras.layers import Layer
class Attention(Layer):
    """Soft attention over the time steps of ``text``, conditioned on ``context``.

    The layer is called on a list ``[text, context]``:

    * ``text`` -- 3D tensor, assumed ``(batch, num_words, dim)`` with a static
      ``dim`` (TODO confirm against callers).
    * ``context`` -- tensor whose non-batch dimensions are all static; it is
      flattened to ``(batch, context_dim)``.

    A scalar score per time step is computed by a learned linear projection of
    ``[context_flat, text_t]``. This is mathematically the same as a
    ``Conv1D(1, 1)`` over the concatenation of the repeated context and the
    text. The scores are normalised with a softmax over the time axis.

    Output:

    * ``before=True``  -- ``text`` re-weighted per time step (same shape as ``text``).
    * ``before=False`` -- attention-weighted sum over time steps, ``(batch, dim)``.

    :param kernel_activation: stored and serialized in the config; it is NOT
        used in the computation.
    :param before: selects the output form described above.
    """

    def __init__(self, kernel_activation='hard_sigmoid', before=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.kernel_activation = activations.get(kernel_activation)
        self.before = before

    def build(self, input_shape):
        """Create the scoring weights.

        :param input_shape: list ``[text_shape, context_shape]``.
        :raises ValueError: if the text feature size or any non-batch
            context dimension is not statically known.
        """
        text_shape, context_shape = input_shape[0], input_shape[1]
        self.num_words = text_shape[1]

        text_dim = text_shape[2]
        if text_dim is None:
            raise ValueError('Attention requires a static last dimension for the text input.')

        context_dim = 1
        for dim in context_shape[1:]:
            if dim is None:
                raise ValueError('Attention requires static non-batch dimensions for the context input.')
            context_dim *= int(dim)

        # Weights are created here (not inside call) so Keras tracks, trains
        # and saves them. The concatenation kernel is split into its text part
        # and its context part, which avoids repeating the context per step.
        self.text_kernel = self.add_weight(name='text_kernel',
                                           shape=(int(text_dim), 1),
                                           initializer='glorot_uniform',
                                           trainable=True)
        self.context_kernel = self.add_weight(name='context_kernel',
                                              shape=(context_dim, 1),
                                              initializer='glorot_uniform',
                                              trainable=True)
        self.score_bias = self.add_weight(name='score_bias',
                                          shape=(1,),
                                          initializer='zeros',
                                          trainable=True)
        super(Attention, self).build(input_shape)

    def get_output_shape_for(self, input_shape):
        # Keras 1 name of compute_output_shape; delegate so both APIs agree.
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        text_shape = input_shape[0]
        if self.before:
            return text_shape
        return (text_shape[0], text_shape[2])

    def call(self, x, mask=None):
        # mask is accepted for API compatibility but ignored, as before.
        text, context = x[0], x[1]

        # Flatten all non-batch axes of the context into one vector per sample.
        context_flat = K.batch_flatten(context)

        # score_t = text_t . W_text + context . W_context + b
        context_score = K.expand_dims(K.dot(context_flat, self.context_kernel), 1)  # (batch, 1, 1)
        scores = K.dot(text, self.text_kernel) + context_score + self.score_bias  # (batch, num_words, 1)

        # Softmax over the time axis (axis 1). The last axis has size 1, so a
        # default softmax over it would make every weight exactly 1.0.
        scores = scores - K.max(scores, axis=1, keepdims=True)
        exp_scores = K.exp(scores)
        weights = exp_scores / K.sum(exp_scores, axis=1, keepdims=True)

        weighted = text * weights  # broadcasts over the feature axis
        if self.before:
            return weighted
        return K.sum(weighted, axis=1)

    def get_config(self):
        config = {'kernel_activation': activations.serialize(self.kernel_activation),
                  'before': self.before}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
from keras.layers import Flatten, Dropout, Dense, Lambda, ZeroPadding2D, MaxPooling2D, Concatenate, Conv2D, LSTM, Bidirectional, Activation
from keras import backend as K
from Core.NeuralNetwork.CustomLayers.Attention import Attention
def build_cnn(input, dropout, kernel_sizes, num_stages, num_filters, pool_sizes, attention=False, context='same', fully_connected_dimension=1000):
    """Build a multi-stage CNN encoder followed by a dense projection.

    Every stage is produced by ``_build_stage``. Attention, if requested, is
    applied in the first stage only.

    :param input: output of the preceding layer; returned unchanged as the encoder input
    :param dropout: dropout rate applied to the flattened feature map
    :param kernel_sizes: list of lists; the outer list indexes stages, the inner
        list holds the kernel sizes used side by side within that stage
    :param num_stages: number of stages to build
    :param num_filters: list with the number of filters used in each stage
    :param pool_sizes: list with the max-pooling size used in each stage
    :param attention: whether the first stage applies attention
    :param context: 'same' for self-attention; otherwise a layer output used as
        the query context
    :param fully_connected_dimension: number of units of the final Dense layer
    :return: tuple ``(input, output)`` of the encoder
    """
    features = input
    for stage in range(num_stages):
        # Only the first stage attends; later stages use plain convolutions.
        use_attention = attention if stage == 0 else False
        features = _build_stage(kernel_sizes[stage], features, num_filters[stage],
                                context, pool_sizes[stage], use_attention)

    flattened = Flatten()(features)
    regularized = Dropout(dropout)(flattened)
    encoded = Dense(units=fully_connected_dimension)(regularized)
    return input, encoded
def _build_stage(kernel_sizes, pre_layer, num_filters, context, pool_size, attention):
    """Build one stage: parallel convolutions, optional attention, max-pooling.

    :param kernel_sizes: kernel heights convolved in parallel; must be non-empty
    :param pre_layer: 3D Keras tensor, presumably ``(batch, steps, features)``
        (TODO confirm against callers)
    :param num_filters: number of filters per kernel size
    :param context: 'same' for self-attention, otherwise a tensor used as the
        attention context
    :param pool_size: max-pooling window along the step axis; clipped to the
        step count when that is statically known
    :param attention: whether to re-weight the conv features with
        ``Attention(before=True)``
    :return: 3D tensor ``(batch, steps // pool_size, num_filters * len(kernel_sizes))``
    :raises ValueError: if ``kernel_sizes`` is empty
    """
    if not kernel_sizes:
        raise ValueError('kernel_sizes must contain at least one kernel size')

    # Add a channel axis once so each 2D convolution can span the full feature
    # width: (batch, steps, features) -> (batch, steps, features, 1).
    expanded = Lambda(lambda x: K.expand_dims(x, 3))(pre_layer)
    feature_width = K.int_shape(expanded)[2]

    convs = []
    for size in kernel_sizes:
        # Zero-pad the step axis so the 'valid' convolution keeps the sequence
        # length: odd sizes pad symmetrically, even sizes one less at the end.
        half = size // 2
        if size % 2 == 0:
            padding = ((half, half - 1), (0, 0))
        else:
            padding = ((half, half), (0, 0))
        padded_input = ZeroPadding2D(padding=padding)(expanded)
        # -> (batch, steps, 1, num_filters)
        conv = Conv2D(num_filters, (size, feature_width), activation='relu', padding='valid')(padded_input)
        convs.append(conv)

    # -> (batch, steps, 1, num_filters * len(kernel_sizes))
    all_filters = Concatenate(axis=3)(convs) if len(convs) > 1 else convs[0]

    if attention:
        squeezed = Lambda(lambda x: K.squeeze(x, 2))(all_filters)
        # context may be a Keras tensor; only compare it with 'same' when it is
        # a string, so a tensor is never compared with a string.
        if isinstance(context, str) and context == 'same':
            attention_context = squeezed
        else:
            attention_context = context
        attended = Attention(before=True)([squeezed, attention_context])
        all_filters = Lambda(lambda x: K.expand_dims(x, axis=2))(attended)

    # (batch, steps, 1, F) -> (batch, steps, F, 1) so pooling runs over steps only.
    transposed = Lambda(lambda x: K.permute_dimensions(x, (0, 1, 3, 2)))(all_filters)
    steps = K.int_shape(transposed)[1]
    if steps is not None and pool_size > steps:
        pool_size = steps
    pooled = MaxPooling2D(pool_size=(pool_size, 1))(transposed)
    return Lambda(lambda x: K.squeeze(x, 3))(pooled)
