https://gist.github.com/hsk/45e2f8432dd18c29d8ff361bb7d45559
Last active: October 27, 2017 07:44
Save hsk/783a35e7a777d0d43383290d6a8616b0 to your computer and use it in GitHub Desktop.
Algorithm W on Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
Algorithm W | |
*/ | |
use std::collections::BTreeMap; | |
use std::collections::BTreeSet; | |
type X = String; | |
type U = i32; | |
type Assump = BTreeMap<X, Scheme>; | |
type Subst = BTreeMap<U, T>; | |
type FTV = BTreeSet<U>; | |
#[derive(PartialEq,Clone)] | |
pub enum E { | |
Int(i32), | |
Bool(bool), | |
Var(X), | |
App(Box<E>,Box<E>), | |
Abs(X,Box<E>), | |
Let(X,Box<E>,Box<E>), | |
LetRec(X,Box<E>,Box<E>), | |
If(Box<E>,Box<E>,Box<E>) | |
} | |
#[derive(PartialEq,Clone,Eq,PartialOrd,Ord)] | |
pub enum T { | |
Var(U), | |
Int, | |
Bool, | |
Arr(Box<T>,Box<T>), | |
} | |
#[derive(PartialEq,Clone,Eq,PartialOrd,Ord)] | |
pub enum Scheme { | |
Mono(T), | |
Poly(FTV,T), | |
} | |
/// Failure reported by type inference: a unification clash, an occurs-check
/// failure, or a lookup of an unbound variable.
#[derive(Debug)]
pub enum TypeError {
    /// Human-readable description of the failure.
    Err(String),
}

impl std::fmt::Display for TypeError {
    /// Prints the contained message verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        match *self {
            TypeError::Err(ref msg) => write!(f, "{}", msg),
        }
    }
}

// Lets `TypeError` be used with `?` in functions returning `Box<dyn Error>`.
impl std::error::Error for TypeError {}
impl std::fmt::Display for E { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
match *self { | |
E::Int(i) => write!(f, "{}", i), | |
E::Bool(b) => write!(f, "{}", b), | |
E::Var(ref x) => write!(f, "{}", x), | |
E::App(ref e1,ref e2) => write!(f, "({} {})", e1, e2), | |
E::Abs(ref x,ref e) => write!(f, "(fun {}->{})", x, e), | |
E::Let(ref x,ref e1,ref e2) => write!(f, "(let {} = {} in {})", x, e1, e2), | |
E::LetRec(ref x,ref e1,ref e2) => write!(f, "(let rec {} = {} in {})", x, e1, e2), | |
E::If(ref e0,ref e1,ref e2) => write!(f, "(if {} then {} else {})", e0, e1, e2), | |
} | |
} | |
} | |
impl std::fmt::Display for T { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
match *self { | |
T::Var(v) => write!(f, "%{}", v), | |
T::Int => write!(f, "int"), | |
T::Bool => write!(f, "bool"), | |
T::Arr(ref t1,ref t2) => write!(f, "({}->{})", t1, t2), | |
} | |
} | |
} | |
fn newvar(c:&mut i32)->U { | |
let v = *c; | |
*c = *c + 1; | |
v | |
} | |
fn subst(s:&Subst, t:&T)->T { | |
match *t { | |
T::Var(v) => match s.get(&v) { | |
Some(t) => subst(s, t), | |
None => T::Var(v), | |
} | |
T::Arr(ref t1, ref t2) => T::Arr(Box::new(subst(s, t1)), Box::new(subst(s, t2))), | |
_ => t.clone(), | |
} | |
} | |
fn subst_scheme(s:&Subst, sc:&Scheme)->Scheme { | |
match *sc { | |
Scheme::Mono(ref t) => Scheme::Mono(subst(s,&t)), | |
Scheme::Poly(ref vs,ref t) => { | |
let s = s.iter().filter_map(|(x,v)|if vs.contains(x) {Some((x.clone(),v.clone()))} else {None}); | |
Scheme::Poly(vs.clone(),subst(&s.collect(),&t)) | |
} | |
} | |
} | |
fn unify(s:&mut Subst, t1:&T, t2:&T) -> Result<(),TypeError> { | |
fn occurs(v:U,t:&T)->Result<(),TypeError> { | |
match *t { | |
T::Var(v2) if v == v2 => Err(TypeError::Err(format!("unify occurs error %{} %{}",v,v2))), | |
T::Arr(ref t2,ref t3) => {try!(occurs(v, t2)); occurs(v, t3)}, | |
_ => Ok(()), | |
} | |
} | |
match(subst(&s, &t1),subst(&s, &t2)) { | |
(ref t1, ref t2) if t1 == t2 => Ok(()), | |
(T::Var(v), ref t) | (ref t, T::Var(v)) => { | |
try!(occurs(v, &t)); | |
s.insert(v,t.clone()); | |
Ok(()) | |
} | |
(T::Arr(ref t1, ref t2), T::Arr(ref t3, ref t4)) => { | |
try!(unify(s, t1, t3)); | |
try!(unify(s, t2, t4)); | |
Ok(()) | |
} | |
(t1, t2) => Err(TypeError::Err(format!("unify error ({},{})",t1,t2))), | |
} | |
} | |
fn gen(a:&Assump, c:&mut i32, s:&mut Subst, t:&T)->Scheme { | |
fn ftv(t:&T)->FTV { | |
match *t { | |
T::Var(v) => {let mut vs=FTV::new(); vs.insert(v); vs}, | |
T::Arr(ref t1, ref t2) => ftv(t1).union(&ftv(t2)).cloned().collect(), | |
_ => FTV::new(), | |
} | |
} | |
fn ftv_scheme(s:&Subst, sc:&Scheme)->FTV { | |
match *sc { | |
Scheme::Mono(ref t) => ftv(&subst(s, &t)), | |
Scheme::Poly(ref vs,ref t) => ftv(t).difference(vs).cloned().collect(), | |
} | |
} | |
fn ftv_assump(s:&Subst, a:&Assump)->FTV { | |
let mut vs = FTV::new(); | |
for (_,sc) in a { vs = vs.union(&ftv_scheme(s, &subst_scheme(s,sc))).cloned().collect() } | |
vs | |
} | |
let mut s1 = Subst::new(); | |
let vs = ftv(&subst(s, t)).difference(&ftv_assump(s, a)).map(|v|{ | |
let v1 = newvar(c); | |
s1.insert(*v,T::Var(v1)); | |
v1 | |
}).collect(); | |
Scheme::Poly(vs, subst(&s1, t)) | |
} | |
fn inst(c:&mut i32, scheme:&Scheme)->T { | |
match *scheme { | |
Scheme::Poly(ref vs, ref t) => | |
subst(&mut vs.iter().map(|v|(*v,T::Var(newvar(c)))).collect(), &t), | |
Scheme::Mono(ref t) => t.clone(), | |
} | |
} | |
fn ti(a:&mut Assump, c:&mut i32, s:&mut Subst, e:&E) -> Result<T,TypeError> { | |
match *e { | |
E::Int(_) => Ok(T::Int), | |
E::Bool(_) => Ok(T::Bool), | |
E::If(ref e0, ref e1, ref e2) => { | |
let t0 = try!(ti(a, c, s, e0)); | |
try!(unify(s, &t0, &T::Bool)); | |
let t1 = try!(ti(a, c, s, e1)); | |
let t2 = try!(ti(a, c, s, e2)); | |
try!(unify(s, &t1, &t2)); | |
Ok(t1) | |
} | |
E::Abs(ref x, ref e) => { | |
let t1 = T::Var(newvar(c)); | |
let mut a1 = a.clone(); | |
a1.insert(x.clone(),Scheme::Mono(t1.clone())); | |
let t2 = try!(ti(&mut a1, c, s, e)); | |
Ok(T::Arr(Box::new(t1), Box::new(t2))) | |
} | |
E::App(ref e1, ref e2) => { | |
let t1 = try!(ti(a, c, s, e1)); | |
let t2 = try!(ti(a, c, s, e2)); | |
let t3 = T::Var(newvar(c)); | |
try!(unify(s, &t1, &T::Arr(Box::new(t2), Box::new(t3.clone())))); | |
Ok(t3) | |
} | |
E::Let(ref x, ref e1, ref e2) => { | |
let t1 = try!(ti(&mut a.clone(), c, s, e1)); | |
let scheme = gen(a, c, s, &t1); | |
a.insert(x.clone(),scheme); | |
ti(a, c, s, e2) | |
} | |
E::Var(ref x) => | |
match a.get(x) { | |
Some(scheme) => Ok(inst(c,scheme)), | |
None => Err(TypeError::Err(format!("lookup error {}", x))), | |
}, | |
E::LetRec(ref x, ref e1, ref e2) => { | |
let mut a1 = a.clone(); | |
let t2 = T::Var(newvar(c)); | |
a1.insert(x.clone(),Scheme::Mono(t2.clone())); | |
let t1 = try!(ti(&mut a1, c, s, e1)); | |
try!(unify(s,&t1,&t2)); | |
let scheme = gen(a, c, s, &t1); | |
a.insert(x.clone(),scheme); | |
ti(a, c, s, e2) | |
} | |
} | |
} | |
fn test(e:E, t:T) { | |
print!("test {} ", e); | |
let mut c = 0; | |
let mut s = Subst::new(); | |
let mut a = Assump::new(); | |
a.insert(format!("="),Scheme::Poly(vec![1].into_iter().collect(),arr(v(1),arr(v(1),T::Bool)))); | |
a.insert(format!("+"),Scheme::Mono(arr(T::Int,arr(T::Int,T::Int)))); | |
a.insert(format!("-"),Scheme::Mono(arr(T::Int,arr(T::Int,T::Int)))); | |
match ti(&mut a, &mut c, &mut s, &e) { | |
Ok(t1) => { | |
let t2 = subst(&s, &t1); | |
if t == t2 { | |
println!(": {} ok",t) | |
} else { | |
println!("error {} : expected {} but {}", e,t,t2) | |
} | |
} | |
Err(TypeError::Err(m)) => println!("error {}", m), | |
} | |
} | |
fn v(i:i32)->T {T::Var(i)} | |
fn arr(t:T,t2:T)->T {T::Arr(Box::new(t),Box::new(t2))} | |
fn x(s:&str)->E{E::Var(format!("{}",s))} | |
fn i(i:i32)->E{E::Int(i)} | |
fn b(b:bool)->E{E::Bool(b)} | |
fn app(e:E,e2:E)->E {E::App(Box::new(e),Box::new(e2))} | |
fn elet(x:&str,e1:E,e2:E)->E {E::Let(format!("{}",x),Box::new(e1),Box::new(e2))} | |
fn letrec(x:&str,e1:E,e2:E)->E {E::LetRec(format!("{}",x),Box::new(e1),Box::new(e2))} | |
fn abs(x:&str,e:E)->E {E::Abs(format!("{}",x),Box::new(e))} | |
fn eif(e0:E,e1:E,e2:E)->E {E::If(Box::new(e0),Box::new(e1),Box::new(e2))} | |
fn main() { | |
test(i(1), T::Int); | |
test(b(true), T::Bool); | |
test(b(false), T::Bool); | |
test(elet("x", i(1), x("x")), T::Int); | |
test(elet("id",abs("x", x("x")), x("id")), arr(v(2),v(2))); | |
test(elet("id",abs("x", x("x")), app(x("id"),x("id"))), arr(v(3),v(3))); | |
test(elet("id",abs("x", x("x")), app(app(x("id"),x("id")), i(1))), T::Int); | |
test(elet("id",abs("x", x("x")), app(app(x("id"),x("id")), b(true))), T::Bool); | |
test(elet("id",abs("x", eif(x("x"),i(1),i(2))), | |
app(x("id"), b(true))), T::Int); | |
test(elet("id",abs("x", abs("y", eif(x("x"),x("y"),x("y")))), | |
app(app(x("id"), b(true)),i(2))), T::Int); | |
test(letrec("id",abs("x", abs("y", | |
eif(app(app(x("="),x("x")),x("y")),x("x"),app(app(x("id"),x("x")),x("y"))))), | |
app(app(x("id"), i(1)),i(2))), T::Int); | |
test(letrec("id",abs("x", abs("y", | |
eif(app(app(x("="),x("x")),x("y")),x("x"),app(app(x("id"),x("x")),x("y"))))), | |
x("id")), arr(v(7),arr(v(7),v(7)))); | |
test( | |
letrec("sum",abs("x", | |
eif(app(app(x("="),x("x")),i(0)), | |
x("x"), | |
app(app(x("+"),app(x("sum"),app(app(x("-"),x("x")),i(1)))),x("x")))), | |
app(x("sum"), i(10))), T::Int); | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.