Skip to content

Instantly share code, notes, and snippets.

@laurent-dinh
Last active August 29, 2015 14:17
Show Gist options
  • Select an option

  • Save laurent-dinh/648b948532f36d4c9dbb to your computer and use it in GitHub Desktop.

Select an option

Save laurent-dinh/648b948532f36d4c9dbb to your computer and use it in GitHub Desktop.
Unexpected recurrent
from theano import tensor
import numpy
from numpy.testing import assert_allclose
from blocks.bricks.recurrent import BaseRecurrent, recurrent
class UnexpectedRecurrent(BaseRecurrent):
    """Toy brick exercising `blocks.bricks.recurrent.recurrent`.

    The author's stated expectation is that the brick sums the elements
    of the input sequence along the time axis.

    At every step `apply` updates two states and derives one output
    from each of them:

    * ``states`` is the previous ``states`` plus the current ``inputs``;
    * ``states_2`` is the previous ``states_2`` plus ``0.5``;
    * ``outputs`` and ``outputs_2`` are ten times the new ``states``
      and the new ``states_2`` respectively.

    Parameters
    ----------
    dim : int
        The dimension of the hidden state

    """
    def __init__(self, dim, **kwargs):
        # Only keyword arguments are forwarded: `super(...)` already
        # binds `self`, so passing it again would hand the brick itself
        # to the parent constructor as its first positional argument.
        super(UnexpectedRecurrent, self).__init__(**kwargs)
        self.dim = dim

    def get_dim(self, name):
        """Return the dimension associated with the variable `name`.

        Inputs, states and outputs all have dimension ``self.dim``;
        ``mask`` is reported as ``0``. Any other name is delegated to
        the parent class.

        """
        if name in ['inputs', 'states', 'outputs', 'states_2', 'outputs_2']:
            return self.dim
        if name == 'mask':
            return 0
        return super(UnexpectedRecurrent, self).get_dim(name)

    # NOTE: the order of `outputs` below intentionally interleaves states
    # and outputs and must stay in sync with the `return` statement of
    # `apply`.
    @recurrent(sequences=['inputs', 'mask'], states=['states', 'states_2'],
               outputs=['outputs', 'states_2', 'outputs_2', 'states'],
               contexts=[])
    def apply(self, inputs=None, states=None, states_2=None, mask=None):
        """Compute one step of the recurrence.

        Parameters
        ----------
        inputs : tensor
            The input for the current step.
        states : tensor
            Previous value of the first state (accumulates `inputs`).
        states_2 : tensor
            Previous value of the second state (incremented by 0.5).
        mask : tensor or None
            Optional mask; where it is 0 the previous `states` value is
            kept instead of the updated one. `states_2` is not masked.

        Returns
        -------
        tuple
            ``(outputs, next_states_2, outputs_2, next_states)``, in the
            order declared in the `recurrent` decorator.

        """
        next_states = states + inputs
        next_states_2 = states_2 + .5
        # Test for presence explicitly rather than relying on the
        # truthiness of a symbolic variable.
        if mask is not None:
            # `mask[:, None]` appends an axis so that the mask
            # broadcasts against the trailing axis of the states.
            next_states = (mask[:, None] * next_states +
                           (1 - mask[:, None]) * states)
        outputs = 10 * next_states
        outputs_2 = 10 * next_states_2
        return outputs, next_states_2, outputs_2, next_states
def main():
    """Run `UnexpectedRecurrent` on a sequence of ones and check it.

    A brick of dimension 1 is applied, without a mask, to a symbolic
    3D tensor that is fed an array of ones of shape ``(5, 1, 1)``. The
    checks are:

    * the first state equals ``1, 2, ..., 5`` along the first axis,
      i.e. the running sum of the inputs;
    * each output equals ten times its corresponding state.

    Raises
    ------
    AssertionError
        If any of the evaluated values does not match the expectation.

    """
    unexpected_recurrent_example = UnexpectedRecurrent(dim=1)
    X = tensor.tensor3('X')
    # Unpacking order follows the `outputs` list of the `recurrent`
    # decorator on `apply`: outputs, states_2, outputs_2, states.
    out, H2, out_2, H = unexpected_recurrent_example.apply(inputs=X,
                                                           mask=None)

    # Build the input once, in the dtype of `X`, so that `eval` accepts
    # it whatever floating point type the symbolic variable uses.
    feed = {X: numpy.ones((5, 1, 1), dtype=X.dtype)}
    h = H.eval(feed)
    h2 = H2.eval(feed)
    out_eval = out.eval(feed)
    out_2_eval = out_2.eval(feed)

    assert_allclose(h, numpy.arange(5).reshape((5, 1, 1)) + 1)
    assert_allclose(h * 10, out_eval)
    assert_allclose(h2 * 10, out_2_eval)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment