Skip to content

Instantly share code, notes, and snippets.

@agrif
Last active December 15, 2015 07:09
Show Gist options
  • Select an option

  • Save agrif/5221674 to your computer and use it in GitHub Desktop.

Select an option

Save agrif/5221674 to your computer and use it in GitHub Desktop.
from functools import wraps
class Monad:
    """Base class for the monads defined in this module.

    Subclasses are expected to override ``bind`` and ``pure``. ``fail`` has
    a default implementation that simply re-raises the exception it is given.
    """

    @classmethod
    def pure(cls, val):
        """Lift a plain value into the monad. Must be overridden."""
        raise NotImplementedError()

    @classmethod
    def fail(cls, e):
        """Handle exception ``e`` raised inside a do-block.

        The default behavior is to re-raise ``e`` unchanged.
        """
        raise e

    def bind(self, fn):
        """Chain ``fn`` onto this monadic value. Must be overridden."""
        raise NotImplementedError()
class ListMonad(Monad):
    """List monad: a monadic value holding zero or more plain values."""

    def __init__(self, *vals):
        # The values are kept as the tuple of positional arguments.
        self._vals = vals

    def __repr__(self):
        # Use the runtime class name so that subclasses are not reported as
        # plain "ListMonad"; output for ListMonad itself is unchanged.
        return type(self).__name__ + repr(self._vals)

    def bind(self, fn):
        """Apply ``fn`` to every held value and concatenate the results.

        ``fn`` must return an object exposing ``_vals`` (another ListMonad).
        The result is a new instance of this instance's class.
        """
        res = []
        for val in self._vals:
            res.extend(fn(val)._vals)
        return self.__class__(*res)

    @classmethod
    def pure(cls, val):
        """Wrap a single value in a one-element instance."""
        return cls(val)

    @classmethod
    def fail(cls, e):
        """Discard the exception ``e`` and return an empty instance."""
        return cls()
class Maybe(Monad):
    """Optional-value monad: holds exactly one value, or nothing."""

    def __init__(self, val):
        self._has_val = True
        self._val = val

    @classmethod
    def nothing(cls):
        """Return an instance that holds no value."""
        # Build a regular instance with a placeholder, then clear the flag.
        o = cls(None)
        o._has_val = False
        return o

    def __repr__(self):
        # Use the runtime class name so that subclasses are not reported as
        # plain "Maybe"; output for Maybe itself is unchanged.
        name = type(self).__name__
        if self._has_val:
            return name + "(" + repr(self._val) + ")"
        return name + ".nothing()"

    def bind(self, fn):
        """Return ``fn(value)`` when a value is held, otherwise ``self``."""
        if not self._has_val:
            return self
        return fn(self._val)

    @classmethod
    def pure(cls, val):
        """Wrap ``val`` in an instance that holds it."""
        return cls(val)

    @classmethod
    def fail(cls, e):
        """Discard the exception ``e`` and return the empty instance."""
        return cls.nothing()
def _general_do_step(gen_factory, prev_vals, monad_val, wrapper=lambda x: x):
def _general_do_fn(v):
gen = gen_factory()
for val in prev_vals:
gen.send(val)
try:
next_monad_val = wrapper(gen.send(v))
except StopIteration as e:
ret = getattr(e, 'value', None)
if ret is None:
ret = monad_val
else:
ret = monad_val.__class__.pure(ret)
return ret
except Exception as e:
return monad_val.__class__.fail(e)
return _general_do_step(gen_factory, prev_vals + [v], next_monad_val, wrapper)
return monad_val.bind(_general_do_fn)
def do(f):
    """Evaluate the zero-argument generator function ``f`` as a do-block.

    Intended for use as a decorator: the decorated name ends up bound to
    the resulting monadic value rather than to a function. The first value
    yielded by ``f`` is taken as the starting monadic value.
    """
    probe = f()
    starting_monad = next(probe)
    # [None] is the replay history: sending None starts a fresh generator.
    return _general_do_step(f, [None], starting_monad)
def callable_do(f):
    """Decorator turning generator function ``f`` into a do-block function.

    Unlike ``do``, the block is not evaluated at decoration time. Each call
    of the returned function evaluates the block with the given arguments
    and returns the resulting monadic value.
    """
    @wraps(f)
    def _callable_do(*args, **kwargs):
        def make_generator():
            # A fresh generator with the same arguments for every replay.
            return f(*args, **kwargs)

        starting_monad = next(make_generator())
        return _general_do_step(make_generator, [None], starting_monad)

    return _callable_do
class ImplicitMonad(Monad):
    """Monad whose do-blocks work on plain values.

    Every value yielded by a block is passed through ``wrap`` and the final
    monadic result is passed through ``unwrap``. Subclasses must provide
    both.
    """

    @classmethod
    def wrap(cls, val):
        """Convert a plain yielded value into an instance. Must be overridden."""
        raise NotImplementedError()

    def unwrap(self):
        """Convert this instance back into a plain value. Must be overridden."""
        raise NotImplementedError()

    @classmethod
    def do(cls, f):
        """Evaluate zero-argument generator function ``f`` and unwrap the result."""
        lift = cls.wrap
        starting_monad = lift(next(f()))
        outcome = _general_do_step(f, [None], starting_monad, lift)
        return outcome.unwrap()

    @classmethod
    def callable_do(cls, f):
        """Decorator: each call evaluates ``f`` as a do-block and unwraps the result."""
        @wraps(f)
        def _callable_do(*args, **kwargs):
            def make_generator():
                # A fresh generator with the same arguments for every replay.
                return f(*args, **kwargs)

            lift = cls.wrap
            starting_monad = lift(next(make_generator()))
            outcome = _general_do_step(make_generator, [None], starting_monad, lift)
            return outcome.unwrap()

        return _callable_do
class ImplicitListMonad(ListMonad, ImplicitMonad):
    """ListMonad that converts to and from plain iterables/lists."""

    def unwrap(self):
        """Return the held values as a new list."""
        return list(self._vals)

    @classmethod
    def wrap(cls, val):
        """Build an instance whose values are the items of iterable ``val``."""
        return cls(*val)
class ImplicitMaybe(Maybe, ImplicitMonad):
    """Maybe that uses ``None`` as the plain representation of "nothing"."""

    def unwrap(self):
        """Return the held value, or ``None`` when there is none."""
        return self._val if self._has_val else None

    @classmethod
    def wrap(cls, val):
        """Map ``None`` to the empty instance and anything else to a held value."""
        return cls.nothing() if val is None else cls(val)
class Continuation(ImplicitMonad):
    """Continuation monad wrapping a function that receives a continuation.

    The wrapped function ``fn`` is called with a continuation ``c`` and
    decides how (and how often) to invoke it.
    """

    def __init__(self, fn=lambda c: c):
        self._fn = fn

    def run(self, last_cont=lambda c: c):
        """Call the wrapped function with ``last_cont`` as the continuation.

        Without an argument, an identity function is used.
        """
        return self._fn(last_cont)

    def bind(self, fn):
        """Sequence ``fn`` after this continuation.

        Each value ``a`` this continuation passes on is given to ``fn``;
        the Continuation that ``fn`` returns is then run with the outer
        continuation ``c``.
        """
        def bound(c):
            def feed(a):
                return fn(a)._fn(c)
            return self._fn(feed)

        return self.__class__(bound)

    @classmethod
    def pure(cls, val):
        """Return a Continuation that passes ``val`` to its continuation."""
        def deliver(c):
            return c(val)

        return cls(deliver)

    @classmethod
    def wrap(cls, val):
        """Use a callable as the wrapped function; lift anything else with ``pure``."""
        if not callable(val):
            return cls.pure(val)
        return cls(val)

    def unwrap(self):
        """Return the bound ``run`` method, so the result can be called directly."""
        return self.run
##
## Explicit Examples
##

# Every pairing of one value from each list.
@do
def test():
    left = yield ListMonad(1, 2, 3)
    right = yield ListMonad(5, 6, 7)
    return (left, right)

print(test)
# ListMonad((1, 5), (1, 6), (1, 7), (2, 5), (2, 6), (2, 7), (3, 5), (3, 6), (3, 7))

# (almost) equivalent to
test = ListMonad(1, 2, 3).bind(
    lambda a: ListMonad(5, 6, 7).bind(
        lambda b: ListMonad.pure((a, b))))
print(test)

# An exception raised inside the block (here a failed assert) is handed to
# ListMonad.fail, which returns an empty ListMonad for that pairing.
@do
def test():
    left = yield ListMonad(1, 2, 3)
    right = yield ListMonad(5, 6, 7)
    assert (left + right) % 2 == 0
    return (left, right)

print(test)
# ListMonad((1, 5), (1, 7), (2, 6), (3, 5), (3, 7))

# callable_do defers evaluation until the function is called with arguments.
@callable_do
def test(a_m, b_m):
    first = yield a_m
    second = yield b_m
    return first + second

print(test(Maybe(1), Maybe(2))) # Maybe(3)
print(test(Maybe(1), Maybe.nothing())) # Maybe.nothing()
##
## Implicit Versions
##

# Yielded plain lists go through ImplicitListMonad.wrap, and the final result
# is unwrapped back into a plain list.
@ImplicitListMonad.do
def test():
    left = yield [1, 2, 3]
    right = yield [4, 5, 6]
    return (left, right)

print(test)
# [(1, 4), (1, 5), (1, 6), (2, 4), (2, 5), (2, 6), (3, 4), (3, 5), (3, 6)]

# Same block, keeping only the pairings that pass the assert.
@ImplicitListMonad.do
def test():
    left = yield [1, 2, 3]
    right = yield [4, 5, 6]
    assert (left + right) % 2 == 0
    return (left, right)

print(test)
# [(1, 5), (2, 4), (2, 6), (3, 5)]

# With ImplicitMaybe, None stands for "nothing" on the way in and out.
@ImplicitMaybe.callable_do
def test(m_a, m_b):
    first = yield m_a
    second = yield m_b
    return first + second

print(test(1, 2)) # 3
print(test(1, None)) # None
##
## Continuations
##

# Each yielded callable receives the continuation `k` for the rest of the
# block and decides how often to invoke it.
@Continuation.do
def test():
    base = 1
    picked = yield lambda k: [k(1), k(2)]
    yield lambda k: [k(None), "hello"]
    return base + picked

# `test` is now the bound `run` method; the argument is the final continuation.
test(print)
# outputs:
# 2
# 3

print(test())
# outputs:
# [[2, 'hello'], [3, 'hello']]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment