Skip to content

Instantly share code, notes, and snippets.

@delta2323
Last active July 13, 2016 15:17
Show Gist options
  • Save delta2323/4ac2b92e65f5d6af190e11b66479e2f0 to your computer and use it in GitHub Desktop.
Save delta2323/4ac2b92e65f5d6af190e11b66479e2f0 to your computer and use it in GitHub Desktop.
import six
import numpy
import chainer
class StatelessRNN(chainer.Chain):
    """Toy RNN link whose recurrent states are passed in and returned explicitly.

    The link stores no recurrent values itself. Callers pass the current
    states (ordered as in ``state_names``) followed by the input, and
    receive the new states back in the same order.
    """

    # state_names defines the order of states; state_shapes (set in
    # __init__) lists one shape per entry, in the same order.
    state_names = ('c', 'h')

    def __init__(self, **kwargs):
        # Initialize the Chain machinery before setting attributes on it.
        super(StatelessRNN, self).__init__(**kwargs)
        # ... some initialization
        # In a real model the shapes are calculated from the arguments.
        c_shape = (2, 3)
        h_shape = (3, 4)
        self.state_shapes = (c_shape, h_shape)

    def __call__(self, *args):
        """Compute the next states.

        Args:
            *args: the current states in ``state_names`` order, followed by
                the input as the last positional argument.

        Returns:
            tuple: the new states, in ``state_names`` order.
        """
        # Placeholder computation: add the input to every state.
        x = args[-1]
        # Build new values rather than using ``+=``. In-place addition would
        # mutate the caller's arrays, and for immutable values (e.g. Python
        # scalars) it would only rebind a loop variable, leaving the returned
        # states unchanged.
        return tuple(state + x for state in args[:-1])
def make_stateful_rnn(stateless_class, name):
    """Build a stateful subclass of a stateless RNN link class.

    The generated class stores the recurrent states on the instance, so its
    ``__call__`` takes only the input.

    Args:
        stateless_class: class whose instances expose ``state_names``,
            ``state_shapes`` (one shape per state, same order) and ``xp``
            (array module), and whose ``__call__(*states, x)`` returns the
            new states in ``state_names`` order.
        name: the ``__name__`` to give the generated class.

    Returns:
        type: the generated subclass of ``stateless_class``.
    """

    class Stateful(stateless_class):

        def __init__(self, **kwargs):
            super(Stateful, self).__init__(**kwargs)
            self.reset_state()

        def reset_state(self):
            """Forget all states; they are re-created as zeros on the next call."""
            self.states = [None] * len(self.state_names)

        def __call__(self, x):
            """Advance one step with input ``x`` and return the last state."""
            # Lazily create missing states as float32 zeros using the link's
            # array module.
            for i, val in enumerate(self.states):
                if val is None:
                    self.states[i] = self.xp.zeros(
                        self.state_shapes[i], dtype=numpy.float32)
            args = tuple(self.states) + (x,)
            # Store as a list. The wrapped __call__ may return a tuple, and
            # the refill loop above assigns to individual entries.
            self.states = list(super(Stateful, self).__call__(*args))
            return self.states[-1]

    # A ``__metaclass__`` class attribute is ignored on Python 3, so a
    # metaclass-based rename only works on Python 2. Setting the name
    # directly works on both.
    Stateful.__name__ = name
    if hasattr(Stateful, '__qualname__'):
        Stateful.__qualname__ = name
    return Stateful
# Demo: wrap the stateless link and drive it through a few steps.
StatefulRNN = make_stateful_rnn(StatelessRNN, 'StatefulRNN')
print(StatefulRNN.__name__)

stateful_rnn = StatefulRNN()
print(type(stateful_rnn).__name__)

# Two consecutive steps: the states carry over from one call to the next.
for step_input in (5, 3):
    h = stateful_rnn(step_input)
    print(h)

# After a reset, the states start again from zeros.
stateful_rnn.reset_state()
h = stateful_rnn(2)
print(h)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment