Skip to content

Instantly share code, notes, and snippets.

@nakamuray
Created October 29, 2013 13:10
Show Gist options
  • Save nakamuray/7214332 to your computer and use it in GitHub Desktop.
Save nakamuray/7214332 to your computer and use it in GitHub Desktop.
State-monad-like feature for Python (though it is not actually a monad)
import functools
import types
class State(object):
    '''A deferred computation that threads a state value through a generator.

    Instances are normally created by the ``with_state`` decorator rather
    than directly. Running the computation calls ``gen(*args, **kwargs)``,
    which must return a generator. Each value the generator yields is a
    command, and the value sent back into the generator is its answer:

    * a ``State`` -- it is run with the current state; its result is sent
      back and its final state becomes the current state;
    * ``get_state()`` -- the current state is sent back;
    * ``set_state(s)`` -- the current state becomes ``s``; ``None`` is sent
      back.

    The computation's result comes from ``return_value(...)`` or, on
    Python 3.3+, from a plain ``return value`` inside the generator; a
    generator that simply finishes produces ``None``.
    '''

    def __init__(self, gen, args, kwargs):
        # The generator function is stored, not called: every run() starts
        # a fresh generator, so one State can be run many times.
        self._gen = gen
        self._args = args
        self._kwargs = kwargs

    def run(self, state=None):
        '''run with state, return result and final state

        :param state: state of this execution
        :type state: object
        :return: (object, state)
        :raises TypeError: if the wrapped function does not return a
            generator, or the generator yields a value that is not a
            ``State``, ``get_state()`` or ``set_state(...)`` command
        '''
        g = self._gen(*self._args, **self._kwargs)
        if not isinstance(g, types.GeneratorType):
            raise TypeError(
                '{0} returns non generator object {1}'.format(self._gen, g))
        # send(None) on a fresh generator is equivalent to next(g). Routing
        # the very first step through the same try block means a generator
        # that finishes, or calls return_value(), before its first yield is
        # handled like any other step instead of leaking the exception.
        sent = None
        while True:
            try:
                command = g.send(sent)
            except StopIteration as e:
                # Python 3.3+ stores a generator's ``return x`` in e.value;
                # older interpreters have no such attribute.
                return (getattr(e, 'value', None), state)
            except _ReturnValue as e:
                return (self._result_from_args(e.args), state)
            if isinstance(command, State):
                (sent, state) = command.run(state)
            elif isinstance(command, _GetState):
                sent = state
            elif isinstance(command, _SetState):
                sent = None
                state = command.state
            else:
                raise TypeError(
                    '{0} yielded unsupported value {1!r}'.format(
                        self._gen, command))

    def eval(self, state=None):
        '''run with state, only return result

        Use this method if you don't care about final state.

        :param state: state of this execution
        :type state: object
        :return: object
        '''
        return self.run(state)[0]

    @staticmethod
    def _result_from_args(args):
        '''Turn the arguments given to return_value(...) into one result.

        No arguments give ``None``, a single argument is returned as-is,
        and several arguments are returned as a tuple.
        '''
        if not args:
            return None
        if len(args) == 1:
            return args[0]
        return args
def with_state(func):
    '''Decorator turning a generator function into a ``State`` factory.

    Calling the decorated function does not execute *func*. Instead it
    captures the call's positional and keyword arguments and returns a
    ``State`` to be run later with ``.run(state)`` or ``.eval(state)``.
    ``functools.wraps`` copies *func*'s name and docstring onto the wrapper.
    '''
    def make_state(*args, **kwargs):
        return State(func, args, kwargs)
    return functools.wraps(func)(make_state)
class _GetState(object):
pass
get_state = _GetState
class _SetState(object):
def __init__(self, state):
self.state = state
set_state = _SetState
class _ReturnValue(Exception):
pass
def return_value(*args):
raise _ReturnValue(*args)
def test():
    '''Self-checks for with_state, get_state, set_state and return_value.

    Each check is a plain ``assert``; a failing check raises AssertionError.
    '''
    @with_state
    def add(i):
        current = yield get_state()
        yield set_state(current + i)

    assert add(5).run(10) == (None, 15)

    @with_state
    def state_sum(*args):
        yield set_state(0)
        for value in args:
            yield add(value)
        total = yield get_state()
        return_value(total)

    assert state_sum(*range(10)).run() == (45, 45)

    class MyState(object):
        def __init__(self, name, age):
            self.name = name
            self.age = age

    @with_state
    def state_reader(attr_name):
        snapshot = yield get_state()
        return_value(getattr(snapshot, attr_name))

    person = MyState('my name', None)
    assert state_reader('name').eval(person) == 'my name'

    @with_state
    def get_name_and_age():
        name = yield state_reader('name')
        age = yield state_reader('age')
        return_value(name, age)

    assert get_name_and_age().eval(MyState('my name', 20)) == ('my name', 20)

    @with_state
    def normal_function():
        return

    # A decorated function that is not a generator must be rejected when run.
    try:
        normal_function().eval(None)
    except TypeError:
        pass
    else:
        assert False
if __name__ == '__main__':
    # Running the file directly executes the self-checks; importing it
    # does not.
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment