Created
October 26, 2013 03:11
-
-
Save nakamuray/7164897 to your computer and use it in GitHub Desktop.
[WIP] monad for python
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from monad import * |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from monad import Monad | |
# Public names exported by ``from either import *``.
__all__ = [
    'Either',
    'Right',
    'Left',
]
class Either(Monad):
    """Either monad: ``Right`` carries a result, ``Left`` short-circuits.

    ``Either`` itself is abstract; instantiate ``Right`` or ``Left``.
    """

    def __init__(self, value):
        # Refuse direct construction of the abstract base.  TypeError is a
        # subclass of Exception, so existing ``except Exception`` still works.
        if type(self) is Either:
            raise TypeError('abstract class')
        self.value = value

    def bind(self, monad_factory):
        """Feed a Right's value to ``monad_factory``; return a Left unchanged.

        :param monad_factory: callable taking the wrapped value and returning
            the next monad.
        """
        if isinstance(self, Right):
            return monad_factory(self.value)
        return self

    @classmethod
    def pure(cls, a):
        """Lift a plain value into the monad as ``Right(a)``."""
        return Right(a)
class Right(Either):
    """Success branch of Either; ``bind`` passes its value onward."""

    def __repr__(self):
        shown = repr(self.value)
        return 'Right({0})'.format(shown)
class Left(Either):
    """Failure branch of Either; ``bind`` returns it unchanged."""

    def __repr__(self):
        shown = repr(self.value)
        return 'Left({0})'.format(shown)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# vim: fileencoding=utf-8 | |
# XXX: this feels more like tail-call optimization (a trampoline) than laziness.
import functools | |
class Thunk(object):
    """A suspended call ``function(*args, **kwargs)``.

    ``eval`` performs the call and keeps performing any Thunk it gets back,
    so a chain of thunks is unwound in a loop instead of on the call stack.
    """

    def __init__(self, function, *args, **kwargs):
        self.function, self.args, self.kwargs = function, args, kwargs

    def eval(self):
        """Force this thunk and every thunk it returns; return the final value."""
        result = self.function(*self.args, **self.kwargs)
        while isinstance(result, Thunk):
            result = result.function(*result.args, **result.kwargs)
        return result
def eval(thunk):
    """Force ``thunk`` if it is a Thunk; return any other value as-is.

    Note: this module-level name shadows the ``eval`` builtin.
    """
    return thunk.eval() if isinstance(thunk, Thunk) else thunk
def lazy(function):
    """Decorator: calling the result returns a Thunk instead of running ``function``.

    The call is performed later, when the Thunk is forced with ``eval``.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        return Thunk(function, *args, **kwargs)
    return wrapper
def test():
    """Smoke test: sum 0..999 through a lazy, trampolined accumulator."""
    def total(ns):
        # Walk by index instead of slicing ns[1:] on every step, which copied
        # the rest of the sequence each time and made the sum quadratic.
        @lazy
        def _sum(i, acc):
            if i == len(ns):
                return acc
            return _sum(i + 1, acc + ns[i])
        return _sum(0, 0).eval()

    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(total(range(1000)))
# Run the smoke test only when this module is executed as a script.
if __name__ == '__main__':
    test()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from monad import Monad, do | |
class Maybe(Monad):
    """Maybe monad: ``Just`` carries a value, ``Nothing`` short-circuits."""

    def bind(self, monad_factory):
        """Feed a Just's value to ``monad_factory``; otherwise return Nothing().

        :param monad_factory: callable taking the wrapped value and returning
            the next monad.
        """
        if isinstance(self, Just):
            return monad_factory(self.value)
        return Nothing()

    @classmethod
    def pure(cls, a):
        """Lift a plain value into the monad as ``Just(a)``."""
        return Just(a)
class Just(Maybe):
    """Maybe holding a value."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        shown = repr(self.value)
        return 'Just({0})'.format(shown)
class Nothing(Maybe):
    """Maybe holding no value; Maybe.bind never calls its function for it."""

    def __repr__(self):
        label = 'Nothing'
        return label
def test():
    """Add Just(10) and Just(20) using explicit bind calls (gives Just(30))."""
    def add_twenty(x):
        return Just(20).bind(lambda y: Just(x + y))
    return Just(10).bind(add_twenty)
@do
def test2():
    """The same addition as test(), written in do-notation."""
    first = yield Just(10)
    second = yield Just(20)
    yield Just(first + second)
if __name__ == '__main__':
    # Single-argument print() behaves the same on Python 2 and Python 3;
    # the print statement was a SyntaxError on Python 3.
    print(test().value)
    print(test2().value)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# vim: fileencoding=utf-8 | |
# Public API of this module.
__all__ = [
    'Monad',
    'do',
    'Identity',
    'sequence',
    'sequence_',
    'mapM',
    'mapM_',
]
import functools | |
import itertools | |
class Monad(object):
    """Base class for monads.

    Subclasses must implement ``bind``; ``pure`` defaults to calling the
    class with the value.
    """

    def bind(self, monad_factory):
        """Sequence self with ``monad_factory`` (value -> monad). Must be overridden."""
        raise NotImplementedError

    def bind_(self, monad_factory):
        """Sequence self with another monad, discarding self's result (Haskell ``>>``).

        Despite its name, ``monad_factory`` here is a monad, not a function.
        """
        def ignore_result(_):
            return monad_factory
        return self.bind(ignore_result)

    @classmethod
    def pure(cls, a):
        """Wrap a plain value; the default implementation is ``cls(a)``."""
        return cls(a)
def do(monad_generator):
    '''do notation

    Turn a generator function into a function that returns a monad.  Each
    ``x = yield m`` binds the monad ``m`` and resumes the generator with the
    value ``m`` produced.  The last yielded monad determines the result.

    >>> @do
    ... def test():
    ...     x = yield Identity(10)
    ...     y = yield Identity(x * 2)
    ...     yield Identity(y + 1)
    ...
    >>> test().runIdentity()
    21
    '''
    def do_(*args, **kwargs):
        g = monad_generator(*args, **kwargs)
        # next(g) works on Python 2.6+ and Python 3; g.next() is Python 2 only.
        # NOTE(review): a generator that yields nothing lets StopIteration escape here.
        m = next(g)
        # FIXME: if the monad returned by do_ is "run" more than once, g has
        # already been driven to the end by the first run, so every later
        # send() raises StopIteration.  As a result, later runs execute only
        # the monad returned by the first yield.
        #
        # One option is to recreate g after catching StopIteration, but that
        # assumes the generator always runs to completion every time the
        # monad is run, which does not seem to hold in general.  It also
        # seems thread-unsafe.
        return m.bind(lambda x: go(x, g, m))

    def go(x, g, m_):
        # x is the value produced by m_, the monad most recently yielded by g.
        try:
            m = g.send(x)
        except StopIteration:
            # g is finished, so m_ was the last monad.  Returning m_ itself
            # would make the enclosing bind run m_'s effects a second time.
            # pure(x) gives m_ >>= return, which equals m_ by the monad laws.
            return m_.pure(x)
        # bind is kept outside the try so a StopIteration raised while
        # binding is not mistaken for the end of the generator.
        return m.bind(lambda y: go(y, g, m))

    return functools.update_wrapper(do_, monad_generator)
class Identity(Monad):
    """The identity monad: a plain wrapper whose bind applies the function directly."""

    # Class-level default for the wrapped value; __init__ always overrides it.
    a = None

    def __init__(self, a):
        self.a = a

    def runIdentity(self):
        """Return the wrapped value."""
        return self.a

    def bind(self, monad_factory):
        """Apply ``monad_factory`` to the wrapped value and return its result."""
        value = self.a
        return monad_factory(value)
# sequence :: Monad m => [m a] -> m [a]
@do
def sequence(monads):
    """Run each monad in order and collect their results in a list.

    The monad class used for the final ``pure`` is the class of the first
    element.  NOTE(review): for an empty input this falls back to
    ``Monad.pure``, i.e. ``Monad(result)``; Monad defines no ``__init__``
    accepting an argument, so this looks like it raises TypeError -- confirm.
    """
    # XXX: is there a way to return iterator instead of list?
    results = []
    monad_type = Monad
    for index, action in enumerate(monads):
        if index == 0:
            monad_type = type(action)
        results.append((yield action))
    yield monad_type.pure(results)
# sequence_ :: Monad m => [m a] -> m ()
@do
def sequence_(monads):
    """Run each monad in order, discarding their results; the result is None.

    The monad class used for the final ``pure`` is the class of the first
    element.  NOTE(review): for an empty input this falls back to
    ``Monad.pure(None)``, i.e. ``Monad(None)``; Monad defines no
    ``__init__`` accepting an argument, so this looks like it raises
    TypeError -- confirm.
    """
    monad_type = Monad
    for index, action in enumerate(monads):
        if index == 0:
            monad_type = type(action)
        yield action
    yield monad_type.pure(None)
def mapM(f, monads):
    """Apply ``f`` to each element of ``monads`` and ``sequence`` the results.

    Despite the parameter name, ``monads`` holds plain values; ``f`` maps
    each one to a monad, and ``f`` is applied lazily as ``sequence``
    consumes the iterable.
    """
    # A generator expression is as lazy as itertools.imap, which no longer
    # exists on Python 3.
    return sequence(f(x) for x in monads)
def mapM_(f, monads):
    """Apply ``f`` to each element of ``monads`` and ``sequence_`` the results.

    Despite the parameter name, ``monads`` holds plain values; ``f`` maps
    each one to a monad, and the results are discarded.
    """
    # A generator expression is as lazy as itertools.imap, which no longer
    # exists on Python 3.
    return sequence_(f(x) for x in monads)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from monad import do | |
from state_lazy import State | |
from either import * | |
from lazy import lazy | |
class Parsec(State):
    """Parser built on the lazy State monad.

    Each step returns an Either: ``Right(value)`` on success or
    ``Left(message)`` on failure, paired with the new state (the parsers in
    this module keep the remaining input string there).
    """

    def bind(self, parser_factory):
        """Run self; on Right feed its value to ``parser_factory``, on Left stop.

        A failure is returned together with the state reached when it failed.
        """
        def runState(s):
            a, s_ = self.runState(s).eval()
            if isinstance(a, Right):
                return parser_factory(a.value).runState(s_)
            # Pass the failure through unchanged; rebuilding it with
            # cls.fail(a.value) only produced an equivalent (Left, state)
            # pair through an extra parser.
            return (a, s_)
        return type(self)(runState)

    @classmethod
    def pure(cls, a):
        """Parser that consumes nothing and succeeds with ``a``."""
        return super(Parsec, cls).pure(Right(a))

    @classmethod
    def fail(cls, msg):
        """Parser that consumes nothing and fails with ``msg``."""
        return super(Parsec, cls).pure(Left(msg))

    @classmethod
    def get(cls):
        """Parser that succeeds with the current state."""
        def runState(s):
            return (Either.pure(s), s)
        return cls(runState)

    @classmethod
    def put(cls, s):
        """Parser that replaces the state with ``s`` and succeeds with None."""
        def runState(_):
            return (Either.pure(None), s)
        return cls(runState)

    @classmethod
    def modify(cls, f):
        """Parser that replaces the state with ``f(state)`` and succeeds with None."""
        def runState(s):
            s = f(s)
            return (Either.pure(None), s)
        return cls(runState)

    def try_(self):
        """Like self, but on failure the state is reset to where it started,
        so a failed attempt consumes no input."""
        def runState(s):
            a, s_ = self.runState(s).eval()
            if isinstance(a, Left):
                # revert to old state
                s_ = s
            return (a, s_)
        return type(self)(runState)

    def or_(self, other):
        """Run self; if it fails, run ``other`` from the state self left behind.

        Input consumed by a failing self is not restored; wrap self in
        ``try_()`` to backtrack.
        """
        def runState(s):
            a, s_ = self.runState(s).eval()
            if isinstance(a, Left):
                a, s_ = other.runState(s_).eval()
            return (a, s_)
        return type(self)(runState)
# Module-level shortcuts for the Parsec state primitives.
get, put, modify = Parsec.get, Parsec.put, Parsec.modify
@do
def char(c):
    """Parser that consumes the single character ``c`` from the input.

    Fails (via Parsec.fail, without raising) when the input starts with
    anything else, including when the input is empty.
    """
    state = yield get()
    # state[:1] instead of state[0]: on empty input state[0] raised
    # IndexError instead of producing a parse failure.
    head = state[:1]
    if head == c:
        yield put(state[1:])
        yield Parsec.pure(c)
    else:
        yield Parsec.fail('"{0}" is expected, but "{1}" come'.format(c, head))
@do
def string(s):
    """Parser that consumes the literal prefix ``s`` from the input."""
    state = yield get()
    width = len(s)
    if not state.startswith(s):
        yield Parsec.fail('"{0}" is expected, but "{1}"... come'.format(s, state[:width]))
    else:
        yield put(state[width:])
        yield Parsec.pure(s)
def many(p):
    """Parser that applies ``p`` zero or more times and succeeds with a list
    of its results.

    ``p`` is wrapped in ``try_()``, so the attempt that finally fails
    consumes no input and ``many`` itself never fails.

    Raises ValueError if ``p`` succeeds without consuming input, which would
    otherwise loop forever.
    """
    # The try_ wrapper is a plain closure over p, so build it once up front.
    p_try = p.try_()

    def runState(s):
        result = []
        while True:
            a, s_next = p_try.runState(s).eval()
            if isinstance(a, Left):
                return (Right(result), s_next)
            if s_next == s:
                # No progress: repeating p would never terminate.
                raise ValueError('many: parser succeeded without consuming input')
            result.append(a.value)
            s = s_next

    return Parsec(runState)
def test():
    """Print the results of a few parsers run on sample input."""
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(string('hello').runState('hello world').eval())
    print(string('hello').runState('hi, world').eval())
    print(string('hello').or_(string('hi')).runState('hi, world').eval())
    # Without try_: char('h') has consumed 'h' when char('i') fails, so
    # string('hello') runs on 'ello world' and fails too.
    print(char('h').bind_(char('i')).or_(string('hello')).runState('hello world').eval())
    # With try_: the failed branch restores the input, so string('hello') succeeds.
    print(char('h').bind_(char('i')).try_().or_(string('hello')).runState('hello world').eval())
    c = char('x')
    print(c.runState('xxx').eval())
    # FIXME: running the same do-built parser a second time does not repeat
    # its whole do block, because its generator was exhausted by the first run.
    print(c.runState('xxx').eval())
    # Disabled for the same reason: many() reruns a single char('x') parser.
    #print(many(char('x')).runState('xxx!').eval())
# Run the parser demo only when this module is executed as a script.
if __name__ == '__main__':
    test()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from monad import Monad, do | |
class State(Monad):
    """Strict State monad around a function ``runState :: s -> (a, s)``."""

    # runState :: s -> (a, s); class-level default, replaced per instance.
    runState = None

    def __init__(self, runState):
        self.runState = runState

    def bind(self, state_factory):
        """Run self, then the State that ``state_factory`` builds from its result."""
        def chained(s):
            value, next_state = self.runState(s)
            follow_up = state_factory(value)
            return follow_up.runState(next_state)
        return type(self)(chained)

    @classmethod
    def pure(cls, a):
        """State that returns ``a`` and leaves the state untouched."""
        return cls(lambda s: (a, s))

    def evalState(self, s):
        """Run from state ``s`` and return only the result value."""
        outcome = self.runState(s)
        return outcome[0]

    @classmethod
    def get(cls):
        """State whose result is the current state."""
        return cls(lambda s: (s, s))

    @classmethod
    def put(cls, s):
        """State that replaces the state with ``s``; the result is None."""
        return cls(lambda _: (None, s))

    @classmethod
    def modify(cls, f):
        """State that replaces the state with ``f(state)``; the result is None."""
        return cls(lambda s: (None, f(s)))
# Module-level shortcuts for the State primitives.
get, put, modify = State.get, State.put, State.modify
def test():
    """Print a few strict State computations run from an initial state of 0."""
    def add(n):
        return modify(lambda s: s + n)

    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(add(1).bind_(add(1)).bind_(add(1)).bind_(get()).runState(0))

    @do
    def test2():
        yield add(1)
        yield add(1)
        yield add(1)
        yield get()
    print(test2().runState(0))

    @do
    def stateSum(ns):
        for n in ns:
            # add(n) binds n immediately.  A bare lambda closing over the loop
            # variable only worked because each step ran before the generator
            # advanced.
            yield add(n)
        yield get()
    print(stateSum(range(1, 11)).evalState(0))
    # Disabled: strict bind nests one Python call per step, so this presumably
    # exceeds the recursion limit (state_lazy's thunks avoid that).
    #print(stateSum(range(10000)).evalState(0))
# Run the State demo only when this module is executed as a script.
if __name__ == '__main__':
    test()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from lazy import lazy | |
from monad import Monad, do | |
from state import State as StateStrict | |
# Public names exported by ``from state_lazy import *``.
__all__ = [
    'State',
    'get',
    'put',
    'modify',
]
class State(StateStrict):
    """State monad whose ``runState`` calls return Thunks.

    ``bind`` hands back the next step's Thunk unevaluated, so ``Thunk.eval``
    can run a chain of steps in a loop instead of recursing once per step.
    """

    def __init__(self, runState):
        # lazy() makes each runState(s) call return a Thunk instead of running.
        self.runState = lazy(runState)

    def bind(self, state_factory):
        """Force self, then return (as a Thunk) the State built by ``state_factory``."""
        def chained(s):
            value, next_state = self.runState(s).eval()
            return state_factory(value).runState(next_state)
        return type(self)(chained)

    def evalState(self, s):
        """Force the computation from state ``s`` and return its result value."""
        outcome = self.runState(s).eval()
        return outcome[0]
# Module-level shortcuts for the lazy State primitives.
get, put, modify = State.get, State.put, State.modify
def test():
    """Print lazy State sums; the 10000-element sum exercises the trampoline."""
    def add(n):
        return modify(lambda s: s + n)

    @do
    def stateSum(ns):
        for n in ns:
            yield add(n)
        yield get()

    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(stateSum(range(1, 11)).evalState(0))
    print(stateSum(range(10000)).evalState(0))
    s = stateSum(range(1, 11))
    print(s.evalState(0))
    # FIXME: the second run does not repeat the do block (its generator was
    # exhausted by the first run), so this does not print 55 again.
    print(s.evalState(0))
# Run the lazy State demo only when this module is executed as a script.
if __name__ == '__main__':
    test()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment