Last active
October 16, 2024 17:59
-
-
Save mb64/4a49d710dcdd1875bebdbc59081acb85 to your computer and use it in GitHub Desktop.
Very simple typechecker for MLTT
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Typechecker for Martin-Löf type theory (MLTT) with a universe hierarchy *)
(* Build with: ocamlfind ocamlc -package angstrom,stdio -linkpkg mltt.ml -o mltt *)
(* Identifiers as written in the source program. *)
type name = string

(* Surface syntax produced by the parser. Variables are referred to by
   name; terms and types share a single syntactic category. *)
module AST = struct
  type tm =
    | Var of name  (* x *)
    | U of int  (* universe Un *)
    | Pi of name * tm * tm  (* pi (x : A) -> B *)
    | Annot of tm * tm  (* e : A *)
    | Lam of name * tm  (* fun x -> e *)
    | App of tm * tm  (* f a *)
    | Let of name * tm * tm  (* let x = e in b *)

  (* Types are just terms; this alias documents intent at use sites. *)
  type ty = tm
end
(* Typechecker for the core language.
   Surface terms ([AST.tm], named variables) are elaborated into core terms
   ([Tychk.tm], de Bruijn indices). Types are kept as semantic values [d],
   and definitional equality is decided on those values by [eq]. *)
module Tychk = struct
  (* de Bruijn index: distance from the innermost enclosing binder. *)
  type idx = int

  (* de Bruijn level: used to name fresh neutral variables ([DNe]). *)
  type lvl = int

  (* Semantic values. Binders are represented as OCaml functions.
     [DNe (x, args)] is the stuck variable [x] applied to [args]; the spine
     is stored most-recently-applied first (see [apply]). *)
  type d =
    | DU of int
    | DPi of d Lazy.t * (d Lazy.t -> d)
    | DLam of (d Lazy.t -> d)
    | DNe of lvl * d Lazy.t list

  (* Core terms, with variables as de Bruijn indices. *)
  type tm =
    | Var of idx
    | U of int
    | Pi of ty * ty
    | Lam of tm
    | App of tm * tm
    | Let of tm * tm
  and ty = tm

  (* Semantic application: beta-reduce a lambda, or push the argument onto
     a neutral's spine.
     @raise Failure if [f] is a universe or a pi type. *)
  let apply f x =
    match f with
    | DNe (f, args) -> DNe (f, x :: args)
    | DLam f -> f x
    | DU _ | DPi _ -> failwith "internal type error"

  (* Evaluate a core term; [env] holds lazily computed values indexed by
     de Bruijn index. *)
  let rec eval env : tm -> d = function
    | Var i -> Lazy.force (List.nth env i)
    | U l -> DU l
    | Pi (a, b) -> DPi (lazy (eval env a), fun x -> eval (x :: env) b)
    | Lam b -> DLam (fun x -> eval (x :: env) b)
    | App (a, b) -> apply (eval env a) (lazy (eval env b))
    | Let (x, b) -> eval (lazy (eval env x) :: env) b

  (* Render a value in surface syntax, naming binders x0, x1, ... in the
     order they are printed. Neutrals are resolved against binders entered
     during printing, so [t] should contain no free neutrals (main only
     prints the type of a closed program); otherwise [List.nth] fails. *)
  let pp t =
    let next_var = ref 0 in
    let fresh () =
      let x = "x" ^ string_of_int !next_var in
      incr next_var;
      x
    in
    let parens p s = if p then "(" ^ s ^ ")" else s in
    (* [p]: parenthesize compound output; [lvl]: binders entered so far;
       [e]: their printed names, innermost first. *)
    let rec go p lvl e = function
      | DU l -> "U" ^ string_of_int l
      | DPi (lazy a, b) ->
          let x = fresh () in
          let v = lazy (DNe (lvl, [])) in
          parens p
            ("pi (" ^ x ^ " : " ^ go false lvl e a ^ ") -> "
            ^ go false (lvl + 1) (x :: e) (b v))
      | DLam f ->
          let x = fresh () in
          let v = lazy (DNe (lvl, [])) in
          parens p ("fun " ^ x ^ " -> " ^ go false (lvl + 1) (x :: e) (f v))
      | DNe (f, []) -> List.nth e (lvl - f - 1)
      | DNe (f, args) ->
          (* [args] is newest-first, so fold_right emits application order. *)
          parens p
            (List.fold_right
               (fun (lazy x) acc -> acc ^ " " ^ go true lvl e x)
               args
               (List.nth e (lvl - f - 1)))
    in
    go false 0 [] t

  (* Typing context.
     [lvl]: number of lambda/pi-bound variables, i.e. the next fresh level
            (let-bound variables do not bump it, see [define]);
     [tys]: types of lambda/pi-bound variables (not read in this module);
     [env]: values of all variables in scope, indexed by de Bruijn index;
     [scope]: surface names with their types, aligned with [env]. *)
  type ctx =
    { lvl : lvl
    ; tys : d Lazy.t list
    ; env : d Lazy.t list
    ; scope : (name * d Lazy.t) list }

  let empty_ctx : ctx = { lvl = 0; tys = []; env = []; scope = [] }

  (* Extend [ctx] with a lambda/pi-bound variable [x : a]. Returns the new
     context and the fresh neutral value standing for [x]. *)
  let bind ctx x a =
    let v = lazy (DNe (ctx.lvl, [])) in
    ( { lvl = ctx.lvl + 1
      ; tys = a :: ctx.tys
      ; env = v :: ctx.env
      ; scope = (x, a) :: ctx.scope }
    , v )

  (* Extend [ctx] with a let-bound variable [x := v : a]. No neutral is
     created, so [lvl] is left unchanged and fresh levels stay distinct. *)
  let define ctx x v a =
    { ctx with env = v :: ctx.env; scope = (x, a) :: ctx.scope }

  exception TypeError of string

  (* Check that two values are definitionally equal at binder depth [lvl],
     up to beta (performed by evaluation) and eta for functions.
     @raise TypeError if they are not. *)
  let rec eq lvl a b =
    match a, b with
    | DU l, DU l' -> if l <> l' then raise (TypeError "universe mismatch")
    | DPi (lazy a, b), DPi (lazy a', b') ->
        eq lvl a a';
        let v = lazy (DNe (lvl, [])) in
        eq (lvl + 1) (b v) (b' v)
    | DLam a, a' ->
        let v = lazy (DNe (lvl, [])) in
        eq (lvl + 1) (a v) (apply a' v)
    | a, DLam a' ->
        let v = lazy (DNe (lvl, [])) in
        eq (lvl + 1) (apply a v) (a' v)
    | DNe (f, args), DNe (f', args')
      when f = f' && List.length args = List.length args' ->
        (* Compare in application order (spines are stored newest-first):
           once earlier arguments agree, later ones have equal types, so we
           never compare e.g. a universe against a lambda. The length guard
           keeps List.iter2 from raising Invalid_argument. *)
        List.iter2
          (fun x y -> eq lvl (Lazy.force x) (Lazy.force y))
          (List.rev args) (List.rev args')
    | _ -> raise (TypeError "Rigid type mismatch")

  (* Check the surface term [tm] against type [ty]; returns its core form.
     @raise TypeError if [tm] does not have type [ty]. *)
  let rec check (ctx : ctx) (tm : AST.tm) (ty : d) : tm =
    match tm, ty with
    | Lam (x, body), DPi (a, b) ->
        let ctx', v = bind ctx x a in
        Lam (check ctx' body (b v))
    | Let (x, e, body), t ->
        let e', a = infer ctx e in
        let ctx' = define ctx x (lazy (eval ctx.env e')) (Lazy.from_val a) in
        Let (e', check ctx' body t)
    | e, t ->
        let e', t' = infer ctx e in
        eq ctx.lvl t t';
        e'

  (* Infer the type of the surface term [tm]; returns its core form and type.
     @raise TypeError if [tm] is ill-typed or is a bare lambda. *)
  and infer (ctx : ctx) (tm : AST.tm) : tm * d =
    match tm with
    | Var x ->
        (* Innermost binding wins, which gives the usual shadowing. *)
        let rec go i = function
          | [] -> raise (TypeError "Variable not in scope")
          | (v, t) :: _ when String.equal v x -> (Var i, Lazy.force t)
          | _ :: xs -> go (i + 1) xs
        in
        go 0 ctx.scope
    | U l -> (U l, DU (l + 1))
    | Pi (x, a, b) -> (
        match infer ctx a with
        | a, DU l_a -> (
            let ctx', _ = bind ctx x (lazy (eval ctx.env a)) in
            match infer ctx' b with
            | b, DU l_b -> (Pi (a, b), DU (max l_a l_b))
            | _ -> raise (TypeError "should be a type"))
        | _ -> raise (TypeError "should be a type"))
    | Annot (e, t) -> (
        match infer ctx t with
        | t, DU _ ->
            let t' = eval ctx.env t in
            let e' = check ctx e t' in
            (e', t')
        | _ -> raise (TypeError "should be a type"))
    | Lam _ -> raise (TypeError "can't infer type of lambda")
    | App (f, x) -> (
        match infer ctx f with
        | f, DPi (a, b) ->
            let x = check ctx x (Lazy.force a) in
            (App (f, x), b (lazy (eval ctx.env x)))
        | _ -> raise (TypeError "should be a function"))
    | Let (x, e, body) ->
        let e', a = infer ctx e in
        let ctx' = define ctx x (lazy (eval ctx.env e')) (Lazy.from_val a) in
        let body', t = infer ctx' body in
        (Let (e', body'), t)
end
(* Angstrom parser for the surface syntax:
     e ::= let x = e in e
         | pi (x : e) ... (x : e) -> e
         | fun x ... x -> e
         | a a ... a [: e]        (left-associative application, optional annotation)
     a ::= x | Un | ( e )
   Whitespace may follow any token. *)
module Parser = struct
  open AST
  open Angstrom

  let keywords = [ "pi"; "let"; "in"; "fun" ]

  (* '\r' is included so that CRLF-terminated input also parses. *)
  let whitespace = take_while (String.contains " \n\t\r")

  (* Run [a], then skip any trailing whitespace. *)
  let lexeme a = a <* whitespace

  (* An identifier: [A-Za-z_][A-Za-z0-9_]*. Keywords are not excluded here. *)
  let ident =
    lexeme
      (let is_start_char c =
         c = '_' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z')
       in
       let is_ident_char c = is_start_char c || ('0' <= c && c <= '9') in
       let* s = satisfy is_start_char in
       let* i = take_while is_ident_char in
       return (String.make 1 s ^ i))

  (* A literal punctuation token such as "(" or "->". *)
  let str s = lexeme (string s) *> return ()

  (* An identifier that is not a keyword. *)
  let name =
    let* i = ident in
    if List.mem i keywords then fail (i ^ " is a keyword") else return i

  (* Exactly the keyword [k]. Parsed as a whole identifier, so "fun" does
     not match a prefix of e.g. "funny". *)
  let keyword k =
    let* i = ident in
    if i = k then return () else fail ("expected " ^ k)

  let parens p = str "(" *> p <* str ")"

  (* Names of the form U<decimal digits> denote universes; everything else
     is a variable. Only plain digits count, so "U0x1" or "U1_0" (which
     int_of_string would also accept) stay ordinary variables, as do digit
     strings too large for an int. *)
  let resolve_name n =
    let len = String.length n in
    let rec all_digits i =
      i >= len || ('0' <= n.[i] && n.[i] <= '9' && all_digits (i + 1))
    in
    if len > 1 && n.[0] = 'U' && all_digits 1 then
      match int_of_string_opt (String.sub n 1 (len - 1)) with
      | Some x -> U x
      | None -> Var n
    else Var n

  (* A name usable as a binder: neither a keyword nor a universe. *)
  let actual_name =
    let* n = name in
    match resolve_name n with
    | Var _ -> return n
    | _ -> fail (n ^ " is a reserved name")

  let exp =
    fix (fun exp ->
        let atomic_exp = parens exp <|> lift resolve_name name in
        (* One or more atoms; application is left-associative. *)
        let simple_exp =
          lift2
            (List.fold_left (fun f arg -> App (f, arg)))
            atomic_exp (many atomic_exp)
        in
        let annot_exp =
          let+ e = simple_exp
          and+ annot =
            option (fun e -> e)
              (lift (fun t e -> Annot (e, t)) (str ":" *> exp))
          in
          annot e
        in
        let let_exp =
          let+ n = keyword "let" *> actual_name <* str "="
          and+ e = exp <* keyword "in"
          and+ body = exp in
          Let (n, e, body)
        in
        let pi_exp =
          let+ () = keyword "pi"
          and+ things =
            many1
              (let+ n = str "(" *> actual_name
               and+ a = str ":" *> exp <* str ")" in
               (n, a))
          and+ () = str "->"
          and+ b = exp in
          List.fold_right (fun (n, a) r -> Pi (n, a, r)) things b
        in
        let lam_exp =
          let+ ns = keyword "fun" *> many1 actual_name <* str "->"
          and+ b = exp in
          List.fold_right (fun x r -> Lam (x, r)) ns b
        in
        let_exp <|> pi_exp <|> lam_exp <|> annot_exp <?> "expression")

  (* Parse a complete program; leading whitespace is allowed.
     @raise Failure with Angstrom's message on a syntax error. *)
  let parse (s : string) =
    match parse_string ~consume:All (whitespace *> exp) s with
    | Ok e -> e
    | Error msg -> failwith msg
end
(* Read a program from standard input, typecheck it, and print its type.
   Syntax and type errors are reported on stderr with exit status 1
   instead of escaping as uncaught exceptions. *)
let main () =
  let input = Stdio.In_channel.(input_all stdin) in
  match Parser.parse input with
  | exception Failure msg ->
      prerr_endline ("parse error: " ^ msg);
      exit 1
  | exp -> (
      print_endline "Checking...";
      let open Tychk in
      match infer empty_ctx exp with
      | exception TypeError msg ->
          prerr_endline ("type error: " ^ msg);
          exit 1
      | _tm, ty -> print_endline ("type is " ^ pp ty))

let () = main ()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment