Skip to content

Instantly share code, notes, and snippets.

@nakamuray
Created October 26, 2013 03:11
Show Gist options
  • Save nakamuray/7164897 to your computer and use it in GitHub Desktop.
Save nakamuray/7164897 to your computer and use it in GitHub Desktop.
[WIP] Monads for Python
from monad import *
from monad import Monad
__all__ = ['Either', 'Right', 'Left']
class Either(Monad):
    """Result that is either a success (``Right``) or a failure (``Left``).

    Abstract: instantiate ``Right`` or ``Left``, never ``Either`` itself.
    ``bind`` applies the factory to the value of a ``Right`` and returns a
    ``Left`` unchanged, so the first failure short-circuits the chain.
    """

    def __init__(self, value):
        # Exact type check on purpose: subclasses may be built, Either may not.
        # TypeError is a subclass of Exception, so existing handlers still match.
        if type(self) is Either:
            raise TypeError('abstract class')
        self.value = value

    def bind(self, monad_factory):
        if isinstance(self, Right):
            return monad_factory(self.value)
        return self

    @classmethod
    def pure(cls, a):
        """Wrap ``a`` as a success."""
        return Right(a)
class Right(Either):
    """Successful Either; ``bind`` passes its value to the next step."""

    def __repr__(self):
        shown = repr(self.value)
        return 'Right({0})'.format(shown)
class Left(Either):
    """Failed Either; ``bind`` returns it unchanged, short-circuiting."""

    def __repr__(self):
        shown = repr(self.value)
        return 'Left({0})'.format(shown)
# vim: fileencoding=utf-8
# XXX: this looks more like tail-call optimization than laziness.
import functools
class Thunk(object):
    """A suspended call ``function(*args, **kwargs)`` that has not run yet.

    ``eval`` runs it. If the call returns another Thunk, that one is run
    too, and so on in a loop, so chains of thunks never deepen the stack.
    """

    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs

    def eval(self):
        """Run this thunk and any thunks it returns; give the first non-Thunk."""
        current = self
        while True:
            if not isinstance(current, Thunk):
                return current
            current = current.function(*current.args, **current.kwargs)
def eval(thunk):
    """Force ``thunk`` if it is a Thunk; return any other value unchanged.

    NOTE: this module-level name shadows the builtin ``eval`` inside this
    module (and for ``from lazy import *`` users, since no ``__all__`` is set).
    """
    return thunk.eval() if isinstance(thunk, Thunk) else thunk
def lazy(function):
    """Decorator: calling the result builds a Thunk instead of running.

    The original body runs only when the returned Thunk is evaluated.  A
    body that returns another lazy call therefore acts as a tail call,
    which Thunk.eval unwinds in a loop instead of growing the stack.
    """
    @functools.wraps(function)
    def wrapper(*args, **kwargs):
        return Thunk(function, *args, **kwargs)
    return wrapper
def test():
    """Sum 0..999 with a lazy tail-recursive helper and print the result."""
    def sum_(ns):
        @lazy
        def _sum(i, acc):
            # Walk by index: slicing ns[1:] on every step copied the rest
            # of the list each time (O(n**2) overall).
            if i == len(ns):
                return acc
            return _sum(i + 1, acc + ns[i])
        return _sum(0, 0).eval()
    # print(x) with a single argument behaves the same on Python 2 and 3.
    print(sum_(range(1000)))


if __name__ == '__main__':
    test()
from monad import Monad, do
class Maybe(Monad):
    """Optional value: ``Just(value)`` or ``Nothing()``.

    ``bind`` calls the factory only for a ``Just``; any other instance
    short-circuits to a fresh ``Nothing()``.
    """

    def bind(self, monad_factory):
        if not isinstance(self, Just):
            return Nothing()
        return monad_factory(self.value)

    @classmethod
    def pure(cls, a):
        return Just(a)
class Just(Maybe):
    """A present value, stored as ``.value``."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        shown = repr(self.value)
        return 'Just({0})'.format(shown)
class Nothing(Maybe):
    """The absent value; ``bind`` on it never calls the factory."""

    def __repr__(self):
        label = 'Nothing'
        return label
def test():
    """Add 10 and 20 inside Maybe using explicit nested binds."""
    def add_to(x):
        return Just(20).bind(lambda y: Just(x + y))
    return Just(10).bind(add_to)
@do
def test2():
    """Same computation as ``test``, written with the ``do`` decorator."""
    first = yield Just(10)
    second = yield Just(20)
    yield Just(first + second)
if __name__ == '__main__':
    # Both demos compute 10 + 20.  print(x) with one argument works on
    # Python 2 and 3, unlike the Python-2-only print statement.
    print(test().value)
    print(test2().value)
# vim: fileencoding=utf-8
__all__ = ['Monad', 'do', 'Identity', 'sequence', 'sequence_', 'mapM', 'mapM_']
import functools
import itertools
class Monad(object):
    """Base class for the monads in this package.

    Subclasses must implement ``bind``.  ``pure`` defaults to calling the
    subclass constructor with the value; subclasses whose constructor takes
    something else must override it.
    """

    def bind(self, monad_factory):
        """Sequence ``self`` with ``monad_factory`` (Haskell's ``>>=``).

        ``monad_factory`` receives the value carried by ``self`` and returns
        another monad.  Must be overridden.
        """
        raise NotImplementedError

    def bind_(self, monad_factory):
        """Sequence ``self`` with a monad, ignoring ``self``'s value (``>>``).

        Despite its name, ``monad_factory`` here is a monad value, not a
        function.
        """
        def ignore_value(_):
            return monad_factory
        return self.bind(ignore_value)

    @classmethod
    def pure(cls, a):
        """Wrap the plain value ``a`` in this monad (Haskell's ``return``)."""
        return cls(a)
def do(monad_generator):
    '''do notation

    Decorate a generator function that yields monads.  Calling the
    decorated function returns a single monad.  Each ``x = yield m`` binds
    ``m`` and sends the value it produces back into the generator as ``x``.
    The value of the whole block is the value of the last yielded monad.
    The generator must yield at least one monad; otherwise the call raises
    StopIteration.

    >>> @do
    ... def test():
    ...     x = yield Identity(10)
    ...     y = yield Identity(x * 2)
    ...     yield Identity(y + 1)
    ...
    >>> test().runIdentity()
    21

    The returned monad may be run more than once, for example a State or
    parser value evaluated twice.  A monad's bind may also call its
    continuation several times.  When a continuation resumes after the
    generator has already moved past it, a fresh generator is created and
    the previously sent values are replayed into it.  The body is re-run
    from the start during a replay, so it must be deterministic given those
    values, and its arguments must be iterable again.

    The cached generator is shared between runs, so running the same monad
    from several threads at once is not safe.
    '''
    # Marks "no generator is parked at a reusable position".
    nowhere = object()

    def do_(*args, **kwargs):
        def replay(history):
            # Build a fresh generator advanced to the position `history`
            # describes.  `history` is a linked list of (parent, sent_value)
            # pairs, newest first.  None means nothing has been sent yet,
            # i.e. the generator is paused at its first yield.
            sent = []
            node = history
            while node is not None:
                node, value = node
                sent.append(value)
            gen = monad_generator(*args, **kwargs)
            next(gen)
            for value in reversed(sent):
                gen.send(value)
            return gen

        def step(history, last, x):
            # Continuation of the monad `last`, which produced `x`.
            if cache['at'] is history:
                gen = cache['gen']  # common case: resume where we stopped
            else:
                gen = replay(history)
            # Until send() returns, gen is not parked anywhere reusable:
            # it may finish or raise.
            cache['at'] = nowhere
            try:
                m = gen.send(x)
            except StopIteration:
                # x is the value of the final monad `last`.  Re-wrap it with
                # pure(); returning `last` itself would bind it a second time
                # and duplicate its effects.
                return last.pure(x)
            here = (history, x)
            cache['gen'], cache['at'] = gen, here
            return m.bind(lambda y: step(here, m, y))

        gen = monad_generator(*args, **kwargs)
        first = next(gen)  # next() works on Python 2.6+ and 3, unlike gen.next()
        cache = {'gen': gen, 'at': None}
        return first.bind(lambda x: step(None, first, x))

    return functools.update_wrapper(do_, monad_generator)
class Identity(Monad):
    """The trivial monad: holds a value and adds no effect."""

    a = None  # class-level default; __init__ sets the real value per instance

    def __init__(self, a):
        self.a = a

    def bind(self, monad_factory):
        # Nothing to thread through: hand the held value straight on.
        return monad_factory(self.a)

    def runIdentity(self):
        """Return the held value."""
        return self.a
# sequence :: Monad m => [m a] -> m [a]
@do
def sequence(monads):
    '''Run each monad in ``monads`` in order and collect their values in a list.

    The list is wrapped with ``pure`` of the first monad's type.  When
    ``monads`` is empty, ``Monad.pure`` is used.  That calls ``Monad([])``,
    which presumably raises TypeError because Monad defines no
    ``__init__`` (TODO: confirm / handle empty input).
    '''
    # XXX: is there a way to return an iterator instead of a list?
    collected = []
    kind = None
    for action in monads:
        if kind is None:
            kind = type(action)  # remembered for the final pure()
        value = yield action
        collected.append(value)
    yield (Monad if kind is None else kind).pure(collected)
# sequence_ :: Monad m => [m a] -> m ()
@do
def sequence_(monads):
    '''Run each monad in ``monads`` in order, discarding their values.

    The result is ``None`` wrapped with ``pure`` of the first monad's type,
    or ``Monad.pure(None)`` when ``monads`` is empty.
    '''
    kind = None
    for action in monads:
        if kind is None:
            kind = type(action)
        yield action
    yield (Monad if kind is None else kind).pure(None)
def mapM(f, monads):
    '''mapM :: Monad m => (a -> m b) -> [a] -> m [b]

    Apply ``f`` to each item of ``monads`` (despite the name, plain values)
    and ``sequence`` the resulting monads, collecting their values.  ``f``
    is applied lazily, one item at a time as ``sequence`` consumes it.
    '''
    # A generator expression is as lazy as itertools.imap, which no longer
    # exists on Python 3.
    return sequence(f(x) for x in monads)
def mapM_(f, monads):
    '''mapM_ :: Monad m => (a -> m b) -> [a] -> m ()

    Like ``mapM`` but discards the values (via ``sequence_``).  ``f`` is
    applied lazily, one item at a time.
    '''
    # A generator expression is as lazy as itertools.imap, which no longer
    # exists on Python 3.
    return sequence_(f(x) for x in monads)
from monad import do
from state_lazy import State
from either import *
from lazy import lazy
class Parsec(State):
    """Parser combinator built on the lazy State monad.

    The state is the input not yet consumed (``char`` and ``string`` slice
    it).  ``runState(s).eval()`` gives ``(result, state)``, where ``result``
    is ``Right(value)`` on success or ``Left(message)`` on failure.
    ``bind`` stops at the first ``Left``.
    """

    def bind(self, parser_factory):
        parser_cls = type(self)

        def runState(s):
            result, rest = self.runState(s).eval()
            if not isinstance(result, Right):
                # Failure: carry the message on via fail(), keeping the state
                # reached so far.
                return parser_cls.fail(result.value).runState(rest)
            return parser_factory(result.value).runState(rest)

        return parser_cls(runState)

    @classmethod
    def pure(cls, a):
        """Parser that consumes nothing and succeeds with ``a``."""
        return super(Parsec, cls).pure(Right(a))

    @classmethod
    def fail(cls, msg):
        """Parser that consumes nothing and fails with ``msg``."""
        return super(Parsec, cls).pure(Left(msg))

    @classmethod
    def get(cls):
        """Succeed with the current state, leaving it unchanged."""
        return cls(lambda s: (Either.pure(s), s))

    @classmethod
    def put(cls, s):
        """Replace the state with ``s``; succeed with None."""
        return cls(lambda _: (Either.pure(None), s))

    @classmethod
    def modify(cls, f):
        """Replace the state with ``f(state)``; succeed with None."""
        def runState(s):
            updated = f(s)
            return (Either.pure(None), updated)
        return cls(runState)

    def try_(self):
        """Run this parser; on failure, rewind the state to where it started."""
        def runState(s):
            result, rest = self.runState(s).eval()
            return (result, s if isinstance(result, Left) else rest)
        return type(self)(runState)

    def or_(self, other):
        """Run this parser; if it fails, run ``other`` from the state left behind.

        Unlike Haskell Parsec's ``<|>``, ``other`` also runs when this parser
        consumed input before failing.  Wrap this parser with ``try_()`` to
        start ``other`` from the original input instead.
        """
        def runState(s):
            result, rest = self.runState(s).eval()
            if isinstance(result, Left):
                result, rest = other.runState(rest).eval()
            return (result, rest)
        return type(self)(runState)
# Module-level shortcuts for the Parsec state primitives.
get, put, modify = Parsec.get, Parsec.put, Parsec.modify
@do
def char(c):
    """Parser for the single character ``c`` at the start of the input.

    On success it consumes that character and returns ``c``.  Otherwise it
    fails without consuming input.  Empty input is a parse failure rather
    than an IndexError.
    """
    state = yield get()
    if not state:
        # state[0] would raise IndexError here; report a normal failure.
        yield Parsec.fail('"{0}" is expected, but end of input reached'.format(c))
    elif state[0] == c:
        yield put(state[1:])
        yield Parsec.pure(c)
    else:
        yield Parsec.fail('"{0}" is expected, but "{1}" come'.format(c, state[0]))
@do
def string(s):
    """Parser for the literal prefix ``s``.

    On success it consumes ``s`` and returns it; otherwise it fails without
    consuming input.
    """
    remaining = yield get()
    if not remaining.startswith(s):
        yield Parsec.fail('"{0}" is expected, but "{1}"... come'.format(s, remaining[:len(s)]))
        return
    yield put(remaining[len(s):])
    yield Parsec.pure(s)
def many(p):
    """Parser applying ``p`` zero or more times; returns a list of its values.

    Each attempt runs ``p.try_()``, so the failing attempt that ends the
    repetition consumes no input, and ``many`` itself never fails.

    Raises ValueError (when run) if ``p`` succeeds without consuming input.
    Retrying from the same state would presumably succeed the same way
    forever; the original loop never terminated in that case.
    """
    p_try = p.try_()  # built once, reused for every attempt and every run

    def runState(s):
        result = []
        while True:
            a, s_ = p_try.runState(s).eval()
            if isinstance(a, Left):
                # try_ rewound the state, so s_ is the input before this attempt.
                return (Right(result), s_)
            if s_ == s:
                raise ValueError('many: parser succeeded without consuming input')
            result.append(a.value)
            s = s_
    return Parsec(runState)
def test():
    """Print the results of a few small parses.

    Uses print(x) with one argument, which behaves the same on Python 2
    and 3.
    """
    print(string('hello').runState('hello world').eval())
    print(string('hello').runState('hi, world').eval())
    print(string('hello').or_(string('hi')).runState('hi, world').eval())
    # Without try_(), or_ retries from the input left after 'h' was consumed,
    # so 'hello' is matched against 'ello world' and fails.
    print(char('h').bind_(char('i')).or_(string('hello')).runState('hello world').eval())
    # With try_(), the failed branch is rewound and 'hello' matches.
    print(char('h').bind_(char('i')).try_().or_(string('hello')).runState('hello world').eval())
    c = char('x')
    print(c.runState('xxx').eval())
    # FIXME: re-running the same parser value should print the same result.
    print(c.runState('xxx').eval())
    # Disabled: many() runs char('x') repeatedly, i.e. it re-runs a
    # do-built parser value.
    #print(many(char('x')).runState('xxx!').eval())


if __name__ == '__main__':
    test()
from monad import Monad, do
class State(Monad):
    """Strict State monad around a function ``runState :: s -> (a, s)``.

    ``bind`` calls the wrapped functions directly, so every bound step adds
    Python stack frames while the state is being run.
    """

    # runState :: s -> (a, s)
    runState = None

    def __init__(self, runState):
        self.runState = runState

    def bind(self, state_factory):
        def run(s):
            value, next_state = self.runState(s)
            return state_factory(value).runState(next_state)
        return type(self)(run)

    @classmethod
    def pure(cls, a):
        """Action that leaves the state alone and returns ``a``."""
        return cls(lambda s: (a, s))

    def evalState(self, s):
        """Run from state ``s`` and return only the value."""
        outcome = self.runState(s)
        return outcome[0]

    @classmethod
    def get(cls):
        """Return the current state as the value."""
        return cls(lambda s: (s, s))

    @classmethod
    def put(cls, s):
        """Replace the state with ``s``; the value is None."""
        return cls(lambda _: (None, s))

    @classmethod
    def modify(cls, f):
        """Replace the state with ``f(state)``; the value is None."""
        def run(s):
            updated = f(s)
            return (None, updated)
        return cls(run)
# Module-level shortcuts for the State primitives.
get, put, modify = State.get, State.put, State.modify
def test():
    """Demonstrate the strict State monad; prints (3, 3), (3, 3) and 55.

    Uses print(x) with one argument, which behaves the same on Python 2
    and 3.
    """
    def add(n):
        return modify(lambda s: s + n)

    print(add(1).bind_(add(1)).bind_(add(1)).bind_(get()).runState(0))

    @do
    def test2():
        yield add(1)
        yield add(1)
        yield add(1)
        yield get()
    print(test2().runState(0))

    @do
    def stateSum(ns):
        for n in ns:
            # Bind n now (default argument) rather than relying on the lambda
            # running before the loop moves to the next n.
            yield modify(lambda x, n=n: x + n)
        yield get()
    print(stateSum(range(1, 11)).evalState(0))
    # Disabled: with the strict State monad each bind adds stack frames,
    # so 10000 steps presumably exceed Python's recursion limit.
    #print(stateSum(range(10000)).evalState(0))


if __name__ == '__main__':
    test()
from lazy import lazy
from monad import Monad, do
from state import State as StateStrict
__all__ = ['State', 'get', 'put', 'modify']
class State(StateStrict):
    """State monad whose ``runState`` returns a Thunk instead of a tuple.

    Every ``runState`` is wrapped with ``lazy()``, so ``bind`` can return
    the next step's ``runState`` call unevaluated.  ``Thunk.eval`` then
    unwinds long chains in a loop.  Call ``.eval()`` on the result of
    ``runState`` to get the ``(a, s)`` pair.
    """

    def __init__(self, runState):
        self.runState = lazy(runState)

    def bind(self, state_factory):
        def runState(s):
            value, next_state = self.runState(s).eval()
            # Tail call: returned as a Thunk for the caller's eval() loop.
            return state_factory(value).runState(next_state)
        return type(self)(runState)

    def evalState(self, s):
        pair = self.runState(s).eval()
        return pair[0]
# Module-level shortcuts for the lazy State primitives.
get, put, modify = State.get, State.put, State.modify
def test():
    """Demonstrate the lazy State monad by summing ranges.

    Uses print(x) with one argument, which behaves the same on Python 2
    and 3.
    """
    def add(n):
        return modify(lambda s: s + n)

    @do
    def stateSum(ns):
        for n in ns:
            yield add(n)
        yield get()

    print(stateSum(range(1, 11)).evalState(0))
    # A long chain: the strict State demo leaves the equivalent line disabled.
    print(stateSum(range(10000)).evalState(0))
    s = stateSum(range(1, 11))
    print(s.evalState(0))
    # FIXME: running the same State value again should print the same result.
    print(s.evalState(0))


if __name__ == '__main__':
    test()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment