(* Hindley-Milner type checking with higher-kinded types, in OCaml.
 *
 * Source: GitHub gist mb64/307c8344f3663b04257faa0c359556fd,
 * created February 13, 2022.
 *
 * Note: GitHub flagged this file as containing bidirectional Unicode text,
 * which may be interpreted or compiled differently from how it is displayed.
 * To review, open the file in an editor that reveals hidden Unicode
 * characters. *)
(** Identifiers, used for term variables in the AST and as lookup keys in the
    type checker's context. *)
type name = string

(** Surface syntax, as handed to the type checker. *)
module AST = struct
  (** Type annotations as written in the source. *)
  type ty =
    | TFun of ty * ty    (** function type: domain, codomain *)
    | TNamed of string   (** a type referred to by its name *)
    | TApp of ty * ty    (** type application: constructor, argument *)

  (** Expressions. *)
  type exp =
    | Annot of exp * ty        (** expression with a type annotation *)
    | Var of name              (** variable reference *)
    | App of exp * exp         (** application: function, argument *)
    | Lam of name * exp        (** lambda: parameter, body *)
    | Let of name * exp * exp  (** let-binding: name, bound value, body *)
end
module Tychk = struct
(** Identifier of a type constructor. *)
type type_id = int

(** Internal representation of types used during checking. *)
type ty =
  | TCon of type_id (* type constructor, like Int or Maybe *)
  | TFun of ty * ty
      (** function type: domain, codomain *)
  | TApp of kind * ty * ty
      (** type application: kind of the argument, constructor, argument *)
  | TUVar of hole ref
      (** unification variable; the shared cell is updated in place *)

(** Contents of a unification variable. *)
and hole =
  | Filled of ty
      (** solved: stands for the given type *)
  | Empty of { lvl: int }
      (** unsolved; [lvl] is the let-nesting level the variable belongs to *)
  | Generalized of int (* only in type schemes *)

(** A type together with the number of [Generalized] variables it
    quantifies over; they are numbered from [0] to [num_vars - 1]. *)
and ty_scheme = { num_vars: int; ty: ty }

(** Kinds classify types. *)
and kind =
  | Star                 (** kind of ordinary types *)
  | KFun of kind * kind  (** kind of type constructors: argument, result *)

(** Type-checking context. *)
type ctx =
  { lvl: int
      (* current let-nesting level *)
  ; var_types: (name * ty_scheme) list
      (* type schemes of the term variables in scope *)
  ; tyvar_kinds: (name * kind) list
      (* kinds of the named types in scope *)
  ; tyvar_values: (name * ty) list
      (* internal types that the named types stand for *)
  }

(** Raised when two kinds or two types cannot be made equal. *)
exception Can'tUnify
module Kinds = struct
  (* Super simple bidirectional elaboration for kinds.
   * Because the grammar of type annotations is so restrictive, kind
   * unification is not needed (yet). *)

  (* TODO: add actual unification when necessary *)

  (** [unify a b] succeeds when [a] and [b] are structurally equal.
      @raise Can'tUnify otherwise. *)
  let unify a b = if a = b then () else raise Can'tUnify

  (** Carries the expected kind followed by the inferred kind. *)
  exception KindMismatch of kind * kind

  (** [check ctx e kind] elaborates the annotation [e] and verifies that it
      has kind [kind].

      The type level is so simple (no lambdas or anything) that checking
      just defers to {!infer} and compares the result; [check] exists for
      convenience.

      @raise KindMismatch if the inferred kind differs from [kind]. *)
  let rec check (ctx: ctx) (e: AST.ty) (kind: kind) : ty =
    let elaborated, actual = infer ctx e in
    (try unify kind actual
     with Can'tUnify -> raise (KindMismatch (kind, actual)));
    elaborated

  (** [infer ctx e] elaborates the annotation [e] into an internal type and
      returns it together with its kind.

      @raise Not_found if a named type is missing from [ctx.tyvar_values]
             or [ctx.tyvar_kinds].
      @raise Failure if the head of a type application does not have a
             function kind.
      @raise KindMismatch if a sub-annotation has the wrong kind. *)
  and infer (ctx: ctx) (e: AST.ty) : ty * kind =
    match e with
    | AST.TFun (dom, cod) ->
      (* Both sides of an arrow must be ordinary types. *)
      let dom', cod' = check ctx dom Star, check ctx cod Star in
      TFun (dom', cod'), Star
    | AST.TNamed name ->
      let value = List.assoc name ctx.tyvar_values in
      let kind = List.assoc name ctx.tyvar_kinds in
      value, kind
    | AST.TApp (f, x) ->
      let f', f_kind = infer ctx f in
      (match f_kind with
       | KFun (arg_kind, res_kind) ->
         (* The argument's kind is recorded in the application node. *)
         TApp (arg_kind, f', check ctx x arg_kind), res_kind
       | Star -> failwith "Kind error: should have a function kind")
end
(* Standard HM generalization, instantiation, and unification *)

(** [deref ty] strips any outer chain of filled unification variables and
    returns the first type that is not a [TUVar] holding [Filled].
    Only the head of the type is resolved; sub-terms are left untouched. *)
let rec deref ty =
  match ty with
  | TUVar { contents = Filled target } -> deref target
  | TCon _ | TFun _ | TApp _ | TUVar _ -> ty
(** [generalize lvl ty] turns [ty] into a type scheme.

    Every empty unification variable whose level is strictly greater than
    [lvl] is overwritten in place with [Generalized i], where [i] is the next
    unused index. Filled variables are replaced by (the generalization of)
    their contents; all other variables are kept as they are.

    The resulting [num_vars] is the number of variables generalized by this
    call. *)
let generalize lvl (ty: ty) : ty_scheme =
  let next_index = ref 0 in
  let rec walk = function
    | TCon id -> TCon id
    | TFun (a, b) -> TFun (walk a, walk b)
    | TApp (k, f, x) -> TApp (k, walk f, walk x)
    | TUVar cell ->
      (match !cell with
       | Filled inner -> walk inner
       | Generalized _ -> TUVar cell
       | Empty { lvl = hole_lvl } ->
         if hole_lvl > lvl then begin
           (* The cell is shared, so every occurrence sees the update. *)
           cell := Generalized !next_index;
           incr next_index
         end;
         TUVar cell)
  in
  let body = walk ty in
  { num_vars = !next_index; ty = body }
(** [instantiate lvl scheme] returns a copy of the scheme's type in which
    each [Generalized i] variable is replaced by a fresh empty unification
    variable at level [lvl]. All occurrences of the same index share the
    same fresh variable. Variables that are not [Generalized] are kept
    as they are (filled variables are not traversed). *)
let instantiate lvl ({ num_vars; ty }: ty_scheme) : ty =
  (* One fresh hole per quantified variable, indexed by its number. *)
  let fresh = Array.init num_vars (fun _ -> TUVar (ref (Empty { lvl }))) in
  let rec subst t =
    match t with
    | TCon id -> TCon id
    | TFun (a, b) ->
      let a' = subst a in
      let b' = subst b in
      TFun (a', b')
    | TApp (k, f, x) ->
      let f' = subst f in
      let x' = subst x in
      TApp (k, f', x')
    | TUVar cell ->
      (match !cell with
       | Generalized i -> fresh.(i)
       | Filled _ | Empty _ -> TUVar cell)
  in
  subst ty
(** [unify ctx x y] makes the types [x] and [y] equal by filling
    unification variables in place.

    Both sides are first resolved with [deref]. Two variables are linked
    together, a variable against any other type is solved to that type, and
    constructors are compared structurally. For type applications the
    recorded kinds are compared with [Kinds.unify] before the components.

    [ctx] is only threaded through; it is not inspected.

    @raise Can'tUnify if the types differ or the occurs check fails.
    @raise Invalid_argument if a [Generalized] variable is encountered;
           those may only appear inside type schemes and must be
           instantiated before unification. *)
let rec unify ctx (x: ty) (y: ty) = match deref x, deref y with
  | TUVar uvar_a, TUVar uvar_b -> unify_two_uvars uvar_a uvar_b
  | TUVar uvar, b -> unify_uvar ctx uvar b
  | a, TUVar uvar -> unify_uvar ctx uvar a
  | TCon tx, TCon ty -> if tx <> ty then raise Can'tUnify
  | TFun (xa, xb), TFun (ya, yb) -> unify ctx xa ya; unify ctx xb yb
  | TApp (xk, xa, xb), TApp (yk, ya, yb) ->
    Kinds.unify xk yk; unify ctx xa ya; unify ctx xb yb
  | _ -> raise Can'tUnify

(** [unify_two_uvars a b] links two unsolved variables. The variable with
    the greater level is filled with the other one, so the surviving
    variable carries the smaller level. Does nothing when [a] and [b] are
    the same cell.

    @raise Invalid_argument if either cell is not [Empty]. *)
and unify_two_uvars a b =
  if a == b then () else
    match !a, !b with
    | Empty { lvl = a_lvl }, Empty { lvl = b_lvl } ->
      if a_lvl < b_lvl then b := Filled (TUVar a) else a := Filled (TUVar b)
    | (Filled _ | Generalized _), _ | Empty _, (Filled _ | Generalized _) ->
      invalid_arg "Tychk.unify_two_uvars: expected two empty variables"

(** [unify_uvar ctx uvar b] solves the unsolved variable [uvar] to the
    type [b]. The context argument is unused.

    Before filling, [b] is traversed to
    - reject the assignment if [uvar] occurs in [b] (occurs check), and
    - lower the level of every empty variable in [b] that is deeper than
      [uvar] down to [uvar]'s level.

    @raise Can'tUnify if [uvar] occurs in [b].
    @raise Invalid_argument if [uvar] is not [Empty], or if [b] contains a
           [Generalized] variable. *)
and unify_uvar _ctx uvar b =
  let lvl =
    match !uvar with
    | Empty { lvl } -> lvl
    | Filled _ | Generalized _ ->
      invalid_arg "Tychk.unify_uvar: expected an empty variable"
  in
  (* occurs check and fixing up levels *)
  let rec check = function
    | TCon _ -> ()
    | TFun (a, b) -> check a; check b
    | TApp (_, f, x) -> check f; check x
    | TUVar u ->
      if u == uvar (* pointer equality! *)
      then raise Can'tUnify
      else begin
        match !u with
        | Filled t -> check t
        | Empty { lvl = l } ->
          if l > lvl then u := Empty { lvl }
        | Generalized _ ->
          invalid_arg
            "Tychk.unify_uvar: generalized variable outside a type scheme"
      end
  in
  check b;
  uvar := Filled b
(* Standard bidirectional typechecking *)

(** Carries the inferred type followed by the expected type. *)
exception TypeMismatch of ty * ty

(** [check ctx e ty] verifies that expression [e] has type [ty].

    A lambda checked against a function type binds its parameter to the
    domain and checks the body against the codomain; a [let] generalizes the
    bound value and checks the body against the same type. Everything else
    is inferred and then unified with [ty].

    @raise TypeMismatch if the inferred type cannot be unified with [ty]. *)
let rec check ctx (e: AST.exp) (ty: ty): unit =
  match e, deref ty with
  | AST.Lam (param, body), TFun (dom, cod) ->
    (* Lambda parameters are monomorphic: a scheme with no variables. *)
    let param_scheme = { num_vars = 0; ty = dom } in
    let ctx' =
      { ctx with var_types = (param, param_scheme) :: ctx.var_types } in
    check ctx' body cod
  | AST.Let (x, value, body), expected ->
    let scheme = infer_and_generalize ctx value in
    let ctx' = { ctx with var_types = (x, scheme) :: ctx.var_types } in
    check ctx' body expected
  | (AST.Annot _ | AST.Var _ | AST.App _ | AST.Lam _), _ ->
    let inferred = infer ctx e in
    (match unify ctx inferred ty with
     | () -> ()
     | exception Can'tUnify -> raise (TypeMismatch (inferred, ty)))

(** [infer ctx e] computes a type for expression [e].

    @raise Not_found if a variable is not bound in [ctx.var_types].
    @raise Failure if the head of an application is neither a function type
           nor an unsolved variable.
    @raise TypeMismatch if a sub-expression fails to check. *)
and infer ctx (e: AST.exp): ty =
  match e with
  | AST.Var name ->
    let scheme = List.assoc name ctx.var_types in
    instantiate ctx.lvl scheme
  | AST.Annot (inner, annotation) ->
    let ty = Kinds.check ctx annotation Star in
    check ctx inner ty;
    ty
  | AST.App (f, x) ->
    (* First split the function's type into argument and result types,
       solving an unknown function type to a fresh arrow if necessary. *)
    let arg_ty, res_ty =
      match deref (infer ctx f) with
      | TFun (a, b) -> a, b
      | TUVar ({ contents = Empty { lvl } } as hole) ->
        let a = TUVar (ref (Empty { lvl })) in
        let b = TUVar (ref (Empty { lvl })) in
        hole := Filled (TFun (a, b));
        a, b
      | _ -> failwith "Must be a function"
    in
    check ctx x arg_ty;
    res_ty
  | AST.Lam (param, body) ->
    let dom = TUVar (ref (Empty { lvl = ctx.lvl })) in
    let param_scheme = { num_vars = 0; ty = dom } in
    let ctx' =
      { ctx with var_types = (param, param_scheme) :: ctx.var_types } in
    let cod = infer ctx' body in
    TFun (dom, cod)
  | AST.Let (x, value, body) ->
    let scheme = infer_and_generalize ctx value in
    let ctx' = { ctx with var_types = (x, scheme) :: ctx.var_types } in
    infer ctx' body

(** [infer_and_generalize ctx e] infers the type of [e] one level deeper
    than [ctx.lvl] and then generalizes it at [ctx.lvl], so variables
    created at the deeper level that remain unsolved become quantified. *)
and infer_and_generalize ctx (e: AST.exp) =
  let inner_ctx = { ctx with lvl = ctx.lvl + 1 } in
  let inferred = infer inner_ctx e in
  generalize ctx.lvl inferred
end
(** Small end-to-end example: infers and generalizes the type of
    [λ x xs. cons (just x) xs] in a context that provides the constructors
    [just] and [cons] and the named types [maybe] and [list]. *)
let test (): Tychk.ty_scheme =
  let open Tychk in
  (* Builds a scheme with one quantified variable; every occurrence of the
     variable in [body] shares the same [Generalized 0] cell. *)
  let forall1 body =
    let a = TUVar (ref (Generalized 0)) in
    { num_vars = 1; ty = body a } in
  let maybe = TCon 0 in
  let list = TCon 1 in
  (* just : a -> maybe a *)
  let just = forall1 (fun a -> TFun (a, TApp (Star, maybe, a))) in
  (* cons : a -> list a -> list a *)
  let cons =
    forall1 (fun a ->
      let list_a = TApp (Star, list, a) in
      TFun (a, TFun (list_a, list_a))) in
  let unary = KFun (Star, Star) in
  let ctx: ctx =
    { lvl = 0
    ; var_types = [ ("just", just); ("cons", cons) ]
    ; tyvar_kinds = [ ("maybe", unary); ("list", unary) ]
    ; tyvar_values = [ ("maybe", maybe); ("list", list) ] } in
  let term = (* λ x xs. cons (just x) xs *)
    AST.(
      Lam ("x",
        Lam ("xs",
          App (App (Var "cons", App (Var "just", Var "x")), Var "xs")))) in
  infer_and_generalize ctx term
(* End of gist. The original page ended with GitHub's prompt to sign up or
 * sign in in order to comment on the gist. *)