(* Last active May 22, 2023 21:37
   A basic implementation of the Hindley-Milner type inference algorithm.
   Engagement for CMSC 22100 final project.

   Expression grammar:
     exp ::=
         id                    -- variable
       | (exp)                 -- parentheses, for disambiguation
       | exp exp               -- application
       | \id -> exp            -- abstraction
       | let id = exp in exp   -- binding
*)
(* Identifiers (term-level variable names) are plain strings. *)
type id = string
(* Maps keyed by identifier; the typing assumptions are stored in one. *)
module IdMap = Map.Make(String)
(* Tokens produced by the tokenizer, shared by the expression, type and
   assumption parsers. *)
type token =
LParen | (* left parenthesis *)
RParen | (* right parenthesis *)
TokLambda | (* backslash, introducing an abstraction *)
TokArrow | (* -> in abstractions and function types *)
TokLet | (* keyword let *)
TokEq | (* = in let-bindings *)
TokIn | (* keyword in *)
TokForall | (* keyword forall, or the UTF-8 FOR ALL sign *)
TokPeriod | (* . after a quantified variable *)
TokColon | (* : between an identifier and its scheme *)
TokDashes | (* three dashes: separates the assumptions from the expression *)
TokId of id (* identifier: letters, digits and underscores *)
(* Abstract syntax of the expression language. *)
type exp =
Id of id | (* variable reference *)
Apply of exp * exp | (* application: function, argument *)
Abstract of id * exp | (* lambda abstraction: parameter, body *)
LetIn of id * exp * exp (* let name = bound value in body *)
(* Pretty-prints an expression. Both sides of an application are wrapped in
   parentheses unless they are bare identifiers; abstraction and let bodies are
   printed without extra parentheses. *)
let rec string_of_exp = function
  | Id name -> name
  | Apply (fn, arg) ->
    let lhs = wrap_if_not_id fn and rhs = wrap_if_not_id arg in
    lhs ^ " " ^ rhs
  | Abstract (param, body) -> Printf.sprintf "\\%s -> %s" param (string_of_exp body)
  | LetIn (name, bound, body) ->
    Printf.sprintf "let %s = %s in %s" name (string_of_exp bound) (string_of_exp body)

(* Prints an identifier as-is and any other expression inside parentheses. *)
and wrap_if_not_id = function
  | Id name -> name
  | compound -> "(" ^ string_of_exp compound ^ ")"
(* tokenizer *)
(* [take_while pred xs] splits [xs] into its longest prefix whose elements all
   satisfy [pred], and the remaining suffix (which, if non-empty, starts with
   the first element failing [pred]). [pred] is applied left to right and is
   not called again after the first failure.
   Tail-recursive, so a very long prefix (e.g. a long comment line or a long
   token stream) cannot overflow the stack. *)
let take_while (pred : 'a -> bool) (xs : 'a list) : ('a list * 'a list) =
  let rec go acc rest =
    match rest with
    | x :: tl when pred x -> go (x :: acc) tl
    | _ -> (List.rev acc, rest)
  in
  go [] xs
(* True exactly for the characters that may appear in an identifier:
   ASCII letters, ASCII digits and the underscore. *)
let is_ident c =
  (c >= 'a' && c <= 'z')
  || (c >= 'A' && c <= 'Z')
  || (c >= '0' && c <= '9')
  || c = '_'
(* Raised by [next_token] on a character that cannot start any token. *)
exception Could_not_tokenize

(* [next_token cs] skips leading whitespace (space, tab, newline and carriage
   return, so files with CRLF line endings are accepted) and comments that
   run from a hash sign to the end of the line. It then returns the next token
   together with the remaining characters, or [None] once the input is used up.
   Raises [Could_not_tokenize] on an unexpected character.
   NOTE: a keyword must be followed by some character to be recognised; a
   keyword at the very end of the input is tokenized as an identifier. *)
let rec next_token (cs : char list) : (token * char list) option =
  match cs with
  | [] -> None
  | (' ' | '\t' | '\n' | '\r') :: rest -> next_token rest (* skip whitespace *)
  | '#' :: rest ->
    (* comment: drop everything up to, but not including, the next newline *)
    let (_, next_line) = take_while (( <> ) '\n') rest in
    next_token next_line
  | '(' :: rest -> Some (LParen, rest)
  | ')' :: rest -> Some (RParen, rest)
  | '\\' :: rest -> Some (TokLambda, rest)
  | '-' :: '>' :: rest -> Some (TokArrow, rest)
  (* keywords must be followed by a non-identifier character, so that e.g.
     letx is tokenized as one identifier instead of TokLet then TokId x *)
  | 'l' :: 'e' :: 't' :: c :: rest when not (is_ident c) -> Some (TokLet, c :: rest)
  | '=' :: rest -> Some (TokEq, rest)
  | 'i' :: 'n' :: c :: rest when not (is_ident c) -> Some (TokIn, c :: rest)
  | 'f' :: 'o' :: 'r' :: 'a' :: 'l' :: 'l' :: c :: rest when not (is_ident c) ->
    Some (TokForall, c :: rest)
  (* UTF-8 encoding of U+2200 FOR ALL *)
  | '\226' :: '\136' :: '\128' :: rest -> Some (TokForall, rest)
  | '.' :: rest -> Some (TokPeriod, rest)
  | ':' :: rest -> Some (TokColon, rest)
  | '-' :: '-' :: '-' :: rest -> Some (TokDashes, rest)
  | c :: rest when is_ident c ->
    let (ident, rest) = take_while is_ident rest in
    Some (TokId (c :: ident |> List.to_seq |> String.of_seq), rest)
  | _ -> raise Could_not_tokenize
(* Tokenizes a whole character list by repeatedly calling [next_token] until
   the input is exhausted. Propagates [Could_not_tokenize].
   Accumulates tokens in reverse and flips them at the end, so the recursion is
   a tail call and long inputs cannot overflow the stack. *)
let scan (ls : char list) : token list =
  let rec go acc cs =
    match next_token cs with
    | Some (t, rest) -> go (t :: acc) rest
    | None -> List.rev acc
  in
  go [] ls
(* Tokenizes a whole string (see [scan]). *)
let scan_string s = scan (List.of_seq (String.to_seq s))
(* parser *)
(* Outcome of one expression-parser step. The original declaration had a
   dangling bar and lacked [End_of_input], which the parser functions use. *)
type parse_result =
  | Success of exp * token list (* parsed expression and the unconsumed tokens *)
  | Failure (* the tokens do not start with a valid expression *)
  | End_of_input (* there were no tokens left to parse *)
(* Recursive-descent parser for the expression grammar.
   [base_exp] parses one atomic expression: an identifier, a parenthesised
   expression, a let-binding or a lambda. Let and lambda bodies are parsed with
   [parse_exp], so they extend as far to the right as possible. *)
let rec base_exp (ts : token list) : parse_result =
  match ts with
  | [] -> End_of_input
  | TokId name :: remaining -> Success (Id name, remaining)
  | LParen :: remaining ->
    begin match parse_exp remaining with
      | Success (inner, RParen :: after) -> Success (inner, after)
      | _ -> Failure
    end
  | TokLet :: TokId name :: TokEq :: remaining ->
    begin match parse_exp remaining with
      | Success (bound, TokIn :: after_in) ->
        begin match parse_exp after_in with
          | Success (body, after) -> Success (LetIn (name, bound, body), after)
          | _ -> Failure
        end
      | _ -> Failure
    end
  | TokLambda :: TokId param :: TokArrow :: remaining ->
    begin match parse_exp remaining with
      | Success (body, after) -> Success (Abstract (param, body), after)
      | _ -> Failure
    end
  | _ -> Failure

(* Collects as many consecutive atomic expressions as possible, in source
   order, together with the leftover tokens (the empty list if the input ran
   out). *)
and parse_apps (ts : token list) : exp list * token list =
  let rec collect acc remaining =
    match base_exp remaining with
    | Success (e, after) -> collect (e :: acc) after
    | Failure -> (List.rev acc, remaining)
    | End_of_input -> (List.rev acc, [])
  in
  collect [] ts

(* Parses a full expression: a non-empty run of atomic expressions combined by
   left-associative application, so f x y becomes Apply (Apply (f, x), y). *)
and parse_exp (ts : token list) : parse_result =
  match parse_apps ts with
  | ([], []) -> End_of_input
  | ([], _ :: _) -> Failure
  | (head :: args, rest) ->
    Success (List.fold_left (fun fn arg -> Apply (fn, arg)) head args, rest)
(* Raised by the type, scheme and assumption parsers, and by [main], when the
   input is malformed. *)
exception Could_not_parse
(* Type variables are single characters. *)
type type_var = char
(* Monotypes: type variables, the base types Nat, Bool and Unit, lists, and
   function types. *)
type typ =
TVar of type_var |
TNat |
TBool |
TUnit |
TList of typ | (* List t *)
TArrow of typ * typ (* argument type, result type *)
(* Renders a type in the same concrete syntax that [parse_type] reads.
   Arrows associate to the right, so a function type on the left of an arrow
   is parenthesised. [List] takes a single atomic argument, so a function type
   used as a list element is parenthesised as well. Previously List (a -> b)
   printed as List a -> b, which reads back as (List a) -> b. *)
let rec string_of_type t =
  match t with
  | TVar v -> String.make 1 v
  | TNat -> "Nat"
  | TBool -> "Bool"
  | TUnit -> "Unit"
  | TList (TArrow _ as elt) -> "List (" ^ string_of_type elt ^ ")"
  | TList elt -> "List " ^ string_of_type elt
  | TArrow (t1, t2) ->
    (match t1 with
     | TArrow _ -> "(" ^ string_of_type t1 ^ ") -> " ^ string_of_type t2
     | _ -> string_of_type t1 ^ " -> " ^ string_of_type t2)
(* Parses one atomic type: a parenthesised type, one of the base types Nat,
   Bool or Unit, List applied to an atomic type, or a single-character type
   variable. Raises [Could_not_parse] on anything else. *)
let rec base_type (ts : token list) : typ * token list =
  match ts with
  | LParen :: after_paren ->
    let (inner, remaining) = parse_type after_paren in
    begin match remaining with
      | RParen :: after -> (inner, after)
      | _ -> raise Could_not_parse
    end
  | TokId "Nat" :: remaining -> (TNat, remaining)
  | TokId "Bool" :: remaining -> (TBool, remaining)
  | TokId "Unit" :: remaining -> (TUnit, remaining)
  | TokId "List" :: remaining ->
    let (elt, after) = base_type remaining in
    (TList elt, after)
  | TokId name :: remaining when String.length name = 1 -> (TVar name.[0], remaining)
  | _ -> raise Could_not_parse

(* Parses a type in which -> associates to the right:
   a -> b -> c is TArrow (a, TArrow (b, c)). *)
and parse_type (ts : token list) : typ * token list =
  let (lhs, remaining) = base_type ts in
  match remaining with
  | TokArrow :: after_arrow ->
    let (rhs, rest) = parse_type after_arrow in
    (TArrow (lhs, rhs), rest)
  | _ -> (lhs, remaining)
(* Type schemes: a monotype under zero or more universal quantifiers. For
   example, forall a. a -> a is SForall (a, SType (TArrow (TVar a, TVar a))). *)
type scheme =
SType of typ | (* unquantified monotype *)
SForall of type_var * scheme (* quantified variable, scheme it scopes over *)
(* Renders a scheme with one FOR ALL sign per quantified variable, outermost
   first, followed by the underlying type. *)
let string_of_scheme s =
  let rec render = function
    | SType t -> string_of_type t
    | SForall (v, inner) -> Printf.sprintf "\u{2200}%c. %s" v (render inner)
  in
  render s
(* Parses a scheme: zero or more [forall v.] binders (single-character
   variables) followed by a type. Raises [Could_not_parse] via [parse_type]. *)
let parse_scheme (ts : token list) : scheme * token list =
  (* collect binders innermost-first *)
  let rec binders acc = function
    | TokForall :: TokId v :: TokPeriod :: remaining when String.length v = 1 ->
      binders (v.[0] :: acc) remaining
    | remaining -> (acc, remaining)
  in
  let (vars_innermost_first, after_binders) = binders [] ts in
  let (body, rest) = parse_type after_binders in
  (List.fold_left (fun sch v -> SForall (v, sch)) (SType body) vars_innermost_first, rest)
(* Typing environment: the type scheme assumed for each identifier. *)
type assumptions = scheme IdMap.t
(* Renders the assumptions one per line as [name : scheme], in increasing
   identifier order. The original piped the binding list straight into a
   lambda (missing [List.map]) and did not type-check. *)
let string_of_assumptions a =
  IdMap.bindings a
  |> List.map (fun (i, s) -> Printf.sprintf "%s : %s" i (string_of_scheme s))
  |> String.concat "\n"
(* Parses one [name : scheme] entry. Returns [None] on empty input. Raises
   [Could_not_parse] if the tokens do not start with an identifier followed by
   a colon. *)
let next_assumption (ts : token list) : (id * scheme * token list) option =
  match ts with
  | TokId name :: TokColon :: after_colon ->
    let (sch, remaining) = parse_scheme after_colon in
    Some (name, sch, remaining)
  | [] -> None
  | _ -> raise Could_not_parse
(* Parses all [name : scheme] entries into an assumption map. If an identifier
   is listed more than once, its first entry wins, because earlier entries are
   added on top of the map built from the later ones. *)
let rec parse_assumptions (ts : token list) : assumptions =
  Option.fold
    ~none:IdMap.empty
    ~some:(fun (name, sch, remaining) -> IdMap.add name sch (parse_assumptions remaining))
    (next_assumption ts)
(* inference engine *)
(* Sets and maps keyed by type variable. *)
module VarSet = Set.Make(Char)
module VarMap = Map.Make(Char)

(* Union of every variable set in the sequence; empty for an empty sequence. *)
let concat_of_seq (sets : VarSet.t Seq.t) : VarSet.t =
  Seq.fold_left (fun acc set -> VarSet.union acc set) VarSet.empty sets
(* Fresh type-variable supply.
   An earlier version only picked a character that was not in [not_in]. That
   let variables created deeper in the recursion be handed out again,
   presumably because [not_in] only reflects what is visible at the call site.
   So a global counter now walks through the characters a..z, then A..Z (then
   the following character codes), and never hands out the same one twice.
   Candidates that are in [not_in] are also skipped. This stops a fresh
   variable from colliding with a type variable that already appears in the
   assumptions or the types at hand. (The previous version ignored [not_in]
   and was missing its return value.) *)
let next_char = ref 'a'
let new_var (not_in : VarSet.t) : type_var =
  let rec pick () =
    let n = !next_char in
    next_char := if n = 'z' then 'A' else Char.chr (Char.code n + 1);
    if VarSet.mem n not_in then pick () else n
  in
  pick ()
(* [new_vars not_in len] draws [len] fresh type variables in order. Each one
   also avoids the variables drawn before it. *)
let new_vars (not_in : VarSet.t) (len : int) : type_var list =
  let rec draw avoid remaining drawn =
    if remaining = 0 then List.rev drawn
    else
      let v = new_var avoid in
      draw (VarSet.add v avoid) (remaining - 1) (v :: drawn)
  in
  draw not_in len []
(* Splits a scheme into its quantified variables (outermost first) and the
   underlying monotype. *)
let decompose (s : scheme) : type_var list * typ =
  let rec peel bound = function
    | SType t -> (List.rev bound, t)
    | SForall (v, inner) -> peel (v :: bound) inner
  in
  peel [] s
(* The set of type variables occurring in a monotype. *)
let rec free_vars (t : typ) : VarSet.t =
  match t with
  | TNat | TBool | TUnit -> VarSet.empty
  | TVar v -> VarSet.singleton v
  | TList elt -> free_vars elt
  | TArrow (dom, cod) -> VarSet.union (free_vars dom) (free_vars cod)
(* Free type variables of a scheme: those of the underlying monotype minus the
   quantified ones. *)
let free_vars_s (s : scheme) : VarSet.t =
  let rec split bound = function
    | SType t -> (bound, t)
    | SForall (v, inner) -> split (VarSet.add v bound) inner
  in
  let (bound, body) = split VarSet.empty s in
  VarSet.diff (free_vars body) bound
(* Substitution: for each type variable in its domain, the type that replaces
   it. *)
type substitution = typ VarMap.t
(* Renders a substitution as {type / var; ...} in increasing variable order.
   The original was missing the [List.map] over the bindings. *)
let string_of_sub s =
  let entries =
    VarMap.bindings s
    |> List.map (fun (a, t) -> Printf.sprintf "%s / %c" (string_of_type t) a)
  in
  "{" ^ String.concat "; " entries ^ "}"
(* Type variables occurring in the range of a substitution, i.e. in any of the
   types it maps to. The original was missing the [Seq.map]. *)
let free_vars_sub (s : substitution) : VarSet.t =
  VarMap.to_seq s |> Seq.map (fun (_, t) -> free_vars t) |> concat_of_seq
(* Applies a substitution to a monotype. This is a single, simultaneous pass:
   a type that replaces a variable is not itself substituted again. *)
let rec subst (a : substitution) (t : typ) : typ =
  match t with
  | TVar v -> Option.value (VarMap.find_opt v a) ~default:t
  | TList elt -> TList (subst a elt)
  | TArrow (dom, cod) -> TArrow (subst a dom, subst a cod)
  | TNat | TBool | TUnit -> t
(* Capture-avoiding application of a substitution to a scheme. *)
let rec subst_s (a : substitution) (s : scheme) : scheme =
  match s with
  | SType t -> SType (subst a t)
  | SForall (v, body) ->
    let range_vars = free_vars_sub a in
    if not (VarSet.mem v range_vars) then
      (* v is bound here: drop any mapping for it so it is left untouched *)
      SForall (v, subst_s (VarMap.remove v a) body)
    else begin
      (* v also occurs in the substitution's range, so substituting under the
         binder would capture it: rename the bound v to a fresh variable *)
      let nv = new_var (VarSet.union (free_vars_s body) range_vars) in
      SForall (nv, subst_s (VarMap.add v (TVar nv) a) body)
    end
(* [unify t1 t2] returns a substitution that makes [t1] and [t2] equal, or
   [None] if they cannot be unified. Unification fails on different type
   constructors, or when a variable would have to contain itself (the occurs
   check).
   The accumulated substitution is kept idempotent: when v := t is added, v is
   also replaced inside every existing binding. [subst] substitutes only one
   level deep. Without this, a chain such as {w := v; v := Nat} left
   subst acc (TVar w) = TVar v, so the two sides never became equal and the
   same binding was re-added forever (e.g. when unifying a -> a with
   b -> Nat). *)
let unify (t1 : typ) (t2 : typ) : substitution option =
  (* extend [acc] with v := t, rewriting v inside the existing bindings *)
  let bind v t acc =
    VarMap.add v t (VarMap.map (subst (VarMap.singleton v t)) acc)
  in
  let rec build_sub (t1 : typ) (t2 : typ) (acc : substitution) : substitution option =
    let t1' = subst acc t1 and t2' = subst acc t2 in
    if t1' = t2' then Some acc
    else
      match (t1', t2') with
      | (TVar v, t) when not (VarSet.mem v (free_vars t)) -> Some (bind v t acc)
      | (t, TVar v) when not (VarSet.mem v (free_vars t)) -> Some (bind v t acc)
      | (TList l1, TList l2) -> build_sub l1 l2 acc
      | (TArrow (a1, r1), TArrow (a2, r2)) ->
        Option.bind (build_sub a1 a2 acc) (build_sub r1 r2)
      | (_, _) ->
        (* different type constructors, or a failed occurs check *)
        None
  in
  build_sub t1 t2 VarMap.empty
(* Combines two substitutions, s1 after s2. First [s1] is applied to every
   type in [s2], so variables introduced by [s2] are substituted as well; then
   the result is merged with [s1]. If both bind the same variable, [s1]'s
   binding is kept. The original applied [subst] to the map itself instead of
   mapping over its values. *)
let concat (s1 : substitution) (s2 : substitution) : substitution =
  let s2_inner_sub = VarMap.map (subst s1) s2 in
  VarMap.union (fun _ v1 _ -> Some v1) s1 s2_inner_sub
(* Applies a substitution to every scheme in the assumptions. (The original
   was missing [IdMap.map] and did not type-check.) *)
let subst_a (s : substitution) : assumptions -> assumptions = IdMap.map (subst_s s)
(* Free type variables of the assumptions: the union of the free variables of
   every scheme, excluding quantified ones. The original was missing the
   [Seq.map]. *)
let free_vars_a (a : assumptions) : VarSet.t =
  IdMap.to_seq a |> Seq.map (fun (_, s) -> free_vars_s s) |> concat_of_seq
(* Generalization: quantifies [t] over every type variable that is free in [t]
   but not free in the assumptions [a]. Variables are wrapped in increasing
   order, so the largest one ends up outermost. *)
let closure (a : assumptions) (t : typ) : scheme =
  let generalizable = VarSet.diff (free_vars t) (free_vars_a a) in
  VarSet.elements generalizable
  |> List.fold_left (fun sch v -> SForall (v, sch)) (SType t)
(* Algorithm W. [infer a e] returns [Some (s, t)], where [s] is the
   substitution inferred for the assumptions and [t] is the type of [e]. It
   returns [None] on a unification failure or an unknown identifier; in the
   latter case it also prints a message to stdout.
   The original was missing several [List.map]/[Option.map] calls and the
   [None] result in the unknown-identifier branch, and did not type-check. *)
let rec infer (a : assumptions) (e : exp) : (substitution * typ) option =
  match e with
  | Id x ->
    (match IdMap.find_opt x a with
     | Some s ->
       (* instantiate: replace every quantified variable with a fresh one *)
       let (vs, t) = decompose s in
       let used_vars = VarSet.union (free_vars t) (free_vars_a a) in
       let inner_sub =
         new_vars used_vars (List.length vs)
         |> List.map (fun v -> TVar v)
         |> List.combine vs
         |> List.to_seq
         |> VarMap.of_seq
       in
       Some (VarMap.empty, subst inner_sub t)
     | None ->
       Printf.printf "Could not find variable %s in the assumptions set!\n" x;
       None)
  | Apply (e1, e2) ->
    Option.bind (infer a e1) (fun (s1, t1) ->
      Option.bind (infer (subst_a s1 a) e2) (fun (s2, t2) ->
        let used_vars =
          free_vars t1 |> VarSet.union (free_vars t2) |> VarSet.union (free_vars_a a)
        in
        (* result type of the application *)
        let b = TVar (new_var used_vars) in
        unify (subst s2 t1) (TArrow (t2, b))
        |> Option.map (fun v -> (concat v (concat s2 s1), subst v b))))
  | Abstract (x, e1) ->
    (* the parameter gets a fresh, unquantified type variable *)
    let b = TVar (new_var (free_vars_a a)) in
    infer (IdMap.add x (SType b) a) e1
    |> Option.map (fun (s1, t1) -> (s1, TArrow (subst s1 b, t1)))
  | LetIn (x, e1, e2) ->
    Option.bind (infer a e1) (fun (s1, t1) ->
      let a' = subst_a s1 a in
      (* let-polymorphism: the bound name gets the generalized type *)
      infer (IdMap.add x (closure a' t1) a') e2
      |> Option.map (fun (s2, t2) -> (concat s2 s1, t2)))
(* Runs [infer] and, on success, returns the assumptions with the inferred
   substitution applied, together with the generalized type scheme of [e]. The
   original was missing [Option.map]. *)
let infer_scheme (a : assumptions) (e : exp) : (assumptions * scheme) option =
  infer a e
  |> Option.map (fun (s, t) ->
    let a' = subst_a s a in
    (a', closure a' t))
(* final I/O *)
(* Reads [filename]. The file contains [name : scheme] assumptions, a ---
   separator, and then a single expression. [main] infers the expression's type
   scheme and prints the substituted assumptions and the scheme. Raises
   [Could_not_tokenize] or [Could_not_parse] on malformed input. *)
let main (filename : string) : unit =
  (* with_open_text closes the channel even if reading raises, unlike the
     previous open_in / input_all / close_in sequence *)
  let input_str = In_channel.with_open_text filename In_channel.input_all in
  let tokens = scan_string input_str in
  let (t_assumptions, t_exp) =
    match take_while (( <> ) TokDashes) tokens with
    | (a, TokDashes :: e) -> (a, e)
    | _ -> raise Could_not_parse
  in
  let assumptions = parse_assumptions t_assumptions in
  let input_expression =
    match parse_exp t_exp with
    | Success (e, []) -> e
    | _ -> raise Could_not_parse
  in
  match infer_scheme assumptions input_expression with
  | Some (new_assumptions, inferred_scheme) ->
    print_endline "Assumptions, after inferred substitution:";
    print_endline (string_of_assumptions new_assumptions);
    print_endline "---";
    print_endline "Inferred type scheme:";
    Printf.printf "%s : %s\n" (string_of_exp input_expression) (string_of_scheme inferred_scheme)
  | None -> print_endline "The inference algorithm failed!";;
(* Format printers that render via the string converters, e.g. for use with
   #install_printer in the toplevel. *)
let pp_exp (f : Format.formatter) (e : exp) = string_of_exp e |> Format.pp_print_string f;;
let pp_typ (f : Format.formatter) (t : typ) = string_of_type t |> Format.pp_print_string f;;
let pp_sub (f : Format.formatter) (s : substitution) = string_of_sub s |> Format.pp_print_string f;;
let pp_a (f : Format.formatter) (a : assumptions) = string_of_assumptions a |> Format.pp_print_string f;;
(* #install_printer pp_exp;;
#install_printer pp_typ;;
#install_printer pp_sub;;
#install_printer pp_a;;
#trace infer;;
#trace unify;; *)
(* Command-line entry point: expects exactly one argument, the input file. *)
let () =
  match Sys.argv with
  | [| _; filename |] -> main filename
  | argv ->
    (* argv always has at least one element: the executable name *)
    Printf.printf "Usage: %s <input filename>\n" argv.(0)
