Last active
April 1, 2020 18:26
-
-
Save felko/f831f59c014feed4e8228582ea6ad36e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3.7 | |
# coding: utf-8 | |
import types | |
import functools | |
import abc | |
import typing | |
from contextlib import contextmanager | |
from collections import OrderedDict | |
def _to_tuple(val): | |
if isinstance(val, tuple): | |
return val | |
else: | |
return (val,) | |
def _from_tuple(val): | |
if len(val) == 1: | |
return val[0] | |
else: | |
return val | |
class PatternMatchError(ValueError):
    """Raised by a pattern's ``match`` when the value does not fit the pattern."""
class Pattern:
    """Base class of all patterns.

    Subclasses implement ``match(val, env)``. It returns normally on success, writing any
    variable bindings into the ``env`` mapping. On failure it raises PatternMatchError.
    Operators build compound patterns:

    * ``'name' @ pat``  -> As(name, pat)
    * ``a * b``         -> Tup(a, b)
    * ``a & b``         -> And(a, b)
    * ``a | b``         -> Alt(a, b)
    """

    def __rmatmul__(self, name):
        return As(name, self)

    def __mul__(self, pat):
        return Tup(self, pat)

    def __and__(self, pat):
        return And(self, pat)

    def __or__(self, pat):
        # Alternation is implemented by the Alt pattern (there is no `Or` class).
        return Alt(self, pat)

    def match(self, val, env):
        raise NotImplementedError
class Val(Pattern):
    """Literal pattern: succeeds only when the value compares equal to ``expected_val``."""

    __slots__ = ['expected_val']

    def __init__(self, val):
        self.expected_val = val

    def match(self, val, env):
        mismatch = val != self.expected_val
        if mismatch:
            raise PatternMatchError(f"Expected {val} to be {self.expected_val}")
class Is(Pattern):
    """Type-test pattern: succeeds when ``isinstance(val, typ)``; binds nothing."""

    __slots__ = ['typ']

    def __init__(self, typ):
        self.typ = typ

    def match(self, val, env):
        if isinstance(val, self.typ):
            return
        raise PatternMatchError(f"Expected {self.typ}, got {val} of type {type(val)}")
class Alt(Pattern):
    """Ordered alternation: the first sub-pattern that matches wins.

    Each alternative is tried against a copy of ``env``. Only the bindings of the
    successful alternative are merged back, so a failed branch leaves ``env`` untouched.
    """

    __slots__ = ['pats']

    def __init__(self, *pats):
        self.pats = [*pats]

    def __ror__(self, lhs):
        return Alt(lhs, *self.pats)

    def __or__(self, rhs):
        return Alt(*self.pats, rhs)

    def match(self, val, env):
        for candidate in self.pats:
            trial_env = env.copy()
            try:
                candidate.match(val, trial_env)
            except PatternMatchError:
                continue
            env.update(trial_env)
            return
        raise PatternMatchError(f"Non exhaustive pattern match, missing case: {val}")
class Nil(Alt):
    """Alternation with no branches: a pattern that never matches."""

    def __init__(self):
        Alt.__init__(self)
class And(Pattern):
    """Conjunction: the value must match every sub-pattern, in order, sharing one env.

    Bindings from sub-patterns that matched before a later one fails are left in ``env``.
    """

    def __init__(self, *pats):
        self.pats = [*pats]

    def __rand__(self, lhs):
        return And(lhs, *self.pats)

    def __and__(self, rhs):
        return And(*self.pats, rhs)

    def match(self, val, env):
        for sub_pattern in self.pats:
            sub_pattern.match(val, env)
class Any(Pattern):
    """Wildcard pattern: accepts every value and binds nothing."""

    def match(self, val, env):
        return None
class Var(Pattern):
    """Capture pattern: binds the matched value to ``name`` in ``env``.

    If ``name`` is already bound (the same variable occurs twice in one pattern), the new
    value must be the same object as, or compare equal to, the existing binding.
    Otherwise PatternMatchError is raised.
    """

    def __init__(self, name):
        self.name = name

    def match(self, val, env):
        if self.name not in env:
            env[self.name] = val
            return
        bound = env[self.name]
        # Identity is checked first so self-unequal objects (e.g. a NaN bound twice)
        # still match themselves. Equality then accepts equal-but-distinct objects such
        # as two separately computed large ints, which a pure `is` test rejected.
        if bound is not val and bound != val:
            raise PatternMatchError(f"Variable pattern failed to satisfy {self.name} = {env[self.name]} != {val}")
class As(Var):
    """Named sub-pattern (``'name' @ pat``): binds the whole value to ``name`` and also
    requires it to match ``pat``, whose own bindings go into the same ``env``."""

    def __init__(self, name, pat):
        super().__init__(name)
        self.pat = pat

    def match(self, val, env):
        super().match(val, env)
        # The inner pattern must see the same value and environment. The previous
        # `pat.match(env)` referenced an undefined name and dropped `val`.
        self.pat.match(val, env)
class View(Pattern):
    """View pattern: apply ``f`` to the value, then match the result against ``pat``."""

    def __init__(self, f, pat):
        self.f, self.pat = f, pat

    def match(self, val, env):
        projected = self.f(val)
        self.pat.match(projected, env)
class Tup(Pattern):
    """Tuple pattern: the value must be a tuple of exactly ``len(pats)`` elements, each
    matching the sub-pattern at the same position."""

    def __init__(self, *pats):
        self.pats = [*pats]

    def __rmul__(self, lhs):
        return Tup(lhs, *self.pats)

    def __mul__(self, rhs):
        return Tup(*self.pats, rhs)

    def match(self, vals, env):
        if not isinstance(vals, tuple) or len(vals) != len(self.pats):
            raise PatternMatchError(f"Cannot match tuple {vals}, expected {len(self.pats)} elements")
        for element_pat, element in zip(self.pats, vals):
            element_pat.match(element, env)
class Constr(Pattern):
    """Constructor pattern (``prism << args``): destructure a value with ``prism.preview``
    and match each resulting component against the corresponding pattern in ``args``.

    For a Case prism the value must also be an instance of the Case's ADT class, built
    by that very constructor. Failures raise PatternMatchError.
    """

    def __init__(self, prism, args):
        self.prism = prism
        self.args = _to_tuple(args)

    def match(self, val, env):
        if isinstance(self.prism, Case):
            # Check the constructor tag *before* previewing. A Case's preview reads
            # `obj._args`, which would raise AttributeError (not a match failure) for
            # values that are not ADT instances, e.g. nested plain values.
            if not (isinstance(val, self.prism.cls) and val._case is self.prism):
                raise PatternMatchError(f"Expected {self.prism.name}, got {val}")
        vals = self.prism.preview(val)
        if not (isinstance(vals, tuple) and len(vals) == len(self.args)):
            raise PatternMatchError(f"Expected {self.prism.review.__name__}, got {val}")
        for arg_pat, component in zip(self.args, vals):
            arg_pat.match(component, env)
class Prism:
    """A constructor (``review``) paired with an optional destructor (``preview``).

    Calling the prism builds a value via ``review``. ``prism << args`` builds a Constr
    pattern that destructures values via ``preview``.
    """

    def __init__(self, review, preview=None):
        self.review, self.preview = review, preview

    def __lshift__(self, args):
        return Constr(self, _to_tuple(args))

    def __call__(self, *args):
        build = self.review
        return build(*args)

    def unwrap(self, p):
        # Replace the destructor; returns None.
        self.preview = p
class Case(Prism):
    """Constructor of an ADT, created by decorating a function inside an ADT class body.

    ``check`` is called with the constructor arguments for validation (its return value
    is ignored). The value is then built as ``self.cls(self, args)``, where ``cls`` is
    assigned by ADTMeta when the owning class is created. Note that keyword arguments
    are passed to ``check`` but are not stored in the resulting value.
    """

    def __init__(self, check, name=None):
        @functools.wraps(check)
        def _review_wrapper(*args, **kwargs):
            check(*args, **kwargs)
            return self.cls(self, args)
        super().__init__(_review_wrapper, lambda obj: obj._args)
        self.cls = None
        self.name = name or check.__name__
        self.check = check

    def match(self, val, env):
        """Match any value built by this constructor, regardless of its arguments.

        Binds nothing. Raises PatternMatchError for values of another type or built
        by a different constructor. (Prism defines no ``match``, so the former
        ``super().match(...)`` always failed with AttributeError.)
        """
        if not isinstance(val, self.cls):
            raise PatternMatchError(f"Expected value of type {self.cls}, got {type(val)}")
        if val._case is not self:
            raise PatternMatchError(f"Expected {self.name}, got {val}")
class ADTMeta(abc.ABCMeta):
    """Metaclass for algebraic data types (ADT and its subclasses).

    The first class named 'ADT' becomes the shared root (`_ADTBase`). Every other ADT
    class is created as a direct child of that root. When such a class lists another
    ADT class as a base, the relationship is inverted: the listed ADT is re-parented
    *under* the new class, which makes the old type a subtype of the new one (see
    These/Either below).
    """
    # The root ADT class, set when the class named 'ADT' is first created.
    _ADTBase = None
    def __new__(mcs, name, bases, attrs, renaming=()):
        # First class named 'ADT': create it with no bases and remember it as the root.
        if name == 'ADT' and mcs._ADTBase is None:
            mcs._ADTBase = super().__new__(mcs, name, (), attrs)
            return mcs._ADTBase
        new_bases = []
        # Every other class is created with only the root as base; real bases are
        # assigned to __bases__ afterwards.
        cls = super().__new__(mcs, name, (mcs._ADTBase,), attrs)
        for b in bases:
            if isinstance(b, ADTMeta) and b is not mcs._ADTBase:
                # An ADT base is not inherited from; it is made to inherit from `cls`.
                if b.__bases__ == (mcs._ADTBase,):
                    b.__bases__ = (cls,)
                else:
                    # NOTE(review): looks like leftover debug output.
                    print(cls, b)
                    b.__bases__ += (cls,)
            else:
                # The root itself and non-ADT bases (e.g. Trait subclasses) are kept.
                new_bases.append(b)
        cls.__bases__ = tuple(new_bases) or (mcs._ADTBase,)
        return cls
    def __prepare__(mcs, bases, renaming=()):
        # NOTE: not declared @classmethod. Python calls
        # ADTMeta.__prepare__(name, bases, **kwds), so the first parameter (`mcs`)
        # actually receives the class *name*.
        # Pre-populate the class namespace from ADT bases:
        #   * each Case is exposed under its own name and under its renamed alias;
        #   * each case method is copied (sharing default and current cases), so the
        #     subclass can add cases without mutating the parent's method.
        # Other attributes (plain methods, classmethods such as `pure`) are not copied.
        renaming = dict(renaming)
        env = {}
        for b in bases:
            if isinstance(b, ADTMeta):
                for name, attr in b.__dict__.items():
                    if isinstance(attr, Case):
                        n = renaming.get(attr.name, attr.name)
                        env[n] = attr
                        env[attr.name] = attr
                    elif isinstance(attr, _CaseMethod):
                        env[name] = _CaseMethod(attr._default, cases=attr._cases)
        return env
    def __init__(cls, name, bases, attrs, renaming=()):
        # Constructor-name aliases, read by ADT.__repr__.
        cls.__renaming__ = dict(renaming)
        if name == 'ADT':
            # NOTE(review): keyed on the name only, so any later class named 'ADT'
            # also skips constructor registration below.
            cls.__traits__ = ()
            cls.__supers__ = ()
            cls.__constrs__ = ()
            super().__init__(name, (), attrs)
            return
        constrs = []
        # Claim every Case and case method in the namespace for this class. This
        # includes Case objects copied in by __prepare__. Those objects are shared with
        # the parent ADT, so the parent's constructors now build instances of `cls`.
        # A renamed Case appears under two names and is therefore listed twice.
        for attr in attrs.values():
            if isinstance(attr, Case):
                attr.cls = cls
                constrs.append(attr)
            elif isinstance(attr, _CaseMethod):
                attr._cls = cls
        cls.__constrs__ = tuple(constrs)
        # Traits: inherited from ADT bases, or taken directly from Trait subclasses.
        traits = []
        for b in bases:
            if isinstance(b, ADTMeta):
                traits.extend(b.__traits__)
            elif issubclass(b, Trait):
                traits.append(b)
        cls.__traits__ = tuple(traits)
        cls.__supers__ = tuple(b for b in bases if isinstance(b, ADTMeta))
        super().__init__(name, cls.__bases__, attrs)
    def _propagate_new_superclass(cls, new):
        # NOTE(review): not called anywhere in this file. `bs.remove(cls._ADTBase)`
        # raises ValueError when the root is not listed in __supers__.
        bs = list(cls.__supers__)
        bs.remove(cls._ADTBase)
        for base in bs:
            if isinstance(base, ADTMeta):
                base._propagate_new_superclass(cls)
        cls.__bases__ = tuple(bs) + (new,)
