Skip to content

Instantly share code, notes, and snippets.

@pydemia
Forked from wassname/keras_attention_wrapper.py
Created January 12, 2018 02:34
Show Gist options
  • Save pydemia/5977fc1d44dfe6e0f9502cef79a27905 to your computer and use it in GitHub Desktop.
Save pydemia/5977fc1d44dfe6e0f9502cef79a27905 to your computer and use it in GitHub Desktop.
A keras attention layer that wraps RNN layers.
"""
A keras attention layer that wraps RNN layers.
Based on tensorflows [attention_decoder](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506)
and [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449).
date: 20161101
author: wassname
url: https://gist.github.com/wassname/5292f95000e409e239b9dc973295327a
"""
from keras import backend as K
from keras.engine import InputSpec
from keras.layers import LSTM, activations, Wrapper, Recurrent
class Attention(Wrapper):
"""
This wrapper will provide an attention layer to a recurrent layer.
# Arguments:
layer: `Recurrent` instance with consume_less='gpu' or 'mem'
# Examples:
```python
model = Sequential()
model.add(LSTM(10, return_sequences=True), batch_input_shape=(4, 5, 10))
model.add(TFAttentionRNNWrapper(LSTM(10, return_sequences=True, consume_less='gpu')))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
```
# References
- [Grammar as a Foreign Language](https://arxiv.org/abs/1412.7449)
"""
def __init__(self, layer, **kwargs):
assert isinstance(layer, Recurrent)
if layer.get_config()['consume_less']=='cpu':
raise Exception("AttentionLSTMWrapper doesn't support RNN's with consume_less='cpu'")
self.supports_masking = True
super(Attention, self).__init__(layer, **kwargs)
def build(self, input_shape):
assert len(input_shape) >= 3
self.input_spec = [InputSpec(shape=input_shape)]
nb_samples, nb_time, input_dim = input_shape
if not self.layer.built:
self.layer.build(input_shape)
self.layer.built = True
super(Attention, self).build()
self.W1 = self.layer.init((input_dim, input_dim, 1, 1), name='{}_W1'.format(self.name))
self.W2 = self.layer.init((self.layer.output_dim, input_dim), name='{}_W2'.format(self.name))
self.b2 = K.zeros((input_dim,), name='{}_b2'.format(self.name))
self.W3 = self.layer.init((input_dim*2, input_dim), name='{}_W3'.format(self.name))
self.b3 = K.zeros((input_dim,), name='{}_b3'.format(self.name))
self.V = self.layer.init((input_dim,), name='{}_V'.format(self.name))
self.trainable_weights = [self.W1, self.W2, self.W3, self.V, self.b2, self.b3]
def get_output_shape_for(self, input_shape):
return self.layer.get_output_shape_for(input_shape)
def step(self, x, states):
# This is based on [tensorflows implementation](https://github.com/tensorflow/tensorflow/blob/c8a45a8e236776bed1d14fd71f3b6755bd63cc58/tensorflow/python/ops/seq2seq.py#L506).
# First, we calculate new attention masks:
# attn = softmax(V^T * tanh(W2 * X +b2 + W1 * h))
# and we make the input as a concatenation of the input and weighted inputs which is then
# transformed back to the shape x of using W3
# x = W3*(x+X*attn)+b3
# Then, we run the cell on a combination of the input and previous attention masks:
# h, state = cell(x, h).
nb_samples, nb_time, input_dim = self.input_spec[0].shape
h = states[0]
X = states[-1]
xW1 = states[-2]
Xr = K.reshape(X,(-1,nb_time,1,input_dim))
hW2 = K.dot(h,self.W2)+self.b2
hW2 = K.reshape(hW2,(-1,1,1,input_dim))
u = K.tanh(xW1+hW2)
a = K.sum(self.V*u,[2,3])
a = K.softmax(a)
a = K.reshape(a,(-1, nb_time, 1, 1))
# Weight attention vector by attention
Xa = K.sum(a*Xr,[1,2])
Xa = K.reshape(Xa,(-1,input_dim))
# Merge input and attention weighted inputs into one vector of the right size.
x = K.dot(K.concatenate([x,Xa],1),self.W3)+self.b3
h, new_states = self.layer.step(x, states)
return h, new_states
def get_constants(self, x):
constants = self.layer.get_constants(x)
# Calculate K.dot(x, W2) only once per sequence by making it a constant
nb_samples, nb_time, input_dim = self.input_spec[0].shape
Xr = K.reshape(x,(-1,nb_time,input_dim,1))
Xrt = K.permute_dimensions(Xr, (0, 2, 1, 3))
xW1t = K.conv2d(Xrt,self.W1,border_mode='same')
xW1 = K.permute_dimensions(xW1t, (0, 2, 3, 1))
constants.append(xW1)
# we need to supply the full sequence of inputs to step (as the attention_vector)
constants.append(x)
return constants
def call(self, x, mask=None):
# input shape: (nb_samples, time (padded with zeros), input_dim)
input_shape = self.input_spec[0].shape
if K._BACKEND == 'tensorflow':
if not input_shape[1]:
raise Exception('When using TensorFlow, you should define '
'explicitly the number of timesteps of '
'your sequences.\n'
'If your first layer is an Embedding, '
'make sure to pass it an "input_length" '
'argument. Otherwise, make sure '
'the first layer has '
'an "input_shape" or "batch_input_shape" '
'argument, including the time axis. '
'Found input shape at layer ' + self.name +
': ' + str(input_shape))
if self.layer.stateful:
initial_states = self.layer.states
else:
initial_states = self.layer.get_initial_states(x)
constants = self.get_constants(x)
preprocessed_input = self.layer.preprocess_input(x)
last_output, outputs, states = K.rnn(self.step, preprocessed_input,
initial_states,
go_backwards=self.layer.go_backwards,
mask=mask,
constants=constants,
unroll=self.layer.unroll,
input_length=input_shape[1])
if self.layer.stateful:
self.updates = []
for i in range(len(states)):
self.updates.append((self.layer.states[i], states[i]))
if self.layer.return_sequences:
return outputs
else:
return last_output
# test likes in https://github.com/fchollet/keras/blob/master/tests/keras/layers/test_wrappers.py
import pytest
import numpy as np
from numpy.testing import assert_allclose
from keras.utils.test_utils import keras_test
from keras.layers import wrappers, Input, recurrent, InputLayer
from keras.layers import core, convolutional, recurrent
from keras.models import Sequential, Model, model_from_json
nb_samples, timesteps, embedding_dim, output_dim = 2, 5, 3, 4
embedding_num = 12
x = np.random.random((nb_samples, timesteps, embedding_dim))
y = np.random.random((nb_samples, timesteps, output_dim))
# base line test with LSTM
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x,y, nb_epoch=1, batch_size=nb_samples)
# test stacked with all RNN layers and consume_less options
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
# test supported consume_less options
# model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim,, consume_less='cpu' return_sequences=True))) # not supported
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, consume_less='gpu', return_sequences=True)))
model.add(Attention(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
# test each other RNN type
model.add(Attention(recurrent.GRU(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
model.add(Attention(recurrent.SimpleRNN(embedding_dim, input_dim=embedding_dim, consume_less='mem', return_sequences=True)))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x,y, nb_epoch=1, batch_size=nb_samples)
# test with return_sequence = False
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=False, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x,y[:,-1,:], nb_epoch=1, batch_size=nb_samples)
# with bidirectional encoder
model = Sequential()
model.add(InputLayer(batch_input_shape=(nb_samples, timesteps, embedding_dim)))
model.add(wrappers.Bidirectional(recurrent.LSTM(embedding_dim, input_dim=embedding_dim, return_sequences=True)))
model.add(Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem')))
model.add(core.Activation('relu'))
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x,y, nb_epoch=1, batch_size=nb_samples)
# test config
model.get_config()
# test to and from json
model = model_from_json(model.to_json(),custom_objects=dict(Attention=Attention))
model.summary()
# test with functional API
input = Input(batch_shape=(nb_samples, timesteps, embedding_dim))
output = Attention(recurrent.LSTM(output_dim, input_dim=embedding_dim, return_sequences=True, consume_less='mem'))(input)
model = Model(input, output)
model.compile(optimizer='rmsprop', loss='mse')
model.fit(x, y, nb_epoch=1, batch_size=nb_samples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment