Skip to content

Instantly share code, notes, and snippets.

@SegFaultAX
Created December 18, 2018 03:03
Show Gist options
  • Save SegFaultAX/f7bf2adcc3346a02c5b4e0c8d3263540 to your computer and use it in GitHub Desktop.
Save SegFaultAX/f7bf2adcc3346a02c5b4e0c8d3263540 to your computer and use it in GitHub Desktop.
Simple IO monad [Python]
#!/usr/bin/env python
import functools
import typing as ty
from dataclasses import dataclass
A = ty.TypeVar("A")
B = ty.TypeVar("B")


@dataclass(frozen=True)
class IO(ty.Generic[A]):
    """Base class of an IO program that yields a value of type ``A``.

    Programs are immutable data describing effects; nothing runs until an
    interpreter walks the tree. Every subclass must also be declared
    ``frozen=True``: ``dataclasses`` raises ``TypeError`` when a non-frozen
    dataclass inherits from a frozen one.
    """


@dataclass(frozen=True)
class IOPure(IO[A]):
    """A program with no effect whose result is ``a``."""

    a: A


@dataclass(frozen=True)
class IOSuspend(IO[A]):
    """A deferred program: calling the thunk ``k`` yields the next program."""

    k: ty.Callable[[], IO[A]]


# Parameterized over both the inner result ``A`` and the final result ``B``
# so the continuation's annotation refers only to the class's own type vars.
@dataclass(frozen=True)
class IOBind(IO[B], ty.Generic[A, B]):
    """Run ``v``, then pass its result to ``f`` to obtain the next program."""

    v: IO[A]
    f: ty.Callable[[A], IO[B]]


@dataclass(frozen=True)
class IOException(IO[None]):
    """A program that fails by raising ``e``."""

    e: Exception
def fmap(f, io):
    """Map the plain function ``f`` over the result of the program ``io``.

    Builds ``bind(io, v -> pure(f(v)))``; ``f`` is only called when the
    resulting program is interpreted.
    """
    def lift_result(value):
        return pure(f(value))

    return bind(io, lift_result)
def pure(a):
    """Wrap the plain value ``a`` in an effect-free IO program."""
    program = IOPure(a=a)
    return program
def app(iof, iov):
    """Apply the function produced by ``iof`` to the value produced by ``iov``.

    ``iof`` runs first; its result ``fn`` is then mapped over ``iov``.
    """
    def apply_to_value(fn):
        return fmap(fn, iov)

    return bind(iof, apply_to_value)
def bind(io, f):
    """Sequence the program ``io`` with the continuation ``f``.

    Returns an ``IOBind`` node. When interpreted, ``io`` runs first and its
    result is passed to ``f``, which is expected to return the next program.

    Raises:
        TypeError: if ``f`` is not callable. Checking here surfaces mistakes
            such as ``bind(m, throw(e))`` (an IO value where a function is
            needed) when the program is built, instead of deep inside the
            interpreter.
    """
    if not callable(f):
        raise TypeError(
            f"bind() continuation must be callable, got {type(f).__name__}"
        )
    return IOBind(io, f)
def lift(f):
    """Suspend the zero-argument thunk ``f`` as an IO step.

    ``f`` is expected to return an IO program; it is not called until the
    interpreter reaches this step.
    """
    suspended = IOSuspend(k=f)
    return suspended
def wrap(f):
    """Turn the side-effecting function ``f`` into an IO-returning function.

    Calling the returned wrapper does not call ``f``. It captures the
    arguments and returns a suspended step that calls ``f`` and yields its
    result when the program is interpreted.
    """
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        def run():
            return pure(f(*args, **kwargs))

        return lift(run)

    return wrapper
def throw(e):
    """Build an IO program that raises ``e`` when it is interpreted."""
    failing = IOException(e=e)
    return failing
def _reassociate(node):
    """Rewrite ``bind(bind(m, f), g)`` as ``bind(m, x -> bind(f(x), g))``.

    The two continuations are read into locals here, so the new lambda
    closes over fixed values rather than over variables the caller keeps
    rebinding.
    """
    inner = node.v
    first, then = inner.f, node.f
    return bind(inner.v, lambda x: bind(first(x), then))


def step(io):
    """Normalize the head of ``io`` as far as possible without running effects.

    Left-nested binds are reassociated to the right, and a bind on a pure
    value is reduced by applying its continuation. The loop stops at the
    first node that is not an ``IOBind``, or at an ``IOBind`` whose inner
    program is neither an ``IOBind`` nor an ``IOPure``, and returns it.
    """
    root = io
    while isinstance(root, IOBind):
        inner = root.v
        if isinstance(inner, IOBind):
            root = _reassociate(root)
        elif isinstance(inner, IOPure):
            root = root.f(inner.a)
        else:
            break
    return root
def unsafe_perform_io(io):
    """Interpret the IO program ``io`` and return its final result.

    The inner program of a bind runs before its continuation. An
    ``IOException`` node raises its stored exception; any other unrecognized
    node raises ``RuntimeError``.

    A bind whose inner program is a suspension is advanced in place (the
    thunk's result replaces the inner program) instead of recursing. This
    keeps deeply nested suspended binds from exhausting the Python stack.
    """
    current = io
    while True:
        current = step(current)
        if isinstance(current, IOPure):
            return current.a
        elif isinstance(current, IOSuspend):
            current = current.k()
        elif isinstance(current, IOBind):
            inner = current.v
            if isinstance(inner, IOSuspend):
                current = IOBind(inner.k(), current.f)
            else:
                # step() has already reduced IOBind/IOPure inner programs,
                # so only IOException or unknown nodes get here; the
                # recursive call raises for both and never nests deeply.
                current = current.f(unsafe_perform_io(inner))
        elif isinstance(current, IOException):
            raise current.e
        else:
            raise RuntimeError(f"Got unexpected value: {current}")
# Demo: prompt with "Derp", echo the line read, then fail with Exception("err").
ioprint = wrap(print)
ioinput = wrap(input)
# bind() needs a function as its continuation; throw(...) alone is an IO
# value, so wrap it in a lambda that ignores print's (None) result.
p = bind(bind(ioinput("Derp"), ioprint), lambda _: throw(Exception("err")))
print(p)
unsafe_perform_io(p)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment