Last active
May 10, 2024 14:11
-
-
Save monomere/5d0eeea26128d060fa158ea909289680 to your computer and use it in GitHub Desktop.
hindley-milner (algo w) with limited adts and mutual recursion in python. (wip) [v0.2]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# by monomere | |
# Apache License 2.0 | |
# the code is horrible and undocumented, | |
# might change that later :P | |
# requires python >=3.12 | |
# requires the parsita python package (for parsing) | |
# you can remove the parser code and play around with the ast directly tho | |
# version 0.3 (may 10th, 2024) | |
# - new: match expression | |
from __future__ import annotations | |
import dataclasses, abc, re, typing | |
from typing import Iterator | |
@dataclasses.dataclass | |
class Rc[T]: | |
ref: T | |
@dataclasses.dataclass(frozen=True)
class Fmt:
    """Immutable pretty-printing context: colour switch plus indentation depth."""
    color: bool
    indent: int = 0

    def _paint(self, code: int, s: str) -> str:
        # wrap `s` in an ANSI SGR sequence, or pass it through when colour is off
        if not self.color:
            return s
        return f"\033[{code}m{s}\033[m"

    def down(self) -> Fmt:
        """Return a context one indentation level deeper."""
        return Fmt(color=self.color, indent=self.indent + 1)

    def as_inline(self) -> Fmt:
        """Return a fresh context with the same colour and indentation."""
        return Fmt(color=self.color, indent=self.indent)

    def ind(self, a: int = 0) -> str:
        """Return the indentation string for the current depth plus `a`."""
        return " " * (self.indent + a)

    # colour helpers: constructor names, values, keywords, punctuation, errors
    def ccn(self, s: str) -> str: return self._paint(95, s)
    def cvl(self, s: str) -> str: return self._paint(35, s)
    def ckw(self, s: str) -> str: return self._paint(36, s)
    def cpn(self, s: str) -> str: return self._paint(90, s)
    def cer(self, s: str) -> str: return self._paint(31, s)
class HasFmt(typing.Protocol):
    """Structural interface for anything that can render itself with a `Fmt`."""

    def fmt(self, fmt: Fmt, /) -> str:
        """Return the textual rendering of this object under `fmt`."""
        ...
type TreeParent = tuple[Expr | Ty, str] | |
# To get nice errors, every AST node and every type stores a reference to its parent
# together with a "reason": a short phrase saying which part of the parent this node is.
# To make this bearable, Tree is a base class of Expr and Ty whose __post_init__ method
# reads the dataclass fields and looks for a Branch[T, "reason"] type annotation.
# If it finds one, it calls the child's _set_parent when the child defines it, and otherwise assigns the child's _parent.
@dataclasses.dataclass | |
class Tree: | |
_parent: TreeParent | None = dataclasses.field(default=None, init=False, repr=False, hash=False, compare=False) | |
__match_args__ = () | |
def __post_init__(self): | |
AnnotatedType = type(Branch[type, '']) # dummy | |
# def find_in_args(a: type) -> type | None: | |
# return next(( | |
# find_in_args(arg) | |
# for arg in typing.get_args(fld.type) | |
# if isinstance(arg, AnnotatedType) | |
# ), a) | |
for fld in dataclasses.fields(self): | |
if fld.name == "_parent": continue | |
match = re.match(r"Branch\[.*,\s*[\"']([^\"]*)[\"']\]", fld.type) | |
if match is None: continue | |
reason = match.group(1) | |
child = getattr(self, fld.name) | |
def update(o: typing.Any, args: tuple[typing.Any, ...]): | |
match o: | |
case dict(): | |
for k, v in o.items(): | |
update(v, args + (k,)) | |
case list(): | |
for i, v in enumerate(o): | |
suffix = ("st", "nd", "rd")[abs(i) % 10] if abs(i) % 10 + 1 < 4 else "th" | |
update(v, args + (f"{i + 1}{suffix}",)) | |
case Rc(ref): | |
update(ref, args) | |
case _: | |
if hasattr(child, "_set_parent"): | |
child._set_parent((self, reason.format(*args))) | |
elif hasattr(child, "_parent"): | |
child._parent = (self, reason.format(*args)) | |
update(child, ()) | |
def with_parent_copy[T: typing.Any](x: T, parent: TreeParent | None) -> T: | |
if hasattr(x, "_parent"): x._parent = parent | |
elif hasattr(x, "_set_parent"): x._set_parent(parent) | |
else: print("no _parent/_set_parent on", x) | |
return x | |
def with_parent[T: typing.Any](x: T, parent: Ty | Expr, reason: str) -> T: | |
return with_parent_copy(x, (parent, reason)) | |
@dataclasses.dataclass | |
class Expr(Tree, abc.ABC, HasFmt): | |
def fmt(self, f: Fmt) -> str: | |
def F(x: HasFmt) -> str: return x.fmt(f) | |
def Fl(x: HasFmt) -> str: return x.fmt(f.down()) | |
def is_let(x: Expr) -> bool: | |
match x: | |
case ELet() | ELetRec() | EData(): return True | |
case _: return False | |
match self: | |
case EApp(ELam(_, _) as lhs, EApp(_, _) as rhs): | |
return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {f.cpn("(")}{F(rhs)}{f.cpn(")")}" | |
case EApp(ELam(_, _) as lhs, rhs): | |
return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {F(rhs)}" | |
case EApp(lhs, EApp(_, _) as rhs): | |
return f"{F(lhs)} {f.cpn("(")}{F(rhs)}{f.cpn(")")}" | |
case EApp(lhs, rhs): | |
return f"{F(lhs)} {F(rhs)}" | |
case ELam(name, body): | |
return f"{f.cpn("λ")}{name}{f.cpn(".")} {F(body)}" | |
case ELet(name, value, body): | |
return f"{f.ckw("let")} {name} {f.cpn("=")} {value.fmt(f)} {f.ckw("in")}" + \ | |
(f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}") | |
case ELetRec(names, body): | |
return f"{f.ckw("let rec")}{"\n" if len(names) > 1 else " "}" + \ | |
f"{f.cpn(",\n")} ".join(f"{f.ind(1) if len(names) > 1 else ""}{name} " + \ | |
f"{f.cpn("=")} {value.fmt(f)}" for name, value in names.items()) + \ | |
f"\n{f.ind()}{f.ckw("in")}" + (f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}") | |
case EIf(cond, then, othr): | |
return f"{f.ckw("if")} {cond.fmt(f)} {f.ckw("then")} " + \ | |
f"{then.fmt(f)} {f.ckw("else")} {othr.fmt(f)} {f.ckw("end")}" | |
case EBool(v): return f.cvl("true" if v else "false") | |
case EInt(v): return f.cvl(f"{v}") | |
case EVar(name): return f"{name}" | |
case EData(name, params, cons, body): return f"{f.ckw("data")} {name}" + \ | |
f" {" ".join(params)}" + \ | |
f"\n{f.ind(1)}{f.cpn("=")} {f.cpn(f"\n{f.ind(1)}| ").join(f"{f.ccn(name)} " + \ | |
f"{f.cpn(" × ").join(F(ty) for ty in tys)}" for name, tys in cons.items())}" + \ | |
f"\n{f.ind()}{f.ckw("in")}" + (f" {f.ind()}{F(body)}" if is_let(body) else f"\n{f.ind(1)}{Fl(body)}") | |
case EMatch(expr, cases): return f"{f.ckw("match")} {F(expr)} {f.ckw("with")}" + \ | |
''.join( | |
f.cpn(f"\n{f.ind()}| ") + | |
' '.join((case.name, *case.args)) + | |
f.cpn(f" ⇒ ") + Fl(case.body) | |
for case in cases | |
) + f"\n{f.ind()}{f.ckw("end")}" | |
case str(): return f.cer(f"<string: {repr(self)}>") | |
case _: raise NotImplementedError(self) | |
def __str__(self) -> str: | |
return self.fmt(Fmt(color=False)) | |
# DO NOT RENAME: Tree.__post_init__ finds child fields by looking for the
# literal text "Branch[" inside the (string) field annotations.
Branch = typing.Annotated
@dataclasses.dataclass
class ELam(Expr):
    """Lambda abstraction, printed as `λname. body`."""
    name: str  # bound variable
    body: Branch[Expr, "lambda body of"]
@dataclasses.dataclass
class EApp(Expr):
    """Application of `lhs` to `rhs`."""
    lhs: Branch[Expr, "left-hand side of"]
    rhs: Branch[Expr, "right-hand side of"]
@dataclasses.dataclass
class EInt(Expr):
    """Integer literal."""
    value: int
@dataclasses.dataclass
class EBool(Expr):
    """Boolean literal, printed as `true` / `false`."""
    value: bool
@dataclasses.dataclass
class EVar(Expr):
    """Reference to a variable by name."""
    name: str
@dataclasses.dataclass
class ELet(Expr):
    """Non-recursive binding: `let name = value in body`."""
    name: str
    value: Branch[Expr, "right-hand side of"]
    body: Branch[Expr, "body of"]
@dataclasses.dataclass
class EIf(Expr):
    """Conditional: `if cond then then else othr end`."""
    cond: Branch[Expr, "condition of"]
    then: Branch[Expr, "true branch of"]
    othr: Branch[Expr, "false branch of"]
@dataclasses.dataclass
class ELetRec(Expr):
    """Group of (mutually) recursive bindings followed by a body."""
    # binding name -> bound expression
    lets: dict[str, Branch[Expr, "right-hand side of"]]
    body: Branch[Expr, "body of"]
@dataclasses.dataclass
class EData(Expr):
    """Data type declaration scoped over `body`."""
    name: str
    params: list[str]  # type parameter names
    # constructor name -> list of its argument types
    cons: dict[str, list[Branch[Ty, "{0}'s {1} type in"]]]
    body: Branch[Expr, "body of"]
@dataclasses.dataclass
class ECase(Expr):
    """One arm of a match: constructor name, bound argument names, arm body."""
    name: str
    args: list[str]
    body: Branch[Expr, "body of"]
@dataclasses.dataclass
class EMatch(Expr):
    """Match expression: a scrutinee and its list of arms."""
    expr: Expr  # scrutinee; not annotated as a Branch
    cases: list[Branch[ECase, "{} case of"]]
@dataclasses.dataclass(unsafe_hash=True) | |
class Ty(Tree, abc.ABC, HasFmt): | |
def as_unsolved(self) -> int: | |
match self: | |
case TMeta(Rc(UnsolvedMeta(id))): return id | |
case _: assert False | |
def force(self) -> Ty: | |
match self: | |
case TFun(lhs, rhs): | |
return with_parent_copy(TFun(lhs.force(), rhs.force()), self._parent) | |
case TApp(to, params): | |
return with_parent_copy(TApp(to, [p.force() for p in params]), self._parent) | |
case TMeta(rc): | |
match rc.ref: | |
case SolvedMeta(ty): return ty.force() | |
case _: return self | |
case _: return self | |
def replace_named(self, ns: dict[str, Ty]) -> Ty: | |
match self: | |
case TFun(lhs, rhs): | |
return with_parent_copy(TFun(lhs.replace_named(ns), rhs.replace_named(ns)), self._parent) | |
case TNamed(name) if name in ns: return ns[name] | |
case TApp(to, params): | |
return with_parent_copy(TApp(to, [p.replace_named(ns) for p in params]), self._parent) | |
case TMeta(rc): | |
match rc.ref: | |
case UnsolvedMeta(_, _): return self | |
case SolvedMeta(ty): return ty.replace_named(ns) | |
case _: assert False | |
case _: return self | |
def occurs_unify(self, meta: UnsolvedMeta) -> bool: | |
match self: | |
case TFun(lhs, rhs): | |
return lhs.occurs_unify(meta) or rhs.occurs_unify(meta) | |
case TApp(_, params): | |
return any(p.occurs_unify(meta) for p in params) | |
case TMeta(rc): | |
match rc.ref: | |
case UnsolvedMeta(id, scope): | |
min_scope = min(scope, meta.scope) | |
rc.ref = with_parent_copy(UnsolvedMeta(id, min_scope), rc.ref._parent) | |
return id == meta.id | |
case SolvedMeta(ty): | |
return ty.occurs_unify(meta) | |
case _: assert False | |
case _: return False | |
def fmt(self, f: Fmt, /) -> str: | |
def F(x: Ty): return x.fmt(f) | |
def Fp(x: Ty): return f"{f.cpn("(")}{F(x)}{f.cpn(")")}" \ | |
if isinstance(x, (TApp, TFun, TCons)) else F(x) | |
match self: | |
case TFun(TFun(_, _) as lhs, rhs): return f"{f.cpn("(")}{F(lhs)}{f.cpn(")")} {f.cpn("→")} {F(rhs)}" | |
case TFun(lhs, rhs): return f"{F(lhs)} {f.cpn("→")} {F(rhs)}" | |
case TData(name, _, _): return name | |
case TNamed(name): return name | |
case TApp(to, params): | |
return f"{F(to)} {" ".join(map(Fp, params))}" if len(params) else F(to) | |
case TMeta(rc): | |
match rc.ref: | |
case SolvedMeta(ty): return f"{F(ty)}" | |
case UnsolvedMeta(id, _): return f"{f.cpn("?")}{id}" | |
case _: assert False | |
case TBool(): return f.ckw("Bool") | |
case TNat(): return f.ckw("Nat") | |
case TCons(data, name, params): | |
return f"{data.name} {f.ckw("constructor")} {name}{' ' if params else ''}{' '.join(map(F, params))}" | |
case _: return f.cer(f"<not implemented {type(self)}>") | |
def __str__(self) -> str: | |
return self.fmt(Fmt(color=False)) | |
@dataclasses.dataclass(unsafe_hash=True)
class GenTy(HasFmt):
    """A type scheme: the ids of the quantified meta variables plus the underlying type."""
    params: list[int]
    under: Ty

    def fmt(self, fmt: Fmt) -> str:
        """Render as `∀ ?a ?b. type`."""
        binders = " ".join(fmt.cpn("?") + str(i) for i in self.params)
        return fmt.cpn("∀ ") + binders + fmt.cpn(". ") + self.under.fmt(fmt)

    def __str__(self):
        plain = Fmt(color=False)
        return self.fmt(plain)
@dataclasses.dataclass(unsafe_hash=True)
class TNat(Ty):
    """The built-in `Nat` type (the type given to integer literals)."""
@dataclasses.dataclass(unsafe_hash=True)
class TBool(Ty):
    """The built-in `Bool` type."""
@dataclasses.dataclass(unsafe_hash=True)
class TNamed(Ty):
    """A type referred to by name; unification looks the name up in the environment."""
    name: str
@dataclasses.dataclass(unsafe_hash=True)
class TApp(Ty):
    """A type applied to a list of type parameters."""
    to: Ty  # not a branch, this is a reference
    params: list[Branch[Ty, "{} type param of"]]
@dataclasses.dataclass(unsafe_hash=True)
class TFun(Ty):
    """Function type `lhs → rhs`."""
    lhs: Branch[Ty, "left-hand side of"]
    rhs: Branch[Ty, "right-hand side of"]
@dataclasses.dataclass(unsafe_hash=True)
class TData(Ty):
    """A declared data type with its constructors."""
    name: str
    params: list[str]  # type parameter names
    # constructor name -> list of its argument types
    cons: dict[str, list[Branch[Ty, "{0}'s {1} type in"]]]
    # type parameter name -> the type standing in for it, when known
    params_inst: dict[str, Ty] | None = None
@dataclasses.dataclass(unsafe_hash=True)
class TCons(Ty):
    """Type of a (partially applied) data constructor."""
    ty: TData  # not a branch, this is a reference
    name: str  # constructor name
    # types of the arguments supplied so far
    applied: list[Branch[Ty, "{} parameter"]]
@dataclasses.dataclass
class Meta(abc.ABC, Tree):
    """State of a meta (unification) variable: either solved or unsolved."""

    @abc.abstractmethod
    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        """Record `parent` as the origin of this meta variable."""
        ...
@dataclasses.dataclass
class SolvedMeta(Meta):
    """A meta variable that has been unified with the type `ty`."""
    ty: Ty

    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        # the solution shares the parent link with the meta itself
        self.ty._parent = parent
        self._parent = parent
@dataclasses.dataclass
class UnsolvedMeta(Meta):
    """A meta variable not yet unified with anything."""
    id: int
    scope: int  # binder nesting level, compared against the checker's cur_scope when generalizing

    def _set_parent(self, parent: tuple[Expr | Ty, str] | None):
        self._parent = parent
@dataclasses.dataclass(unsafe_hash=True)
class TMeta(Ty):
    """A meta variable used as a type; its state lives in a shared mutable box."""
    var: Rc[Branch[Meta, "meta"]]
@dataclasses.dataclass(frozen=True)
class Env(abc.ABC):
    """Immutable linked environment; each frame points at its parent frame."""

    def find_var(self, name: str) -> GenTy | None:
        """Return the scheme of the innermost variable called `name`, or None."""
        if isinstance(self, VarEnv) and self.name == name:
            return self.ty
        if isinstance(self, (EmptyEnv, TyEnv, VarEnv)):
            return self.parent.find_var(name)
        if isinstance(self, RootEnv):
            return None
        raise NotImplementedError(self)

    def find_ty(self, name: str) -> GenTy | None:
        """Return the scheme of the innermost type called `name`, or None."""
        if isinstance(self, TyEnv) and self.name == name:
            return self.ty
        if isinstance(self, (EmptyEnv, TyEnv, VarEnv)):
            return self.parent.find_ty(name)
        if isinstance(self, RootEnv):
            return None
        raise NotImplementedError(self)

    def find_var_check(self, name: str) -> GenTy:
        """Like `find_var`, but raise NameError when the variable is missing."""
        found = self.find_var(name)
        if found is None:
            raise NameError(f"there's no '{name}' silly")
        return found

    def find_ty_check(self, name: str) -> GenTy:
        """Like `find_ty`, but raise NameError when the type is missing."""
        found = self.find_ty(name)
        if found is None:
            raise NameError(f"there's no '{name}' silly")
        return found

    def vars(self) -> Iterator[tuple[str, GenTy]]:
        """Yield `(name, scheme)` for every variable frame, innermost first, with forced types."""
        if isinstance(self, VarEnv):
            yield self.name, GenTy(self.ty.params, self.ty.under.force())
            yield from self.parent.vars()
        elif isinstance(self, (EmptyEnv, TyEnv)):
            yield from self.parent.vars()
        elif isinstance(self, RootEnv):
            return
        else:
            raise NotImplementedError(self)
@dataclasses.dataclass(frozen=True)
class VarEnv(Env):
    """Environment frame binding one variable name to a type scheme."""
    parent: Env
    name: str
    ty: GenTy

    def __post_init__(self):
        # sanity check: the bound name must be a plain string
        assert isinstance(self.name, str)
@dataclasses.dataclass(frozen=True)
class TyEnv(Env):
    """Environment frame binding one type name to a type scheme."""
    parent: Env
    name: str
    ty: GenTy
@dataclasses.dataclass(frozen=True)
class EmptyEnv(Env):
    """Environment frame that binds nothing and only forwards to its parent."""
    parent: Env
@dataclasses.dataclass(frozen=True)
class RootEnv(Env):
    """Outermost environment; every lookup ends here with no result."""
class UnifyError(Exception):
    """Raised when type checking fails.

    `a` is the offending type, `b` the type it clashed with (or None when
    only one type is involved) and `reason` a short human-readable message.
    """

    def __init__(self, a: Ty, b: Ty | None, reason: str):
        self.a, self.b = a, b
        self.reason = reason
class Tycheck: | |
def __init__(self): | |
self.next_id: int = 0 | |
self.cur_scope: int = 0 | |
def gen(self, ty: Ty) -> GenTy: | |
vs: set[int] = set() | |
ty = ty.force() | |
def helper(ty: Ty): | |
match ty: | |
case TFun(lhs, rhs): | |
helper(lhs) | |
helper(rhs) | |
case TMeta(rc): | |
match rc.ref: | |
case UnsolvedMeta(id, scope): | |
if scope > self.cur_scope: | |
vs.add(id) | |
case _: pass | |
case _: pass | |
helper(ty) | |
return GenTy(list(vs), ty) | |
def dont_gen(self, ty: Ty) -> GenTy: | |
return GenTy([], ty) | |
def inst(self, gen: GenTy, parent: Ty | Expr) -> Ty: | |
vs = { p: self.fresh(parent, f"inferred type of parameter {p}, instantiated in") for p in gen.params } | |
def helper(ty: Ty) -> Ty: | |
match ty: | |
case TFun(lhs, rhs): | |
return with_parent(TFun(helper(lhs), helper(rhs)), ty, "instantiated") | |
case TData(name, params, cons): | |
return with_parent(TApp(ty, list(vs.values())), ty, "instantiated") | |
case TMeta(rc): | |
match rc.ref: | |
case UnsolvedMeta(id, _) if id in vs: return vs[id] | |
case _: return ty | |
case _: return ty | |
return helper(gen.under) | |
def fresh(self, parent: Ty | Expr, reason: str) -> Ty: | |
id = self.next_id | |
self.next_id += 1 | |
meta = UnsolvedMeta(id, self.cur_scope) | |
meta._set_parent((parent, reason)) | |
r = TMeta(Rc(meta)) | |
return r | |
def infer(self, env: Env, e: Expr, i: int) -> Ty: | |
# print(" " * i, "infer", str(e)) | |
r = self.infer_(env, e, i) | |
# print(" " * i, ":", str(r)) | |
return r | |
def infer_(self, env: Env, e: Expr, i: int) -> Ty: | |
match e: | |
case EApp(lhs, rhs): | |
lhsty = self.infer(env, lhs, i + 1) | |
if isinstance(lhsty, TCons): # the constructor case, special type for constructors | |
rhsty = self.infer(env, rhs, i + 1) | |
con = lhsty.ty.cons[lhsty.name] | |
conty = con[len(lhsty.applied)] | |
self.unify(env, rhsty, conty, i + 1) | |
if len(lhsty.applied) + 1 == len(con): | |
assert lhsty.ty.params_inst is not None | |
ps = [lhsty.ty.params_inst[name] for name in lhsty.ty.params] | |
return TApp(lhsty.ty, ps) | |
return TCons(lhsty.ty, lhsty.name, lhsty.applied + [rhsty]) | |
rhsty = self.infer(env, rhs, i + 1) | |
resty = self.fresh(e, "inferred return type of") | |
self.unify(env, lhsty, with_parent(TFun(rhsty, resty), lhs, "inferred type of"), i + 1) | |
return resty | |
case ELam(name, body): | |
argty = self.fresh(e, "inferred type of argument of") | |
bodyty = self.infer(VarEnv(env, name, self.dont_gen(argty)), body, i + 1) | |
return with_parent(TFun(argty, bodyty), e, "type of") | |
case EIf(cond, then, othr): | |
condty = self.infer(env, cond, i + 1) | |
self.unify(env, condty, with_parent(TBool(), condty, "inferred type of"), i + 1) | |
thenty = self.infer(env, then, i + 1) | |
othrty = self.infer(env, othr, i + 1) | |
self.unify(env, thenty, othrty, i + 1) | |
return thenty | |
case ELet(name, value, body): | |
self.cur_scope += 1 | |
varty = self.infer(env, value, i + 1) | |
self.cur_scope -= 1 | |
g = self.gen(varty) | |
return self.infer(VarEnv(env, name, g), body, i + 1) | |
case ELetRec(lets, body): | |
# enter scope | |
self.cur_scope += 1 | |
# generate fresh metavars for the bindings | |
vartys = { name: self.fresh(e, f"inferred type of {name} in") for name in lets.keys() } | |
# construct environment (associate bindings with the fresh types) | |
env2 = env | |
for name in lets.keys(): | |
# wait until later to generalize | |
env2 = VarEnv(env2, name, self.dont_gen(vartys[name])) | |
# unify the inferred types and the fresh vars | |
for name, value in lets.items(): | |
r = self.infer(env2, value, i + 1) | |
self.unify(env, vartys[name], r, i + 1) | |
# exit scope | |
self.cur_scope -= 1 | |
# reconstruct the environment but generalize the variables this time | |
env2 = env | |
for name in lets.keys(): | |
env2 = VarEnv(env2, name, self.gen(vartys[name])) | |
# infer body | |
return self.infer(env2, body, i + 1) | |
case EData(name, params, cons, body): | |
ptys = { p: self.fresh(e, f"type parameter {p} in") for p in params } | |
dty = with_parent(TData(name, params, { | |
name: [c.replace_named(ptys) for c in con] | |
for name, con in cons.items() | |
}, ptys), e, "type of") | |
ps = [p.as_unsolved() for p in ptys.values()] | |
env2 = TyEnv(env, name, GenTy(ps, dty)) | |
for con in cons.keys(): | |
# an unary constructor is just the datatype, | |
# >1-ary constructors are special types because we don't really have rank 2 types. | |
conty = TCons(dty, con, []) if len(cons[con]) > 0 else dty | |
env2 = VarEnv(env2, con, GenTy(ps, with_parent(conty, e, f"constructor `{con}` of"))) | |
return self.infer(env2, body, i + 1) | |
case EVar(name): | |
v = env.find_var_check(name) | |
ty = self.inst(v, e) | |
return with_parent(ty, e, "type of") | |
case EMatch(expr, cases): | |
ety = self.infer(env, expr, i + 1) | |
retty = self.fresh(e, "inferred type of") | |
dataty: TApp | None = None # ety will be unified with this | |
for case in cases: | |
consty = self.inst(env.find_var_check(case.name), case) | |
# figure out the type of the constructor; | |
# we need both the instantiated and raw TData versions. | |
match consty: | |
# TData is weird and special (might change that in the future) | |
# instantiated TData has its params in params_inst instead of it being TApp. | |
# we check for both TCons and TData because unary cases are TData and >1ary are TCons. | |
case TCons(): | |
tdata = consty.ty | |
assert consty.ty.params_inst is not None | |
tdata_app = TApp(consty.ty, [consty.ty.params_inst[name] for name in consty.ty.params]) | |
case TData(): | |
tdata = consty | |
assert consty.params_inst is not None | |
tdata_app = TApp(consty, [consty.params_inst[name] for name in consty.params]) | |
case TApp(): # but also maybe TApp idk i need to fix this | |
match consty.to: | |
case TData(): tdata, tdata_app = consty.to, consty | |
case _: raise UnifyError(consty, None, "not a constructor") | |
case _: raise UnifyError(consty, None, "not a constructor") | |
# check if already set or set the datatype | |
if dataty is not None: | |
if dataty.to != tdata: | |
raise UnifyError(dataty.to, tdata, "mismatching constructor types") | |
else: | |
dataty = tdata_app | |
# check if already set or set the datatype | |
argtys = tdata.cons[case.name] | |
if len(argtys) != len(case.args): | |
raise UnifyError(consty, None, f"wrong number of arguments (expected {len(case.args)})") | |
# create argument types | |
env2 = env | |
for name, argty in zip(case.args, argtys): | |
env2 = VarEnv(env2, name, self.dont_gen(argty)) | |
# infer body | |
bodyty = self.infer(env2, case.body, i + 1) | |
self.unify(env2, bodyty, retty, i + 1) | |
if dataty is None: raise UnifyError(ety, None, "no cases") # shouldn't happen | |
self.unify(env, ety, dataty, i + 1) | |
return retty | |
case EInt(_): return with_parent(TNat(), e, "literal type of") | |
case EBool(_): return with_parent(TBool(), e, "literal type of") | |
case _: raise NotImplementedError(e) | |
def unify(self, env: Env, a: Ty, b: Ty, i: int): | |
a = a.force() | |
b = b.force() | |
if isinstance(a, TNamed): | |
return self.unify(env, env.find_ty_check(a.name).under, b, i) | |
if isinstance(b, TNamed): | |
return self.unify(env, a, env.find_ty_check(b.name).under, i) | |
if a == b: return | |
match (a, b): | |
case (TFun(al, ar), TFun(bl, br)): | |
self.unify(env, al, bl, i + 1) | |
self.unify(env, ar, br, i + 1) | |
case (TApp(ato, aps), TApp(bto, bps)) if len(aps) == len(bps): | |
self.unify(env, ato, bto, i + 1) | |
for ap, bp in zip(aps, bps): | |
self.unify(env, ap, bp, i + 1) | |
case (TMeta(rc), t) | (t, TMeta(rc)): | |
match rc.ref: | |
case UnsolvedMeta(id, _): | |
if t.occurs_unify(rc.ref): | |
raise UnifyError(a, b, f"can't unify, ?{id} occurs in {t}") | |
rc.ref = with_parent_copy(SolvedMeta(t), rc.ref._parent) | |
case _: assert False | |
case _: raise UnifyError(a, b, "can't unify") | |
class Parser:
    """Parsita grammar for the surface language; `Parser.parse` is the entry point."""
    import parsita as P
    import parsita.util
    import parsita.options
    # whitespace also swallows `-- ...` line comments
    parsita.options.whitespace = P.reg(r"\s*(\s*--.*\n\s*)*")

    @staticmethod
    def _reduce_to_app(l: list[Expr]) -> Expr:
        """Fold a non-empty list of atoms into a left-nested application."""
        es = iter(l)
        r = next(es)
        for e in es: r = EApp(r, e)
        return r

    KWS = { 'if', 'then', 'else', 'end', 'in', '*', 'rec', 'let', 'match', 'with' }

    @staticmethod
    def _reduce_iden(v: str) -> P.Parser[str, str]:
        """Reject keywords as identifiers; accept anything else unchanged."""
        if v in Parser.KWS: return Parser.P.failure("keyword") # type: ignore
        return Parser.P.success(v)

    pexpr = P.fwd()
    pty = P.fwd()
    piden = P.reg(r"[a-zA-Z_]+") >= _reduce_iden

    # --- types ---
    pnatty = P.lit('Nat') > parsita.util.constant(TNat())
    # an instance, like pnatty above (passing the bare class would yield the class object)
    pboolty = P.lit('Bool') > parsita.util.constant(TBool())
    pnamety = piden > TNamed
    ptypar = P.lit('(') >> pty << P.lit(')')
    ptyatom = P.first(ptypar, pnatty, pboolty, pnamety)
    pty.define(P.first((pnamety & P.rep1(ptyatom) > parsita.util.splat(TApp)) | ptyatom))

    # --- expressions ---
    pvar = piden > EVar
    pif = P.lit('if') >> pexpr << P.lit('then') & pexpr << P.lit('else') & pexpr << P.lit('end') > parsita.util.splat(EIf)
    pconsdef = piden & P.repsep(pty, '*') > tuple[str, list[Ty]]
    pdata = P.lit('data') >> piden & P.rep(piden) << P.lit('=') & \
        (P.rep1sep(pconsdef, '|') > dict[str, list[Ty]]) << P.lit('in') & pexpr > parsita.util.splat(EData)
    pcase = P.lit('|') >> piden & P.rep(piden) << P.lit('=>') & pexpr > parsita.util.splat(ECase)
    pmatch = P.lit('match') >> pexpr << P.lit('with') & P.rep1(pcase) << P.lit('end') > parsita.util.splat(EMatch)
    plet = P.lit('let') >> piden << P.lit('=') & pexpr << P.lit('in') & pexpr > parsita.util.splat(ELet)
    pletrec = P.lit('let') >> P.lit('rec') >> \
        (P.rep1sep(piden << P.lit('=') & pexpr > tuple, ',') > dict) << P.lit('in') & pexpr > parsita.util.splat(ELetRec)
    plam = P.lit('\\') >> piden << P.lit('.') & pexpr > parsita.util.splat(ELam)
    pnum = (P.reg(r"[0-9]+") > int) > EInt
    ppar = P.lit('(') >> pexpr << P.lit(')')
    pbool = (
        (P.lit('true') > parsita.util.constant(EBool(True)))
        | (P.lit('false') > parsita.util.constant(EBool(False)))
    )
    patom = P.first(pmatch, ppar, pif, pletrec, plet, plam, pbool, pnum, pvar | pdata)
    # juxtaposition of atoms is application
    pexpr.define(P.rep1(patom) > _reduce_to_app)

    @staticmethod
    def parse(source: str | bytes) -> P.Result[Expr]:
        """Parse `source` as a single expression."""
        return Parser.pexpr.parse(source)
def main(): | |
r = Parser.pexpr.parse( | |
R""" | |
data List a = Cons a * List a | Nil in | |
data Pair a b = Pair a * b in | |
let x = Pair true (Cons true Nil) in | |
match x with | |
| Pair x y => match y with | |
| Cons x xs => x | |
| Nil => x | |
end | |
end | |
""" | |
) | |
match r: | |
case Parser.P.Success(Expr() as e): | |
tc = Tycheck() | |
fmt = Fmt(color=True) | |
print(e.fmt(fmt)) | |
try: | |
print(fmt.cpn(':'), tc.infer(RootEnv(), e, 0).force().fmt(fmt)) | |
except UnifyError as err: | |
if err.b is not None: | |
print(fmt.cpn('│'), fmt.cer(err.reason + ':'), err.a.fmt(fmt), '≠', err.b.fmt(fmt)) | |
else: | |
print(fmt.cpn('│'), fmt.cer(err.reason + ':'), err.a.fmt(fmt)) | |
def print_parents(x: HasFmt, reason: str, prefix: bool, i = 0): | |
p = fmt.cpn("│") if prefix and i > 0 else "" | |
if hasattr(x, "_parent"): | |
t: tuple[HasFmt, str] | None = getattr(x, "_parent") | |
if t is not None: | |
if not reason: | |
print_parents(t[0], t[1], prefix, i) | |
return | |
lns = f"{x.fmt(fmt)}".split("\n") | |
start = "╰" if not prefix or i > 0 else "├" | |
end = "──" if len(lns) == 1 and (not hasattr(t[0], "_parent") or getattr(t[0], "_parent")) is None else "─╮" | |
ind = p + " " * (2*(i)-prefix) | |
if len(lns) == 1: print(ind + fmt.cpn(start + end), fmt.cpn(reason), lns[0]) | |
else: print(ind + fmt.cpn(start + end), fmt.cpn(reason)) | |
for j in range(len(lns) == 1, len(lns)): print(ind + " " + fmt.cpn("│"), lns[j]) | |
print_parents(t[0], t[1], prefix, i + 1) | |
else: | |
print(x, "has no _parent") | |
print(fmt.cpn('│'), err.a.fmt(fmt), 'is from:') | |
print_parents(err.a, "", True) | |
if err.b is not None: | |
print(fmt.cpn('│'), err.b.fmt(fmt), 'is from:') | |
print_parents(err.b, "", False) | |
case Parser.P.Failure(err): # type: ignore | |
print(err) | |
exit(1) | |
# run the demo only when executed as a script
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment