Skip to content

Instantly share code, notes, and snippets.

@rizar
Created June 7, 2015 16:37
Show Gist options
  • Save rizar/36451d08719833fb59d0 to your computer and use it in GitHub Desktop.
Save rizar/36451d08719833fb59d0 to your computer and use it in GitHub Desktop.
Recurrent with fork
class RecurrentWithFork(Initializable):
    """Drive a recurrent application from a single input through a fork.

    A :class:`Fork` with a :class:`Linear` prototype is built with one
    output per sequence name of ``recurrent``, excluding ``'mask'``.
    Calling :meth:`apply` projects ``input_`` to every one of those
    sequence inputs and then calls the wrapped recurrent application.

    Parameters
    ----------
    recurrent : application
        The recurrent application to wrap. This class only uses its
        ``brick``, ``sequences`` and ``states`` attributes and calls it
        directly. It is presumably a bound application such as
        ``SimpleRecurrent(...).apply``; that is assumed, not checked here.
    input_dim : int
        Dimension of ``input_``. This is lazily allocated configuration
        that is pushed to the fork as its input dimension (presumably an
        integer; confirm against callers).
    **kwargs
        Forwarded unchanged to :class:`Initializable`.

    """

    @lazy(allocation=['input_dim'])
    def __init__(self, recurrent, input_dim, **kwargs):
        super(RecurrentWithFork, self).__init__(**kwargs)
        self.recurrent = recurrent
        self.input_dim = input_dim
        # The mask is handed to the recurrent application as-is in
        # ``apply``, so it must not receive a projection of its own.
        projected_names = [seq_name
                           for seq_name in self.recurrent.sequences
                           if seq_name != 'mask']
        self.fork = Fork(projected_names, prototype=Linear())
        # Register the recurrent brick and the fork as sub-bricks, in
        # this order.
        self.children = [recurrent.brick, self.fork]

    def _push_allocation_config(self):
        # The fork consumes ``input_`` directly.
        fork = self.fork
        fork.input_dim = self.input_dim
        # Each fork output must match the dimension the recurrent brick
        # reports for the sequence of the same name.
        recurrent_brick = self.recurrent.brick
        fork.output_dims = [recurrent_brick.get_dim(output_name)
                            for output_name in fork.output_names]

    @application(inputs=['input_', 'mask'])
    def apply(self, input_, mask=None, **kwargs):
        """Project ``input_`` with the fork and run the recurrent step.

        The fork outputs, keyed by sequence name, are merged with any
        extra keyword arguments through ``dict_union``. ``dict_union``
        presumably rejects overlapping keys; verify against its
        implementation. The merged arguments are passed to the wrapped
        recurrent application together with ``mask``, and its result is
        returned.
        """
        sequence_inputs = self.fork.apply(input_, as_dict=True)
        recurrent_kwargs = dict_union(sequence_inputs, kwargs)
        return self.recurrent(mask=mask, **recurrent_kwargs)

    @apply.property('outputs')
    def apply_outputs(self):
        # ``apply`` yields exactly what the recurrent application yields,
        # so its declared outputs are the recurrent states.
        return self.recurrent.states
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment