Created
December 17, 2018 08:48
-
-
Save SegFaultAX/1196f1522b672959debc882bf7e290df to your computer and use it in GitHub Desktop.
Simple Free Monad [Python]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import dataclasses as dc | |
import typing as ty | |
import inspect | |
import functools | |
# Type variables for the Free data types below:
#   S - the instruction type stored by ``Suspend``
#   A - the result type carried by a program
#   B - the result type produced by a ``FlatMap`` continuation
S = ty.TypeVar("S")
A = ty.TypeVar("A")
B = ty.TypeVar("B")
@dc.dataclass(frozen=True)
class Monad:
    """Immutable record bundling a monad's operations as plain callables.

    The field order (fmap, pure, bind) is also the positional order of the
    generated constructor.
    """

    fmap: ty.Callable  # called as fmap(fn, m)
    pure: ty.Callable  # called as pure(a)
    bind: ty.Callable  # called as bind(m, fn)
@dc.dataclass(frozen=True)
class Free(ty.Generic[S, A]):
    """Base class of the Free monad: a program over instructions ``S`` that
    eventually yields a value of type ``A``.

    It has no fields of its own; programs are built from its subclasses.
    """
@dc.dataclass(frozen=True)
class Pure(Free[S, A]):
    """A finished program whose result is ``a``."""

    a: A  # the result value
@dc.dataclass(frozen=True)
class Suspend(Free[S, A]):
    """A program suspended on a single instruction ``k``, which an
    interpreter is expected to handle."""

    # Author's note: S[Free[S, A]]. The example code stores plain instruction
    # values (e.g. ReadLine / PrintLine) here.
    k: S
@dc.dataclass(frozen=True)
class FlatMap(Free[S, A]):
    """Sequencing node: run ``v``, then pass its result to ``f`` to obtain
    the rest of the program."""

    v: Free[S, A]  # the program to run first
    # Continuation building the next program from v's result.
    # NOTE(review): B is not a type parameter of this class, so static
    # checkers cannot relate it to anything; runtime is unaffected.
    f: ty.Callable[[A], Free[S, B]]
def fmap(fn, free):
    """Map ``fn`` over the result of the Free program ``free``.

    Builds ``FlatMap(free, x -> Pure(fn(x)))``. Nothing runs until the
    program is interpreted.

    The original referenced the undefined name ``puref``, which raised
    NameError as soon as the continuation was called.
    """
    # Lift fn's plain result back into a Free program with ``pure`` so the
    # continuation has the shape FlatMap expects.
    return FlatMap(free, lambda x: pure(fn(x)))
def pure(a):
    """Lift the plain value ``a`` into a finished Free program."""
    return Pure(a=a)
def suspend(k):
    """Wrap the instruction ``k`` as a Free program suspended on it."""
    return Suspend(k=k)
def bind(free, fn):
    """Sequence ``free`` with the continuation ``fn`` (lazily, as a FlatMap
    node)."""
    return FlatMap(v=free, f=fn)
# The Free monad's operations packaged as a Monad record (used by ``free``).
MonadFree = Monad(fmap=fmap, pure=pure, bind=bind)
def match(free, if_pure, if_suspend, if_flatmap):
    """Dispatch on the case of ``free`` and return the chosen callback's result.

    ``if_pure`` is called for ``Pure`` and ``if_suspend`` for ``Suspend``.
    ``if_flatmap`` is called for anything else, including values that are not
    Free programs at all. The callback receives ``free`` itself.
    """
    if isinstance(free, Pure):
        branch = if_pure
    elif isinstance(free, Suspend):
        branch = if_suspend
    else:
        branch = if_flatmap
    return branch(free)
def _chain(g, h):
    """Return the continuation ``x -> bind(g(x), h)`` with g and h fixed now."""
    return lambda x: bind(g(x), h)


def step(free):
    """Normalize the head of a Free program without interpreting it.

    Repeatedly rewrites the front of ``free``:

    * ``FlatMap(Pure(a), f)`` becomes ``f(a)``;
    * ``FlatMap(FlatMap(v, g), h)`` is reassociated into
      ``FlatMap(v, x -> FlatMap(g(x), h))``.

    Stops at the first node that is not a ``FlatMap``, or whose inner program
    is neither ``Pure`` nor ``FlatMap``, and returns that node.
    """
    root = free
    while isinstance(root, FlatMap):
        head = root.v
        if isinstance(head, Pure):
            root = root.f(head.a)
        elif isinstance(head, FlatMap):
            # Capture the *current* inner and outer continuations. The
            # original lambda closed over the loop variables ``inner`` and
            # ``root``. Both are rebound on later iterations, so the rebuilt
            # continuation ended up referring to (and calling) itself.
            root = bind(head.v, _chain(head.f, root.f))
        else:
            break
    return root
def foldmap(free, natural, monad, tailrec):
    """Interpret the Free program ``free`` into a target monad.

    Args:
        free: the program to interpret.
        natural: maps one instruction (the ``k`` of a ``Suspend``) to a value
            of the target monad.
        monad: the target ``Monad`` record. Its ``pure`` and ``fmap`` are used
            to keep every step's value inside the target monad.
        tailrec: loop driver, called as ``tailrec(free, run1)``. ``run1``
            returns ``(done, m)``, where ``m`` is a target-monad value holding
            either the next program (``done`` false) or the final result
            (``done`` true). Unwrapping ``m`` is ``tailrec``'s job.

    Returns:
        Whatever ``tailrec`` returns.
    """
    def run1(program):
        return match(
            program,
            lambda p: (True, monad.pure(p.a)),
            lambda s: (True, natural(s.k)),
            # Interpret the sub-program first (recursively), then map the
            # continuation over its result *inside* the target monad. A
            # short-circuiting result (e.g. None for the nullable monad) then
            # stops the program instead of being fed to the continuation.
            lambda fm: (
                False,
                monad.fmap(fm.f, foldmap(fm.v, natural, monad, tailrec)),
            ),
        )

    return tailrec(free, run1)
def do(monad, inst=lambda e: True):
    """Decorator factory providing generator-based "do notation" for ``monad``.

    The decorated function may be a generator that yields monadic values.
    Each yielded value is chained with ``monad.bind``, and the value bound by
    it is sent back into the generator. When the generator returns, its
    return value is passed through unchanged if ``inst(value)`` is true;
    otherwise it is wrapped with ``monad.pure``. If the decorated function
    does not return a generator, its result is returned unchanged.

    Args:
        monad: record with ``bind(m, fn)`` and ``pure(a)``.
        inst: predicate deciding whether a return value is already monadic.

    NOTE: every continuation resumes the same generator object. A monad whose
    ``bind`` invokes its callback more than once (e.g. a list monad) is
    therefore not supported.
    """
    def binder(gen):
        def advance(value):
            try:
                result = gen.send(value)
            except StopIteration as stop:
                done = stop.value
                return done if inst(done) else monad.pure(done)
            # Bind outside the ``try``, so a StopIteration escaping from
            # ``monad.bind`` is not mistaken for this generator finishing.
            return monad.bind(result, advance)

        return advance

    def decorator(fn):
        @functools.wraps(fn)  # keep fn's name/docstring on the wrapper
        def wrapper(*args, **kwargs):
            gen = fn(*args, **kwargs)
            if not inspect.isgenerator(gen):
                return gen
            return binder(gen)(None)

        return wrapper

    return decorator
def free(fn):
    """Decorator: run the generator function ``fn`` as Free-monad do-notation.

    Yielded values are chained through ``MonadFree``. A return value that is
    already a ``Free`` is passed through; any other value is wrapped with
    ``pure``.
    """
    def returns_free(value):
        return isinstance(value, Free)

    program = do(MonadFree, returns_free)(fn)
    return functools.wraps(fn)(program)
### Example ### | |
@dc.dataclass(frozen=True)
class ReadLine:
    """Instruction: ask the interpreter for a line of input."""

    prompt: str  # text to show before reading
def readln(prompt):
    """Build a Free program suspended on a ``ReadLine(prompt)`` instruction."""
    instruction = ReadLine(prompt=prompt)
    return suspend(instruction)
@dc.dataclass(frozen=True)
class PrintLine:
    """Instruction: ask the interpreter to output one line of text."""

    line: str  # text to output
def println(line):
    """Build a Free program suspended on a ``PrintLine(line)`` instruction."""
    instruction = PrintLine(line=line)
    return suspend(instruction)
# "Nullable" monad over plain Python values: None means "no value" and
# short-circuits fmap/bind; any other value is the value itself, so pure is
# the identity.
MonadNullable = Monad(
    fmap=lambda f, e: None if e is None else f(e),
    pure=lambda e: e,
    bind=lambda e, f: None if e is None else f(e),
)
def tailrec_nullable(val, step):
    """Drive ``step`` in a loop for the nullable monad.

    ``step(x)`` returns ``(done, next_x)``. It is first called with ``val``
    and then repeatedly with the previous ``next_x``. As soon as ``next_x``
    is None, the loop stops and None is returned (even when ``done`` is
    true). Otherwise the first ``next_x`` reported with ``done`` true is
    returned.
    """
    current = val
    while True:
        finished, current = step(current)
        if current is None:
            return None
        if finished:
            return current
def handler(cmd):
    """Interpret one instruction against the console.

    ``PrintLine`` prints its line and returns ``()``. ``ReadLine`` returns the
    text read by ``input`` with its prompt. Any other value yields None.
    """
    if isinstance(cmd, PrintLine):
        print(cmd.line)
        return ()
    if isinstance(cmd, ReadLine):
        return input(cmd.prompt)
    return None
@free
def program1():
    """Example program: ask for a name and an age, echo both, return them."""
    name = yield readln("What is your name? ")
    age = yield readln("What is your age? ")
    message = f"Your name is {name} and you are {age} years old!"
    yield println(message)
    return name, age
if __name__ == "__main__":
    # Run the interactive example only when executed as a script, so that
    # importing this module does not block on input().
    print(foldmap(program1(), handler, MonadNullable, tailrec_nullable))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment