Created
July 24, 2015 15:33
-
-
Save matpalm/503130427d9631ff11bc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import theano | |
import theano.tensor as T | |
import numpy as np | |
# Problem sizes: the number of tokens in the sequence being attended to, and
# the single dimensionality shared by the inputs, the RNN state and the
# attention parameters.
NUM_TOKENS = 5
D = 3

# Fixed seed so every run draws the same parameters and the same inputs.
np.random.seed(123)

# Recurrent weight matrix of the dummy RNN used to generate the annotations.
Wx = theano.shared(np.random.randn(D, D).astype('float32'))

# Attention mechanism parameters: Wag projects the annotations, Wug projects
# the attention context and wgs collapses each glimpse vector to a scalar.
Wag = theano.shared(np.random.randn(D, D).astype('float32'))
Wug = theano.shared(np.random.randn(D, D).astype('float32'))
wgs = theano.shared(np.random.randn(D).astype('float32'))

# Concrete input sequence: one D-dimensional row per token.
x_v = np.random.randn(NUM_TOKENS, D).astype('float32')

# Symbolic stand-in for the input sequence. It first feeds a dummy RNN whose
# only job is to produce a sequence of annotations (that RNN could be
# arbitrarily complex).
x = T.fmatrix('x')
def _annotation(x_t, h_t_minus_1):
    """Single step of the dummy RNN, used as the `fn` of a `theano.scan`.

    Args:
        x_t: input for the current step.
        h_t_minus_1: hidden state produced by the previous step.

    Returns:
        A two-element list holding the new hidden state twice: the scan that
        drives this function collects the first entry as the per-step
        annotation and feeds the second back in as the recurrent state.
    """
    recurrent_term = T.dot(Wx, h_t_minus_1)
    new_state = T.tanh(x_t + recurrent_term)
    return [new_state, new_state]
# All-zero initial hidden state for the dummy RNN.
h0 = np.zeros((D,), dtype=np.float32)

# Run the RNN over the rows of x. The first output is only collected (no
# initial value, hence None); the second is the recurrent state seeded by h0.
(annotations, _hidden), _ = theano.scan(
    fn=_annotation,
    sequences=[x],
    outputs_info=[None, h0],
)
def _attended_annotation(u, annotations, Wag, Wug, wgs):
    """Return the attention-weighted combination of `annotations` for context `u`.

    Used as the `fn` of a `theano.scan`, which supplies one context `u` per
    step and passes the remaining arguments unchanged on every step.

    Args:
        u: attention context vector (in this script, one row of the input).
        annotations: matrix with one annotation per row (one row per token).
        Wag: matrix applied to the annotations.
        Wug: matrix applied to the context.
        wgs: vector that collapses each glimpse vector to a scalar score.

    Returns:
        A vector: the sum of the annotation rows weighted by the softmax of
        their scores with respect to `u`.
    """
    # First we need to mix the annotations using 'u' as the context of
    # attention. We do _all_ annotations wrt u in one hit, so we need a
    # column-broadcastable version of u.
    u_col = u.dimshuffle(0, 'x')
    glimpse_vectors = T.tanh(T.dot(Wag, annotations.T) + T.dot(Wug, u_col))
    # Now collapse the glimpse vectors (there's one per token) to scalars.
    unnormalised_glimpse_scalars = T.dot(wgs, glimpse_vectors)
    # Normalise the glimpses with a softmax. Softmax is invariant to adding a
    # constant to every score, so subtract the max first: the largest
    # exponent is then exp(0) = 1 and exp() cannot overflow to inf (which
    # would turn the normalised weights into NaN).
    shifted_scalars = (unnormalised_glimpse_scalars
                       - T.max(unnormalised_glimpse_scalars))
    exp_glimpses = T.exp(shifted_scalars)
    glimpses = exp_glimpses / T.sum(exp_glimpses)
    # The attended version of the annotations is the convex combination of
    # the annotations, using the normalised glimpses as the weights.
    attended_annotations = T.dot(annotations.T, glimpses)
    return attended_annotations
# Attend over the full set of annotations once per row of x: each row is used
# in turn as the attention context, while the annotations and the attention
# parameters are handed to every step unchanged.
attended_annotations, _ = theano.scan(
    fn=_attended_annotation,
    sequences=[x],
    non_sequences=[annotations, Wag, Wug, wgs],
)
# --------------------------------
# Demo: evaluate the graphs on the concrete input and print the results.

# HACK: canary check. Rows 1 and 2 of the input are made identical, so the
# attended combinations for u2 and u1 should come out the same.
x_v[2] = x_v[1]

# Each print passes a single pre-formatted string, so the parenthesised form
# produces the same output under Python 2 (print statement) and Python 3
# (print function).
print("x_v %s" % (x_v,))

annotations_v = annotations.eval({x: x_v})
print("annotations %s" % (annotations_v,))

attended_annotations_v = attended_annotations.eval({x: x_v})
print("attended_annotations_v %s" % (attended_annotations_v,))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment