(* Gist page header (not part of the program):
   Created March 10, 2016 13:02.
   Save gsg/9975df61eeae4c02eb55 to your computer and use it in GitHub Desktop.
   This file contains hidden or bidirectional Unicode text that may be
   interpreted or compiled differently than what appears below. To review,
   open the file in an editor that reveals hidden Unicode characters.
   Learn more about bidirectional Unicode characters. *)
(** Variable names as they appear in the untyped surface syntax. *)
type id = string

(** Finite maps keyed by string identifiers. *)
module IdMap = Map.Make (String)

(** Binary sum: a value is either [Left] of an ['a] or [Right] of a ['b]. *)
type ('a, 'b) alt =
  | Left of 'a
  | Right of 'b
(** Untyped (surface) syntax of types, as written in term annotations. *)
module TypeTerm = struct
  type t =
    | Unit
    | Bool
    | Int
    | Pair of t * t   (* product of the two component types *)
    | Alt of t * t    (* sum: left alternative, right alternative *)
    | Arrow of t * t  (* function: argument type, result type *)
end
(** Untyped (surface) syntax of terms; variables are referred to by name. *)
module Term = struct
  type t =
    | Unit
    | Bool of bool
    | Int of int
    | Lam of id * TypeTerm.t * t
    (* [Lam (x, ty, body)]: function whose parameter [x] is annotated [ty] *)
    | Var of id
    | App of t * t
    (* [App (f, arg)] *)
    | Pair of t * t
    | Fst of t
    | Snd of t
    | Inl of t * TypeTerm.t
    (* [Inl (payload, right_ty)]: left injection; the annotation gives the
       type of the absent right alternative *)
    | Inr of TypeTerm.t * t
    (* [Inr (left_ty, payload)]: right injection; the annotation gives the
       type of the absent left alternative *)
    | Case of t * id * t * id * t
    (* [Case (scrutinee, lvar, lbody, rvar, rbody)]: [lvar] is bound in
       [lbody] and [rvar] in [rbody] *)
end
(** Type-equality witnesses and the congruence rules used to combine them. *)
module Eq = struct
  (** A value of type [('a, 'b) t] can only be [Refl], and matching on it
      reveals to the type checker that ['a] and ['b] are the same type. *)
  type ('a, 'b) t = Refl : ('a, 'a) t

  (** Componentwise equality of two types implies equality of their pairs. *)
  let pair : type a b c d . (a, b) t -> (c, d) t -> (a * c, b * d) t =
    fun Refl Refl -> Refl

  (** Componentwise equality implies equality of the corresponding sums. *)
  let alt : type a b c d . (a, b) t -> (c, d) t -> ((a, c) alt, (b, d) alt) t =
    fun Refl Refl -> Refl

  (** Equal argument types and equal result types give equal arrow types. *)
  let arrow : type a b c d . (a, b) t -> (c, d) t -> (a -> c, b -> d) t =
    fun Refl Refl -> Refl
end
(** Runtime representations of types, indexed by the OCaml type they denote. *)
module Type = struct
  type _ t =
    | Unit : unit t
    | Bool : bool t
    | Int : int t
    | Pair : 'a t * 'b t -> ('a * 'b) t
    | Alt : 'a t * 'b t -> ('a, 'b) alt t
    | Arrow : 'a t * 'b t -> ('a -> 'b) t

  (** [equal a b] is [Some Eq.Refl] when [a] and [b] have the same structure
      and [None] otherwise. *)
  let rec equal : type a b . a t -> b t -> (a, b) Eq.t option =
    fun a b ->
      match a, b with
      | Unit, Unit -> Some Eq.Refl
      | Bool, Bool -> Some Eq.Refl
      | Int, Int -> Some Eq.Refl
      | Pair (a1, b1), Pair (a2, b2) -> equal2 Eq.pair a1 b1 a2 b2
      | Alt (a1, b1), Alt (a2, b2) -> equal2 Eq.alt a1 b1 a2 b2
      | Arrow (a1, b1), Arrow (a2, b2) -> equal2 Eq.arrow a1 b1 a2 b2
      | _, _ -> None

  (* [equal2 join a b c d] compares [a] with [c], then [b] with [d], and
     combines the two witnesses with [join]; it is [None] as soon as one of
     the comparisons fails. *)
  and equal2 : type eq a b c d .
    ((a, c) Eq.t -> (b, d) Eq.t -> eq) ->
    a t -> b t -> c t -> d t -> eq option =
    fun join a b c d ->
      match equal a c with
      | None -> None
      | Some x ->
        (match equal b d with
         | None -> None
         | Some y -> Some (join x y))

  (** A type representation whose index has been hidden. *)
  type any = Any : _ t -> any

  (* A binary type constructor packaged as a record with a polymorphic field,
     so that [of_term] can apply it to representations whose indices only
     became known by unpacking [Any]. *)
  type join_any = {
    f : 'a 'b . 'a t -> 'b t -> any;
  }

  let pair_any = { f = fun a b -> Any (Pair (a, b)) }
  let alt_any = { f = fun a b -> Any (Alt (a, b)) }
  let arrow_any = { f = fun a b -> Any (Arrow (a, b)) }

  (** [of_term term] converts a surface type into its indexed representation,
      wrapped in [Any]. *)
  let rec of_term term =
    let binary f a b =
      let Any aty = of_term a in
      let Any bty = of_term b in
      f.f aty bty
    in
    match term with
    | TypeTerm.Unit -> Any Unit
    | TypeTerm.Bool -> Any Bool
    | TypeTerm.Int -> Any Int
    | TypeTerm.Pair (a, b) -> binary pair_any a b
    | TypeTerm.Alt (a, b) -> binary alt_any a b
    | TypeTerm.Arrow (a, b) -> binary arrow_any a b

  (** [print ty] writes [ty] to standard output. Pairs are written
      [(a * b)], sums [(a + b)] and functions [a -> b]; an arrow is only
      parenthesised when it is the argument side of another arrow. *)
  let print ty =
    let rec pr : type a . bool -> a t -> unit =
      fun arrow_parens -> function
        | Unit -> print_string "unit"
        | Bool -> print_string "bool"
        | Int -> print_string "int"
        | Pair (a, b) ->
          print_char '(';
          pr false a;
          print_string " * ";
          pr false b;
          print_char ')'
        | Alt (a, b) ->
          print_char '(';
          pr false a;
          print_string " + ";
          pr false b;
          print_char ')'
        | Arrow (a, b) ->
          if arrow_parens then print_char '(';
          (* Only the argument side needs parentheses when it is itself an
             arrow; the result side is printed without them. *)
          pr true a;
          print_string " -> ";
          pr false b;
          if arrow_parens then print_char ')'
    in
    pr false ty
end
(* Abstract type with no constructors declared here; it serves as the index
   of the empty environments ([Env.empty_static], [Env.empty_dynamic]). *)
type empty
(** Environments whose type index lists the types of the bound variables,
    innermost binding first (['a * ('b * empty)] and so on).

    A static environment maps names to types and is used while type checking;
    a dynamic environment holds the corresponding values during evaluation.
    An [access] obtained from a static lookup is a typed position that can be
    used to fetch the value from any dynamic environment with the same
    index. *)
module Env : sig
  type 'e static_env
  type 'e dynamic_env
  type ('e, 'a) access

  (** Outcome of a successful static lookup: the variable's type together
      with the position at which it is bound. *)
  type _ result = Result : 'a Type.t * ('e, 'a) access -> 'e result

  val empty_static : empty static_env

  (** [extend_static env name ty] adds a binding of [name] at type [ty] in
      front of [env]. *)
  val extend_static : 'e static_env -> string -> 'a Type.t -> ('a * 'e) static_env

  (** [find_static env name] looks [name] up, preferring the most recently
      added binding; [None] when [name] is not bound. *)
  val find_static : 'e static_env -> string -> 'e result option

  val empty_dynamic : empty dynamic_env

  (** [extend_dynamic env value] pushes [value] in front of [env]. *)
  val extend_dynamic : 'e dynamic_env -> 'a -> ('a * 'e) dynamic_env

  (** [find_dynamic env access] returns the value stored at [access]. *)
  val find_dynamic : 'e dynamic_env -> ('e, 'a) access -> 'a
end = struct
  type _ static_env =
    | StaticNil : empty static_env
    | StaticCons : string * 'a Type.t * 'e static_env -> ('a * 'e) static_env

  type _ dynamic_env =
    | DynNil : empty dynamic_env
    | DynCons : 'a * 'e dynamic_env -> ('a * 'e) dynamic_env

  (* A typed position inside an environment: [Here] designates the head,
     [FurtherIn p] skips the head and continues with [p]. *)
  type (_, 'a) access =
    | FurtherIn : ('e, 'a) access -> (_ * 'e, 'a) access
    | Here : ('a * _, 'a) access

  type _ result = Result : 'a Type.t * ('e, 'a) access -> 'e result

  let empty_static = StaticNil

  let extend_static env name ty = StaticCons (name, ty, env)

  (* The search is a total function returning [option]; a missing name is an
     expected outcome for the caller, so no exception is involved. *)
  let find_static env var_name =
    let rec loop : type t . t static_env -> t result option = function
      | StaticNil -> None
      | StaticCons (name, ty, rest) ->
        if String.equal var_name name then Some (Result (ty, Here))
        else
          (match loop rest with
           | None -> None
           | Some (Result (found_ty, pos)) ->
             (* The binding lives one step deeper than in [rest]. *)
             Some (Result (found_ty, FurtherIn pos)))
    in
    loop env

  let empty_dynamic = DynNil

  let extend_dynamic env value = DynCons (value, env)

  let rec find_dynamic : type t a . t dynamic_env -> (t, a) access -> a =
    fun env access ->
      match env, access with
      | DynCons (x, _), Here -> x
      | DynCons (_, xs), FurtherIn a -> find_dynamic xs a
      (* Both [access] constructors are indexed by a pair-shaped environment,
         so a position into the [empty] environment is not expected to
         exist. *)
      | DynNil, _ -> assert false
end
(** Intrinsically typed terms: [('env, 'a) t] is a term of type ['a] whose
    free variables are described by the environment index ['env]. Variables
    are typed positions ([Env.access]) rather than names. *)
module TypedTerm = struct
  type ('env, _) t =
    | Unit : ('env, unit) t
    | Bool : bool -> ('env, bool) t
    | Int : int -> ('env, int) t
    | Lam : ('a * 'env, 'b) t -> ('env, ('a -> 'b)) t
    (* the body sees the parameter as the head of its environment *)
    | Var : ('env, 'a) Env.access -> ('env, 'a) t
    | App : ('env, ('a -> 'b)) t * ('env, 'a) t -> ('env, 'b) t
    | Pair : ('env, 'a) t * ('env, 'b) t -> ('env, 'a * 'b) t
    | Fst : ('env, 'a * 'b) t -> ('env, 'a) t
    | Snd : ('env, 'a * 'b) t -> ('env, 'b) t
    | Inl : ('env, 'l) t -> ('env, ('l, _) alt) t
    | Inr : ('env, 'r) t -> ('env, (_, 'r) alt) t
    | Case : ('env, ('l, 'r) alt) t *
             ('l * 'env, 'a) t *
             ('r * 'env, 'a) t -> ('env, 'a) t
    (* scrutinee, left branch, right branch; each branch sees the payload of
       its alternative as the head of its environment *)
end
(** Result of type checking in an environment indexed by ['a]: a typed term
    packaged with the representation of its type, which is hidden
    existentially. *)
type 'a typed_result =
  | TypedResult : 'r Type.t * ('a, 'r) TypedTerm.t -> 'a typed_result
(** [typed env term] type checks the surface term [term] in the static
    environment [env] and returns the typed term together with its type.

    @raise Failure with message ["unbound variable"] when a variable is not
    bound in [env], and with message ["type error"] when an application,
    projection or case analysis is ill-typed. *)
let rec typed : type t . t Env.static_env -> Term.t -> t typed_result =
  fun env term ->
    match term with
    | Term.Unit ->
      TypedResult (Type.Unit, TypedTerm.Unit)
    | Term.Bool b ->
      TypedResult (Type.Bool, TypedTerm.Bool b)
    | Term.Int i ->
      TypedResult (Type.Int, TypedTerm.Int i)
    | Term.Var id ->
      (match Env.find_static env id with
       | None -> failwith "unbound variable"
       | Some (Env.Result (ty, access)) ->
         TypedResult (ty, TypedTerm.Var access))
    | Term.App (f, arg) ->
      (* Check the function before the argument with explicit sequencing:
         the components of a tuple expression are evaluated in an
         unspecified order, which would leave open which error is reported
         when both subterms are ill-typed. *)
      let TypedResult (f_ty, f_term) = typed env f in
      let TypedResult (arg_ty, arg_term) = typed env arg in
      (match f_ty with
       | Type.Arrow (param_ty, result_ty) ->
         (match Type.equal param_ty arg_ty with
          | Some Eq.Refl ->
            TypedResult (result_ty, TypedTerm.App (f_term, arg_term))
          | None ->
            failwith "type error")
       | _ -> failwith "type error")
    | Term.Lam (id, ty_term, body) ->
      let Type.Any arg_ty = Type.of_term ty_term in
      let body_env = Env.extend_static env id arg_ty in
      let TypedResult (body_ty, body_term) = typed body_env body in
      TypedResult (Type.Arrow (arg_ty, body_ty),
                   TypedTerm.Lam body_term)
    | Term.Pair (a, b) ->
      let TypedResult (aty, aterm) = typed env a in
      let TypedResult (bty, bterm) = typed env b in
      TypedResult (Type.Pair (aty, bty),
                   TypedTerm.Pair (aterm, bterm))
    | Term.Fst p ->
      (match typed env p with
       | TypedResult (Type.Pair (fst_ty, _), pterm) ->
         TypedResult (fst_ty, TypedTerm.Fst pterm)
       | _ -> failwith "type error")
    | Term.Snd p ->
      (match typed env p with
       | TypedResult (Type.Pair (_, snd_ty), pterm) ->
         TypedResult (snd_ty, TypedTerm.Snd pterm)
       | _ -> failwith "type error")
    | Term.Inl (arg, rty_term) ->
      (* The annotation supplies the type of the right alternative, which
         cannot be recovered from the payload. *)
      let TypedResult (arg_ty, arg_term) = typed env arg in
      let Type.Any rty = Type.of_term rty_term in
      TypedResult (Type.Alt (arg_ty, rty),
                   TypedTerm.Inl arg_term)
    | Term.Inr (lty_term, arg) ->
      let TypedResult (arg_ty, arg_term) = typed env arg in
      let Type.Any lty = Type.of_term lty_term in
      TypedResult (Type.Alt (lty, arg_ty),
                   TypedTerm.Inr arg_term)
    | Term.Case (test, lvar, lterm, rvar, rterm) ->
      let TypedResult (test_ty, test_term) = typed env test in
      (match test_ty with
       | Type.Alt (lty, rty) ->
         let lenv = Env.extend_static env lvar lty in
         let renv = Env.extend_static env rvar rty in
         let TypedResult (lterm_ty, lbody_term) = typed lenv lterm in
         let TypedResult (rterm_ty, rbody_term) = typed renv rterm in
         (* Both branches must produce the same type. *)
         (match Type.equal lterm_ty rterm_ty with
          | Some Eq.Refl ->
            TypedResult (lterm_ty,
                         TypedTerm.Case (test_term, lbody_term, rbody_term))
          | None -> failwith "type error")
       | _ -> failwith "type error")
(** [eval env term] evaluates the typed term [term] to an OCaml value of the
    type given by the term's index, looking variables up in the dynamic
    environment [env]. *)
let rec eval : type t a . t Env.dynamic_env -> (t, a) TypedTerm.t -> a =
  fun env term ->
    match term with
    | TypedTerm.Unit -> ()
    | TypedTerm.Bool b -> b
    | TypedTerm.Int n -> n
    | TypedTerm.Lam body ->
      (* The closure pushes its argument in front of the defining
         environment before running the body. *)
      (fun x -> eval (Env.extend_dynamic env x) body)
    | TypedTerm.Var access -> Env.find_dynamic env access
    | TypedTerm.App (f, arg) -> (eval env f) (eval env arg)
    | TypedTerm.Pair (a, b) -> (eval env a, eval env b)
    | TypedTerm.Fst p -> fst (eval env p)
    | TypedTerm.Snd p -> snd (eval env p)
    | TypedTerm.Inl a -> Left (eval env a)
    | TypedTerm.Inr a -> Right (eval env a)
    | TypedTerm.Case (scrutinee, l, r) ->
      (match eval env scrutinee with
       | Left x -> eval (Env.extend_dynamic env x) l
       | Right x -> eval (Env.extend_dynamic env x) r)
(** [print_value ty value] writes [value] to standard output, using the type
    representation [ty] to decide how to display it. Functions are shown as
    the placeholder [<fun>]. *)
let rec print_value : type a . a Type.t -> a -> unit =
  fun ty value ->
    match ty with
    | Type.Unit -> print_string "()"
    | Type.Bool -> print_string (if value then "true" else "false")
    | Type.Int -> print_int value
    | Type.Pair (a, b) ->
      print_char '(';
      print_value a (fst value);
      print_string ", ";
      print_value b (snd value);
      print_char ')'
    | Type.Alt (lty, rty) ->
      (match value with
       | Left l ->
         print_string "Left ";
         print_value lty l
       | Right r ->
         print_string "Right ";
         print_value rty r)
    | Type.Arrow _ -> print_string "<fun>"
(** [type_eval_print term] type checks the closed term [term], evaluates it,
    and writes its type, the separator [" = "] and its value to standard
    output, in that order. Failures raised by [typed] are not caught here. *)
let type_eval_print term =
  let TypedResult (ty, ty_term) = typed Env.empty_static term in
  let value = eval Env.empty_dynamic ty_term in
  Type.print ty;
  print_string " = ";
  print_value ty value
(* Gist page footer (not part of the program):
   Sign up for free to join this conversation on GitHub.
   Already have an account? Sign in to comment. *)