class ADT(metaclass=ADTMeta):
    """Root class of all algebraic data types.

    An instance records the constructor that built it in ``_case`` and that
    constructor's positional arguments in ``_args``. Case methods found in the
    instance's class ``__dict__`` are returned bound to the instance.
    """

    def __init__(self, *args):
        arity = len(args)
        if arity == 0:
            raise ValueError("Expected value or case/arguments pair")
        if arity == 1:
            # Copy-construct from another value of this ADT.
            (source,) = args
            if not isinstance(source, type(self)):
                raise TypeError(f"Expected a subtype of {type(self)}, got {type(source)}")
            self._case = source._case
            self._args = source._args
        elif arity == 2:
            # (constructor, arguments) pair, as passed by Case's review wrapper.
            case, payload = args
            self._case = case
            self._args = tuple(payload)
        else:
            raise TypeError(f"Too many arguments, expected 1 or 2")

    def __repr__(self):
        case_name = self._case.name
        label = type(self).__renaming__.get(case_name, case_name)
        if not self._args:
            return f"<{label}>"
        rendered = ' '.join(repr(arg) for arg in self._args)
        return f"<{label} {rendered}>"

    def __getattribute__(self, attr):
        # Only the concrete class's own __dict__ is consulted for case methods.
        candidate = type(self).__dict__.get(attr)
        if isinstance(candidate, _CaseMethod):
            return _BoundCaseMethod(candidate, self)
        return super(ADT, self).__getattribute__(attr)

    def case(self, cases):
        """Try each (pattern, branch) pair in order; call the first matching branch
        with the pattern's bindings as keyword arguments."""
        for pattern, branch in OrderedDict(cases).items():
            bindings = {}
            try:
                pattern.match(self, bindings)
            except PatternMatchError:
                continue
            return branch(**bindings)
        raise PatternMatchError("No match")
def Wrapper(f):
    """Create a one-constructor ADT class named after *f*.

    The class has a single constructor, also named ``f.__name__``, that calls *f* with
    the constructor arguments for validation before building the value.
    """
    # Case's signature is Case(check, name=None): the validating function comes first.
    # The previous Case(f.__name__, f) made the name string the check.
    return ADTMeta(f.__name__, (ADT,), {f.__name__: Case(f, f.__name__)})
@contextmanager
def match(val, pat, exc=None):
    """Context manager that matches *val* against *pat* and yields the bound values.

    The yielded value is the single binding when exactly one variable is bound,
    otherwise a tuple of all bindings in binding order. On mismatch, *exc* is raised
    (without chaining) if given, else the original PatternMatchError propagates.
    """
    bindings = OrderedDict()
    try:
        pat.match(val, bindings)
    except PatternMatchError:
        if exc is None:
            raise
        raise exc from None
    yield _from_tuple(tuple(bindings.values()))
class _CaseFunction(typing.Callable): | |
def __init__(self, f, cases=()): | |
self._default = f | |
self._cases = OrderedDict(cases) | |
def __call__(self, *args): | |
for pats, branch in self._cases.items(): | |
if len(args) == len(pats): | |
env = {} | |
try: | |
for pat, val in zip(pats, vals): | |
pat.match(val, env) | |
except PatternMatchError: | |
continue | |
return branch(**env) | |
else: | |
raise TypeError(f"Expected {len(pats)} arguments, got {len(args)}") | |
try: | |
raise PatternMatchError(f"Non exhaustive pattern match, missing case: {args}") | |
except PatternMatchError: | |
return self._default(*args) | |
def case(self, pats): | |
def _decorator_wrapper(f): | |
self._cases[pats] = f | |
return _decorator_wrapper | |
class _CaseMethod(_CaseFunction):
    """Case function used as a method: clauses dispatch on the receiver only.

    Each clause is keyed by a single pattern matched against ``obj``. The chosen branch
    is called as ``branch(obj, *bound_values, *args)``, where ``bound_values`` are the
    pattern's variables in binding order. ``_cls`` is assigned by ADTMeta when the
    owning class is created.
    """

    def __init__(self, f, cases=()):
        self._default = f
        self._cases = OrderedDict(cases)
        self._cls = None

    def __call__(self, obj, *args):
        owner = self._cls
        if not isinstance(obj, owner):
            raise TypeError(f"Expected {owner} instance, got {type(obj)}")
        for pattern, branch in self._cases.items():
            bindings = OrderedDict()
            try:
                pattern.match(obj, bindings)
            except PatternMatchError:
                continue
            return branch(obj, *bindings.values(), *args)
        # Run the default while the error is being handled: a default whose body is a
        # bare `raise` re-raises this non-exhaustive-match error.
        try:
            raise PatternMatchError(f"Non exhaustive pattern match in method {self._default.__name__}, missing case: {obj}")
        except PatternMatchError:
            return self._default(obj, *args)
class _BoundCaseMethod(typing.Callable): | |
def __init__(self, method, instance): | |
if not isinstance(instance, method._cls): | |
raise TypeError(f"Cannot bind case method of datatype {method._cls} to {type(instance)} object") | |
@functools.wraps(method) | |
def _method_wrapper(*args): | |
return method(instance, *args) | |
self._method = _method_wrapper | |
self._cases = method._cases | |
self._default = method._default | |
def __call__(self, *args): | |
return self._method(*args) | |
# Public decorator names for defining pattern-matching functions and methods.
casefunc, casemethod = _CaseFunction, _CaseMethod
class Trait(metaclass=abc.ABCMeta):
    """Abstract base for type-class-style interfaces (Functor, Monad, Foldable, ...)."""
class Functor(Trait):
    """Interface for containers whose contents can be transformed with ``map``."""

    @abc.abstractmethod
    def map(self, f):
        """Apply *f* inside the container (implemented by subclasses)."""
        raise NotImplementedError
class Applicative(Functor):
    """Functor that can inject plain values (``pure``) and apply wrapped functions (``app``)."""

    @classmethod
    @abc.abstractmethod
    def pure(cls, x):
        """Wrap the plain value *x* in the minimal context (implemented by subclasses)."""
        raise NotImplementedError

    @abc.abstractmethod
    def app(self, xs):
        """Apply the function(s) held by ``self`` to the value(s) held by *xs*."""
        raise NotImplementedError
class Monad(Applicative):
    """Applicative with sequencing via ``bind``; ``join`` and ``app`` are derived from it."""

    @abc.abstractmethod
    def bind(self, f):
        """Feed the contained value(s) to *f*, which returns a new monadic value."""
        raise NotImplementedError

    def join(self):
        """Flatten one level of nesting: ``self.bind(identity)``."""
        return self.bind(lambda inner: inner)

    def app(self, xs):
        """Applicative apply derived from bind: map each held function over *xs*."""
        return self.bind(lambda fn: xs.map(fn))
class Foldable(Trait):
    """Interface for structures reducible with a right fold; ``foldl`` is derived."""

    @abc.abstractmethod
    def foldr(self, f, i):
        """Right fold: combine elements with ``f(element, acc)`` starting from *i*."""
        raise NotImplementedError

    def foldl(self, f, i):
        """Left fold ``f(acc, element)`` expressed via foldr by building a function chain."""
        def step(element, continue_with):
            return lambda acc: continue_with(f(acc, element))
        return self.foldr(step, lambda acc: acc)(i)
class Traversable(Foldable, Functor):
    """Foldable functor that can be traversed with an effectful function."""

    @abc.abstractmethod
    def traverse(self, f):
        """Apply *f* to each element and collect the effects (implemented by subclasses)."""
        raise NotImplementedError
class Maybe(ADT, Monad, Traversable):
    """Optional value: ``Just(x)`` or ``Nothing()``.

    map/bind/foldr/traverse are case methods. Their default bodies are a bare
    ``raise``, which re-raises the non-exhaustive-match PatternMatchError that
    ``_CaseMethod.__call__`` is handling when no registered case matches.
    """
    @Case
    def Just(x): pass
    @Case
    def Nothing(): pass
    @casemethod
    def map(self, *_): raise
    @classmethod
    def pure(cls, x):
        # Wrap a plain value in the Just constructor.
        return cls.Just(x)
    @casemethod
    def bind(self, f): raise
    @casemethod
    def foldr(self, f, i): raise
    @casemethod
    def traverse(self, f): raise
    # Case registrations. Each branch receives the instance, then the pattern's bound
    # variables in binding order, then the caller's arguments.
    @map.case(Just << Var('x'))
    def map_just(self, x, f):
        return self.Just(f(x))
    @map.case(Nothing << ())
    def map_nothing(self, f):
        return self
    @bind.case(Just << Var('x'))
    def bind_just(self, x, f):
        return f(x)
    @bind.case(Nothing << ())
    def bind_maybe(self, f):
        return self
    @foldr.case(Just << Var('x'))
    def foldr_just(self, x, f, i):
        return f(x, i)
    @foldr.case(Nothing << ())
    def foldr_nothing(self, f, i):
        return i
    @traverse.case(Just << Var('x'))
    def traverse_just(self, x, f):
        fb = f(x)
        # Remember the result type on the function object so a later traversal of a
        # Nothing with the same f can call that type's `pure`.
        # NOTE(review): this mutates the caller's function, and traverse_nothing fails
        # with AttributeError if f was never applied to a Just first.
        f._applicative = type(fb)
        return fb.map(self.Just)
    @traverse.case(Nothing << ())
    def traverse_nothing(self, f):
        return f._applicative.pure(self)
class Either(ADT):
    """Sum of two alternatives: ``Left(x)`` (typically an error) or ``Right(x)``.

    map/bind/foldr/traverse act on the Right value and pass Left through. Defaults are
    a bare ``raise``, which re-raises the non-exhaustive-match error that
    ``_CaseMethod.__call__`` is handling when no case matches.
    """
    @Case
    def Left(x): pass
    @Case
    def Right(x): pass
    @casemethod
    def map(self, f): raise
    @classmethod
    def pure(cls, x):
        # Wrap a plain value in the Right constructor.
        return cls.Right(x)
    @casemethod
    def bind(self, f): raise
    @casemethod
    def foldr(self, f, i): raise
    @casemethod
    def traverse(self, f): raise

    @map.case(Left << Var('x'))
    def map_left(self, x, f):
        return self.Left(x)

    @map.case(Right << Var('x'))
    def map_right(self, x, f):
        return self.Right(f(x))

    @bind.case(Left << Any())
    def bind_left(self, f):
        return self

    @bind.case(Right << Var('x'))
    def bind_right(self, x, f):
        # Monadic bind hands the wrapped value to the continuation and returns its
        # result (itself an Either) unchanged. `self.Right()` called the constructor
        # with no argument and always raised TypeError.
        return f(x)

    @foldr.case(Left << Any())
    def foldr_left(self, f, i):
        return i

    @foldr.case(Right << Var('x'))
    def foldr_right(self, x, f, i):
        return f(x, i)

    @traverse.case(Left << Any())
    def traverse_nothing(self, f):
        # NOTE(review): relies on f._applicative, which Maybe.traverse_just sets on f;
        # nothing in this class sets it.
        return f._applicative.pure(self)

    @traverse.case(Right << Var('x'))
    def traverse_just(self, x, f):
        return f(x).map(self.Right)
class These(Either, renaming={'Left': 'This', 'Right': 'That'}):
    """Either extended with a third constructor holding both values.

    ADTMeta.__prepare__ pre-loads this class body with Either's constructors, under both
    their own names (Left/Right) and the renamed aliases (This/That), and with copies of
    Either's case methods. ADTMeta.__new__ then re-parents Either under These, so values
    built by Either's constructors are instances of These.
    """
    @Case
    def These(x, y): pass
    # Extend the copied `map` case method with a case for the new constructor; the
    # function is applied to the second component only.
    @map.case(These << (Var('x'), Var('y')))
    def map_these(self, x, y, f):
        return self.These(x, f(y))
    # Return (first, second), filling the missing side from the defaults `a` / `b`.
    # The default body is `pass`, so a non-matching value yields None instead of raising.
    @casemethod
    def from_these(self, a, b): pass
    @from_these.case(This << Var('x'))
    def from_this(self, x, a, b):
        return x, b
    @from_these.case(That << Var('x'))
    def from_that(self, x, a, b):
        return a, x
    @from_these.case(These << (Var('x'), Var('y')))
    def from_these_(self, x, y, a, b):
        return x, y
class List(ADT, Functor):
    """Immutable singly linked list: ``Nil()`` or ``Cons(head, tail)``.

    ``map`` and ``sum`` recurse once per element, so very long lists can hit Python's
    recursion limit.
    """
    @Case
    def Nil(): pass
    @Case
    def Cons(x, xs): pass

    def map(self, f):
        """Return a new list with *f* applied to every element."""
        return self.case({
            self.Nil << (): lambda: self.Nil(),
            self.Cons << (Var('x'), Var('xs')): lambda x, xs: self.Cons(f(x), xs.map(f))
        })

    def sum(self):
        """Right-fold the elements with ``+``, starting from 0."""
        return self.case({
            self.Nil << (): lambda: 0,
            self.Cons << (Var('x'), Var('xs')): lambda x, xs: x + xs.sum()
        })

    @classmethod
    def from_list(cls, lst):
        """Build a List holding the elements of the iterable *lst*, in order."""
        # Materialize first so any iterable (not only sliceable sequences) works, then
        # cons from the back. `cls.Cons` keeps construction consistent with `cls.Nil`.
        result = cls.Nil()
        for item in reversed(list(lst)):
            result = cls.Cons(item, result)
        return result
# Example: lambda calculus interpreter | |
class Expr(ADT):
    """Lambda-calculus terms plus unary naturals, evaluated by the ``eval`` case method.

    ``env`` maps identifier strings to terms. ``Lam(x, e, closure)`` carries a closure
    dict that overrides ``env`` entries when the lambda is evaluated or applied.
    ``Val(x)`` wraps a host value (e.g. the ints produced by ``Zero``/``Succ``).
    """
    @Case
    def Id(name):
        # Constructor-time validation: names must be valid Python identifiers.
        if not name.isidentifier():
            raise ValueError(f"{name!r} is not a valid identifier")
    @Case
    def App(f, x): pass
    @Case
    def Lam(x, e, closure): pass
    @Case
    def Zero(): pass
    @Case
    def Succ(n): pass
    @Case
    def Val(x): pass
    @casemethod
    def eval(self, env): raise

    @eval.case(Id << Var('name'))
    def eval_id(self, name, env):
        """Evaluate the term bound to *name*; raise ValueError(self) if it is unbound."""
        try:
            bound = env[name]
        except KeyError:
            raise ValueError(self) from None
        # Evaluate outside the try, so a KeyError raised while evaluating the bound term
        # is not misreported as this identifier being unbound.
        return bound.eval(env)

    @eval.case(App << (Var('f'), Var('y')))
    def eval_app(self, f, y, env):
        """Evaluate *f* to a lambda and evaluate its body with the evaluated argument bound.

        Raises ValueError(self) if *f* does not evaluate to a Lam.
        """
        pat = View(lambda e: e.eval(env), self.Lam << (Var('x'), Var('e'), Var('c')))
        with match(f, pat, exc=ValueError(self)) as (x, e, closure):
            # Later entries win: the closure overrides env, the argument overrides both.
            local_env = {**env, **closure, x: y.eval(env)}
            return e.eval(local_env)

    @eval.case(Lam << (Var('x'), Var('e'), Var('c')))
    def eval_lam(self, x, e, c, env):
        """Capture the current environment into the lambda's closure (closure wins)."""
        return self.Lam(x, e, {**env, **c})

    @eval.case(Zero << ())
    def eval_zero(self, env):
        return self.Val(0)

    @eval.case(Succ << Var('n'))
    def eval_succ(self, n, env):
        """Evaluate *n* to ``Val(x)`` and return ``Val(x + 1)``; ValueError(n) otherwise."""
        with match(n, View(lambda e: e.eval(env), self.Val << Var('x')), exc=ValueError(n)) as x:
            return self.Val(x + 1)

    @eval.case(Val << Var('x'))
    def eval_val(self, x, env):
        # Values are already fully evaluated.
        return self
if __name__ == '__main__':
    # --- Maybe: explicit matching via .case, then Functor.map ---
    just_one = Maybe.Just(1)
    just_one.case({
        Maybe.Just << Var('x'): lambda x: print(x),
        Maybe.Nothing << (): lambda: print('nothing'),
    })
    print(just_one.map(lambda x: x + 9))

    # --- List: fold and map over a cons list ---
    numbers = List.from_list([1, 2, 3])
    print(numbers.sum())
    print(numbers.map(lambda x: x ** 2))

    # --- These / Either: shared constructors and renaming ---
    print(These.__mro__)
    print(Either.__mro__)
    samples = (These.This('test'), These.That(3), These.These('test', 3))
    for sample in samples:
        print(sample.map(lambda x: x * 2))
    print(isinstance(Either.Left(1), These))
    print(Either.Left(1).from_these('a', 'b'))

    # --- Church numerals in the lambda-calculus interpreter ---
    # Nat = forall r. (r -> r) -> r -> r
    Lam, App, Id = Expr.Lam, Expr.App, Expr.Id
    z = Lam('f', Lam('x', Id('x'), {}), {})
    s = Lam('n', Lam('f', Lam('x', App(Id('f'), App(App(Id('n'), Id('f')), Id('x'))), {}), {}), {})
    incr = Lam('n', Expr.Succ(Id('n')), {})
    to_int = Lam('n', App(App(Id('n'), incr), Expr.Zero()), {})
    add = Lam('m', Lam('n', Lam('f', Lam('x', App(App(Id('m'), Id('f')), App(App(Id('n'), Id('f')), Id('x'))), {}), {}), {}), {})
    mul = Lam('m', Lam('n', Lam('f', Lam('x', App(App(Id('m'), App(Id('n'), Id('f'))), Id('x')), {}), {}), {}), {})
    two = App(s, App(s, z))
    four = App(App(mul, two), two)
    six = App(App(add, two), four)
    seven = App(s, six)
    fourtytwo = App(App(mul, six), seven)
    print(App(to_int, fourtytwo).eval({}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment