Skip to content

Instantly share code, notes, and snippets.

@mb64
Created February 13, 2022 00:16
Show Gist options
  • Save mb64/307c8344f3663b04257faa0c359556fd to your computer and use it in GitHub Desktop.
Save mb64/307c8344f3663b04257faa0c359556fd to your computer and use it in GitHub Desktop.
Hindley-Milner type checking with higher-kinded types, in OCaml
(* Identifier type used for term variables in [AST.exp] and as the key of the
 * association lists stored in the type-checking context. *)
type name = string
(* Surface syntax consumed by the type checker: type annotations and
 * expressions. *)
module AST = struct
(* Type annotations as they appear in the input program. *)
type ty =
| TFun of ty * ty (* function type: argument type, result type *)
| TNamed of string (* a type referred to by name; looked up in the checker's context *)
| TApp of ty * ty (* type application: the applied type, then its argument *)
(* Expressions. *)
type exp =
| Annot of exp * ty (* expression carrying a type annotation *)
| Var of name (* variable reference *)
| App of exp * exp (* application: function, then argument *)
| Lam of name * exp (* lambda: parameter name, then body *)
| Let of name * exp * exp (* let binding: bound name, bound value, body *)
end
module Tychk = struct
(* Integer identifying a type constructor; wrapped by [TCon]. *)
type type_id = int
(* Internal representation of types manipulated by the checker. *)
type ty =
| TCon of type_id (* type constructor, like Int or Maybe *)
| TFun of ty * ty (* function type: argument type, result type *)
| TApp of kind * ty * ty (* type application: kind of the argument, applied type, argument *)
| TUVar of hole ref (* unification variable; its cell is updated in place *)
(* Current state of a unification variable. *)
and hole =
| Filled of ty (* solved: the variable stands for this type *)
| Empty of { lvl: int } (* unsolved; [lvl] is compared with the generalization level *)
| Generalized of int (* only in type schemes; index of the quantified variable *)
(* A type scheme: [ty] may mention [Generalized i] for [0 <= i < num_vars]. *)
and ty_scheme = { num_vars: int; ty: ty }
(* Kinds classify types. *)
and kind =
| Star (* kind required of both sides of a function type and of annotations *)
| KFun of kind * kind (* kind of an applicable type: argument kind, result kind *)
(* Type-checking context threaded through kind checking and type inference. *)
type ctx =
{ lvl: int (* current level; raised by one while inferring a let-bound value *)
; var_types: (name * ty_scheme) list (* type schemes of the term variables in scope *)
; tyvar_kinds: (name * kind) list (* kind of each named type *)
; tyvar_values: (name * ty) list } (* internal type that each named type denotes *)
(* Raised by [Kinds.unify] and by [unify] (including its occurs check) when the
 * two sides cannot be made equal. *)
exception Can'tUnify
module Kinds = struct
  (* Kind checking for type annotations, written as a small bidirectional
   * elaborator from [AST.ty] to the internal [ty] representation.
   * The annotation grammar is restrictive enough that kinds can be compared
   * structurally; no kind unification variables are needed (yet). *)

  (* TODO: add actual unification when necessary *)
  (* [unify a b] returns unit when the two kinds are structurally equal and
   * raises [Can'tUnify] otherwise. *)
  let unify a b =
    if a = b then () else raise Can'tUnify

  (* Raised by [check]; carries the expected kind, then the inferred kind. *)
  exception KindMismatch of kind * kind

  (* [check ctx e kind] elaborates [e] and requires it to have kind [kind].
   * The type level has no binders, so checking simply infers and compares.
   * Raises [KindMismatch (kind, inferred)] when the kinds differ. *)
  let rec check (ctx: ctx) (e: AST.ty) (kind: kind) : ty =
    let elaborated, actual = infer ctx e in
    try
      unify kind actual;
      elaborated
    with Can'tUnify -> raise (KindMismatch (kind, actual))

  (* [infer ctx e] elaborates [e] and returns it together with its kind.
   * Raises [Not_found] (from [List.assoc]) when a named type is missing from
   * the context, and [Failure] when a non-function kind is applied. *)
  and infer (ctx: ctx) (e: AST.ty) : ty * kind =
    match e with
    | AST.TFun (dom, cod) ->
      (* both sides of an arrow must be proper types *)
      let dom', cod' = check ctx dom Star, check ctx cod Star in
      TFun (dom', cod'), Star
    | AST.TNamed name ->
      let value = List.assoc name ctx.tyvar_values in
      let value_kind = List.assoc name ctx.tyvar_kinds in
      value, value_kind
    | AST.TApp (head, arg) ->
      (match infer ctx head with
       | head', KFun (arg_kind, result_kind) ->
         (* the argument's expected kind is recorded in the [TApp] node *)
         TApp (arg_kind, head', check ctx arg arg_kind), result_kind
       | _, Star -> failwith "Kind error: should have a function kind")
end
(* Standard HM generalization, instantiation, and unification *)
(* [deref ty] follows a chain of filled unification variables and returns the
 * first type that is not a filled variable. Only the head of the type is
 * resolved; nested components are left untouched. *)
let rec deref ty =
  match ty with
  | TUVar { contents = Filled target } -> deref target
  | TUVar { contents = Empty _ | Generalized _ } | TCon _ | TFun _ | TApp _ ->
    ty
(* [generalize lvl ty] turns [ty] into a type scheme by quantifying every empty
 * unification variable whose level is strictly greater than [lvl].
 * Quantified holes are overwritten in place with [Generalized i], numbered in
 * the order they are first reached; filled holes are replaced by their
 * (recursively processed) contents. [num_vars] is the number of holes that
 * were quantified. *)
let generalize lvl (ty: ty) : ty_scheme =
  let next_index = ref 0 in
  let rec walk = function
    | TCon id -> TCon id
    | TFun (dom, cod) -> TFun (walk dom, walk cod)
    | TApp (k, head, arg) -> TApp (k, walk head, walk arg)
    | TUVar hole ->
      (match !hole with
       | Filled inner -> walk inner
       | Generalized _ ->
         (* already quantified earlier in this traversal *)
         TUVar hole
       | Empty { lvl = hole_lvl } ->
         if hole_lvl > lvl then begin
           let index = !next_index in
           hole := Generalized index;
           next_index := index + 1
         end;
         TUVar hole)
  in
  (* the traversal must finish before the counter is read *)
  let body = walk ty in
  { num_vars = !next_index; ty = body }
(* [instantiate lvl scheme] returns a copy of the scheme's type in which each
 * [Generalized i] variable is replaced by a fresh empty hole at level [lvl];
 * every occurrence of the same index maps to the same fresh hole.
 * Unification variables that are not generalized are returned as they are
 * (the contents of filled holes are not traversed). *)
let instantiate lvl ({ num_vars; ty }: ty_scheme) : ty =
  (* one distinct fresh hole per quantified variable *)
  let fresh = Array.init num_vars (fun _ -> TUVar (ref (Empty { lvl }))) in
  let rec subst = function
    | TCon id -> TCon id
    | TFun (dom, cod) -> TFun (subst dom, subst cod)
    | TApp (k, head, arg) -> TApp (k, subst head, subst arg)
    | TUVar { contents = Generalized i } -> fresh.(i)
    | TUVar ({ contents = Filled _ | Empty _ } as hole) -> TUVar hole
  in
  subst ty
(* [unify ctx x y] makes [x] and [y] equal by filling unification variables in
 * place. [ctx] is only passed along; it is not inspected.
 * Raises [Can'tUnify] when the types have different shapes, different
 * constructors or application kinds, or when the occurs check fails.
 * Raises [Invalid_argument] if a variable that should be an empty hole is not
 * one (for example a [Generalized] variable that was never instantiated). *)
let rec unify ctx (x: ty) (y: ty) =
  match deref x, deref y with
  | TUVar uvar_a, TUVar uvar_b -> unify_two_uvars uvar_a uvar_b
  | TUVar uvar, b -> unify_uvar ctx uvar b
  | a, TUVar uvar -> unify_uvar ctx uvar a
  | TCon ta, TCon tb -> if ta <> tb then raise Can'tUnify
  | TFun (xa, xb), TFun (ya, yb) -> unify ctx xa ya; unify ctx xb yb
  | TApp (xk, xa, xb), TApp (yk, ya, yb) ->
    Kinds.unify xk yk; unify ctx xa ya; unify ctx xb yb
  | (TCon _ | TFun _ | TApp _), (TCon _ | TFun _ | TApp _) ->
    (* remaining pairs have different head constructors *)
    raise Can'tUnify

(* [unify_two_uvars a b] links two empty holes. The hole with the strictly
 * higher level is pointed at the other one, so the surviving variable keeps
 * the lower level; on equal levels [a] is pointed at [b]. *)
and unify_two_uvars a b =
  if a == b then () (* same cell: nothing to do, and linking would create a cycle *)
  else
    match !a, !b with
    | Empty { lvl = a_lvl }, Empty { lvl = b_lvl } ->
      if a_lvl < b_lvl then b := Filled (TUVar a) else a := Filled (TUVar b)
    | (Filled _ | Generalized _), _ | _, (Filled _ | Generalized _) ->
      invalid_arg "Tychk.unify_two_uvars: variable is not an empty hole"

(* [unify_uvar ctx uvar b] fills the empty hole [uvar] with [b] after an occurs
 * check. While walking [b], every empty hole whose level is above [uvar]'s is
 * lowered to [uvar]'s level. *)
and unify_uvar ctx uvar b =
  let lvl =
    match !uvar with
    | Empty { lvl } -> lvl
    | Filled _ | Generalized _ ->
      invalid_arg "Tychk.unify_uvar: variable is not an empty hole"
  in
  (* occurs check and fixing up levels *)
  let rec check = function
    | TCon _ -> ()
    | TFun (a, b) -> check a; check b
    | TApp (_, f, x) -> check f; check x
    | TUVar u ->
      if u == uvar (* pointer equality! *)
      then raise Can'tUnify
      else begin
        match !u with
        | Filled t -> check t
        | Empty { lvl = l } -> if l > lvl then u := Empty { lvl }
        | Generalized _ ->
          invalid_arg "Tychk.unify_uvar: generalized variable outside a type scheme"
      end
  in
  check b;
  uvar := Filled b
(* Standard bidirectional typechecking *)
(* Raised by [check] when the type inferred for an expression cannot be unified
 * with the expected type; carries the inferred type, then the expected type. *)
exception TypeMismatch of ty * ty
(* [check ctx e ty] verifies that [e] has type [ty], filling unification
 * variables as needed. Lambdas checked against a function type and let
 * bindings push the expected type inwards; everything else is inferred and
 * then unified with [ty].
 * Raises [TypeMismatch (inferred, ty)] when that unification fails. *)
let rec check ctx (e: AST.exp) (ty: ty): unit =
  match e, deref ty with
  | AST.Lam (param, body), TFun (param_ty, result_ty) ->
    (* lambda-bound variables get a scheme with no quantified variables *)
    let binding = (param, { num_vars = 0; ty = param_ty }) in
    check { ctx with var_types = binding :: ctx.var_types } body result_ty
  | AST.Let (x, value, body), expected ->
    let scheme = infer_and_generalize ctx value in
    check { ctx with var_types = (x, scheme) :: ctx.var_types } body expected
  | (AST.Annot _ | AST.Var _ | AST.App _ | AST.Lam _), _ ->
    let inferred = infer ctx e in
    (match unify ctx inferred ty with
     | () -> ()
     | exception Can'tUnify -> raise (TypeMismatch (inferred, ty)))

(* [infer ctx e] computes a type for [e].
 * Raises [Not_found] (from [List.assoc]) for a variable missing from
 * [ctx.var_types], and [Failure] when the head of an application is neither a
 * function type nor an empty unification variable. *)
and infer ctx (e: AST.exp): ty =
  match e with
  | AST.Var name ->
    let scheme = List.assoc name ctx.var_types in
    instantiate ctx.lvl scheme
  | AST.Annot (inner, annotation) ->
    (* annotations must denote proper types, i.e. have kind [Star] *)
    let annotated = Kinds.check ctx annotation Star in
    check ctx inner annotated;
    annotated
  | AST.App (f, x) ->
    (match deref (infer ctx f) with
     | TFun (param_ty, result_ty) ->
       check ctx x param_ty;
       result_ty
     | TUVar ({ contents = Empty { lvl } } as uvar) ->
       (* unknown function type: refine the hole to [param -> result], with
        * both new holes created at the hole's own level *)
       let param_ty = TUVar (ref (Empty { lvl })) in
       let result_ty = TUVar (ref (Empty { lvl })) in
       uvar := Filled (TFun (param_ty, result_ty));
       check ctx x param_ty;
       result_ty
     | TCon _ | TApp _ | TUVar { contents = Filled _ | Generalized _ } ->
       failwith "Must be a function")
  | AST.Lam (param, body) ->
    let param_ty = TUVar (ref (Empty { lvl = ctx.lvl })) in
    let binding = (param, { num_vars = 0; ty = param_ty }) in
    let body_ty =
      infer { ctx with var_types = binding :: ctx.var_types } body in
    TFun (param_ty, body_ty)
  | AST.Let (x, value, body) ->
    let scheme = infer_and_generalize ctx value in
    infer { ctx with var_types = (x, scheme) :: ctx.var_types } body

(* [infer_and_generalize ctx e] infers [e] one level deeper than [ctx] and then
 * generalizes the result at [ctx]'s own level, so holes created for [e] that
 * were not lowered by unification become quantified. *)
and infer_and_generalize ctx (e: AST.exp) =
  let outer_lvl = ctx.lvl in
  let inferred = infer { ctx with lvl = outer_lvl + 1 } e in
  generalize outer_lvl inferred
end
(* Small end-to-end example: infers and generalizes the type of
 * [λ x xs. cons (just x) xs] in a context that provides [just] and [cons]
 * together with the [maybe] and [list] type constructors. *)
let test (): Tychk.ty_scheme =
  let open Tychk in
  (* scheme with a single quantified variable, handed to [body] as a type *)
  let forall1 body =
    let a = TUVar (ref (Generalized 0)) in
    { num_vars = 1; ty = body a } in
  let maybe = TCon 0 and list = TCon 1 in
  (* just : a -> maybe a *)
  let just = forall1 (fun a -> TFun (a, TApp (Star, maybe, a))) in
  (* cons : a -> list a -> list a *)
  let cons =
    forall1 (fun a ->
      let list_a = TApp (Star, list, a) in
      TFun (a, TFun (list_a, list_a))) in
  let ctx: ctx =
    { lvl = 0
    ; var_types = ["just", just; "cons", cons]
    ; tyvar_kinds = ["maybe", KFun (Star, Star); "list", KFun (Star, Star)]
    ; tyvar_values = ["maybe", maybe; "list", list] } in
  (* λ x xs. cons (just x) xs *)
  let term =
    AST.(
      Lam ("x",
        Lam ("xs",
          App (App (Var "cons", App (Var "just", Var "x")), Var "xs")))) in
  infer_and_generalize ctx term
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment