-
-
Save pydemia/5977fc1d44dfe6e0f9502cef79a27905 to your computer and use it in GitHub Desktop.
A Keras attention layer that wraps RNN layers.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A keras attention layer that wraps RNN layers. | |
Based on tensorflows [attention_decoder](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506) | |
and [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449). | |
date: 20161101 | |
author: wassname | |
url: https://gist.github.com/wassname/5292f95000e409e239b9dc973295327a | |
""" | |
from keras import backend as K | |
from keras.engine import InputSpec | |
from keras.layers import LSTM, activations, Wrapper, Recurrent | |
class Attention(Wrapper):
    """
    This wrapper will provide an attention layer to a recurrent layer.

    At every timestep the wrapped RNN's previous output ``h`` scores each
    timestep of the full input sequence ``X``. The softmax-normalised scores
    weight ``X`` into a context vector. That vector is concatenated with the
    current input and projected back to ``input_dim`` before being passed to
    the wrapped layer's own ``step``.

    # Arguments:
        layer: `Recurrent` instance with consume_less='gpu' or 'mem'

    # Examples:

    ```python
        model = Sequential()
        model.add(LSTM(10, return_sequences=True, batch_input_shape=(4, 5, 10)))
        model.add(Attention(LSTM(10, return_sequences=True, consume_less='gpu')))
        model.add(TimeDistributed(Dense(5)))
        model.add(Activation('softmax'))
        model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
    ```

    # References
        - [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449)
    """

    def __init__(self, layer, **kwargs):
        if not isinstance(layer, Recurrent):
            raise TypeError('Attention expects a Recurrent layer, got ' +
                            repr(layer))
        # NOTE(review): presumably 'cpu' is rejected because the layer's
        # preprocess_input would project the inputs before `step` sees them,
        # while W3 below expects raw inputs of size input_dim. TODO confirm.
        if layer.get_config()['consume_less'] == 'cpu':
            raise Exception("Attention doesn't support RNN's with consume_less='cpu'")
        self.supports_masking = True
        super(Attention, self).__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped RNN (if needed) and create the attention weights.

        # Arguments
            input_shape: (nb_samples, nb_time, input_dim).

        # Raises
            ValueError: if ``input_shape`` is not 3D.
        """
        if len(input_shape) != 3:
            raise ValueError('Attention expects a 3D input shape '
                             '(nb_samples, nb_time, input_dim), got ' +
                             str(input_shape))
        self.input_spec = [InputSpec(shape=input_shape)]
        _, _, input_dim = input_shape

        if not self.layer.built:
            self.layer.build(input_shape)
            self.layer.built = True

        super(Attention, self).build()

        # W1 is a channels-first 1x1 conv kernel (see get_constants).
        self.W1 = self.layer.init((input_dim, input_dim, 1, 1), name='{}_W1'.format(self.name))
        # W2/b2 project the previous RNN output into input space.
        self.W2 = self.layer.init((self.layer.output_dim, input_dim), name='{}_W2'.format(self.name))
        self.b2 = K.zeros((input_dim,), name='{}_b2'.format(self.name))
        # W3/b3 map [x; context] (2 * input_dim) back to input_dim.
        self.W3 = self.layer.init((input_dim * 2, input_dim), name='{}_W3'.format(self.name))
        self.b3 = K.zeros((input_dim,), name='{}_b3'.format(self.name))
        # V turns each tanh activation into a scalar attention score.
        self.V = self.layer.init((input_dim,), name='{}_V'.format(self.name))

        # Keep the wrapped RNN's weights trainable: previously this list was
        # replaced by the attention weights alone, freezing the RNN. A fresh
        # list is built so the wrapped layer's own list is not mutated.
        self.trainable_weights = list(getattr(self.layer, 'trainable_weights', [])) + [
            self.W1, self.W2, self.W3, self.V, self.b2, self.b3]

    def get_output_shape_for(self, input_shape):
        """Output shape is exactly that of the wrapped RNN."""
        return self.layer.get_output_shape_for(input_shape)

    def compute_mask(self, x, mask=None):
        """Let the wrapped RNN decide whether the timestep mask is propagated.

        Without this the input mask would be passed through unchanged, even
        when the wrapped layer only returns its last output.
        """
        return self.layer.compute_mask(x, mask)

    def step(self, x, states):
        """Run one attention-augmented step of the wrapped RNN.

        Based on TensorFlow's attention_decoder:
            attn = softmax(V^T * tanh(W1 * X + W2 * h + b2))
            x    = W3 * [x; sum_t(attn_t * X_t)] + b3
            h, state = cell(x, h)

        # Arguments
            x: input at the current timestep, (nb_samples, input_dim).
            states: the wrapped layer's states and constants, followed by
                the two constants appended in `get_constants`: the
                precomputed ``X * W1`` (``states[-2]``) and the full input
                sequence ``X`` (``states[-1]``).

        # Returns
            ``(h, new_states)`` from the wrapped layer's ``step``.
        """
        _, nb_time, input_dim = self.input_spec[0].shape
        h = states[0]
        X = states[-1]
        xW1 = states[-2]

        # Lay the sequence out as (nb_samples, nb_time, 1, input_dim) so it
        # broadcasts against the per-timestep attention weights.
        Xr = K.reshape(X, (-1, nb_time, 1, input_dim))

        # Project the previous output and broadcast it over all timesteps.
        hW2 = K.dot(h, self.W2) + self.b2
        hW2 = K.reshape(hW2, (-1, 1, 1, input_dim))

        # One score per timestep, softmax-normalised over time.
        u = K.tanh(xW1 + hW2)
        a = K.softmax(K.sum(self.V * u, [2, 3]))
        a = K.reshape(a, (-1, nb_time, 1, 1))

        # Attention-weighted sum of the input sequence (the context vector).
        Xa = K.reshape(K.sum(a * Xr, [1, 2]), (-1, input_dim))

        # Merge input and context, then project back to input_dim so the
        # wrapped RNN's input weights still fit.
        x = K.dot(K.concatenate([x, Xa], 1), self.W3) + self.b3

        h, new_states = self.layer.step(x, states)
        return h, new_states

    def get_constants(self, x):
        """Return the wrapped layer's constants plus two attention constants.

        ``X * W1`` does not depend on the timestep, so it is computed once
        per sequence here rather than in `step`. The raw input sequence is
        appended too, because `step` attends over all of it.
        """
        constants = self.layer.get_constants(x)
        _, nb_time, input_dim = self.input_spec[0].shape

        # (nb_samples, input_dim, nb_time, 1): a channels-first "image" whose
        # channels are the input features.
        Xr = K.reshape(x, (-1, nb_time, input_dim, 1))
        Xrt = K.permute_dimensions(Xr, (0, 2, 1, 3))
        # W1 is a channels-first (nb_filter, stack, 1, 1) kernel, so force
        # 'th' ordering. Under the 'tf' image_dim_ordering default this call
        # was a 2D spatial convolution that mixed timesteps, not the intended
        # per-timestep 1x1 projection, even though the output shape matched.
        xW1t = K.conv2d(Xrt, self.W1, border_mode='same', dim_ordering='th')
        # Back to (nb_samples, nb_time, 1, input_dim), the layout step() uses.
        xW1 = K.permute_dimensions(xW1t, (0, 2, 3, 1))
        constants.append(xW1)
        constants.append(x)
        return constants

    def call(self, x, mask=None):
        """Run the wrapped RNN over ``x`` with attention applied at every step.

        # Arguments
            x: input sequence (nb_samples, nb_time, input_dim), zero-padded
               in time.
            mask: optional timestep mask forwarded to ``K.rnn``.

        # Returns
            All outputs if the wrapped layer has ``return_sequences=True``,
            otherwise only the last output.
        """
        input_shape = self.input_spec[0].shape
        if K.backend() == 'tensorflow' and not input_shape[1]:
            raise Exception('When using TensorFlow, you should define '
                            'explicitly the number of timesteps of '
                            'your sequences.\n'
                            'If your first layer is an Embedding, '
                            'make sure to pass it an "input_length" '
                            'argument. Otherwise, make sure '
                            'the first layer has '
                            'an "input_shape" or "batch_input_shape" '
                            'argument, including the time axis. '
                            'Found input shape at layer ' + self.name +
                            ': ' + str(input_shape))

        if self.layer.stateful:
            initial_states = self.layer.states
        else:
            initial_states = self.layer.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.layer.preprocess_input(x)

        last_output, outputs, states = K.rnn(self.step, preprocessed_input,
                                             initial_states,
                                             go_backwards=self.layer.go_backwards,
                                             mask=mask,
                                             constants=constants,
                                             unroll=self.layer.unroll,
                                             input_length=input_shape[1])
        if self.layer.stateful:
            # Carry the final states over to the next batch.
            self.updates = [(old, new) for old, new in zip(self.layer.states, states)]

        return outputs if self.layer.return_sequences else last_output
# test likes in https://github.com/fchollet/keras/blob/master/tests/keras/layers/test_wrappers.py | |
import pytest | |
import numpy as np | |
from numpy.testing import assert_allclose | |
from keras.utils.test_utils import keras_test | |
from keras.layers import wrappers, Input, recurrent, InputLayer | |
from keras.layers import core, convolutional, recurrent | |
from keras.models import Sequential, Model, model_from_json | |
nb_samples, timesteps, embedding_dim, output_dim = 2, 5, 3, 4
embedding_num = 12
x = np.random.random((nb_samples, timesteps, embedding_dim))
y = np.random.random((nb_samples, timesteps, output_dim))


def _fit_one_epoch(model, target):
    """Compile ``model`` (rmsprop / mse) and fit it on ``x`` for one epoch."""
    model.compile(optimizer='rmsprop', loss='mse')
    model.fit(x, target, nb_epoch=1, batch_size=nb_samples)


def _sequential_with_input():
    """Return a Sequential model that starts with a fixed-batch InputLayer."""
    model = Sequential()
    model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
    return model


@keras_test
def test_attention_lstm():
    """Baseline: a single attention-wrapped LSTM returning sequences."""
    model = _sequential_with_input()
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim,
                                       return_sequences=True, consume_less='mem')))
    model.add(core.Activation('relu'))
    _fit_one_epoch(model, y)


@keras_test
def test_attention_stacked_rnn_types():
    """Stack every RNN type with each supported consume_less option."""
    # consume_less='cpu' is rejected by the wrapper's constructor.
    with pytest.raises(Exception):
        Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim,
                                 consume_less='cpu', return_sequences=True))

    model = _sequential_with_input()
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim,
                                       consume_less='gpu', return_sequences=True)))
    model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim,
                                       consume_less='mem', return_sequences=True)))
    model.add(Attention(recurrent.GRU(embedding_dim, input_dim=embedding_dim,
                                      consume_less='mem', return_sequences=True)))
    # The last RNN must emit output_dim features to match the target y.
    model.add(Attention(recurrent.SimpleRNN(output_dim, input_dim=embedding_dim,
                                            consume_less='mem', return_sequences=True)))
    model.add(core.Activation('relu'))
    _fit_one_epoch(model, y)


@keras_test
def test_attention_last_output_only():
    """return_sequences=False: only the final timestep's output is produced."""
    model = _sequential_with_input()
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim,
                                       return_sequences=False, consume_less='mem')))
    model.add(core.Activation('relu'))
    _fit_one_epoch(model, y[:, -1, :])


@keras_test
def test_attention_after_bidirectional_and_serialization():
    """Attention over a Bidirectional encoder, then a config / JSON round-trip."""
    model = _sequential_with_input()
    model.add(wrappers.Bidirectional(recurrent.LSTM(embedding_dim, input_dim=embedding_dim,
                                                    return_sequences=True)))
    model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim,
                                       return_sequences=True, consume_less='mem')))
    model.add(core.Activation('relu'))
    _fit_one_epoch(model, y)

    model.get_config()
    model = model_from_json(model.to_json(), custom_objects=dict(Attention=Attention))
    model.summary()


@keras_test
def test_attention_functional_api():
    """The wrapper also works when called on a tensor (functional API)."""
    inputs = Input(batch_shape=(nb_samples, timesteps, embedding_dim))
    outputs = Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim,
                                       return_sequences=True, consume_less='mem'))(inputs)
    model = Model(inputs, outputs)
    _fit_one_epoch(model, y)


if __name__ == '__main__':
    pytest.main([__file__])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment