Created
January 5, 2018 06:00
-
-
Save bkase/7adebbf841c497ad29ed045ccd900e22 to your computer and use it in GitHub Desktop.
Type Inference Algorithm W (ported from Swift Sandbox)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Algorithm W | |
// From "Principal type-schemes for functional programs" (Damas & Milner, POPL 1982)
// http://web.cs.wpi.edu/~cs4536/c12/milner-damas_principal_types.pdf | |
// by bkase | |
// MARK: - Standard-library helpers

/// Semigroup-style "append": merges dictionaries and unions sets.
infix operator <>: AdditionPrecedence

extension Dictionary {
    /// Right-biased merge: every entry of `rhs` is kept, and an entry of `lhs`
    /// survives only when `rhs` has no value for the same key.
    static func <>(lhs: Dictionary, rhs: Dictionary) -> Dictionary {
        var merged = lhs
        for (key, value) in rhs {
            merged[key] = value
        }
        return merged
    }

    /// Returns a dictionary with the same keys, each value passed through `f`.
    func mapValues<V2>(f: (Value) -> V2) -> Dictionary<Key, V2> {
        var transformed: [Key: V2] = [:]
        for (key, value) in self {
            transformed[key] = f(value)
        }
        return transformed
    }
}

extension Set {
    /// Set union.
    static func <>(lhs: Set, rhs: Set) -> Set {
        var combined = lhs
        combined.formUnion(rhs)
        return combined
    }
}
typealias Ident = String

/// Namespace for the object language: expressions (`Exp`), monotypes (`Typ`),
/// and a fresh-name supply for type variables.
struct Syntax {
    /// Counter behind `newIdent()`; starts at -1 so the first name is `$new$_0`.
    private var freshVar: Int = -1

    /// Returns a new type-variable name of the form `$new$_N`, where N
    /// increases by one on each call.
    mutating func newIdent() -> Ident {
        self.freshVar += 1
        return "$new$_\(self.freshVar)"
    }

    /// Lambda-calculus expressions extended with integer literals and `let`.
    indirect enum Exp {
        case Val(Int)
        case Var(Ident)
        case App(Exp, Exp)
        case Lam(Ident, Exp)
        case Let(Ident, Exp, body: Exp)

        /// Convenience constructor for an integer literal.
        static func prim(_ x: Int) -> Exp {
            return .Val(x)
        }
    }

    /// Monotypes: `Int`, type variables, and function types.
    indirect enum Typ {
        case IntType
        case TyVar(Ident)
        case Fun(Typ, Typ)

        /// Every type variable occurring in this type.
        var tyvars: Set<Ident> {
            switch self {
            case .IntType:
                return Set<Ident>()
            case .TyVar(let a):
                return Set([a])
            case .Fun(let tIn, let tOut):
                return tIn.tyvars.union(tOut.tyvars)
            }
        }

        /// Applies `tyContext` as a substitution: each type variable bound in
        /// it is replaced by its image. The replacement is applied once; it is
        /// not iterated to a fixpoint.
        func sub(tyContext: TyContext) -> Typ {
            switch self {
            case .IntType:
                return self
            case .TyVar(let typIdent):
                // Variables the substitution does not mention are left as-is.
                return tyContext[typIdent] ?? self
            case .Fun(let tIn, let tOut):
                return .Fun(tIn.sub(tyContext: tyContext), tOut.sub(tyContext: tyContext))
            }
        }

        /// Substitutes `t` for the single type variable `typIdent`.
        func sub(_ t: Typ, forTypIdent typIdent: Ident) -> Typ {
            return sub(tyContext: [typIdent: t])
        }
    }
}
extension Syntax.Typ: Equatable {
    /// Structural equality. Two types are equal when they have the same shape
    /// and identical type-variable names; no alpha-renaming is performed.
    static func ==(lhs: Syntax.Typ, rhs: Syntax.Typ) -> Bool {
        switch lhs {
        case .IntType:
            if case .IntType = rhs {
                return true
            }
            return false
        case .TyVar(let leftName):
            if case .TyVar(let rightName) = rhs {
                return leftName == rightName
            }
            return false
        case .Fun(let leftIn, let leftOut):
            if case .Fun(let rightIn, let rightOut) = rhs {
                return leftIn == rightIn && leftOut == rightOut
            }
            return false
        }
    }
}
// TODO: DoctorPretty!
extension Syntax.Typ: CustomStringConvertible {
    /// Renders `Int`, a type variable's own name, or a fully parenthesised
    /// arrow such as `(Int → Int)`.
    var description: String {
        switch self {
        case .Fun(let domain, let codomain):
            return "(\(domain.description) → \(codomain.description))"
        case .TyVar(let name):
            return name
        case .IntType:
            return "Int"
        }
    }
}
/// A substitution: keys are type-variable names, values are the monotypes replacing them.
typealias TyContext = [Ident: Syntax.Typ]

/// Global fresh-name supply; each `syntax.newIdent()` call advances it.
var syntax: Syntax = Syntax()

/// A type scheme: a monotype (`Tau`) under zero or more universal quantifiers (`ForAll`).
indirect enum TypScheme {
    case Tau(Syntax.Typ)
    case ForAll(Ident, TypScheme)

    /// Wraps every type in `d` as an unquantified scheme stored under the same key.
    // TODO: This seems messed up...
    // NOTE(review): the keys of `d` are type-variable names, while a `Context`
    // is keyed by term variables, so merging the result into a context mixes
    // the two namespaces.
    static func tauLift(_ d: TyContext) -> Context {
        var lifted: Context = [:]
        for (name, typ) in d {
            lifted[name] = .Tau(typ)
        }
        return lifted
    }

    /// The variables bound by this scheme's `ForAll` quantifiers.
    var idents: Set<Ident> {
        guard case .ForAll(let bound, let inner) = self else {
            return Set<Ident>()
        }
        return Set([bound]).union(inner.idents)
    }

    /// Drops every quantifier whose variable is a key of `tyContext` and
    /// substitutes `tyContext` into the body. Quantifiers that `tyContext`
    /// does not mention are kept, in their original order.
    func monomorphize(tyContext: TyContext) -> TypScheme {
        switch self {
        case .Tau(let body):
            return .Tau(body.sub(tyContext: tyContext))
        case .ForAll(let bound, let inner):
            let instantiated = inner.monomorphize(tyContext: tyContext)
            return tyContext[bound] == nil ? .ForAll(bound, instantiated) : instantiated
        }
    }

    /// Single-variable form of `monomorphize(tyContext:)`.
    func monomorphize(_ t: Syntax.Typ, forTypIdent typIdent1: Ident) -> TypScheme {
        return monomorphize(tyContext: [typIdent1: t])
    }

    /// Instantiates every quantifier with a fresh type variable taken from the
    /// global `syntax` supply, then returns the resulting monotype.
    func monomorphiseToFreeVar() -> Syntax.Typ {
        var freshNames: TyContext = [:]
        for bound in idents {
            freshNames[bound] = .TyVar(syntax.newIdent())
        }
        guard case .Tau(let typ) = monomorphize(tyContext: freshNames) else {
            fatalError("Expected scheme \(self) to be monomorphized")
        }
        return typ
    }
}
/// A typing context (the paper's assumption set `A`): maps term variables to type schemes.
typealias Context = [Ident: TypScheme]

/// Generalizes `typ` relative to the context `a` (the paper's closure Ā(τ)).
/// Every type variable of `typ` that is *not* free in `a` becomes universally
/// quantified.
///
/// - Parameters:
///   - a: The context `typ` was inferred under. Its free type variables must
///     stay monomorphic.
///   - typ: The monotype to generalize.
/// - Returns: `∀α₁…αₙ. typ`, where the `αᵢ` are the generalizable variables.
func forall(_ a: Context, typ: Syntax.Typ) -> TypScheme {
    // Free type variables of a scheme: those of its body minus the quantified ones.
    func freeTyvars(_ scheme: TypScheme) -> Set<Ident> {
        switch scheme {
        case .Tau(let body):
            return body.tyvars
        case .ForAll(let bound, let inner):
            var free = freeTyvars(inner)
            free.remove(bound)
            return free
        }
    }
    // The keys of `a` are *term* variables, so subtracting them here is
    // meaningless. What must stay fixed are the type variables still free
    // somewhere in the context's schemes.
    let fixed = a.values.reduce(Set<Ident>()) { $0.union(freeTyvars($1)) }
    return typ.tyvars.subtracting(fixed).reduce(.Tau(typ)) { .ForAll($1, $0) }
}
/// Robinson unification (the paper's `U`).
///
/// - Returns: A substitution `v` such that `t1.sub(tyContext: v)` and
///   `t2.sub(tyContext: v)` coincide, or `nil` when the types cannot be unified.
///   Failure cases are a constructor clash, or a variable that would have to
///   contain itself (occurs check).
func unify(_ t1: Syntax.Typ, _ t2: Syntax.Typ) -> TyContext? {
    switch (t1, t2) {
    // iota
    case (.IntType, .IntType):
        return [:]

    // A variable trivially unifies with itself; no binding is needed.
    case (.TyVar(let a), .TyVar(let b)) where a == b:
        return [:]

    // var1 / var2
    case (.TyVar(let typIdent), let other), (let other, .TyVar(let typIdent)):
        // Occurs check: `a ≐ τ` with `a` inside `τ` has no finite solution.
        return other.tyvars.contains(typIdent) ? nil : [typIdent: other]

    // fun
    case (.Fun(let t1In, let t1Out), .Fun(let t2In, let t2Out)):
        // The argument solution must be applied to the result types before
        // they are unified; otherwise the two halves can bind the same
        // variable inconsistently.
        guard let v1 = unify(t1In, t2In),
              let v2 = unify(t1Out.sub(tyContext: v1), t2Out.sub(tyContext: v1)) else {
            return nil
        }
        // Composition v2 ∘ v1: push v2 through v1's range, then keep v2's own bindings.
        var composed = v2
        for (ident, typ) in v1 {
            composed[ident] = typ.sub(tyContext: v2)
        }
        return composed

    default:
        return nil
    }
}
/// Infers the type of `e` under `a` and applies the inferred substitution to
/// it. Returns `nil` when `synW` fails.
func syn(context a: Context, _ e: Syntax.Exp) -> Syntax.Typ? {
    guard let inferred = synW(context: a, e) else {
        return nil
    }
    let substitution = inferred.0
    return inferred.1.sub(tyContext: substitution)
}
/// Debug helper: prints `t` and hands it back as an optional, so it can be
/// dropped into an `if`/`guard` condition list (`let _ = printIt(...)`).
func printIt(_ t: String) -> String? {
    let echoed: String? = t
    print(t)
    return echoed
}
/// Algorithm W (Damas–Milner): infers a type for `e` under the assumptions `a`.
///
/// - Parameters:
///   - a: Assumptions mapping term variables to type schemes.
///   - e: The expression to type.
/// - Returns: `(S, τ)`, where `S` is the substitution of type variables
///   accumulated during inference and `τ` is the inferred type. Returns `nil`
///   when `e` mentions an unbound variable or a unification fails.
func synW(context a: Context, _ e: Syntax.Exp) -> (TyContext, Syntax.Typ)? {
    // Applies `s` to a scheme, leaving its quantified variables untouched.
    func subScheme(_ scheme: TypScheme, _ s: TyContext) -> TypScheme {
        switch scheme {
        case .Tau(let t):
            return .Tau(t.sub(tyContext: s))
        case .ForAll(let bound, let inner):
            var unshadowed = s
            unshadowed[bound] = nil
            return .ForAll(bound, subScheme(inner, unshadowed))
        }
    }

    // Applies `s` to every scheme in a context (the paper's `S A`).
    func subContext(_ ctx: Context, _ s: TyContext) -> Context {
        var result: Context = [:]
        for (x, scheme) in ctx {
            result[x] = subScheme(scheme, s)
        }
        return result
    }

    // Composition `outer ∘ inner`: apply `inner` first, then `outer`.
    func compose(_ outer: TyContext, _ inner: TyContext) -> TyContext {
        var result = outer
        for (ident, t) in inner {
            result[ident] = t.sub(tyContext: outer)
        }
        return result
    }

    switch e {
    // prim
    case .Val(_):
        return ([:], .IntType)

    // var: instantiate the scheme's quantifiers with fresh type variables.
    case .Var(let ident):
        guard let scheme = a[ident] else { return nil }
        return ([:], scheme.monomorphiseToFreeVar())

    // app: (S1, τ1) = W(A, e1); (S2, τ2) = W(S1 A, e2); V = U(S2 τ1, τ2 → β)
    case .App(let e1, let e2):
        guard let (s1, t1) = synW(context: a, e1),
              let (s2, t2) = synW(context: subContext(a, s1), e2) else {
            return nil
        }
        let beta = Syntax.Typ.TyVar(syntax.newIdent())
        guard let v = unify(t1.sub(tyContext: s2), .Fun(t2, beta)) else {
            return nil
        }
        return (compose(v, compose(s2, s1)), beta.sub(tyContext: v))

    // abs: β fresh; (S1, τ1) = W(A_x ∪ {x: β}, e1); result is S1 β → τ1.
    case .Lam(let x, let e1):
        let beta = Syntax.Typ.TyVar(syntax.newIdent())
        var extended = a
        // Assigning replaces any outer binding of x, which is the paper's A_x.
        extended[x] = .Tau(beta)
        guard let (s1, t1) = synW(context: extended, e1) else { return nil }
        return (s1, .Fun(beta.sub(tyContext: s1), t1))

    // let: (S1, τ1) = W(A, e1); (S2, τ2) = W(S1 A_x ∪ {x: gen(S1 A, τ1)}, e2)
    case .Let(let x, let e1, body: let e2):
        guard let (s1, t1) = synW(context: a, e1) else { return nil }
        // The substitution is *applied* to the context rather than merged into it.
        let s1a = subContext(a, s1)
        var extended = s1a
        extended[x] = forall(s1a, typ: t1)
        guard let (s2, t2) = synW(context: extended, e2) else { return nil }
        return (compose(s2, s1), t2)
    }
}
// TDD yo

/// Checks `e` against `t` in the empty context.
func check(e: Syntax.Exp, synthesizes t: Syntax.Typ) {
    check(e: e, synthesizes: t, underContext: [:])
}

/// Prints the expression and the expected type, then runs inference under
/// `ctx` and prints what was synthesized. Traps with `fatalError` unless
/// inference succeeded with a type that unifies with `t`.
///
/// Passing only requires unification, not equality. For example, a
/// polymorphic `a → a` satisfies an expected `Int → Int`.
func check(e: Syntax.Exp, synthesizes t: Syntax.Typ, underContext ctx: Context) {
    print("\n\n\nChecking that:\n\n\t\(e)\n\nSynthesizes:\n\n\t\(t)")
    let synthesized = syn(context: ctx, e)
    print("Synthesized:\n\n\t\(String(describing:synthesized))")
    let unifies = synthesized.flatMap { unify(t, $0) } != nil
    if !unifies {
        fatalError("Assertion error: Expected to synthesize \(String(describing: t)) but did synthesize \(String(describing: synthesized))")
    }
}
// Example checks, run in order. Each entry is (expression, expected type,
// context). The expected type only has to *unify* with the synthesized one.
let examples: [(expression: Syntax.Exp, expected: Syntax.Typ, context: Context)] = [
    // given [:], 4 ~> Int
    (Syntax.Exp.prim(4), .IntType, [:]),
    // given foo: ∀alpha. Int, foo ~> Int
    (.Var("foo"), .IntType, ["foo": .ForAll("alpha", .Tau(.IntType))]),
    // given foo: ∀alpha. Int, \x -> foo ~> a -> Int
    (.Lam("x", .Var("foo")), .Fun(.IntType, .IntType), ["foo": .ForAll("alpha", .Tau(.IntType))]),
    // given foo: Int, (\x -> foo)(3) ~> Int
    (.App(.Lam("x", .Var("foo")), Syntax.Exp.prim(3)), .IntType, ["foo": .Tau(.IntType)]),
    // identity: \x -> x
    (.Lam("x", .Var("x")), .Fun(.IntType, .IntType), [:]),
    // const: \c -> \x -> c
    (.Lam("c", .Lam("x", .Var("c"))), .Fun(.IntType, .Fun(.IntType, .IntType)), [:]),
    // let-polymorphism: let const = \c -> \x -> c in const 3 4
    (.Let("const", .Lam("c", .Lam("x", .Var("c"))),
          body: .App(.App(.Var("const"), Syntax.Exp.prim(3)), Syntax.Exp.prim(4))),
     .IntType,
     [:]),
]

for example in examples {
    check(e: example.expression, synthesizes: example.expected, underContext: example.context)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment