Created
December 18, 2018 03:03
-
-
Save SegFaultAX/f7bf2adcc3346a02c5b4e0c8d3263540 to your computer and use it in GitHub Desktop.
Simple IO monad [Python]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python
import functools
import typing as ty
from dataclasses import dataclass
A = ty.TypeVar("A")
B = ty.TypeVar("B")


# Every node of the IO tree is an immutable description of work; nothing is
# executed until an interpreter (unsafe_perform_io) walks the tree.
# NOTE: all subclasses must be frozen too -- dataclasses refuses to derive a
# non-frozen dataclass from a frozen one and raises TypeError at class
# definition time.
@dataclass(frozen=True)
class IO(ty.Generic[A]):
    """Base class of all IO actions producing a value of type ``A``."""


@dataclass(frozen=True)
class IOPure(IO[A]):
    """An action that performs no effect and simply yields ``a``."""

    a: A  # the value produced


@dataclass(frozen=True)
class IOSuspend(IO[A]):
    """A deferred action: calling ``k()`` does the work and returns the next IO."""

    k: ty.Callable[[], IO[A]]  # thunk invoked by the interpreter


@dataclass(frozen=True)
class IOBind(IO[B], ty.Generic[A, B]):
    """Sequencing: run ``v``, then pass its result to ``f`` to get the next action.

    Generic in both the input type ``A`` (result of ``v``) and the output type
    ``B`` (result of the action returned by ``f``).
    """

    v: IO[A]  # first action
    f: ty.Callable[[A], IO[B]]  # continuation receiving v's result


@dataclass(frozen=True)
class IOException(IO[None]):
    """An action that fails: interpreting it raises ``e``."""

    e: Exception  # exception raised when this action is interpreted
def fmap(f, io):
    """Map the plain function ``f`` over the result of ``io``.

    Returns a new IO that runs ``io`` and yields ``f`` applied to its result.
    """
    def _lift_result(value):
        return pure(f(value))

    return bind(io, _lift_result)
def pure(a):
    """Wrap the plain value ``a`` in an IO action that performs no effect."""
    return IOPure(a=a)
def app(iof, iov):
    """Applicative apply: run ``iof``, then ``iov``, and apply the function to the value.

    ``iof`` must yield a one-argument function; the result IO yields that
    function applied to the value produced by ``iov``.
    """
    return bind(iof, lambda fn: fmap(fn, iov))
def bind(io, f):
    """Sequence ``io`` with the continuation ``f``.

    ``f`` receives the result of ``io`` and must return the next IO action.
    Nothing runs here; this only builds an ``IOBind`` node.
    """
    return IOBind(v=io, f=f)
def lift(f):
    """Suspend the zero-argument thunk ``f``, which must itself return an IO action."""
    return IOSuspend(k=f)
def wrap(f):
    """Turn the effectful function ``f`` into one that returns a suspended IO action.

    Calling the wrapped function only records its arguments; ``f`` itself runs
    when the returned action is interpreted, and its return value becomes the
    action's result.
    """
    @functools.wraps(f)
    def _suspended_call(*args, **kwargs):
        def _thunk():
            return pure(f(*args, **kwargs))

        return lift(_thunk)

    return _suspended_call
def throw(e):
    """Build an IO action that raises ``e`` when it is interpreted."""
    return IOException(e=e)
def _kleisli_then(g, h):
    """Return the continuation ``v -> bind(g(v), h)``.

    Building the lambda in its own function captures ``g`` and ``h`` by value,
    so later reassignments of the caller's loop variables cannot affect it.
    """
    return lambda v: bind(g(v), h)


def step(io):
    """Simplify ``io`` as far as possible without interpreting any suspension.

    While the outermost node is an ``IOBind``:

    * ``bind(bind(x, g), h)`` is reassociated to ``bind(x, v -> bind(g(v), h))``
      so the left-hand side gets shallower;
    * ``bind(pure(a), h)`` is reduced to ``h(a)``.

    Returns the first node that is not an ``IOBind``, or an ``IOBind`` whose
    left side is neither an ``IOBind`` nor an ``IOPure`` (e.g. an
    ``IOSuspend`` that needs to be run).
    """
    root = io
    while isinstance(root, IOBind):
        inner = root.v
        if isinstance(inner, IOBind):
            # inner.f / root.f must be captured now: an inline lambda would
            # close over `inner`/`root` and see their *later* values.
            root = bind(inner.v, _kleisli_then(inner.f, root.f))
        elif isinstance(inner, IOPure):
            root = root.f(inner.a)
        else:
            break
    return root
def unsafe_perform_io(io):
    """Interpret the IO program ``io`` and return its final result.

    This is where effects actually happen: suspended thunks are called and
    ``IOException`` actions raise their exception. The interpreter is a single
    loop and does not call itself recursively, so long chains of binds over
    suspended actions do not grow the interpreter's call stack.

    Raises:
        Exception: the exception carried by an ``IOException`` that is reached,
            or anything raised by a thunk or continuation.
        RuntimeError: if a node that is not one of the IO constructors is met.
    """
    current = io
    while True:
        current = step(current)
        if isinstance(current, IOPure):
            return current.a
        if isinstance(current, IOSuspend):
            current = current.k()
        elif isinstance(current, IOBind):
            # step() only hands back a bind whose left side needs real work.
            first = current.v
            if isinstance(first, IOSuspend):
                # Run the suspended left side in place and keep the
                # continuation, instead of recursing into ourselves.
                current = bind(first.k(), current.f)
            elif isinstance(first, IOException):
                raise first.e
            else:
                raise RuntimeError(f"Got unexpected value: {first}")
        elif isinstance(current, IOException):
            raise current.e
        else:
            raise RuntimeError(f"Got unexpected value: {current}")
# IO-returning versions of the builtins: calling them only builds an action.
ioprint = wrap(print)
ioinput = wrap(input)

# Demo program: read a line, echo it, then fail via throw().
# bind() needs a *function* as its continuation, so the thrown action is
# produced by a lambda that ignores the result of the print action.
p = bind(bind(ioinput("Derp"), ioprint), lambda _: throw(Exception("err")))

if __name__ == "__main__":
    print(p)
    # Intended to end by raising Exception("err") from the throw() above.
    unsafe_perform_io(p)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment