Last active
July 3, 2023 15:12
-
-
Save aradarbel10/837aa65d2f06ac6710c6fbe479909b4c to your computer and use it in GitHub Desktop.
Minimal type inference for the simply typed lambda calculus (STLC), using mutable metavariables solved by unification.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* ===== language definition ===== *)

(** Identifiers: used for program variables and as display names for
    metavariables. *)
type nom = string

(** Binary integer operators. *)
type bop =
  | Add
  | Sub
  | Mul

(** Types of the simply typed lambda calculus, plus metavariables ([Meta])
    that inference fills in by mutating their cell. *)
type typ =
  | Int
  | Arrow of typ * typ (* domain, codomain *)
  | Meta of meta

(* A metavariable is a mutable cell that starts [Unsolved] and is
   overwritten with [Solved t] once unification decides what it is. *)
and meta = meta_state ref

and meta_state =
  | Solved of typ
  (* The name is shown by the pretty printer and is also what the occurs
     check compares to identify the variable. *)
  | Unsolved of nom

(** Surface expressions; lambdas carry no type annotation. *)
type expr =
  | Num of int
  | Bop of bop * expr * expr (* e0 op e1 *)
  | Var of nom
  | Lam of nom * expr (* λx. e *)
  | App of expr * expr (* e0 e1 *)
  | Let of nom * expr * expr (* let x = e in e' *)
(* ===== pretty printing ===== *)

(** Render a type. Arrows are always fully parenthesised. A solved
    metavariable prints as the type it was solved to, and an unsolved one
    prints as [?name]. *)
let rec string_of_typ : typ -> string = function
  | Int -> "int"
  | Arrow (dom, cod) ->
    String.concat "" [ "("; string_of_typ dom; " -> "; string_of_typ cod; ")" ]
  | Meta cell -> string_of_meta !cell

(* Render the current contents of a metavariable cell. *)
and string_of_meta : meta_state -> string = function
  | Unsolved name -> "?" ^ name
  | Solved ty -> string_of_typ ty
(** Render an expression, fully parenthesised, in a concrete syntax close
    to the source language. For example, [Let ("x", Num 1, Var "x")]
    renders as ["(let x = 1 in x)"]. *)
let rec string_of_expr : expr -> string = function
  | Num n -> string_of_int n
  | Bop (op, e0, e1) ->
    "(" ^ string_of_expr e0 ^ " " ^ string_of_op op ^ " " ^ string_of_expr e1 ^ ")"
  | Var x -> x
  | Lam (x, e) -> "(𝜆" ^ x ^ ". " ^ string_of_expr e ^ ")"
  | App (e0, e1) -> "(" ^ string_of_expr e0 ^ " " ^ string_of_expr e1 ^ ")"
  | Let (x, e, e') ->
    (* Mirror the surface form [let x = e in e']: the bound name and the
       bound expression must be separated by " = ", and the body follows
       " in ". *)
    "(let " ^ x ^ " = " ^ string_of_expr e ^ " in " ^ string_of_expr e' ^ ")"

(* Concrete symbol for each binary operator. *)
and string_of_op : bop -> string = function
  | Add -> "+"
  | Sub -> "-"
  | Mul -> "*"
(* ===== errors ===== *)

(** Raised by [infer] when a variable is not bound in the context. *)
exception UndefinedVar of nom

(** Raised by [unify] when two forced types have incompatible shapes. The
    payload is (expected, actual), in the order [unify] received them. *)
exception UnUnifiable of typ * typ

(** Raised by the occurs check when solving a metavariable would make it
    refer to itself, i.e. build a cyclic type. *)
exception OccursFailure
(* ===== fresh name supply ===== *)

(** Supply of fresh metavariables. The counter is global, and nothing in
    this file resets it, so metavariables get distinct names [x0], [x1], ... *)
module Fresh : sig
  (** Next index to hand out. It is exposed, so callers can observe or
      modify it. *)
  val freshi : int ref

  (** Return the current index and advance the counter by one. *)
  val nexti : unit -> int

  (** A new unsolved metavariable named ["x" ^ string_of_int i], where [i]
      is the next index. *)
  val fresh : unit -> typ
end = struct
  let freshi = ref 0

  let nexti () =
    let i = !freshi in
    incr freshi;
    i

  let fresh () =
    let name = "x" ^ string_of_int (nexti ()) in
    Meta (ref (Unsolved name))
end

open Fresh
(* ===== metavariable forcing ===== *)

(** [force t] follows [Solved] links until the head of the type is no
    longer a solved metavariable. Call it before pattern matching on a type
    so that solved metavariables are transparent. The result is [Int], an
    [Arrow], or an unsolved [Meta], returned physically unchanged. *)
let rec force (t : typ) : typ =
  match t with
  | Meta { contents = Solved inner } -> force inner
  | Meta { contents = Unsolved _ } | Int | Arrow _ -> t
(* ===== unification ===== *)

(** [occurs x t] checks that the unsolved metavariable named [x] does not
    appear anywhere in [t], looking through solved metavariables. Solving
    [x := t] when it does would create a cyclic type.
    @raise OccursFailure if [x] occurs in [t]. *)
let rec occurs (x : nom) (t : typ) : unit =
  match force t with
  | Int -> ()
  | Arrow (dom, cod) ->
    occurs x dom;
    occurs x cod
  | Meta cell ->
    (match !cell with
     | Unsolved y when String.equal x y -> raise OccursFailure
     | Unsolved _ | Solved _ -> ())
(** [unify (t0, t1)] makes [t0] and [t1] equal by solving unsolved
    metavariables in place. [check] passes the expected type as [t0] and
    the actual type as [t1]. That order is kept in the recursive calls and
    in the [UnUnifiable] payload.
    @raise UnUnifiable with the first pair of incompatible (forced)
      sub-terms encountered.
    @raise OccursFailure if solving a metavariable would build a cyclic
      type. *)
let rec unify (t0, t1 : typ * typ) : unit =
  match force t0, force t1 with
  (* The same metavariable cell on both sides is trivially unified.
     Comparing the cells by identity avoids structurally comparing whole
     types (which hold mutable refs) at every recursive step. *)
  | Meta m, Meta m' when m == m' -> ()
  | Meta m, t | t, Meta m ->
    begin match !m with
      | Unsolved x -> occurs x t; m := Solved t
      | Solved _ ->
        (* impossible: [force] never returns a solved metavariable *)
        failwith "absurd!"
    end
  | Int, Int -> ()
  | Arrow (dom0, cod0), Arrow (dom1, cod1) ->
    (* Arrows are equal iff their domains and their codomains are equal. *)
    unify (dom0, dom1);
    unify (cod0, cod1)
  | t0, t1 -> raise (UnUnifiable (t0, t1))
(* ===== type inference ===== *)

(** Typing context. The most recent binding comes first, so
    [List.assoc_opt] finds the innermost binding (shadowing). *)
type ctx = (nom * typ) list

(** [infer ctx e] synthesises a type for [e] under [ctx].
    - It introduces fresh metavariables for lambda parameters and
      application results.
    - It solves those metavariables by unification as it goes.
    - The result may still contain metavariables.
    @raise UndefinedVar if [e] mentions an unbound variable.
    @raise OccursFailure if unification would build a cyclic type.
    @raise Failure on a type mismatch (raised by [check]). *)
let rec infer (ctx : ctx) : expr -> typ = function
  | Num _ -> Int
  | Bop (_, e0, e1) ->
    (* every operator in [bop] takes two ints and returns an int *)
    check ctx e0 Int;
    check ctx e1 Int;
    Int
  | Var x ->
    begin match List.assoc_opt x ctx with
      | Some t -> t
      | None -> raise (UndefinedVar x)
    end
  | Lam (x, e) ->
    (* The parameter is unannotated. Guess a metavariable and let the body
       refine it. *)
    let t0 = fresh () in
    let t1 = infer ((x, t0) :: ctx) e in
    Arrow (t0, t1)
  | App (e0, e1) ->
    (* Infer the argument first, then require the function to have type
       arg_typ -> ?ret for a fresh ?ret. *)
    let arg_typ = infer ctx e1 in
    let ret_typ = fresh () in
    check ctx e0 (Arrow (arg_typ, ret_typ));
    ret_typ
  | Let (x, e, e') ->
    (* No generalisation: [x] gets a single monomorphic type. *)
    let t = infer ctx e in
    infer ((x, t) :: ctx) e'

(* [check ctx e t] infers a type for [e] and unifies the expected type [t]
   with it. A mismatch is turned into [Failure] with a readable message
   naming the first incompatible pair of sub-terms. *)
and check (ctx : ctx) (e : expr) (t : typ) : unit =
  let t' = infer ctx e in
  try unify (t, t') with
  | UnUnifiable (expected, actual) ->
    failwith ("expected type " ^ string_of_typ expected ^ " but received " ^ string_of_typ actual)
(* Program entry point: currently it only prints a greeting line. *)
let () = "hello STLC!" |> print_endline
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment