A basic implementation of the Hindley-Milner inference algorithm. Engagement for CMSC 22100 final project.
(*
  exp ::=
    id |
    (exp) |                -- grouping (resolves ambiguity)
    exp exp |              -- application
    \id -> exp |           -- abstraction
    let id = exp in exp    -- binding
*)
type id = string

module IdMap = Map.Make(String)

type token =
  LParen |
  RParen |
  TokLambda |
  TokArrow |
  TokLet |
  TokEq |
  TokIn |
  TokForall |
  TokPeriod |
  TokColon |
  TokDashes |
  TokId of id

type exp =
  Id of id |
  Apply of exp * exp |
  Abstract of id * exp |
  LetIn of id * exp * exp
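(* Illustrative sketch (not part of the original file): the concrete program
   "let id = \\x -> x in id id" corresponds to the AST below. *)
let _example_ast : exp = LetIn ("id", Abstract ("x", Id "x"), Apply (Id "id", Id "id"))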
let rec string_of_exp e =
  match e with
  | Id i -> i
  | Apply (e1, e2) -> wrap_if_not_id e1 ^ " " ^ wrap_if_not_id e2
  | Abstract (var, body) -> Printf.sprintf "\\%s -> %s" var (string_of_exp body)
  | LetIn (var, value, body) -> Printf.sprintf "let %s = %s in %s" var (string_of_exp value) (string_of_exp body)
and wrap_if_not_id e =
  match e with
  | Id i -> i
  | _ -> "(" ^ string_of_exp e ^ ")"
(* tokenizer *)

let rec take_while (pred : 'a -> bool) (xs : 'a list) : ('a list * 'a list) =
  match xs with
  | [] -> ([], [])
  | x::rest ->
    if pred x
    then let (prefix, suffix) = take_while pred rest in (x::prefix, suffix)
    else ([], x::rest)

let is_ident c =
  match c with
  | 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' -> true
  | _ -> false
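(* Illustrative sketch (not part of the original file): take_while splits a list at the
   first element that fails the predicate, e.g.
   take_while is_ident ['f'; 'o'; 'o'; ' '; 'x'] evaluates to (['f'; 'o'; 'o'], [' '; 'x']). *)
let _example_take_while : char list * char list =
  take_while is_ident ['f'; 'o'; 'o'; ' '; 'x']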
exception Could_not_tokenize

let rec next_token (cs : char list) : (token * char list) option =
  match cs with
  | [] -> None
  | ' '::rest | '\t'::rest | '\n'::rest -> next_token rest (* skip whitespace *)
  | '#'::rest -> let (_, next_line) = take_while ((<>) '\n') rest in next_token next_line (* comments *)
  | '('::rest -> Some (LParen, rest)
  | ')'::rest -> Some (RParen, rest)
  | '\\'::rest -> Some (TokLambda, rest)
  | '-'::'>'::rest -> Some (TokArrow, rest)
  (* make sure "letx" is tokenized as [TokId "letx"], not [TokLet; TokId "x"] *)
  | 'l'::'e'::'t'::c::rest when not (is_ident c) -> Some (TokLet, c::rest)
  | '='::rest -> Some (TokEq, rest)
  (* ditto *)
  | 'i'::'n'::c::rest when not (is_ident c) -> Some (TokIn, c::rest)
  | 'f'::'o'::'r'::'a'::'l'::'l'::c::rest when not (is_ident c) -> Some (TokForall, c::rest)
  | '\226'::'\136'::'\128'::rest -> Some (TokForall, rest) (* UTF-8 encoding of U+2200 FOR ALL *)
  | '.'::rest -> Some (TokPeriod, rest)
  | ':'::rest -> Some (TokColon, rest)
  | '-'::'-'::'-'::rest -> Some (TokDashes, rest)
  | c::rest when is_ident c ->
    let (ident, rest) = take_while is_ident rest in
    Some (TokId (c::ident |> List.to_seq |> String.of_seq), rest)
  | _ -> raise Could_not_tokenize

let rec scan (ls : char list) : token list =
  match next_token ls with
  | Some (t, rest) -> t::scan rest
  | None -> []

let scan_string s = s |> String.to_seq |> List.of_seq |> scan
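(* Illustrative sketch (not part of the original file): tokenizing a small program.
   scan_string "\\x -> x" should produce [TokLambda; TokId "x"; TokArrow; TokId "x"]. *)
let _example_tokens : token list = scan_string "\\x -> x"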
(* parser *)

type parse_result =
  Success of exp * token list |
  Failure |
  End_of_input

let rec base_exp (ts : token list) : parse_result =
  match ts with
  | [] -> End_of_input
  | (TokId i)::rest -> Success (Id i, rest)
  | LParen::rest -> (match parse_exp rest with
    | Success (expr, RParen::rest) -> Success (expr, rest)
    | _ -> Failure)
  | TokLet::(TokId i)::TokEq::rest -> (match parse_exp rest with
    | Success (value, TokIn::rest) -> (match parse_exp rest with
      | Success (body, rest) -> Success (LetIn (i, value, body), rest)
      | _ -> Failure)
    | _ -> Failure)
  | TokLambda::(TokId i)::TokArrow::rest -> (match parse_exp rest with
    | Success (body, rest) -> Success (Abstract (i, body), rest)
    | _ -> Failure)
  | _ -> Failure
and parse_apps (ts : token list) : exp list * token list =
  match base_exp ts with
  | Success (expr, rest) -> let (es, final) = parse_apps rest in (expr::es, final)
  | Failure -> ([], ts)
  | End_of_input -> ([], [])
and parse_exp (ts : token list) : parse_result =
  match parse_apps ts with
  | (e1::e2::es, rest) -> Success (List.fold_left (fun acc e -> Apply (acc, e)) (Apply (e1, e2)) es, rest)
  | ([e1], rest) -> Success (e1, rest)
  | ([], []) -> End_of_input
  | ([], _::_) -> Failure
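(* Illustrative sketch (not part of the original file): parsing the tokens from above.
   parse_exp (scan_string "\\x -> x") should give Success (Abstract ("x", Id "x"), []). *)
let _example_parse : parse_result = parse_exp (scan_string "\\x -> x")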
exception Could_not_parse

type type_var = char

type typ =
  TVar of type_var |
  TNat |
  TBool |
  TUnit |
  TList of typ |
  TArrow of typ * typ

let rec string_of_type t =
  match t with
  | TVar v -> String.make 1 v
  | TNat -> "Nat"
  | TBool -> "Bool"
  | TUnit -> "Unit"
  | TList l -> "List " ^ string_of_type l
  | TArrow (t1, t2) -> match t1 with
    | TArrow _ -> "(" ^ string_of_type t1 ^ ") -> " ^ string_of_type t2
    | _ -> string_of_type t1 ^ " -> " ^ string_of_type t2
let rec base_type (ts : token list) : typ * token list =
  match ts with
  | LParen::rest -> let (t, rest') = parse_type rest in (match rest' with
    | RParen::rest'' -> (t, rest'')
    | _ -> raise Could_not_parse)
  | (TokId "Nat")::rest -> (TNat, rest)
  | (TokId "Bool")::rest -> (TBool, rest)
  | (TokId "Unit")::rest -> (TUnit, rest)
  | (TokId "List")::rest -> let (t, rest') = base_type rest in (TList t, rest')
  | (TokId v)::rest when String.length v = 1 -> (TVar v.[0], rest)
  | _ -> raise Could_not_parse
and parse_type (ts : token list) : typ * token list =
  match base_type ts with
  | (t, TokArrow::rest) -> let (ret, rest') = parse_type rest in (TArrow (t, ret), rest')
  | (t, rest) -> (t, rest)
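(* Illustrative sketch (not part of the original file): parsing a type.
   parse_type (scan_string "List a -> Bool") should give
   (TArrow (TList (TVar 'a'), TBool), []). *)
let _example_type : typ * token list = parse_type (scan_string "List a -> Bool")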
type scheme =
  SType of typ |
  SForall of type_var * scheme

let rec string_of_scheme s =
  match s with
  | SType t -> string_of_type t
  | SForall (v, s') -> Printf.sprintf "\u{2200}%c. %s" v (string_of_scheme s')

let rec parse_scheme (ts : token list) : scheme * token list =
  match ts with
  | TokForall::(TokId v)::TokPeriod::rest when String.length v = 1 -> let (s, rest') = parse_scheme rest in (SForall (v.[0], s), rest')
  | _ -> let (t, rest) = parse_type ts in (SType t, rest)
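(* Illustrative sketch (not part of the original file): parsing a type scheme.
   parse_scheme (scan_string "forall a. a -> a") should give
   (SForall ('a', SType (TArrow (TVar 'a', TVar 'a'))), []). *)
let _example_scheme : scheme * token list = parse_scheme (scan_string "forall a. a -> a")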
type assumptions = scheme IdMap.t

let string_of_assumptions a =
  IdMap.bindings a |> List.map (fun (i, s) -> Printf.sprintf "%s : %s" i (string_of_scheme s)) |> String.concat "\n"

let next_assumption (ts : token list) : (id * scheme * token list) option =
  match ts with
  | [] -> None
  | (TokId i)::TokColon::rest -> let (s, rest') = parse_scheme rest in Some (i, s, rest')
  | _ -> raise Could_not_parse

let rec parse_assumptions (ts : token list) : assumptions =
  match next_assumption ts with
  | Some (i, s, rest) -> IdMap.add i s (parse_assumptions rest)
  | None -> IdMap.empty
(* inference engine *)

module VarSet = Set.Make(Char)
module VarMap = Map.Make(Char)

let concat_of_seq = Seq.fold_left VarSet.union VarSet.empty

(* originally, this checked the values of not_in and just tried not to reuse anything in the environment *)
(* however somehow this led to variables being reused if new ones were defined deeper in the recursion tree *)
(* so now this just iterates through characters and literally never reuses anything *)
(* horrible hack, but it works *)
let next_char = ref 'a'
let new_var (not_in : VarSet.t) : type_var =
  let n = !next_char in
  next_char := if n = 'z' then 'A' else n |> Char.code |> succ |> Char.chr;
  n

let rec new_vars (not_in : VarSet.t) (len : int) : type_var list =
  if len = 0
  then []
  else let v = new_var not_in in v::new_vars (VarSet.add v not_in) (len - 1)
let rec decompose (s : scheme) : type_var list * typ =
  match s with
  | SType t -> ([], t)
  | SForall (v, s) -> let (vs, t) = decompose s in (v::vs, t)

let rec free_vars (t : typ) : VarSet.t =
  match t with
  | TVar v -> VarSet.singleton v
  | TList t -> free_vars t
  | TArrow (t1, t2) -> VarSet.union (free_vars t1) (free_vars t2)
  | _ -> VarSet.empty

let rec free_vars_s (s : scheme) : VarSet.t =
  match s with
  | SType t -> free_vars t
  | SForall (v, s) -> VarSet.remove v (free_vars_s s)
type substitution = typ VarMap.t

let string_of_sub s = "{" ^ (VarMap.bindings s |> List.map (fun (a, t) -> Printf.sprintf "%s / %c" (string_of_type t) a) |> String.concat "; ") ^ "}"

let free_vars_sub (s : substitution) : VarSet.t = VarMap.to_seq s |> Seq.map (fun (_, t) -> free_vars t) |> concat_of_seq

let rec subst (a : substitution) (t : typ) : typ =
  match t with
  | TVar v -> (match VarMap.find_opt v a with
    | Some t -> t
    | None -> TVar v)
  | TList t -> TList (subst a t)
  | TArrow (t1, t2) -> TArrow (subst a t1, subst a t2)
  | _ -> t

let rec subst_s (a : substitution) (s : scheme) : scheme =
  match s with
  | SType t -> SType (subst a t)
  | SForall (v, s) ->
    let free_in_env = free_vars_sub a in
    if VarSet.mem v free_in_env
    then let nv = new_var (VarSet.union (free_vars_s s) free_in_env) in
      SForall (nv, subst_s (VarMap.add v (TVar nv) a) s) (* v is free in the environment, so rename it to a new var nv *)
    else SForall (v, subst_s (VarMap.remove v a) s) (* v is bound in s, so it should not be substituted *)
let unify (t1 : typ) (t2 : typ) : substitution option =
  let rec build_sub (t1 : typ) (t2 : typ) (acc : substitution) : substitution option =
    if subst acc t1 = subst acc t2 then Some acc else match (subst acc t1, subst acc t2) with
    | (TVar v, t2') when free_vars t2' |> VarSet.mem v |> not -> build_sub t1 t2 (VarMap.add v t2' acc)
    | (t1', TVar v) when free_vars t1' |> VarSet.mem v |> not -> build_sub t1 t2 (VarMap.add v t1' acc)
    | (TList l1, TList l2) -> build_sub l1 l2 acc
    | (TArrow (a1, r1), TArrow (a2, r2)) -> Option.bind (build_sub a1 a2 acc) (build_sub r1 r2)
    | (_, _) ->
      None (* if t1 and t2 are of different base types and neither is a variable, there's nothing we can do to unify *)
  in build_sub t1 t2 VarMap.empty
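(* Illustrative sketch (not part of the original file): unifying a -> Nat with Bool -> b
   should yield Some {Bool / a; Nat / b}, while unifying Nat with Bool should yield None. *)
let _example_unify : substitution option =
  unify (TArrow (TVar 'a', TNat)) (TArrow (TBool, TVar 'b'))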
(* perform the inner substitution first and then combine, so that variables inside a substitution's values get substituted as well *)
let concat (s1 : substitution) (s2 : substitution) : substitution =
  let s2_inner_sub = VarMap.map (subst s1) s2 in
  VarMap.union (fun _ v1 _ -> Some v1) s1 s2_inner_sub
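(* Illustrative sketch (not part of the original file): composing {Nat / b} with
   {(b -> b) / a} should give {(Nat -> Nat) / a; Nat / b}, because the second
   substitution's values are rewritten by the first before the maps are merged. *)
let _example_concat : substitution =
  concat (VarMap.singleton 'b' TNat) (VarMap.singleton 'a' (TArrow (TVar 'b', TVar 'b')))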
let subst_a (s : substitution) : assumptions -> assumptions = IdMap.map (subst_s s)

let free_vars_a (a : assumptions) : VarSet.t = IdMap.to_seq a |> Seq.map (fun (_, s) -> free_vars_s s) |> concat_of_seq

let closure (a : assumptions) (t : typ) : scheme =
  let closed_vars = VarSet.diff (free_vars t) (free_vars_a a) in
  VarSet.fold (fun v s -> SForall (v, s)) closed_vars (SType t)
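(* Illustrative sketch (not part of the original file): generalizing a type over an empty
   assumption set. closure IdMap.empty (TArrow (TVar 'a', TVar 'a')) should give
   SForall ('a', SType (TArrow (TVar 'a', TVar 'a'))), printed as "\u{2200}a. a -> a". *)
let _example_closure : scheme = closure IdMap.empty (TArrow (TVar 'a', TVar 'a'))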
let rec infer (a : assumptions) (e : exp) : (substitution * typ) option =
  (match e with
  | Id x -> (match IdMap.find_opt x a with
    | Some s ->
      let (vs, t) = decompose s in
      let used_vars = VarSet.union (free_vars t) (free_vars_a a) in
      let inner_sub = new_vars used_vars (List.length vs) |> List.map (fun v -> TVar v) |> List.combine vs |> List.to_seq |> VarMap.of_seq in
      Some (VarMap.empty, subst inner_sub t)
    | None ->
      Printf.printf "Could not find variable %s in the assumptions set!\n" x;
      None)
  | Apply (e1, e2) -> Option.bind (infer a e1) (fun (s1, t1) ->
    Option.bind (infer (subst_a s1 a) e2) (fun (s2, t2) ->
      let used_vars = free_vars t1 |> VarSet.union (free_vars t2) |> VarSet.union (free_vars_a a) in
      let b = TVar (new_var used_vars) in
      let maybe_v = unify (subst s2 t1) (TArrow (t2, b)) in
      Option.map (fun v -> (concat v (concat s2 s1), subst v b)) maybe_v))
  | Abstract (x, e1) ->
    let b = TVar (new_var (free_vars_a a)) in
    Option.map (fun (s1, t1) -> (s1, TArrow (subst s1 b, t1))) (infer (IdMap.add x (SType b) a) e1)
  | LetIn (x, e1, e2) -> Option.bind (infer a e1) (fun (s1, t1) ->
    let a' = subst_a s1 a in
    Option.map (fun (s2, t2) -> (concat s2 s1, t2)) (infer (IdMap.add x (closure a' t1) a') e2)))

let infer_scheme (a : assumptions) (e : exp) : (assumptions * scheme) option =
  Option.map (fun (s, t) -> let a' = subst_a s a in (a', closure a' t)) (infer a e)
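(* Illustrative sketch (not part of the original file): inferring the identity function
   with no assumptions, e.g. in the toplevel:
     infer_scheme IdMap.empty (Abstract ("x", Id "x"));;
   should yield Some (IdMap.empty, SForall (v, SType (TArrow (TVar v, TVar v)))) for some
   fresh variable v; the exact letter depends on the global counter. *)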
(* final I/O *)

let main (filename : string) : unit =
  let ic = open_in filename in
  let input_str = In_channel.input_all ic in
  close_in ic;
  let tokens = scan_string input_str in
  let (t_assumptions, t_exp) = match take_while ((<>) TokDashes) tokens with
    | (a, TokDashes::e) -> (a, e)
    | _ -> raise Could_not_parse in
  let assumptions = parse_assumptions t_assumptions in
  let input_expression = match parse_exp t_exp with
    | Success (e, []) -> e
    | _ -> raise Could_not_parse in
  match infer_scheme assumptions input_expression with
  | Some (new_assumptions, inferred_scheme) ->
    print_endline "Assumptions, after inferred substitution:";
    print_endline (string_of_assumptions new_assumptions);
    print_endline "---";
    print_endline "Inferred type scheme:";
    Printf.printf "%s : %s\n" (string_of_exp input_expression) (string_of_scheme inferred_scheme)
  | None -> print_endline "The inference algorithm failed!";;

let pp_exp (f : Format.formatter) (e : exp) = Format.pp_print_string f (string_of_exp e);;
let pp_typ (f : Format.formatter) (t : typ) = Format.pp_print_string f (string_of_type t);;
let pp_sub (f : Format.formatter) (s : substitution) = Format.pp_print_string f (string_of_sub s);;
let pp_a (f : Format.formatter) (a : assumptions) = Format.pp_print_string f (string_of_assumptions a);;

(* #install_printer pp_exp;;
   #install_printer pp_typ;;
   #install_printer pp_sub;;
   #install_printer pp_a;;
   #trace infer;;
   #trace unify;; *)

match Sys.argv with
| [| _; arg |] -> main arg
| args -> Printf.printf "Usage: %s <input filename>\n" args.(0) (* argv always has at least one argument: the executable name *)
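(* Illustrative sketch (not part of the original file): a hypothetical input file in the
   format main expects, with assumptions above the "---" separator and the expression to
   infer below it:

     id : forall a. a -> a
     n : Nat
     ---
     id n

   On such a file the program should print the (substituted) assumptions followed by a
   line like "id n : Nat". *)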