Skip to content

Instantly share code, notes, and snippets.

@monomere
Last active May 10, 2024 14:11
Show Gist options
  • Save monomere/5d0eeea26128d060fa158ea909289680 to your computer and use it in GitHub Desktop.
Save monomere/5d0eeea26128d060fa158ea909289680 to your computer and use it in GitHub Desktop.
Hindley–Milner type inference (Algorithm W) with limited ADTs and mutual recursion in Python (WIP) [v0.3]
# by monomere
# Apache License 2.0
# the code is horrible and undocumented,
# might change that later :P
# requires python >=3.12
# requires the parsita python package (for parsing)
# you can remove the parser code and play around with the ast directly tho
# version 0.3 (may 10th, 2024)
# - new: match expression
from __future__ import annotations
import dataclasses, abc, re, typing
from typing import Iterator
@dataclasses.dataclass
class Rc[T]:
ref: T
@dataclasses.dataclass(frozen=True)
class Fmt:
    """Pretty-printing options.

    `color` enables ANSI SGR colouring; `indent` is the current nesting depth,
    rendered as one space per level by `ind`.
    """

    color: bool
    indent: int = 0

    def down(self) -> Fmt:
        """Return a copy nested one indentation level deeper."""
        return dataclasses.replace(self, indent=self.indent + 1)

    def as_inline(self) -> Fmt:
        """Return an equal copy (same colour flag and indentation)."""
        return dataclasses.replace(self)

    def ind(self, a: int = 0) -> str:
        """Indentation for the current level plus `a` extra levels."""
        return " " * (self.indent + a)

    def _paint(self, code: str, s: str) -> str:
        # Wrap `s` in an SGR colour escape, but only when colouring is on.
        if not self.color:
            return s
        return f"\033[{code}m{s}\033[m"

    # constructor names
    def ccn(self, s: str) -> str: return self._paint("95", s)
    # literal values
    def cvl(self, s: str) -> str: return self._paint("35", s)
    # keywords
    def ckw(self, s: str) -> str: return self._paint("36", s)
    # punctuation
    def cpn(self, s: str) -> str: return self._paint("90", s)
    # error text
    def cer(self, s: str) -> str: return self._paint("31", s)
class HasFmt(typing.Protocol):
    """Structural interface: anything that can pretty-print itself with a `Fmt`."""

    def fmt(self, f: Fmt, /) -> str:
        """Render `self` according to the options in `f`."""
        ...
type TreeParent = tuple[Expr | Ty, str]
# to get nice errors, each ast node and type store a reference to its parent
# with a "reason" (bad name) that tells the user what part of the parent this node is.
# to make this bearable, this is a baseclass of Expr and Ty that has a __post_init__ method
# which reads the dataclass fields and sees if there's a Branch[T, "reason"] type annotation
# if so, either sets _parent if it exists or calls _set_parent if it exists.
@dataclasses.dataclass
class Tree:
_parent: TreeParent | None = dataclasses.field(default=None, init=False, repr=False, hash=False, compare=False)
__match_args__ = ()
def __post_init__(self):
AnnotatedType = type(Branch[type, '']) # dummy
# def find_in_args(a: type) -> type | None:
# return next((
# find_in_args(arg)
# for arg in typing.get_args(fld.type)
# if isinstance(arg, AnnotatedType)
# ), a)
for fld in dataclasses.fields(self):
if fld.name == "_parent": continue
match = re.match(r"Branch\[.*,\s*[\"']([^\"]*)[\"']\]", fld.type)
if match is None: continue
reason = match.group(1)
child = getattr(self, fld.name)
def update(o: typing.Any, args: tuple[typing.Any, ...]):
match o:
case dict():
for k, v in o.items():
update(v, args + (k,))
case list():
for i, v in enumerate(o):
suffix = ("st", "nd", "rd")[abs(i) % 10] if abs(i) % 10 + 1 < 4 else "th"
update(v, args + (f"{i + 1}{suffix}",))
case Rc(ref):
update(ref, args)
case _:
if hasattr(child, "_set_parent"):
child._set_parent((self, reason.format(*args)))
elif hasattr(child, "_parent"):
child._parent = (self, reason.format(*args))
update(child, ())
def with_parent_copy[T: typing.Any](x: T, parent: TreeParent | None) -> T:
if hasattr(x, "_parent"): x._parent = parent
elif hasattr(x, "_set_parent"): x._set_parent(parent)
else: print("no _parent/_set_parent on", x)
return x
def with_parent[T: typing.Any](x: T, parent: Ty | Expr, reason: str) -> T:
return with_parent_copy(x, (parent, reason))
@dataclasses.dataclass
class Expr(Tree, abc.ABC, HasFmt):
    """Base class of all expression AST nodes.

    `fmt` pretty-prints any expression by structural pattern matching on the
    concrete subclass, using the colours and indentation of a `Fmt`.
    """
    def fmt(self, f: Fmt) -> str:
        # F formats a child at the current indentation; Fl one level deeper.
        def F(x: HasFmt) -> str: return x.fmt(f)
        def Fl(x: HasFmt) -> str: return x.fmt(f.down())
        # Binding forms: when one is the body of another, it is chained after
        # `in` on the same line instead of on a new indented line.
        def is_let(x: Expr) -> bool:
            match x:
                case ELet() | ELetRec() | EData(): return True
                case _: return False
        match self:
            # Applications: parenthesize a lambda on the left and a nested
            # application on the right so the printed form is unambiguous.
            case EApp(ELam(_, _) as lhs, EApp(_, _) as rhs):
                return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {f.cpn("(")}{F(rhs)}{f.cpn(")")}"
            case EApp(ELam(_, _) as lhs, rhs):
                return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {F(rhs)}"
            case EApp(lhs, EApp(_, _) as rhs):
                return f"{F(lhs)} {f.cpn("(")}{F(rhs)}{f.cpn(")")}"
            case EApp(lhs, rhs):
                return f"{F(lhs)} {F(rhs)}"
            case ELam(name, body):
                return f"{f.cpn("λ")}{name}{f.cpn(".")} {F(body)}"
            # NOTE(review): in ELet/ELetRec/EData, the same-line continuation
            # `f" {f.ind()}..."` inserts the current indentation after a space,
            # so nested chained lets get extra spaces when indent > 0.
            case ELet(name, value, body):
                return f"{f.ckw("let")} {name} {f.cpn("=")} {value.fmt(f)} {f.ckw("in")}" + \
                    (f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}")
            # `let rec`: one binding stays on the header line; several are put
            # one per line, indented and comma-separated.
            case ELetRec(names, body):
                return f"{f.ckw("let rec")}{"\n" if len(names) > 1 else " "}" + \
                    f"{f.cpn(",\n")} ".join(f"{f.ind(1) if len(names) > 1 else ""}{name} " + \
                    f"{f.cpn("=")} {value.fmt(f)}" for name, value in names.items()) + \
                    f"\n{f.ind()}{f.ckw("in")}" + (f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}")
            case EIf(cond, then, othr):
                return f"{f.ckw("if")} {cond.fmt(f)} {f.ckw("then")} " + \
                    f"{then.fmt(f)} {f.ckw("else")} {othr.fmt(f)} {f.ckw("end")}"
            case EBool(v): return f.cvl("true" if v else "false")
            case EInt(v): return f.cvl(f"{v}")
            case EVar(name): return f"{name}"
            # `data`: one `= Con` / `| Con` line per constructor, with argument
            # types separated by " × ".
            case EData(name, params, cons, body): return f"{f.ckw("data")} {name}" + \
                f" {" ".join(params)}" + \
                f"\n{f.ind(1)}{f.cpn("=")} {f.cpn(f"\n{f.ind(1)}| ").join(f"{f.ccn(name)} " + \
                    f"{f.cpn(" × ").join(F(ty) for ty in tys)}" for name, tys in cons.items())}" + \
                f"\n{f.ind()}{f.ckw("in")}" + (f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}")
            # `match`: one `| Con args ⇒ body` line per case; bodies are
            # formatted one level deeper, closed by `end`.
            case EMatch(expr, cases): return f"{f.ckw("match")} {F(expr)} {f.ckw("with")}" + \
                ''.join(
                    f.cpn(f"\n{f.ind()}| ") +
                    ' '.join((case.name, *case.args)) +
                    f.cpn(f" ⇒ ") + Fl(case.body)
                    for case in cases
                ) + f"\n{f.ind()}{f.ckw("end")}"
            # defensive fallbacks: highlight unexpected values instead of crashing
            case str(): return f.cer(f"<string: {repr(self)}>")
            case _: raise NotImplementedError(self)
    def __str__(self) -> str:
        # plain (uncoloured) rendering
        return self.fmt(Fmt(color=False))
# `Branch[T, "reason"]` marks a dataclass field as holding child node(s) of
# type T. At runtime it is simply `typing.Annotated`, but Tree.__post_init__
# recognises the literal text "Branch[" inside the (stringified) annotation,
# so DO NOT RENAME it. `{}` / `{0}` / `{1}` in a reason are filled with the
# dict keys / ordinal positions under which nested children are stored.
Branch = typing.Annotated


@dataclasses.dataclass
class ELam(Expr):
    """Lambda abstraction `λname. body`."""
    name: str
    body: Branch[Expr, "lambda body of"]


@dataclasses.dataclass
class EApp(Expr):
    """Application `lhs rhs`."""
    lhs: Branch[Expr, "left-hand side of"]
    rhs: Branch[Expr, "right-hand side of"]


@dataclasses.dataclass
class EInt(Expr):
    """Natural-number literal (inferred as `Nat`)."""
    value: int


@dataclasses.dataclass
class EBool(Expr):
    """Boolean literal `true` / `false`."""
    value: bool


@dataclasses.dataclass
class EVar(Expr):
    """Reference to a variable or constructor by name."""
    name: str


@dataclasses.dataclass
class ELet(Expr):
    """`let name = value in body`; the value's type is generalized."""
    name: str
    value: Branch[Expr, "right-hand side of"]
    body: Branch[Expr, "body of"]


@dataclasses.dataclass
class EIf(Expr):
    """`if cond then then else othr end`."""
    cond: Branch[Expr, "condition of"]
    then: Branch[Expr, "true branch of"]
    othr: Branch[Expr, "false branch of"]


@dataclasses.dataclass
class ELetRec(Expr):
    """`let rec a = ..., b = ... in body`: mutually recursive bindings."""
    lets: dict[str, Branch[Expr, "right-hand side of"]]
    body: Branch[Expr, "body of"]


@dataclasses.dataclass
class EData(Expr):
    """`data name params = Con t1 * t2 | ... in body`.

    `cons` maps each constructor name to its argument types.
    """
    name: str
    params: list[str]
    cons: dict[str, list[Branch[Ty, "{0}'s {1} type in"]]]
    body: Branch[Expr, "body of"]


@dataclasses.dataclass
class ECase(Expr):
    """One arm of a `match`: `| name args... => body`."""
    name: str
    args: list[str]
    body: Branch[Expr, "body of"]


@dataclasses.dataclass
class EMatch(Expr):
    """`match expr with cases... end`."""
    # Marked as a Branch like every other child, so type errors on the
    # scrutinee can be traced back through the match expression.
    expr: Branch[Expr, "scrutinee of"]
    cases: list[Branch[ECase, "{} case of"]]
@dataclasses.dataclass(unsafe_hash=True)
class Ty(Tree, abc.ABC, HasFmt):
def as_unsolved(self) -> int:
match self:
case TMeta(Rc(UnsolvedMeta(id))): return id
case _: assert False
def force(self) -> Ty:
match self:
case TFun(lhs, rhs):
return with_parent_copy(TFun(lhs.force(), rhs.force()), self._parent)
case TApp(to, params):
return with_parent_copy(TApp(to, [p.force() for p in params]), self._parent)
case TMeta(rc):
match rc.ref:
case SolvedMeta(ty): return ty.force()
case _: return self
case _: return self
def replace_named(self, ns: dict[str, Ty]) -> Ty:
match self:
case TFun(lhs, rhs):
return with_parent_copy(TFun(lhs.replace_named(ns), rhs.replace_named(ns)), self._parent)
case TNamed(name) if name in ns: return ns[name]
case TApp(to, params):
return with_parent_copy(TApp(to, [p.replace_named(ns) for p in params]), self._parent)
case TMeta(rc):
match rc.ref:
case UnsolvedMeta(_, _): return self
case SolvedMeta(ty): return ty.replace_named(ns)
case _: assert False
case _: return self
def occurs_unify(self, meta: UnsolvedMeta) -> bool:
match self:
case TFun(lhs, rhs):
return lhs.occurs_unify(meta) or rhs.occurs_unify(meta)
case TApp(_, params):
return any(p.occurs_unify(meta) for p in params)
case TMeta(rc):
match rc.ref:
case UnsolvedMeta(id, scope):
min_scope = min(scope, meta.scope)
rc.ref = with_parent_copy(UnsolvedMeta(id, min_scope), rc.ref._parent)
return id == meta.id
case SolvedMeta(ty):
return ty.occurs_unify(meta)
case _: assert False
case _: return False
def fmt(self, f: Fmt, /) -> str:
def F(x: Ty): return x.fmt(f)
def Fp(x: Ty): return f"{f.cpn("(")}{F(x)}{f.cpn(")")}" \
if isinstance(x, (TApp, TFun, TCons)) else F(x)
match self:
case TFun(TFun(_, _) as lhs, rhs): return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {f.cpn("→")} {F(rhs)}"
case TFun(lhs, rhs): return f"{F(lhs)} {f.cpn("→")} {F(rhs)}"
case TData(name, _, _): return name
case TNamed(name): return name
case TApp(to, params):
return f"{F(to)} {" ".join(map(Fp, params))}" if len(params) else F(to)
case TMeta(rc):
match rc.ref:
case SolvedMeta(ty): return f"{F(ty)}"
case UnsolvedMeta(id, _): return f"{f.cpn("?")}{id}"
case _: assert False
case TBool(): return f.ckw("Bool")
case TNat(): return f.ckw("Nat")
case TCons(data, name, params):
return f"{data.name} {f.ckw("constructor")} {name}{' ' if params else ''}{' '.join(map(F, params))}"
case _: return f.cer(f"<not implemented {type(self)}>")
def __str__(self) -> str:
return self.fmt(Fmt(color=False))
@dataclasses.dataclass(unsafe_hash=True)
class GenTy(HasFmt):
    """A type scheme `∀ params. under`.

    `params` are ids of the metavariables in `under` that are universally
    quantified (produced by `Tycheck.gen`); `Tycheck.inst` replaces them with
    fresh metavariables at each use.
    """

    params: list[int]
    under: Ty

    def fmt(self, fmt: Fmt) -> str:
        binders = " ".join(f"{fmt.cpn('?')}{p}" for p in self.params)
        return f"{fmt.cpn('∀ ')}{binders}{fmt.cpn('. ')}{self.under.fmt(fmt)}"

    def __str__(self):
        return self.fmt(Fmt(color=False))
# NOTE: the `Branch[...]` annotations below are read as text by
# Tree.__post_init__ to link children to parents; keep their exact form.

@dataclasses.dataclass(unsafe_hash=True)
class TNat(Ty):
    """The built-in `Nat` type."""
    pass


@dataclasses.dataclass(unsafe_hash=True)
class TBool(Ty):
    """The built-in `Bool` type."""
    pass


@dataclasses.dataclass(unsafe_hash=True)
class TNamed(Ty):
    """A type referred to by name, resolved through the environment (or
    substituted via `replace_named` when it names a type parameter)."""
    name: str


@dataclasses.dataclass(unsafe_hash=True)
class TApp(Ty):
    """Type constructor `to` applied to `params`."""
    to: Ty # not a branch, this is a reference
    params: list[Branch[Ty, "{} type param of"]]


@dataclasses.dataclass(unsafe_hash=True)
class TFun(Ty):
    """Function type `lhs → rhs`."""
    lhs: Branch[Ty, "left-hand side of"]
    rhs: Branch[Ty, "right-hand side of"]


@dataclasses.dataclass(unsafe_hash=True)
class TData(Ty):
    """A data type definition.

    `cons` maps constructor names to their argument types; `params_inst`
    maps each parameter name to the type standing for it in `cons`.
    """
    name: str
    params: list[str]
    cons: dict[str, list[Branch[Ty, "{0}'s {1} type in"]]]
    params_inst: dict[str, Ty] | None = None


@dataclasses.dataclass(unsafe_hash=True)
class TCons(Ty):
    """Type of constructor `name` of `ty` (which takes at least one
    argument), applied so far to the argument types in `applied`."""
    ty: TData # not a branch, this is a reference
    name: str
    # reason reads e.g. "1st parameter of <constructor>", like the other reasons
    applied: list[Branch[Ty, "{} parameter of"]]
@dataclasses.dataclass
class Meta(abc.ABC, Tree):
    """Contents of a metavariable's shared `Rc` cell: solved or unsolved."""

    @abc.abstractmethod
    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        """Record `parent` as this cell's parent link."""
        ...


@dataclasses.dataclass
class SolvedMeta(Meta):
    """A metavariable that has been solved to `ty`."""

    ty: Ty

    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        # the solution shares the cell's parent link
        self.ty._parent = parent
        self._parent = parent


@dataclasses.dataclass
class UnsolvedMeta(Meta):
    """A not-yet-solved metavariable.

    `id` identifies it (allocated from `Tycheck.next_id`); `scope` is the
    let-nesting level `Tycheck.gen` compares against when generalizing.
    """

    id: int
    scope: int

    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        self._parent = parent


@dataclasses.dataclass(unsafe_hash=True)
class TMeta(Ty):
    """A metavariable type; several `TMeta`s may share one `Rc` cell."""

    # The annotation text is read by Tree.__post_init__; keep it verbatim.
    var: Rc[Branch[Meta, "meta"]]
@dataclasses.dataclass(frozen=True)
class Env(abc.ABC):
    """Immutable typing environment stored as a linked list of frames.

    Every frame except `RootEnv` points at its `parent`. Lookups walk from the
    innermost frame outwards, so later bindings shadow earlier ones. `VarEnv`
    binds term variables, `TyEnv` binds type names, `EmptyEnv` binds nothing.
    """

    def _find(self, name: str, binder: type) -> GenTy | None:
        # Walk outwards until a `binder` frame for `name` or the root is hit.
        frame = self
        while True:
            if isinstance(frame, binder) and frame.name == name:
                return frame.ty
            if isinstance(frame, (EmptyEnv, TyEnv, VarEnv)):
                frame = frame.parent
            elif isinstance(frame, RootEnv):
                return None
            else:
                raise NotImplementedError(frame)

    def find_var(self, name: str) -> GenTy | None:
        """Return the scheme bound to term variable `name`, or None."""
        return self._find(name, VarEnv)

    def find_ty(self, name: str) -> GenTy | None:
        """Return the scheme bound to type name `name`, or None."""
        return self._find(name, TyEnv)

    def find_var_check(self, name: str) -> GenTy:
        """Like `find_var`, but raise NameError when `name` is unbound."""
        found = self.find_var(name)
        if found is None:
            raise NameError(f"there's no '{name}' silly")
        return found

    def find_ty_check(self, name: str) -> GenTy:
        """Like `find_ty`, but raise NameError when `name` is unbound."""
        found = self.find_ty(name)
        if found is None:
            raise NameError(f"there's no '{name}' silly")
        return found

    def vars(self) -> Iterator[tuple[str, GenTy]]:
        """Yield `(name, scheme)` for each `VarEnv` frame, innermost first.

        Shadowed bindings are included; each scheme's type is forced.
        """
        frame = self
        while True:
            if isinstance(frame, VarEnv):
                yield frame.name, GenTy(frame.ty.params, frame.ty.under.force())
                frame = frame.parent
            elif isinstance(frame, (EmptyEnv, TyEnv)):
                frame = frame.parent
            elif isinstance(frame, RootEnv):
                return
            else:
                raise NotImplementedError(frame)


@dataclasses.dataclass(frozen=True)
class VarEnv(Env):
    """Frame binding term variable `name` to scheme `ty`."""

    parent: Env
    name: str
    ty: GenTy

    def __post_init__(self):
        assert isinstance(self.name, str)


@dataclasses.dataclass(frozen=True)
class TyEnv(Env):
    """Frame binding type name `name` to scheme `ty`."""

    parent: Env
    name: str
    ty: GenTy


@dataclasses.dataclass(frozen=True)
class EmptyEnv(Env):
    """Frame that binds nothing."""

    parent: Env


@dataclasses.dataclass(frozen=True)
class RootEnv(Env):
    """Terminates the frame chain."""

    pass
class UnifyError(Exception):
    """Raised when two types cannot be unified (or a type is otherwise invalid).

    Attributes:
        a: the first offending type.
        b: the second offending type, or None when the error concerns only `a`.
        reason: human-readable description of the failure.
    """

    def __init__(self, a: Ty, b: Ty | None, reason: str):
        # Pass the constructor arguments up so `args` is populated and the
        # exception round-trips through copy/pickle.
        super().__init__(a, b, reason)
        self.a = a
        self.b = b
        self.reason = reason

    def __str__(self) -> str:
        # Show the reason (instead of an args tuple) when printed or uncaught.
        return self.reason
class Tycheck:
def __init__(self):
self.next_id: int = 0
self.cur_scope: int = 0
def gen(self, ty: Ty) -> GenTy:
vs: set[int] = set()
ty = ty.force()
def helper(ty: Ty):
match ty:
case TFun(lhs, rhs):
helper(lhs)
helper(rhs)
case TMeta(rc):
match rc.ref:
case UnsolvedMeta(id, scope):
if scope > self.cur_scope:
vs.add(id)
case _: pass
case _: pass
helper(ty)
return GenTy(list(vs), ty)
def dont_gen(self, ty: Ty) -> GenTy:
return GenTy([], ty)
def inst(self, gen: GenTy, parent: Ty | Expr) -> Ty:
vs = { p: self.fresh(parent, f"inferred type of parameter {p}, instantiated in") for p in gen.params }
def helper(ty: Ty) -> Ty:
match ty:
case TFun(lhs, rhs):
return with_parent(TFun(helper(lhs), helper(rhs)), ty, "instantiated")
case TData(name, params, cons):
return with_parent(TApp(ty, list(vs.values())), ty, "instantiated")
case TMeta(rc):
match rc.ref:
case UnsolvedMeta(id, _) if id in vs: return vs[id]
case _: return ty
case _: return ty
return helper(gen.under)
def fresh(self, parent: Ty | Expr, reason: str) -> Ty:
id = self.next_id
self.next_id += 1
meta = UnsolvedMeta(id, self.cur_scope)
meta._set_parent((parent, reason))
r = TMeta(Rc(meta))
return r
def infer(self, env: Env, e: Expr, i: int) -> Ty:
# print(" " * i, "infer", str(e))
r = self.infer_(env, e, i)
# print(" " * i, ":", str(r))
return r
def infer_(self, env: Env, e: Expr, i: int) -> Ty:
match e:
case EApp(lhs, rhs):
lhsty = self.infer(env, lhs, i + 1)
if isinstance(lhsty, TCons): # the constructor case, special type for constructors
rhsty = self.infer(env, rhs, i + 1)
con = lhsty.ty.cons[lhsty.name]
conty = con[len(lhsty.applied)]
self.unify(env, rhsty, conty, i + 1)
if len(lhsty.applied) + 1 == len(con):
assert lhsty.ty.params_inst is not None
ps = [lhsty.ty.params_inst[name] for name in lhsty.ty.params]
return TApp(lhsty.ty, ps)
return TCons(lhsty.ty, lhsty.name, lhsty.applied + [rhsty])
rhsty = self.infer(env, rhs, i + 1)
resty = self.fresh(e, "inferred return type of")
self.unify(env, lhsty, with_parent(TFun(rhsty, resty), lhs, "inferred type of"), i + 1)
return resty
case ELam(name, body):
argty = self.fresh(e, "inferred type of argument of")
bodyty = self.infer(VarEnv(env, name, self.dont_gen(argty)), body, i + 1)
return with_parent(TFun(argty, bodyty), e, "type of")
case EIf(cond, then, othr):
condty = self.infer(env, cond, i + 1)
self.unify(env, condty, with_parent(TBool(), condty, "inferred type of"), i + 1)
thenty = self.infer(env, then, i + 1)
othrty = self.infer(env, othr, i + 1)
self.unify(env, thenty, othrty, i + 1)
return thenty
case ELet(name, value, body):
self.cur_scope += 1
varty = self.infer(env, value, i + 1)
self.cur_scope -= 1
g = self.gen(varty)
return self.infer(VarEnv(env, name, g), body, i + 1)
case ELetRec(lets, body):
# enter scope
self.cur_scope += 1
# generate fresh metavars for the bindings
vartys = { name: self.fresh(e, f"inferred type of {name} in") for name in lets.keys() }
# construct environment (associate bindings with the fresh types)
env2 = env
for name in lets.keys():
# wait until later to generalize
env2 = VarEnv(env2, name, self.dont_gen(vartys[name]))
# unify the inferred types and the fresh vars
for name, value in lets.items():
r = self.infer(env2, value, i + 1)
self.unify(env, vartys[name], r, i + 1)
# exit scope
self.cur_scope -= 1
# reconstruct the environment but generalize the variables this time
env2 = env
for name in lets.keys():
env2 = VarEnv(env2, name, self.gen(vartys[name]))
# infer body
return self.infer(env2, body, i + 1)
case EData(name, params, cons, body):
ptys = { p: self.fresh(e, f"type parameter {p} in") for p in params }
dty = with_parent(TData(name, params, {
name: [c.replace_named(ptys) for c in con]
for name, con in cons.items()
}, ptys), e, "type of")
ps = [p.as_unsolved() for p in ptys.values()]
env2 = TyEnv(env, name, GenTy(ps, dty))
for con in cons.keys():
# an unary constructor is just the datatype,
# >1-ary constructors are special types because we don't really have rank 2 types.
conty = TCons(dty, con, []) if len(cons[con]) > 0 else dty
env2 = VarEnv(env2, con, GenTy(ps, with_parent(conty, e, f"constructor `{con}` of")))
return self.infer(env2, body, i + 1)
case EVar(name):
v = env.find_var_check(name)
ty = self.inst(v, e)
return with_parent(ty, e, "type of")
case EMatch(expr, cases):
ety = self.infer(env, expr, i + 1)
retty = self.fresh(e, "inferred type of")
dataty: TApp | None = None # ety will be unified with this
for case in cases:
consty = self.inst(env.find_var_check(case.name), case)
# figure out the type of the constructor;
# we need both the instantiated and raw TData versions.
match consty:
# TData is weird and special (might change that in the future)
# instantiated TData has its params in params_inst instead of it being TApp.
# we check for both TCons and TData because unary cases are TData and >1ary are TCons.
case TCons():
tdata = consty.ty
assert consty.ty.params_inst is not None
tdata_app = TApp(consty.ty, [consty.ty.params_inst[name] for name in consty.ty.params])
case TData():
tdata = consty
assert consty.params_inst is not None
tdata_app = TApp(consty, [consty.params_inst[name] for name in consty.params])
case TApp(): # but also maybe TApp idk i need to fix this
match consty.to:
case TData(): tdata, tdata_app = consty.to, consty
case _: raise UnifyError(consty, None, "not a constructor")
case _: raise UnifyError(consty, None, "not a constructor")
# check if already set or set the datatype
if dataty is not None:
if dataty.to != tdata:
raise UnifyError(dataty.to, tdata, "mismatching constructor types")
else:
dataty = tdata_app
# check if already set or set the datatype
argtys = tdata.cons[case.name]
if len(argtys) != len(case.args):
raise UnifyError(consty, None, f"wrong number of arguments (expected {len(case.args)})")
# create argument types
env2 = env
for name, argty in zip(case.args, argtys):
env2 = VarEnv(env2, name, self.dont_gen(argty))
# infer body
bodyty = self.infer(env2, case.body, i + 1)
self.unify(env2, bodyty, retty, i + 1)
if dataty is None: raise UnifyError(ety, None, "no cases") # shouldn't happen
self.unify(env, ety, dataty, i + 1)
return retty
case EInt(_): return with_parent(TNat(), e, "literal type of")
case EBool(_): return with_parent(TBool(), e, "literal type of")
case _: raise NotImplementedError(e)
def unify(self, env: Env, a: Ty, b: Ty, i: int):
a = a.force()
b = b.force()
if isinstance(a, TNamed):
return self.unify(env, env.find_ty_check(a.name).under, b, i)
if isinstance(b, TNamed):
return self.unify(env, a, env.find_ty_check(b.name).under, i)
if a == b: return
match (a, b):
case (TFun(al, ar), TFun(bl, br)):
self.unify(env, al, bl, i + 1)
self.unify(env, ar, br, i + 1)
case (TApp(ato, aps), TApp(bto, bps)) if len(aps) == len(bps):
self.unify(env, ato, bto, i + 1)
for ap, bp in zip(aps, bps):
self.unify(env, ap, bp, i + 1)
case (TMeta(rc), t) | (t, TMeta(rc)):
match rc.ref:
case UnsolvedMeta(id, _):
if t.occurs_unify(rc.ref):
raise UnifyError(a, b, f"can't unify, ?{id} occurs in {t}")
rc.ref = with_parent_copy(SolvedMeta(t), rc.ref._parent)
case _: assert False
case _: raise UnifyError(a, b, "can't unify")
class Parser:
    """Concrete-syntax parser built with parsita.

    Expressions (application is juxtaposition, left-associative):
      match e with | C x y => e ... end     if e then e else e end
      let x = e in e                        let rec f = e, g = e in e
      \\x. e                                 data T a b = C ty * ty | D in e
      true / false, decimal naturals, identifiers, ( e )
    Types: Nat, Bool, names, ( ty ), and applications `Name atom+`.
    Whitespace and `--` line comments are skipped.
    """
    import parsita as P
    import parsita.util
    import parsita.options
    parsita.options.whitespace = P.reg(r"\s*(\s*--.*\n\s*)*")
    @staticmethod
    def _reduce_to_app(l: list[Expr]) -> Expr:
        # fold `f a b c` into ((f a) b) c
        es = iter(l)
        r = next(es)
        for e in es: r = EApp(r, e)
        return r
    KWS = { 'if', 'then', 'else', 'end', 'in', '*', 'rec', 'let', 'match', 'with' }
    @staticmethod
    def _reduce_iden(v: str) -> P.Parser[str, str]:
        # reserved words are not identifiers
        if v in Parser.KWS: return Parser.P.failure("keyword") # type: ignore
        return Parser.P.success(v)
    pexpr = P.fwd()
    pty = P.fwd()
    piden = P.reg(r"[a-zA-Z_]+") >= _reduce_iden
    # Build a fresh node for every occurrence: nodes carry mutable parent
    # links (see Tree), so a single shared instance would have its parent
    # overwritten by each use. (`Bool` also must produce an instance, not the
    # TBool class itself.)
    pnatty = P.lit('Nat') > (lambda _: TNat())
    pboolty = P.lit('Bool') > (lambda _: TBool())
    pnamety = piden > TNamed
    ptypar = P.lit('(') >> pty << P.lit(')')
    ptyatom = P.first(ptypar, pnatty, pboolty, pnamety)
    pty.define(P.first((pnamety & P.rep1(ptyatom) > parsita.util.splat(TApp)) | ptyatom))
    pvar = piden > EVar
    pif = P.lit('if') >> pexpr << P.lit('then') & pexpr << P.lit('else') & pexpr << P.lit('end') > parsita.util.splat(EIf)
    pconsdef = piden & P.repsep(pty, '*') > tuple[str, list[Ty]]
    pdata = P.lit('data') >> piden & P.rep(piden) << P.lit('=') & \
        (P.rep1sep(pconsdef, '|') > dict[str, list[Ty]]) << P.lit('in') & pexpr > parsita.util.splat(EData)
    pcase = P.lit('|') >> piden & P.rep(piden) << P.lit('=>') & pexpr > parsita.util.splat(ECase)
    pmatch = P.lit('match') >> pexpr << P.lit('with') & P.rep1(pcase) << P.lit('end') > parsita.util.splat(EMatch)
    plet = P.lit('let') >> piden << P.lit('=') & pexpr << P.lit('in') & pexpr > parsita.util.splat(ELet)
    pletrec = P.lit('let') >> P.lit('rec') >> \
        (P.rep1sep(piden << P.lit('=') & pexpr > tuple, ',') > dict) << P.lit('in') & pexpr > parsita.util.splat(ELetRec)
    plam = P.lit('\\') >> piden << P.lit('.') & pexpr > parsita.util.splat(ELam)
    pnum = (P.reg(r"[0-9]+") > int) > EInt
    ppar = P.lit('(') >> pexpr << P.lit(')')
    pbool \
        = (P.lit('true') > (lambda _: EBool(True))) \
        | (P.lit('false') > (lambda _: EBool(False)))
    patom = P.first(pmatch, ppar, pif, pletrec, plet, plam, pbool, pnum, pvar | pdata)
    pexpr.define(P.rep1(patom) > _reduce_to_app)
    @staticmethod
    def parse(source: str | bytes) -> P.Result[Expr]:
        """Parse a whole expression from `source`."""
        return Parser.pexpr.parse(source)
def main():
r = Parser.pexpr.parse(
R"""
data List a = Cons a * List a | Nil in
data Pair a b = Pair a * b in
let x = Pair true (Cons true Nil) in
match x with
| Pair x y => match y with
| Cons x xs => x
| Nil => x
end
end
"""
)
match r:
case Parser.P.Success(Expr() as e):
tc = Tycheck()
fmt = Fmt(color=True)
print(e.fmt(fmt))
try:
print(fmt.cpn(':'), tc.infer(RootEnv(), e, 0).force().fmt(fmt))
except UnifyError as err:
if err.b is not None:
print(fmt.cpn('│'), fmt.cer(err.reason + ':'), err.a.fmt(fmt), '≠', err.b.fmt(fmt))
else:
print(fmt.cpn('│'), fmt.cer(err.reason + ':'), err.a.fmt(fmt))
def print_parents(x: HasFmt, reason: str, prefix: bool, i = 0):
p = fmt.cpn("│") if prefix and i > 0 else ""
if hasattr(x, "_parent"):
t: tuple[HasFmt, str] | None = getattr(x, "_parent")
if t is not None:
if not reason:
print_parents(t[0], t[1], prefix, i)
return
lns = f"{x.fmt(fmt)}".split("\n")
start = "╰" if not prefix or i > 0 else "├"
end = "──" if len(lns) == 1 and (not hasattr(t[0], "_parent") or getattr(t[0], "_parent")) is None else "─╮"
ind = p + " " * (2*(i)-prefix)
if len(lns) == 1: print(ind + fmt.cpn(start + end), fmt.cpn(reason), lns[0])
else: print(ind + fmt.cpn(start + end), fmt.cpn(reason))
for j in range(len(lns) == 1, len(lns)): print(ind + " " + fmt.cpn("│"), lns[j])
print_parents(t[0], t[1], prefix, i + 1)
else:
print(x, "has no _parent")
print(fmt.cpn('│'), err.a.fmt(fmt), 'is from:')
print_parents(err.a, "", True)
if err.b is not None:
print(fmt.cpn('│'), err.b.fmt(fmt), 'is from:')
print_parents(err.b, "", False)
case Parser.P.Failure(err): # type: ignore
print(err)
exit(1)
if __name__ == "__main__":
    # Run the built-in demo only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment