Created
October 29, 2013 13:10
-
-
Save nakamuray/7214332 to your computer and use it in GitHub Desktop.
State-monad-like feature for Python (it's not actually a monad, though)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import functools | |
import types | |
class State(object):
    """A deferred, state-threading computation built from a generator function.

    Instances are normally created by the ``with_state`` decorator: calling a
    decorated function only records the function and its arguments here.
    ``run``/``eval`` later create the generator and drive it, interpreting
    every object it yields:

    * another ``State`` -- it is run with the current state; its result is
      sent back into the generator and its final state becomes the current
      state;
    * a ``get_state()`` marker -- the current state is sent back;
    * a ``set_state(x)`` marker -- the current state becomes ``x`` and
      ``None`` is sent back.

    The computation's result is what was passed to ``return_value`` (one
    argument -> that value, several -> a tuple, none -> ``None``), otherwise
    the generator's own ``return`` value, or ``None`` if it simply finishes.
    """

    def __init__(self, gen, args, kwargs):
        # gen: the generator function; args/kwargs: the arguments it will be
        # called with when the computation is run.
        self._gen = gen
        self._args = args
        self._kwargs = kwargs

    def run(self, state=None):
        '''run with state, return result and final state

        :param state: state of this execution
        :type state: object
        :return: (object, state)
        :raises TypeError: if the wrapped function does not return a
            generator, or the generator yields something other than a
            ``State``, a ``get_state()`` or a ``set_state(...)`` object
        '''
        g = self._gen(*self._args, **self._kwargs)
        if not isinstance(g, types.GeneratorType):
            raise TypeError(
                '{0} returns non generator object {1}'.format(self._gen, g))
        # send(None) on a fresh generator is equivalent to next(g). Routing the
        # very first step through the same try block means a generator that
        # finishes (or calls return_value) before its first yield completes
        # normally instead of leaking StopIteration/_ReturnValue to the caller.
        n = None
        while True:
            try:
                r = g.send(n)
            except StopIteration as e:
                # Python 3 generators may ``return value``; Python 2's
                # StopIteration has no ``value`` attribute.
                return (getattr(e, 'value', None), state)
            except _ReturnValue as e:
                args = e.args
                if not args:
                    # return_value() with no arguments: no result.
                    value = None
                elif len(args) == 1:
                    value = args[0]
                else:
                    value = args
                return (value, state)
            if isinstance(r, State):
                # Nested computation: it shares (and may update) our state.
                (n, state) = r.run(state)
            elif isinstance(r, _GetState):
                n = state
            elif isinstance(r, _SetState):
                n = None
                state = r.state
            else:
                raise TypeError(
                    '{0} yielded unsupported object {1!r}'.format(self._gen, r))

    def eval(self, state=None):
        '''run with state, only return result

        Use this method if you don't care about final state.

        :param state: state of this execution
        :type state: object
        :return: object
        '''
        return self.run(state)[0]
def with_state(func):
    """Decorator turning a generator function into a ``State`` factory.

    Calling the decorated function does not execute its body; it only bundles
    the function together with the call's positional and keyword arguments
    into a ``State``, whose ``run``/``eval`` methods drive the generator later.
    """
    def make_state(*args, **kwargs):
        return State(func, args, kwargs)

    # Copy __name__, __doc__, etc. from func and set __wrapped__.
    return functools.wraps(func)(make_state)
class _GetState(object): | |
pass | |
get_state = _GetState | |
class _SetState(object): | |
def __init__(self, state): | |
self.state = state | |
set_state = _SetState | |
class _ReturnValue(Exception): | |
pass | |
def return_value(*args):
    """Finish the enclosing ``with_state`` computation with a result.

    Raises ``_ReturnValue(*args)``, which ``State.run`` catches. A single
    argument becomes the computation's result; several arguments become a
    tuple. Control never returns to the caller.
    """
    signal = _ReturnValue(*args)
    raise signal
def test():
    """Self-check for the with_state helpers; raises AssertionError on failure."""

    @with_state
    def add(i):
        # Read the current state and store it back increased by i.
        current = yield get_state()
        yield set_state(current + i)

    assert add(5).run(10) == (None, 15)

    @with_state
    def state_sum(*args):
        # Start from 0 and accumulate every argument through the nested add().
        yield set_state(0)
        for value in args:
            yield add(value)
        total = yield get_state()
        return_value(total)

    assert state_sum(*range(10)).run() == (45, 45)

    class MyState(object):
        def __init__(self, name, age):
            self.name = name
            self.age = age

    @with_state
    def state_reader(attr_name):
        # Produce one attribute of the state object as the result.
        snapshot = yield get_state()
        return_value(getattr(snapshot, attr_name))

    assert state_reader('name').eval(MyState('my name', None)) == 'my name'

    @with_state
    def get_name_and_age():
        # Two nested readers; return_value with two arguments gives a tuple.
        name = yield state_reader('name')
        age = yield state_reader('age')
        return_value(name, age)

    assert get_name_and_age().eval(MyState('my name', 20)) == ('my name', 20)

    @with_state
    def normal_function():
        # Not a generator function: running it must fail with TypeError.
        return

    try:
        normal_function().eval(None)
    except TypeError:
        pass
    else:
        assert False
if __name__ == '__main__':
    # Executed as a script (not imported): run the built-in self-checks.
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment