Skip to content

Instantly share code, notes, and snippets.

@hsk
Created October 27, 2017 08:18
Show Gist options
  • Save hsk/f372f5abb7e70cf4520eb08e252cfe4b to your computer and use it in GitHub Desktop.
Save hsk/f372f5abb7e70cf4520eb08e252cfe4b to your computer and use it in GitHub Desktop.
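/*
Algorithm W: Hindley-Milner type inference with let-polymorphism for a small
lambda calculus (int/bool literals, if, fun, application, let, let rec).
*/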
object poly extends App {
  /** Term-variable names. */
  type X = String
  /** Type-variable identifiers; fresh ones are drawn from a threaded counter (see `newvar`). */
  type U = Int
  /** Typing environment: term variable -> type scheme. */
  type Assump = Map[X, Scheme]
  /** Substitution from type variables to types. Bindings may mention other bound
    * variables; `subst` chases them recursively. */
  type Subst = Map[U, T]
  /** A set of type variables. */
  type FTV = Set[U]

  /** Expressions of the object language; `toString` prints an ML-like form. */
  sealed trait E {
    override def toString: String =
      this match {
        case EInt(i)            => "%d".format(i)
        case EBool(b)           => "%b".format(b)
        case EVar(x)            => "%s".format(x)
        case EApp(e1, e2)       => "(%s %s)".format(e1, e2)
        case EAbs(x, e)         => "(fun %s->%s)".format(x, e)
        case ELet(x, e1, e2)    => "(let %s = %s in %s)".format(x, e1, e2)
        case ELetRec(x, e1, e2) => "(let rec %s = %s in %s)".format(x, e1, e2)
        case EIf(e0, e1, e2)    => "(if %s then %s else %s)".format(e0, e1, e2)
      }
  }
  case class EInt(i: Int) extends E
  case class EBool(b: Boolean) extends E
  case class EVar(x: X) extends E
  case class EApp(e1: E, e2: E) extends E
  case class EAbs(x: X, e2: E) extends E
  case class ELet(x: X, e1: E, e2: E) extends E
  case class ELetRec(x: X, e1: E, e2: E) extends E
  case class EIf(e1: E, e2: E, e3: E) extends E

  /** Monotypes; a type variable prints as `%n`. */
  sealed trait T {
    override def toString: String =
      this match {
        case TVar(v)      => "%%%d".format(v)
        case TInt         => "int"
        case TBool        => "bool"
        case TArr(t1, t2) => "(%s->%s)".format(t1, t2)
      }
  }
  case class TVar(u: U) extends T
  case object TInt extends T
  case object TBool extends T
  case class TArr(t1: T, t2: T) extends T

  /** Type schemes: `Mono(t)` is unquantified, `Poly(vs, t)` is `forall vs. t`. */
  sealed trait Scheme
  case class Mono(t: T) extends Scheme
  case class Poly(vs: FTV, t: T) extends Scheme

  /** Allocates a type variable from counter `c`; returns (next counter, fresh variable). */
  def newvar(c: U): (U, U) = (c + 1, c)

  /** Applies `s` to `t`, recursively chasing bindings (e.g. %1 -> %2 -> int).
    * Terminates only if `s` is acyclic; `unify` maintains that with its occurs check. */
  def subst(s: Subst, t: T): T =
    t match {
      case TVar(v) =>
        s.get(v) match {
          case Some(bound) => subst(s, bound)
          case None        => t
        }
      case TArr(t1, t2) => TArr(subst(s, t1), subst(s, t2))
      case _            => t
    }

  /** Applies the renaming `r` to every type variable of `t` exactly once, without
    * chasing. Used for fresh-variable renamings, where a map such as {1 -> %1}
    * must not be followed recursively (that would never terminate under `subst`). */
  private def rename(r: Subst, t: T): T =
    t match {
      case TVar(v)      => r.getOrElse(v, t)
      case TArr(t1, t2) => TArr(rename(r, t1), rename(r, t2))
      case _            => t
    }

  /** Applies `s` to a scheme. A `Poly`'s quantified variables are local to the scheme,
    * so they are removed from `s` before substituting; only free variables change. */
  def subst_scheme(s: Subst, sc: Scheme): Scheme =
    sc match {
      case Mono(t)     => Mono(subst(s, t))
      case Poly(vs, t) => Poly(vs, subst(s -- vs, t))
    }

  /** Extends `s` so that `t1` and `t2` become equal under it.
    *
    * @throws Error "unify occurs error" when a variable would have to contain itself,
    *               "unify error" when the two types have incompatible shapes.
    */
  def unify(s: Subst, t1: T, t2: T): Subst = {
    // Rejects binding v to a type containing v, which would make `s` cyclic.
    def occurs(v: U, t: T): Unit =
      t match {
        case TVar(v2) if v == v2 => throw new Error("unify occurs error %d %d".format(v, v2))
        case TArr(ta, tb)        => occurs(v, ta); occurs(v, tb)
        case _                   =>
      }
    (subst(s, t1), subst(s, t2)) match {
      case (u1, u2) if u1 == u2           => s
      case (TVar(v), t)                   => occurs(v, t); s + (v -> t)
      case (t, TVar(v))                   => occurs(v, t); s + (v -> t)
      case (TArr(a1, r1), TArr(a2, r2)) => unify(unify(s, a1, a2), r1, r2)
      case (u1, u2)                       => throw new Error("unify error (%s,%s)".format(u1, u2))
    }
  }

  /** Generalizes `t` (resolved through `s`) over the type variables that are not free in
    * the environment `a`. Each quantified variable is renamed to a fresh one drawn from
    * counter `c`.
    *
    * @return (updated counter, resulting `Poly` scheme)
    */
  def gen(a: Assump, c: Int, s: Subst, t: T): (U, Scheme) = {
    def ftv(t: T): FTV =
      t match {
        case TVar(v)      => Set(v)
        case TArr(t1, t2) => ftv(t1) ++ ftv(t2)
        case _            => Set()
      }
    def ftv_scheme(sc: Scheme): FTV =
      sc match {
        case Mono(t)     => ftv(t)
        case Poly(vs, t) => ftv(t) -- vs
      }
    def ftv_assump(s: Subst, a: Assump): FTV =
      a.values.foldLeft(Set[U]())((acc, sc) => acc ++ ftv_scheme(subst_scheme(s, sc)))
    // Generalize the *resolved* type: variables reachable only through `s` must be
    // quantified too, otherwise the scheme silently stays monomorphic in them.
    val resolved = subst(s, t)
    val (c1, vs, r) = (ftv(resolved) -- ftv_assump(s, a)).foldLeft((c, Set[U](), Map[U, T]())) {
      case ((c0, bound, ren), v) =>
        val (next, fresh) = newvar(c0)
        (next, bound + fresh, ren + (v -> TVar(fresh)))
    }
    (c1, Poly(vs, rename(r, resolved)))
  }

  /** Instantiates a scheme: each quantified variable of a `Poly` is replaced by a fresh
    * one from counter `c`; a `Mono` is returned unchanged.
    *
    * @return (updated counter, instantiated type)
    */
  def inst(c: U, scheme: Scheme): (U, T) =
    scheme match {
      case Mono(t) => (c, t)
      case Poly(vs, t) =>
        val (c1, r) = vs.foldLeft((c, Map[U, T]())) {
          case ((c0, ren), v) =>
            val (next, fresh) = newvar(c0)
            (next, ren + (v -> TVar(fresh)))
        }
        // One-shot renaming: a fresh variable may coincide with a bound one.
        (c1, rename(r, t))
    }

  /** Algorithm W: infers the type of `e` under environment `a`, threading the fresh-variable
    * counter `c` and the current substitution `s`.
    *
    * @return (updated counter, updated substitution, type of `e`; apply the substitution
    *         with `subst` to resolve it)
    * @throws Error     on a unification failure (see `unify`)
    * @throws Exception "lookup error" when `e` mentions a variable not bound in `a`
    */
  def ti(a: Assump, c: Int, s: Subst, e: E): (U, Subst, T) =
    e match {
      case EInt(_)  => (c, s, TInt)
      case EBool(_) => (c, s, TBool)
      case EIf(e0, e1, e2) =>
        val (c0, s0, t0) = ti(a, c, s, e0)
        val s0_ = unify(s0, t0, TBool)
        val (c1, s1, t1) = ti(a, c0, s0_, e1)
        val (c2, s2, t2) = ti(a, c1, s1, e2)
        val s3 = unify(s2, t1, t2)
        (c2, s3, t1)
      case EAbs(x, body) =>
        // The parameter is monomorphic inside the body.
        val (c1, v1) = newvar(c)
        val t1 = TVar(v1)
        val (c2, s2, t2) = ti(a + (x -> Mono(t1)), c1, s, body)
        (c2, s2, TArr(t1, t2))
      case EApp(e1, e2) =>
        val (c1, s1, t1) = ti(a, c, s, e1)
        val (c2, s2, t2) = ti(a, c1, s1, e2)
        val (c3, v3) = newvar(c2)
        val t3 = TVar(v3)
        val s3 = unify(s2, t1, TArr(t2, t3))
        (c3, s3, t3)
      case ELet(x, e1, e2) =>
        val (c1, s1, t1) = ti(a, c, s, e1)
        val (c2, scheme) = gen(a, c1, s1, t1)
        ti(a + (x -> scheme), c2, s1, e2)
      case EVar(x) =>
        a.get(x) match {
          case Some(scheme) =>
            val (c1, t) = inst(c, scheme)
            (c1, s, t)
          case None => throw new Exception("lookup error " + x)
        }
      case ELetRec(x, e1, e2) =>
        // x is monomorphic while its own definition is checked, then generalized
        // against the outer environment `a` (which does not contain x).
        val (c0, v0) = newvar(c)
        val t0 = TVar(v0)
        val (c1, s1, t1) = ti(a + (x -> Mono(t0)), c0, s, e1)
        val s2 = unify(s1, t1, t0)
        val (c3, scheme) = gen(a, c1, s2, t1)
        ti(a + (x -> scheme), c3, s2, e2)
    }

  /** Renumbers type variables by order of first occurrence (%0, %1, ...). Two types have
    * equal normal forms exactly when they are equal up to renaming of type variables. */
  private def normalize(t: T): T = {
    def go(t: T, m: Map[U, U]): (T, Map[U, U]) =
      t match {
        case TVar(v) =>
          val m1 = if (m.contains(v)) m else m + (v -> m.size)
          (TVar(m1(v)), m1)
        case TArr(t1, t2) =>
          val (n1, m1) = go(t1, m)
          val (n2, m2) = go(t2, m1)
          (TArr(n1, n2), m2)
        case other => (other, m)
      }
    go(t, Map())._1
  }

  /** Infers the type of `e` in a small prelude (`=`, `+`, `-`) and prints whether it matches
    * `t`. Types are compared up to renaming of type variables, because the exact variable
    * numbers depend on the order in which fresh variables were allocated. */
  def test(e: E, t: T): Unit = {
    print("test " + e + " : ")
    val a: Assump = Map(
      "=" -> Poly(Set(1), TArr(TVar(1), TArr(TVar(1), TBool))),
      "+" -> Mono(TArr(TInt, TArr(TInt, TInt))),
      "-" -> Mono(TArr(TInt, TArr(TInt, TInt)))
    )
    try {
      val (_, s, t1) = ti(a, 0, Map(), e)
      val t2 = subst(s, t1)
      if (normalize(t) == normalize(t2)) {
        println(s"$t ok")
      } else {
        printf("error expected %s but %s s:%s\n", t, t2, s)
      }
    } catch {
      // unify reports type errors as Error; an unbound variable raises an Exception.
      case err: Error     => println("error " + err.getMessage())
      case err: Exception => println("error " + err.getMessage())
    }
  }

  // Small constructors that keep the test programs below readable.
  def v(i: Int): T = TVar(i)
  def arr(t1: T, t2: T): T = TArr(t1, t2)
  def x(x: X): E = EVar(x)
  def i(i: Int): E = EInt(i)
  def b(b: Boolean): E = EBool(b)
  def app(e1: E, e2: E): E = EApp(e1, e2)
  def elet(x: X, e1: E, e2: E): E = ELet(x, e1, e2)
  def letrec(x: X, e1: E, e2: E): E = ELetRec(x, e1, e2)
  def abs(x: X, e: E): E = EAbs(x, e)
  def eif(e0: E, e1: E, e2: E): E = EIf(e0, e1, e2)

  test(i(1), TInt)
  test(b(true), TBool)
  test(b(false), TBool)
  test(elet("x", i(1), x("x")), TInt)
  test(elet("id", abs("x", x("x")), x("id")), arr(v(2), v(2)))
  test(elet("id", abs("x", x("x")), app(x("id"), x("id"))), arr(v(3), v(3)))
  test(elet("id", abs("x", x("x")), app(app(x("id"), x("id")), i(1))), TInt)
  test(elet("id", abs("x", x("x")), app(app(x("id"), x("id")), b(true))), TBool)
  test(elet("id", abs("x", eif(x("x"), i(1), i(2))),
    app(x("id"), b(true))), TInt)
  test(elet("id", abs("x", abs("y", eif(x("x"), x("y"), x("y")))),
    app(app(x("id"), b(true)), i(2))), TInt)
  test(letrec("id", abs("x", abs("y",
    eif(app(app(x("="), x("x")), x("y")), x("x"), app(app(x("id"), x("x")), x("y"))))),
    app(app(x("id"), i(1)), i(2))), TInt)
  test(letrec("id", abs("x", abs("y",
    eif(app(app(x("="), x("x")), x("y")), x("x"), app(app(x("id"), x("x")), x("y"))))),
    x("id")), arr(v(7), arr(v(7), v(7))))
  test(
    letrec("sum", abs("x",
      eif(app(app(x("="), x("x")), i(0)),
        x("x"),
        app(app(x("+"), app(x("sum"), app(app(x("-"), x("x")), i(1)))), x("x")))),
      app(x("sum"), i(10))), TInt)
  // A let-bound function whose type is only fully known through the substitution
  // must still be generalized: f is used at bool and then at int.
  test(elet("f", abs("x", abs("y", eif(b(true), x("x"), x("y")))),
    eif(app(app(x("f"), b(true)), b(false)), app(app(x("f"), i(1)), i(2)), i(3))), TInt)
  // Instantiating "=" when the counter reaches its bound variable (1) must not loop.
  test(abs("z", x("=")), arr(v(0), arr(v(1), arr(v(1), TBool))))
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment