Created
October 27, 2017 08:18
-
-
Save hsk/f372f5abb7e70cf4520eb08e252cfe4b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
Algorithm W
*/
/**
 * A small Hindley-Milner type inferencer (Algorithm W) for a lambda calculus
 * with integers, booleans, `let`, `let rec` and `if`.
 *
 * Fresh type variables are produced by threading an integer counter through
 * every function, and the current substitution is threaded the same way, so
 * all functions here are pure. The statements at the end of the object are a
 * self-test that runs when the program starts.
 */
object poly extends App {

  /** Program variable name. */
  type X = String
  /** Type variable identifier. */
  type U = Int
  /** Typing environment: program variable -> type scheme. */
  type Assump = Map[X, Scheme]
  /** Substitution: type variable -> type (may contain chains, see [[subst]]). */
  type Subst = Map[U, T]
  /** A set of type variables. */
  type FTV = Set[U]

  /** Expressions of the object language. */
  sealed trait E {
    override def toString: String =
      this match {
        case EInt(i)            => "%d".format(i)
        case EBool(b)           => "%b".format(b)
        case EVar(x)            => "%s".format(x)
        case EApp(e1, e2)       => s"($e1 $e2)"
        case EAbs(x, e)         => s"(fun $x->$e)"
        case ELet(x, e1, e2)    => s"(let $x = $e1 in $e2)"
        case ELetRec(x, e1, e2) => s"(let rec $x = $e1 in $e2)"
        case EIf(e0, e1, e2)    => s"(if $e0 then $e1 else $e2)"
      }
  }
  final case class EInt(i: Int) extends E
  final case class EBool(b: Boolean) extends E
  final case class EVar(x: X) extends E
  final case class EApp(e1: E, e2: E) extends E
  final case class EAbs(x: X, e2: E) extends E
  final case class ELet(x: X, e1: E, e2: E) extends E
  final case class ELetRec(x: X, e1: E, e2: E) extends E
  final case class EIf(e1: E, e2: E, e3: E) extends E

  /** Monomorphic types. Type variables print as `%n`. */
  sealed trait T {
    override def toString: String =
      this match {
        case TVar(v)      => "%%%d".format(v)
        case TInt         => "int"
        case TBool        => "bool"
        case TArr(t1, t2) => "(%s->%s)".format(t1, t2)
      }
  }
  final case class TVar(u: U) extends T
  case object TInt extends T
  case object TBool extends T
  final case class TArr(t1: T, t2: T) extends T

  /** Type schemes: either a plain type or a type with quantified variables `vs`. */
  sealed trait Scheme
  final case class Mono(t: T) extends Scheme
  final case class Poly(vs: FTV, t: T) extends Scheme

  /**
   * Allocates a fresh type variable.
   *
   * @param c the current counter
   * @return `(next counter, fresh variable id)`; the fresh id is `c` itself
   */
  def newvar(c: U): (U, U) = (c + 1, c)

  /**
   * Applies a substitution to a type, following chains of bindings until an
   * unbound variable or a non-variable type is reached.
   */
  def subst(s: Subst, t: T): T =
    t match {
      case TVar(v)      => s.get(v).fold(t)(bound => subst(s, bound))
      case TArr(t1, t2) => TArr(subst(s, t1), subst(s, t2))
      case TInt | TBool => t
    }

  /**
   * Renames type variables in a single pass.
   *
   * Unlike [[subst]] the result of a renaming is never looked up again, so a
   * fresh id that happens to coincide with a renamed id cannot cause a loop
   * or merge two distinct variables.
   */
  private def rename(m: Map[U, U], t: T): T =
    t match {
      case TVar(v)      => m.get(v).fold(t)(w => TVar(w))
      case TArr(t1, t2) => TArr(rename(m, t1), rename(m, t2))
      case TInt | TBool => t
    }

  /**
   * Applies a substitution to a scheme.
   *
   * Only the free variables of a [[Poly]] are substituted: bindings for its
   * quantified variables are dropped first, since those names are local to
   * the scheme.
   */
  def subst_scheme(s: Subst, sc: Scheme): Scheme =
    sc match {
      case Mono(t)     => Mono(subst(s, t))
      case Poly(vs, t) => Poly(vs, subst(s.filterNot { case (v, _) => vs.contains(v) }, t))
    }

  /**
   * Extends `s` so that `t1` and `t2` become equal under it.
   *
   * @throws java.lang.Error if the types cannot be unified, or if a variable
   *                         would have to be bound to a type containing itself
   */
  def unify(s: Subst, t1: T, t2: T): Subst = {
    // Occurs check: fails when variable `v` appears inside `t`.
    def occurs(v: U, t: T): Unit =
      t match {
        case TVar(v2) if v == v2 =>
          throw new Error("unify occurs error %d %d".format(v, v2))
        case TArr(dom, cod) =>
          occurs(v, dom)
          occurs(v, cod)
        case _ => ()
      }

    def bind(v: U, t: T): Subst = {
      occurs(v, t)
      s + (v -> t)
    }

    // Both sides are fully resolved first, so the cases below only ever see
    // unbound variables.
    (subst(s, t1), subst(s, t2)) match {
      case (l, r) if l == r                 => s
      case (TVar(v), t)                     => bind(v, t)
      case (t, TVar(v))                     => bind(v, t)
      case (TArr(d1, c1), TArr(d2, c2))     => unify(unify(s, d1, d2), c1, c2)
      case (l, r) =>
        throw new Error("unify error (%s,%s)".format(l, r))
    }
  }

  /**
   * Generalizes `t` (as seen through `s`) over every type variable that is
   * not free in the environment `a`.
   *
   * Each generalized variable is replaced by a fresh one, and the scheme body
   * is the resolved type `subst(s, t)` with that renaming applied.
   *
   * @param c the current fresh-variable counter
   * @return `(next counter, generalized scheme)`
   */
  def gen(a: Assump, c: Int, s: Subst, t: T): (U, Scheme) = {
    def ftv(t: T): FTV =
      t match {
        case TVar(v)      => Set(v)
        case TArr(t1, t2) => ftv(t1).union(ftv(t2))
        case _            => Set()
      }

    def ftv_scheme(sc: Scheme): FTV =
      sc match {
        case Mono(t)     => ftv(t)
        case Poly(vs, t) => ftv(t).diff(vs)
      }

    def ftv_assump(s: Subst, a: Assump): FTV =
      a.values.foldLeft(Set.empty[U]) { (acc, sc) =>
        acc.union(ftv_scheme(subst_scheme(s, sc)))
      }

    // The free variables are computed on the resolved type, so the renaming
    // must be applied to that same resolved type.
    val resolved = subst(s, t)
    val generalizable = ftv(resolved).diff(ftv_assump(s, a))
    val (c1, renaming) = generalizable.foldLeft((c, Map.empty[U, U])) {
      case ((counter, m), v) =>
        val (next, fresh) = newvar(counter)
        (next, m + (v -> fresh))
    }
    val quantified: FTV = renaming.values.toSet
    (c1, Poly(quantified, rename(renaming, resolved)))
  }

  /**
   * Instantiates a scheme by replacing each quantified variable with a fresh
   * one. A [[Mono]] is returned unchanged.
   *
   * @param c the current fresh-variable counter
   * @return `(next counter, instantiated type)`
   */
  def inst(c: U, scheme: Scheme): (U, T) =
    scheme match {
      case Mono(t) => (c, t)
      case Poly(vs, t) =>
        val (c1, renaming) = vs.foldLeft((c, Map.empty[U, U])) {
          case ((counter, m), v) =>
            val (next, fresh) = newvar(counter)
            (next, m + (v -> fresh))
        }
        // Single-pass rename: a fresh id may equal a quantified id.
        (c1, rename(renaming, t))
    }

  /**
   * Infers the type of `e` in environment `a`.
   *
   * @param c the current fresh-variable counter
   * @param s the substitution accumulated so far
   * @return `(next counter, extended substitution, type)`; the type is not
   *         resolved, apply [[subst]] with the returned substitution to read it
   * @throws java.lang.Error     if unification fails
   * @throws java.lang.Exception if a variable is not bound in `a`
   */
  def ti(a: Assump, c: Int, s: Subst, e: E): (U, Subst, T) =
    e match {
      case EInt(_)  => (c, s, TInt)
      case EBool(_) => (c, s, TBool)

      case EVar(x) =>
        a.get(x) match {
          case Some(scheme) =>
            val (c1, t) = inst(c, scheme)
            (c1, s, t)
          case None =>
            throw new Exception("lookup error " + x)
        }

      case EIf(cond, thenE, elseE) =>
        val (c0, s0, condT) = ti(a, c, s, cond)
        val sCond = unify(s0, condT, TBool)
        val (c1, s1, thenT) = ti(a, c0, sCond, thenE)
        val (c2, s2, elseT) = ti(a, c1, s1, elseE)
        (c2, unify(s2, thenT, elseT), thenT)

      case EAbs(x, body) =>
        val (c1, paramVar) = newvar(c)
        val paramT = TVar(paramVar)
        val (c2, s2, bodyT) = ti(a + (x -> Mono(paramT)), c1, s, body)
        (c2, s2, TArr(paramT, bodyT))

      case EApp(fn, arg) =>
        val (c1, s1, fnT) = ti(a, c, s, fn)
        val (c2, s2, argT) = ti(a, c1, s1, arg)
        val (c3, resultVar) = newvar(c2)
        val resultT = TVar(resultVar)
        (c3, unify(s2, fnT, TArr(argT, resultT)), resultT)

      case ELet(x, bound, body) =>
        val (c1, s1, boundT) = ti(a, c, s, bound)
        val (c2, scheme) = gen(a, c1, s1, boundT)
        ti(a + (x -> scheme), c2, s1, body)

      case ELetRec(x, bound, body) =>
        // The bound name is monomorphic inside its own definition.
        val (c0, selfVar) = newvar(c)
        val selfT = TVar(selfVar)
        val (c1, s1, boundT) = ti(a + (x -> Mono(selfT)), c0, s, bound)
        val s2 = unify(s1, boundT, selfT)
        val (c3, scheme) = gen(a, c1, s2, boundT)
        ti(a + (x -> scheme), c3, s2, body)
    }

  /**
   * Infers the type of `e` in a small built-in environment (`=`, `+`, `-`)
   * and prints whether it equals the expected type `t`.
   *
   * Inference failures are printed rather than propagated. Both `Error`
   * (unification) and `Exception` (unbound variable) are reported, so one
   * failing case does not abort the remaining ones.
   */
  def test(e: E, t: T): Unit = {
    print("test " + e + " : ")
    val a: Assump = Map(
      "=" -> Poly(Set(1), TArr(TVar(1), TArr(TVar(1), TBool))),
      "+" -> Mono(TArr(TInt, TArr(TInt, TInt))),
      "-" -> Mono(TArr(TInt, TArr(TInt, TInt)))
    )
    try {
      val (_, s, t1) = ti(a, 0, Map.empty, e)
      val t2 = subst(s, t1)
      if (t == t2) {
        println(t + " ok")
      } else {
        printf("error expected %s but %s s:%s\n", t, t2, s)
      }
    } catch {
      case err @ (_: Error | _: Exception) => println("error " + err.getMessage)
    }
  }

  // Short constructors used by the self-test below.
  def v(i: Int): T = TVar(i)
  def arr(t1: T, t2: T): T = TArr(t1, t2)
  def x(x: X): E = EVar(x)
  def i(i: Int): E = EInt(i)
  def b(b: Boolean): E = EBool(b)
  def app(e1: E, e2: E): E = EApp(e1, e2)
  def elet(x: X, e1: E, e2: E): E = ELet(x, e1, e2)
  def letrec(x: X, e1: E, e2: E): E = ELetRec(x, e1, e2)
  def abs(x: X, e: E): E = EAbs(x, e)
  def eif(e0: E, e1: E, e2: E): E = EIf(e0, e1, e2)

  // Self-test. Expected variable numbers depend on the allocation order of
  // fresh variables during inference.
  test(i(1), TInt)
  test(b(true), TBool)
  test(b(false), TBool)
  test(elet("x", i(1), x("x")), TInt)
  test(elet("id", abs("x", x("x")), x("id")), arr(v(2), v(2)))
  test(elet("id", abs("x", x("x")), app(x("id"), x("id"))), arr(v(3), v(3)))
  test(elet("id", abs("x", x("x")), app(app(x("id"), x("id")), i(1))), TInt)
  test(elet("id", abs("x", x("x")), app(app(x("id"), x("id")), b(true))), TBool)
  test(elet("id", abs("x", eif(x("x"), i(1), i(2))),
    app(x("id"), b(true))), TInt)
  test(elet("id", abs("x", abs("y", eif(x("x"), x("y"), x("y")))),
    app(app(x("id"), b(true)), i(2))), TInt)
  test(letrec("id", abs("x", abs("y",
    eif(app(app(x("="), x("x")), x("y")), x("x"), app(app(x("id"), x("x")), x("y"))))),
    app(app(x("id"), i(1)), i(2))), TInt)
  // `id` is generalized here and instantiated once, so the result uses the
  // variable allocated by that instantiation.
  test(letrec("id", abs("x", abs("y",
    eif(app(app(x("="), x("x")), x("y")), x("x"), app(app(x("id"), x("x")), x("y"))))),
    x("id")), arr(v(9), arr(v(9), v(9))))
  test(
    letrec("sum", abs("x",
      eif(app(app(x("="), x("x")), i(0)),
        x("x"),
        app(app(x("+"), app(x("sum"), app(app(x("-"), x("x")), i(1)))), x("x")))),
      app(x("sum"), i(10))), TInt)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment