Skip to content

Instantly share code, notes, and snippets.

@hsk
Last active October 27, 2017 07:48
Show Gist options
  • Save hsk/45e2f8432dd18c29d8ff361bb7d45559 to your computer and use it in GitHub Desktop.
Save hsk/45e2f8432dd18c29d8ff361bb7d45559 to your computer and use it in GitHub Desktop.
(* Identifier: used both for program variables (EVar, EAbs, ELet) and for
   type-variable names (TVar, TQ). *)
type x = string
(* Abstract syntax of the expression language that is type-checked below. *)
type e =
| EInt of int (* integer literal *)
| EBool of bool (* boolean literal *)
| EVar of x (* variable reference *)
| EApp of e * e (* application: function, argument *)
| EAbs of x * e (* lambda abstraction: parameter, body *)
| ELet of x * e * e (* let binding: name, bound expression, body *)
(* [show_e expr] renders an expression as text.  Applications and lets are
   wrapped in parentheses; abstractions are printed as "fun x->body" without
   parentheses. *)
let rec show_e expr =
  match expr with
  | EInt n -> string_of_int n
  | EBool flag -> Printf.sprintf "%b" flag
  | EVar name -> name
  | EApp (fn, arg) ->
    let fn_text = show_e fn in
    let arg_text = show_e arg in
    Printf.sprintf "(%s %s)" fn_text arg_text
  | EAbs (param, body) ->
    let body_text = show_e body in
    Printf.sprintf "fun %s->%s" param body_text
  | ELet (name, bound, body) ->
    let bound_text = show_e bound in
    let body_text = show_e body in
    Printf.sprintf "(let %s = %s in %s)" name bound_text body_text
(* Types manipulated by the inferencer. *)
type t =
| TVar of x (* type variable *)
| TInt (* printed as "int" *)
| TBool (* printed as "bool" *)
| TArr of t * t (* function type: argument, result *)
| TQ of x list * t (* quantified type: bound variable names, body *)
(* [show_t ty] renders a type as text.  Arrows are always parenthesised and
   a quantified type lists its bound variables separated by spaces. *)
let rec show_t ty =
  match ty with
  | TVar name -> name
  | TInt -> "int"
  | TBool -> "bool"
  | TArr (arg, res) ->
    let arg_text = show_t arg in
    let res_text = show_t res in
    Printf.sprintf "(%s->%s)" arg_text res_text
  | TQ (bound, body) ->
    let bound_text = String.concat " " bound in
    Printf.sprintf "∀(%s.%s)" bound_text (show_t body)
(* Typing environment: association list from program-variable names to the
   type recorded for them (tp looks names up here with List.assoc). *)
type assump = (x * t) list
(* Substitution: association list from type-variable names to types. *)
type subst = (x * t) list
(* [show_a a] renders an environment as "[name:type,name:type,...]". *)
let show_a a =
  let entry (name, ty) = name ^ ":" ^ show_t ty in
  a
  |> List.map entry
  |> String.concat ","
  |> Printf.sprintf "[%s]"
(* [show_s s] renders a substitution as "[var=type,var=type,...]". *)
let show_s s =
  let entry (name, ty) = name ^ "=" ^ show_t ty in
  s
  |> List.map entry
  |> String.concat ","
  |> Printf.sprintf "[%s]"
(* [union xs ys] is the elements of [xs] that do not occur in [ys], in their
   original order, followed by all of [ys].  Duplicates inside [xs] or inside
   [ys] are not removed. *)
let union xs ys =
  let only_in_xs = List.filter (fun x -> not (List.mem x ys)) xs in
  only_in_xs @ ys
(* [subtract xs ys] is [xs] with every element that occurs in [ys] removed;
   the relative order of the survivors is kept. *)
let subtract xs ys =
  let absent_from_ys x = not (List.mem x ys) in
  List.filter absent_from_ys xs
(* [newTVar c] returns the incremented counter together with a type variable
   named "x" followed by the decimal value of [c]. *)
let newTVar c =
  let name = "x" ^ string_of_int c in
  let next = c + 1 in
  (next, TVar name)
(* [subst s t] applies substitution [s] to type [t].

   - A variable bound in [s] is replaced by its binding, and [s] is then
     applied to that binding as well, so chains such as
     [a := b->b; b := int] resolve [a] to [int->int].
   - A variable with no binding in [s] is returned unchanged.
   - Under a quantifier, bindings for the quantifier's own bound variables
     are dropped first: those names are local to the quantified body and
     must not be replaced, while free variables of the body still are.

   [s] is expected to be acyclic (unify maintains this through its occurs
   check); a cyclic substitution would make this function loop. *)
let rec subst (s:subst) = function
  | TVar x ->
    (* Only "no binding" means "leave the variable alone"; any other
       exception is allowed to propagate instead of being swallowed. *)
    begin match List.assoc x s with
    | exception Not_found -> TVar x
    | bound -> subst s bound
    end
  | TArr (t1, t2) -> TArr (subst s t1, subst s t2)
  | TQ (xs, t) ->
    (* Keep only the bindings that do NOT concern the bound variables. *)
    let s = List.filter (fun (x, _) -> not (List.mem x xs)) s in
    TQ (xs, subst s t)
  | (TInt | TBool) as t -> t
(* [occurs (v, ty)] is the occurs check used by unification: it raises
   [Failure] with a "unify occurs error" message when variable [v] is found
   in [ty], and returns unit otherwise.

   NOTE(review): when [v] is one of the variables bound by a quantifier the
   check fails rather than treating [v] as shadowed -- confirm this is the
   intended behaviour. *)
let rec occurs (v, ty) =
  match ty with
  | TVar w ->
    if v = w then failwith ("unify occurs error " ^ v ^ " " ^ w)
  | TArr (arg, res) ->
    occurs (v, arg);
    occurs (v, res)
  | TQ (bound, body) ->
    if List.mem v bound then
      failwith ("unify occurs error " ^ v ^ " [" ^ (String.concat "," bound) ^ "]")
    else
      occurs (v, body)
  | TInt | TBool -> ()
(* [unify s (t1, t2)] extends substitution [s] so that [t1] and [t2] become
   equal, and returns the extended substitution.

   Both types are first resolved through [s].  Then:
   - structurally equal types need no new binding;
   - a variable on either side is bound to the other type, after the occurs
     check has confirmed the variable does not appear inside it;
   - two arrow types are unified argument first, then result, threading the
     substitution through.

   Raises [Failure] on an occurs-check violation or when the two types have
   incompatible shapes. *)
let rec unify s (t1, t2) =
  match (subst s t1, subst s t2) with
  (* Structural equality is required here: subst allocates fresh values, so
     physical equality (==) is false even for two copies of the same
     variable, which would then be rejected by the occurs check below. *)
  | (t1, t2) when t1 = t2 -> s
  | (TVar x1, t) | (t, TVar x1) ->
    occurs (x1, t);
    (x1, t) :: s
  | (TArr (t1, t2), TArr (t3, t4)) ->
    let s1 = unify s (t1, t3) in
    unify s1 (t2, t4)
  | (t1, t2) ->
    failwith ("unify error (" ^ show_t t1 ^ "," ^show_t t2 ^ ") " ^ show_s s)
(* [gen c s a t] generalises type [t] with respect to environment [a] under
   the current substitution [s].

   The type variables that are free in [t] (after applying [s]) but not free
   in any type of [a] (after applying [s]) are each renamed to a fresh
   variable drawn from counter [c], and the renamed type is wrapped in a
   quantifier over those fresh variables.

   Returns the advanced counter and the quantified type.  The quantifier is
   produced even when there is nothing to generalise (empty variable list). *)
let gen c s (a:assump) t =
  (* Free type variables of a type; quantifier-bound names are excluded. *)
  let rec fvt = function
    | TVar x -> [x]
    | TArr (t1, t2) -> union (fvt t1) (fvt t2)
    | TQ (vs, t) -> subtract (fvt t) vs
    | TInt | TBool -> []
  in
  (* Free type variables of every type in the environment, seen through [s]. *)
  let fv_assump (a:assump) : x list =
    List.fold_left (fun (fv1:x list) (_, t) ->
      union fv1 (fvt (subst s t))
    ) [] a
  in
  (* Resolve [t] through [s] once and use that same resolved type both for
     finding the variables to generalise and for building the body.  Renaming
     the unresolved [t] would leave the quantified variables absent from the
     body whenever they are only reachable through [s]. *)
  let t = subst s t in
  let fvs = subtract (fvt t) (fv_assump a) in
  let (c2, renaming, fvs2) =
    List.fold_left (fun (c, renaming, bound) fv ->
      match newTVar c with
      | (c2, (TVar x as fresh)) -> (c2, (fv, fresh) :: renaming, x :: bound)
      | _ -> assert false (* newTVar always returns a TVar *)
    ) (c, [], []) fvs
  in
  (c2, TQ (fvs2, subst renaming t))
(* [inst c ty] instantiates a quantified type: each bound variable, taken in
   list order, is replaced in the body by a fresh variable drawn from counter
   [c].  Returns the advanced counter and the instantiated body.  A type that
   is not quantified is returned unchanged with the counter untouched. *)
let inst c ty =
  match ty with
  | TQ (bound, body) ->
    let replace_one (counter, acc) v =
      let (counter', fresh) = newTVar counter in
      (counter', subst [(v, fresh)] acc)
    in
    List.fold_left replace_one (c, body) bound
  | TVar _ | TInt | TBool | TArr _ -> (c, ty)
(* [tp c s a expr] infers a type for [expr].

   [c] is the fresh-variable counter, [s] the substitution accumulated so
   far and [a] the typing environment.  The result is the advanced counter,
   the extended substitution and the (not yet substituted) type of [expr].

   Raises [Failure "lookup error <name>"] for a variable missing from [a],
   and propagates any [Failure] raised by unify. *)
let rec tp c s a expr =
  match expr with
  | EInt _ -> (c, s, TInt)
  | EBool _ -> (c, s, TBool)
  | EVar name ->
    begin match List.assoc name a with
    | exception Not_found -> failwith ("lookup error " ^ name)
    | recorded ->
      (* Each use of a variable gets its own instance of the recorded type. *)
      let (next, ty) = inst c recorded in
      (next, s, ty)
    end
  | EApp (fn, arg) ->
    let (c_fn, s_fn, fn_ty) = tp c s a fn in
    let (c_arg, s_arg, arg_ty) = tp c_fn s_fn a arg in
    (* The function's type must be "argument type -> fresh result type". *)
    let (c_res, res_ty) = newTVar c_arg in
    let s_res = unify s_arg (TArr (arg_ty, res_ty), fn_ty) in
    (c_res, s_res, res_ty)
  | EAbs (param, body) ->
    let (c_param, param_ty) = newTVar c in
    let (c_body, s_body, body_ty) =
      tp c_param s ((param, param_ty) :: a) body
    in
    (c_body, s_body, TArr (param_ty, body_ty))
  | ELet (name, bound, body) ->
    let (c_bound, s_bound, bound_ty) = tp c s a bound in
    (* Generalise the bound expression's type before checking the body. *)
    let (c_gen, generalised) = gen c_bound s_bound a bound_ty in
    tp c_gen s_bound ((name, generalised) :: a) body
(* [tp1 e t] is a single test case: it infers the type of [e] starting from
   counter 0 with an empty substitution and environment, applies the final
   substitution to the inferred type and compares it with the expected type
   [t].  It prints "ok", a mismatch report, or the message of any [Failure]
   raised during inference. *)
let tp1 e t =
  Printf.printf "test %s " (show_e e);
  try
    let (_, final_subst, inferred) = tp 0 [] [] e in
    let actual = subst final_subst inferred in
    if t <> actual then
      Printf.printf "error %s : expected %s but %s\n"
        (show_e e) (show_t t) (show_t actual)
    else
      Printf.printf "ok\n"
  with Failure m ->
    Printf.printf "error %s\n" m
(* Test driver: runs tp1 on each (expression, expected type) pair in order. *)
let () =
  let identity = EAbs ("x", EVar "x") in
  let id = EVar "id" in
  (* Wraps [body] in "let id = fun x->x in body". *)
  let with_id body = ELet ("id", identity, body) in
  let cases = [
    (EInt 1, TInt);
    (EBool true, TBool);
    (EBool false, TBool);
    (ELet ("x", EInt 1, EVar "x"), TInt);
    (with_id id, TArr (TVar "x2", TVar "x2"));
    (with_id (EApp (id, id)), TArr (TVar "x3", TVar "x3"));
    (with_id (EApp (EApp (id, id), EInt 1)), TInt);
  ] in
  List.iter (fun (expr, expected) -> tp1 expr expected) cases
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment