Attentive Convolution with custom Attention layer
from keras.layers import Reshape, RepeatVector, Concatenate, Conv1D, Softmax
from keras.layers import Layer
from keras import activations
from keras import backend as K
import tensorflow as tf


class Attention(Layer):
    def __init__(self, kernel_activation='hard_sigmoid', before=False, **kwargs):
        super(Attention, self).__init__(**kwargs)
        self.kernel_activation = activations.get(kernel_activation)
        K.set_floatx('float32')
        self.before = before

    def build(self, input_shape):
        # input_shape is a list: [text_shape, context_shape]
        self.num_words = input_shape[0][1]
        super(Attention, self).build(input_shape)

    def get_output_shape_for(self, input_shape):
        if self.before:
            return input_shape
        return (input_shape[0], input_shape[2])

    def compute_output_shape(self, input_shape):
        input_shape = input_shape[0]
        if self.before:
            return input_shape
        return (input_shape[0], input_shape[2])

    def call(self, x, mask=None):
        text = x[0]
        context = x[1]

        # Flatten the context to one vector per sample
        length = 1
        for i in range(len(context.shape)):
            if i > 0:
                length *= int(context.shape[i])
        context = Reshape((length,))(context)

        # Repeat the context for every word and score each (word, context) pair
        context_repeated = RepeatVector(self.num_words)(context)
        merged = Concatenate(axis=2)([context_repeated, text])
        scores = Conv1D(1, 1)(merged)
        # Normalise the scores over the word axis so the weights sum to 1
        weights = Softmax(axis=1)(scores)

        if not self.before:
            # Collapse the sequence into a single attention-weighted vector
            weighted = K.batch_dot(K.permute_dimensions(text, (0, 2, 1)), weights)
            return K.squeeze(weighted, 2)
        # Keep the sequence and scale each word by its attention weight
        weighted = tf.multiply(text, weights)
        return weighted

    def get_config(self):
        config = {'kernel_activation': activations.serialize(self.kernel_activation),
                  'before': self.before}
        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
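
A minimal usage sketch (not part of the gist) of how the layer might be wired into a functional model; the input shapes, the two-class Dense head and the compile settings are assumptions chosen for illustration. With before=False the layer returns one attention-weighted vector per sample, with before=True it returns the full sequence with each word scaled by its weight. The second file below uses the layer inside a convolutional encoder.

from keras.layers import Input, Dense
from keras.models import Model

text_input = Input(shape=(50, 300))   # assumed: 50 words, 300-dim embeddings
context_input = Input(shape=(300,))   # assumed: a single 300-dim context vector

# Query attention with the Attention layer defined above:
# the text attends to the external context vector
attended = Attention(before=False)([text_input, context_input])
prediction = Dense(2, activation='softmax')(attended)

model = Model(inputs=[text_input, context_input], outputs=prediction)
model.compile(optimizer='adam', loss='categorical_crossentropy')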
from keras.layers import Flatten, Dropout, Dense, Lambda, ZeroPadding2D, MaxPooling2D, Concatenate, Conv2D
from keras import backend as K
from Core.NeuralNetwork.CustomLayers.Attention import Attention


def build_cnn(input, dropout, kernel_sizes, num_stages, num_filters, pool_sizes, attention=False, context='same', fully_connected_dimension=1000):
    '''
    Builds a CNN encoder with the given parameters.
    :param input: output of the preceding layer
    :param dropout: dropout rate
    :param kernel_sizes: list of lists; the inner list defines the kernel sizes within a stage, the outer list has one entry per stage
    :param num_stages: number of stages
    :param num_filters: list, number of filters per stage
    :param pool_sizes: list, pool size per stage
    :param attention: whether attention should be applied (only used in the first stage)
    :param context: defines the context. 'same' means self-attention; otherwise a layer output can be passed to apply query attention
    :param fully_connected_dimension: number of units in the final dense layer
    :return: input & output of the encoder
    '''
    input_cpy = input
    for i in range(num_stages):
        if i > 0:
            attention = False  # attention is only applied in the first stage
        input_cpy = _build_stage(kernel_sizes[i], input_cpy, num_filters[i], context, pool_sizes[i], attention)

    flatten = Flatten()(input_cpy)
    dropout = Dropout(dropout)(flatten)
    fully_connected = Dense(units=fully_connected_dimension)(dropout)
    return input, fully_connected


def _build_stage(kernel_sizes, pre_layer, num_filters, context, pool_size, attention):
    convs = []
    for size in kernel_sizes:
        # Add a channel dimension so Conv2D can slide over the (words x embedding) plane
        reshape = Lambda(lambda x: K.expand_dims(x, 3))(pre_layer)
        # Pad along the word axis so the convolution preserves the sequence length
        if size % 2 == 0:
            padded_input = ZeroPadding2D(padding=((int(size / 2), int(size / 2) - 1), (0, 0)))(reshape)
        else:
            padded_input = ZeroPadding2D(padding=((int(size / 2), int(size / 2)), (0, 0)))(reshape)
        conv = Conv2D(num_filters, (size, int(reshape.shape[2])), activation='relu', padding='valid')(padded_input)
        convs.append(conv)

    if len(convs) > 1:
        all_filters = Concatenate(axis=3)(convs)
    else:
        all_filters = convs[0]

    if attention:
        all_filters = Lambda(lambda x: K.squeeze(x, 2))(all_filters)
        if context == 'same':
            # Self-attention: the convolved text attends to itself
            attentive_context = Attention(before=True)([all_filters, all_filters])
        else:
            # Query attention: attend with respect to the given context tensor
            attentive_context = Attention(before=True)([all_filters, context])
        attentive_context = Lambda(lambda x: K.expand_dims(x, axis=2))(attentive_context)
        reshape = Lambda(lambda x: K.permute_dimensions(x, (0, 1, 3, 2)))(attentive_context)
    else:
        reshape = Lambda(lambda x: K.permute_dimensions(x, (0, 1, 3, 2)))(all_filters)

    # Pool over the word axis; clamp the pool size to the remaining sequence length
    if pool_size > int(reshape.shape[1]):
        pool_size = int(reshape.shape[1])
    filtered = MaxPooling2D(pool_size=(pool_size, 1))(reshape)
    reshape = Lambda(lambda x: K.squeeze(x, 3))(filtered)
    return reshape
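
For completeness, a hedged sketch of how build_cnn might be called; the embedding shape, the stage hyper-parameters and the classification head are assumptions for illustration, not values taken from the gist.

from keras.layers import Input, Dense
from keras.models import Model

embedded_text = Input(shape=(50, 300))   # assumed: 50 words, 300-dim embeddings

# build_cnn is the function defined above; two stages, self-attention in the first
model_input, encoded = build_cnn(
    input=embedded_text,
    dropout=0.5,
    kernel_sizes=[[2, 3, 4], [3]],
    num_stages=2,
    num_filters=[64, 32],
    pool_sizes=[2, 2],
    attention=True,
    context='same',
    fully_connected_dimension=256)

prediction = Dense(2, activation='softmax')(encoded)
model = Model(inputs=model_input, outputs=prediction)
model.compile(optimizer='adam', loss='categorical_crossentropy')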