(* Gist hsk/45e2f8432dd18c29d8ff361bb7d45559 -- last active October 27, 2017 07:48.
   Note: this file contains bidirectional Unicode text that may be interpreted
   or compiled differently than what appears below. To review, open the file in
   an editor that reveals hidden Unicode characters. *)
(** Names. The same string type is used for program variables ([EVar],
    [EAbs], [ELet]) and for type variables ([TVar], [TQ]). *)
type x = string
(** Expressions of the object language being type-checked. *)
type e =
  | EInt of int  (** integer literal *)
  | EBool of bool  (** boolean literal *)
  | EVar of x  (** variable reference *)
  | EApp of e * e  (** application: function, then argument *)
  | EAbs of x * e  (** lambda abstraction: parameter name and body *)
  | ELet of x * e * e  (** [let name = bound in body] *)
(** [show_e expr] renders an expression as text.
    Applications and let-expressions are parenthesised; abstractions are
    printed as [fun x->body] without surrounding parentheses. *)
let rec show_e expr =
  match expr with
  | EInt n -> string_of_int n
  | EBool flag -> Printf.sprintf "%b" flag
  | EVar name -> name
  | EApp (fn, arg) ->
    let fn_text = show_e fn in
    let arg_text = show_e arg in
    Printf.sprintf "(%s %s)" fn_text arg_text
  | EAbs (param, body) ->
    let body_text = show_e body in
    Printf.sprintf "fun %s->%s" param body_text
  | ELet (name, bound, body) ->
    let bound_text = show_e bound in
    let body_text = show_e body in
    Printf.sprintf "(let %s = %s in %s)" name bound_text body_text
(** Types of the object language, including quantified type schemes. *)
type t =
  | TVar of x  (** type variable *)
  | TInt  (** the type of integer literals *)
  | TBool  (** the type of boolean literals *)
  | TArr of t * t  (** function type: argument, then result *)
  | TQ of x list * t  (** scheme: quantified variable names and body *)
(** [show_t ty] renders a type as text. Arrows are always parenthesised and
    a scheme lists its quantified variables separated by spaces. *)
let rec show_t ty =
  match ty with
  | TInt -> "int"
  | TBool -> "bool"
  | TVar name -> name
  | TArr (dom, cod) ->
    let dom_text = show_t dom in
    let cod_text = show_t cod in
    Printf.sprintf "(%s->%s)" dom_text cod_text
  | TQ (bound, body) ->
    let bound_text = String.concat " " bound in
    Printf.sprintf "∀(%s.%s)" bound_text (show_t body)
(** Typing assumptions: an association list from program variable names to
    their types (or type schemes). *)
type assump = (x * t) list

(** Substitutions: an association list from type-variable names to types. *)
type subst = (x * t) list
(** [show_a a] renders assumptions as [[name:type,name:type,...]]. *)
let show_a a =
  let show_binding (name, ty) = name ^ ":" ^ show_t ty in
  a
  |> List.map show_binding
  |> String.concat ","
  |> Printf.sprintf "[%s]"
(** [show_s s] renders a substitution as [[var=type,var=type,...]]. *)
let show_s s =
  let show_binding (name, ty) = name ^ "=" ^ show_t ty in
  s
  |> List.map show_binding
  |> String.concat ","
  |> Printf.sprintf "[%s]"
(** [union xs ys] is the elements of [xs] that do not occur in [ys]
    (in their original order), followed by all of [ys].
    Duplicates inside [xs] that are absent from [ys] are kept. *)
let union xs ys =
  let absent_from_ys x = not (List.mem x ys) in
  List.filter absent_from_ys xs @ ys
(** [subtract xs ys] is [xs] with every element that occurs in [ys] removed;
    the relative order of the remaining elements is preserved. *)
let subtract xs ys =
  List.filter (fun x -> not (List.mem x ys)) xs
(** [newTVar c] makes the type variable named ["x" ^ string_of_int c] and
    returns it together with the next counter value [c + 1]. *)
let newTVar c =
  let fresh = TVar ("x" ^ string_of_int c) in
  (c + 1, fresh)
(** [subst s ty] applies substitution [s] to [ty].

    - A variable bound in [s] is replaced by its binding, and the binding is
      itself substituted again, so chains such as [a := b -> c; b := int]
      are resolved all the way down.
    - Under a scheme [TQ (xs, body)] the bindings for the quantified names
      [xs] are dropped before descending, because those names are bound by
      the scheme and must not be replaced.

    Assumes [s] is acyclic (no variable is reachable from its own binding);
    a cyclic substitution would make this function loop. *)
let rec subst (s:subst) = function
  | TVar x ->
    (* Only [Not_found] means "unbound"; any other exception must propagate.
       The recursive call sits outside the exception handler on purpose. *)
    (match List.assoc x s with
     | bound -> subst s bound
     | exception Not_found -> TVar x)
  | TArr (t1, t2) -> TArr (subst s t1, subst s t2)
  | TQ (xs, t) ->
    (* Remove (not keep) the bindings of the quantified variables. *)
    let s = List.filter (fun (x, _) -> not (List.mem x xs)) s in
    TQ (xs, subst s t)
  | (TInt | TBool) as t -> t
(** [occurs (x, ty)] is the occurs check used by unification.
    It returns [()] when the variable name [x] does not appear in [ty].
    @raise Failure when [ty] contains the variable [x], or when [ty] is a
    scheme whose quantified names include [x]. *)
let rec occurs (x, ty) =
  match ty with
  | TVar y when x = y -> failwith ("unify occurs error " ^ x ^ " " ^ y)
  | TArr (dom, cod) ->
    occurs (x, dom);
    occurs (x, cod)
  | TQ (bound, body) ->
    if List.mem x bound then
      failwith ("unify occurs error " ^ x ^ " [" ^ (String.concat "," bound) ^ "]")
    else occurs (x, body)
  | TVar _ | TInt | TBool -> ()
(** [unify s (t1, t2)] extends substitution [s] so that [t1] and [t2] become
    equal, and returns the extended substitution.

    Both types are first rewritten with [s]. Then:
    - equal types need no new binding;
    - a variable on either side is bound to the other type after the occurs
      check;
    - two arrows are unified component-wise, threading the substitution.

    @raise Failure if the occurs check fails or the types have incompatible
    shapes. *)
let rec unify s (t1, t2) =
  match (subst s t1, subst s t2) with
  (* Structural equality: [subst] allocates fresh [TVar] nodes, so physical
     equality ([==]) would report a variable as different from itself and
     then trip the occurs check. *)
  | (t1, t2) when t1 = t2 -> s
  | (TVar x1, t) | (t, TVar x1) ->
    occurs (x1, t);
    (x1, t) :: s
  | (TArr (t1, t2), TArr (t3, t4)) ->
    let s1 = unify s (t1, t3) in
    unify s1 (t2, t4)
  | (t1, t2) ->
    failwith ("unify error (" ^ show_t t1 ^ "," ^show_t t2 ^ ") " ^ show_s s)
(** [gen c s a t] generalises type [t] into a scheme.

    @param c counter used to create fresh type-variable names
    @param s current substitution
    @param a current typing assumptions
    @param t type to generalise
    @return [(c', TQ (vars, body))] where [vars] are fresh names standing
    for the variables that are free in [subst s t] but not free in any
    assumption of [a] (after applying [s]), and [c'] is the advanced
    counter. *)
let gen c s (a:assump) t =
  (* Free type variables of a type; quantified names are not free. *)
  let rec fvt = function
    | TVar v -> [v]
    | TArr (t1, t2) -> union (fvt t1) (fvt t2)
    | TQ (vs, body) -> subtract (fvt body) vs
    | TInt | TBool -> []
  in
  (* Free type variables of the whole environment, under [s]. *)
  let fv_assump (a:assump) : x list =
    List.fold_left (fun acc (_, ty) -> union acc (fvt (subst s ty))) [] a
  in
  (* Resolve [t] with the current substitution once; both the set of
     generalisable variables and the scheme body must be taken from this
     resolved type, otherwise the quantified names never occur in the body. *)
  let resolved = subst s t in
  let generalisable = subtract (fvt resolved) (fv_assump a) in
  (* Give every generalisable variable a fresh quantified name. *)
  let rename (counter, renaming, quantified) fv =
    match newTVar counter with
    | (counter', (TVar fresh as tv)) ->
      (counter', (fv, tv) :: renaming, fresh :: quantified)
    | _ -> assert false (* newTVar always returns a TVar *)
  in
  let (c2, renaming, quantified) =
    List.fold_left rename (c, [], []) generalisable
  in
  (c2, TQ (quantified, subst renaming resolved))
(** [inst c ty] instantiates a scheme: each quantified variable of
    [TQ (xs, body)] is replaced, one at a time and in list order, by a fresh
    type variable drawn from counter [c]. Any other type is returned as is.
    Returns the advanced counter and the resulting type. *)
let inst c = function
  | TQ (xs, body) ->
    let replace_one (counter, ty) bound =
      let (counter', fresh) = newTVar counter in
      (counter', subst [ (bound, fresh) ] ty)
    in
    List.fold_left replace_one (c, body) xs
  | t -> (c, t)
(** [tp c s a e] infers a type for expression [e].

    @param c counter for fresh type variables
    @param s substitution accumulated so far
    @param a typing assumptions for the variables in scope
    @return [(c', s', ty)]: the advanced counter, the extended substitution
    and the (not yet substituted) type of [e]
    @raise Failure on an unbound variable (["lookup error ..."]) or when
    unification fails. *)
let rec tp c s a = function
  | EInt _ -> (c, s, TInt)
  | EBool _ -> (c, s, TBool)
  | EVar name ->
    (match List.assoc name a with
     | scheme ->
       (* Each use of a variable gets its own instance of the scheme. *)
       let (c', ty) = inst c scheme in
       (c', s, ty)
     | exception Not_found -> failwith ("lookup error " ^ name))
  | EApp (fn, arg) ->
    let (c1, s1, fn_ty) = tp c s a fn in
    let (c2, s2, arg_ty) = tp c1 s1 a arg in
    (* The function must have type [arg_ty -> result_ty] for a fresh result. *)
    let (c3, result_ty) = newTVar c2 in
    let s3 = unify s2 (TArr (arg_ty, result_ty), fn_ty) in
    (c3, s3, result_ty)
  | EAbs (param, body) ->
    let (c1, param_ty) = newTVar c in
    let (c2, s2, body_ty) = tp c1 s ((param, param_ty) :: a) body in
    (c2, s2, TArr (param_ty, body_ty))
  | ELet (name, bound, body) ->
    let (c1, s1, bound_ty) = tp c s a bound in
    (* Generalise the bound expression before checking the body. *)
    let (c2, scheme) = gen c1 s1 a bound_ty in
    tp c2 s1 ((name, scheme) :: a) body
(** [tp1 e t] is a one-line test: it infers the type of [e] from an empty
    environment, applies the final substitution, and prints ["ok"] when the
    result equals the expected type [t]. A mismatch or a [Failure] raised
    during inference is reported on standard output instead. *)
let tp1 e t =
  let shown = show_e e in
  Printf.printf "test %s " shown;
  try
    let (_, s1, t1) = tp 0 [] [] e in
    let actual = subst s1 t1 in
    if t <> actual then
      Printf.printf "error %s : expected %s but %s\n"
        shown (show_t t) (show_t actual)
    else Printf.printf "ok\n"
  with Failure m ->
    Printf.printf "error %s\n" m
(* Test driver: each pair is an expression and the type [tp1] should report.
   The expected variable names (x2, x3) depend on the order in which the
   fresh-variable counter is consumed during inference. *)
let () =
  let id = EAbs ("x", EVar "x") in
  let cases =
    [ (EInt 1, TInt);
      (EBool true, TBool);
      (EBool false, TBool);
      (ELet ("x", EInt 1, EVar "x"), TInt);
      (ELet ("id", id, EVar "id"), TArr (TVar "x2", TVar "x2"));
      (ELet ("id", id, EApp (EVar "id", EVar "id")),
       TArr (TVar "x3", TVar "x3"));
      (ELet ("id", id, EApp (EApp (EVar "id", EVar "id"), EInt 1)), TInt) ]
  in
  List.iter (fun (expr, expected) -> tp1 expr expected) cases
(* GitHub page footer (not part of the program): Sign up for free to join this
   conversation on GitHub. Already have an account? Sign in to comment. *)