Skip to content

Instantly share code, notes, and snippets.

@matpalm
Created July 24, 2015 15:33
Show Gist options
  • Save matpalm/503130427d9631ff11bc to your computer and use it in GitHub Desktop.
Save matpalm/503130427d9631ff11bc to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
from __future__ import print_function

import numpy as np
import theano
import theano.tensor as T
NUM_TOKENS = 5  # length of the token sequence that attention runs over
D = 3  # dimensionality of inputs, hidden states and attention vectors

# fixed seed so every run draws the same parameters and inputs; the draws
# below must stay in this order to reproduce the same values
np.random.seed(123)

# weight matrix of the dummy RNN that produces the annotations
Wx = theano.shared(np.random.randn(D, D).astype(np.float32))

# attention parameters: Wag projects the annotations, Wug projects the
# query u, and wgs collapses each resulting glimpse vector to a score
Wag = theano.shared(np.random.randn(D, D).astype(np.float32))
Wug = theano.shared(np.random.randn(D, D).astype(np.float32))
wgs = theano.shared(np.random.randn(D).astype(np.float32))

# concrete input sequence: NUM_TOKENS rows, each a D-dim vector
x_v = np.random.randn(NUM_TOKENS, D).astype(np.float32)

# symbolic input fed to a throwaway RNN whose only job is to turn the input
# sequence into a sequence of annotations; any encoder could stand in here
x = T.fmatrix('x')
def _annotation(x_t, h_t_minus_1):
    """Advance the dummy RNN one step: h_t = tanh(x_t + Wx . h_{t-1}).

    The new state is returned twice. With ``outputs_info=[None, h0]`` the
    first copy is only collected by ``theano.scan`` (the annotation for
    this token) and the second copy is fed back as the next ``h_t_minus_1``.
    """
    recurrent_term = T.dot(Wx, h_t_minus_1)
    new_state = T.tanh(x_t + recurrent_term)
    return [new_state, new_state]
# zero initial hidden state: a plain 1-D float32 vector of length D
# (1-D arrays have no row/column orientation, so no transpose is needed)
h0 = np.zeros(D, dtype=np.float32)
# run the dummy RNN over the rows of x. With outputs_info=[None, h0] the
# first output of _annotation is only collected (the per-token annotations)
# and the second is fed back as the recurrent state, starting from h0.
[annotations, _hidden], _ = theano.scan(fn=_annotation,
                                        sequences=[x],
                                        outputs_info=[None, h0])
def _attended_annotation(u, annotations, Wag, Wug, wgs):
    """Attend over all annotations using ``u`` as the query (one scan step).

    Called by ``theano.scan`` once per row of ``x``: ``u`` is that row, and
    ``annotations``, ``Wag``, ``Wug`` and ``wgs`` are passed through
    unchanged as non-sequences.

    Args:
        u: query vector for this step; its length must match the number of
            columns of ``Wug``.
        annotations: matrix with one row per token (the RNN outputs).
        Wag: projection applied to every annotation.
        Wug: projection applied to the query ``u``.
        wgs: vector that collapses each glimpse vector to a scalar score.

    Returns:
        The weighted sum of the annotation rows, using the softmax of the
        per-token scores as weights.
    """
    # mix the annotations using 'u' as the context of attention. all
    # annotations are scored wrt u in one hit, so we need a
    # column-broadcastable version of u
    u_col = u.dimshuffle(0, 'x')
    glimpse_vectors = T.tanh(T.dot(Wag, annotations.T) + T.dot(Wug, u_col))
    # collapse the glimpse vectors (one column per token) to scalar scores
    scores = T.dot(wgs, glimpse_vectors)
    # normalise with a softmax. subtracting the max score first leaves the
    # result mathematically unchanged but keeps exp() from overflowing
    # float32 when scores become large
    exp_glimpses = T.exp(scores - T.max(scores))
    glimpses = exp_glimpses / T.sum(exp_glimpses)
    # the attended annotation is the convex combination of the annotations
    # (weights are non-negative and sum to 1), weighted by the glimpses
    attended_annotations = T.dot(annotations.T, glimpses)
    return attended_annotations
# second scan: each row of x in turn serves as the attention query u, while
# the annotations and attention parameters stay fixed across steps
attended_annotations, _ = theano.scan(
    fn=_attended_annotation,
    sequences=[x],
    non_sequences=[annotations, Wag, Wug, wgs],
)
# --------------------------------
# Canary: make token 2 a copy of token 1. The attention query for step t is
# row t of x and every step attends over the same annotations, so attended
# outputs 1 and 2 must match, even though the RNN annotations for those
# tokens generally differ (the RNN state carries history).
x_v[2] = x_v[1]
print("x_v", x_v)
# compile a single function that returns both outputs, instead of
# compiling one function per output variable via .eval()
compute_annotations = theano.function([x], [annotations, attended_annotations])
annotations_v, attended_annotations_v = compute_annotations(x_v)
print("annotations", annotations_v)
print("attended_annotations_v", attended_annotations_v)
# enforce the canary rather than relying on eyeballing the printout;
# np.testing is used instead of `assert` so the check survives `python -O`
np.testing.assert_allclose(attended_annotations_v[1],
                           attended_annotations_v[2])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment