Skip to content

Instantly share code, notes, and snippets.

@bkase
Created January 5, 2018 06:00
Show Gist options
  • Save bkase/7adebbf841c497ad29ed045ccd900e22 to your computer and use it in GitHub Desktop.
Save bkase/7adebbf841c497ad29ed045ccd900e22 to your computer and use it in GitHub Desktop.
Type Inference Algorithm W (ported from Swift Sandbox)
// Algorithm W
// From "Principal type-schemes for functional programs" (Damas & Milner, POPL 1982)
// http://web.cs.wpi.edu/~cs4536/c12/milner-damas_principal_types.pdf
// by bkase
// STD library
// `<>` is a generic "combine" operator; it is given an instance for
// dictionaries (right-biased merge) and for sets (union) below.
infix operator <>: AdditionPrecedence

extension Dictionary {
    /// Merges two dictionaries into a new one containing every entry of both.
    /// When a key appears on both sides, the value from `rhs` wins.
    static func <>(lhs: Dictionary, rhs: Dictionary) -> Dictionary {
        return lhs.merging(rhs) { _, fromRhs in fromRhs }
    }

    /// Returns a dictionary with the same keys whose values are `f` applied
    /// to the original values.
    func mapValues<V2>(f: (Value) -> V2) -> Dictionary<Key, V2> {
        return reduce(into: [Key: V2]()) { result, entry in
            result[entry.key] = f(entry.value)
        }
    }
}

extension Set {
    /// Set union: every element that occurs in `lhs` or `rhs`.
    static func <>(lhs: Set, rhs: Set) -> Set {
        var combined = lhs
        combined.formUnion(rhs)
        return combined
    }
}
typealias Ident = String

/// Namespace for the object language: expressions (`Exp`), monotypes (`Typ`),
/// and a counter that mints fresh identifiers.
struct Syntax {
    /// Suffix of the most recently minted identifier; starts at -1 so the
    /// first call to `newIdent()` yields `$new$_0`.
    private var freshVar: Int = -1

    /// Returns a new identifier of the form `$new$_N`, advancing the counter
    /// so successive calls never repeat a name.
    mutating func newIdent() -> Ident {
        freshVar += 1
        return "$new$_\(freshVar)"
    }

    /// Expressions: integer literals, variables, application, lambda, and `let`.
    indirect enum Exp {
        case Val(Int)
        case Var(Ident)
        case App(Exp, Exp)
        case Lam(Ident, Exp)
        case Let(Ident, Exp, body: Exp)

        /// Convenience constructor for an integer literal.
        static func prim(_ x: Int) -> Exp {
            return .Val(x)
        }
    }

    /// Monotypes: `Int`, type variables, and function types `in -> out`.
    indirect enum Typ {
        case IntType
        case TyVar(Ident)
        case Fun(Typ, Typ)

        /// Every type-variable name occurring anywhere in this type.
        var tyvars: Set<Ident> {
            switch self {
            case .IntType:
                return []
            case .TyVar(let a):
                return [a]
            case .Fun(let tIn, let tOut):
                return tIn.tyvars.union(tOut.tyvars)
            }
        }

        /// Applies a substitution: every type variable that is a key of
        /// `tyContext` is replaced by its mapped type. Unmapped variables are
        /// kept. This is a single pass; replacement types are not themselves
        /// re-substituted.
        ///
        /// `[Ident: Typ]` is the same type as the file-level `TyContext` alias.
        func sub(tyContext: [Ident: Typ]) -> Typ {
            switch self {
            case .IntType:
                return self
            case .TyVar(let typIdent):
                // Look up once instead of testing for nil and force-unwrapping.
                return tyContext[typIdent] ?? self
            case .Fun(let tIn, let tOut):
                return .Fun(tIn.sub(tyContext: tyContext), tOut.sub(tyContext: tyContext))
            }
        }

        /// Replaces the single type variable `typIdent` with `t`.
        func sub(_ t: Typ, forTypIdent typIdent: Ident) -> Typ {
            return sub(tyContext: [typIdent: t])
        }
    }
}
extension Syntax.Typ: Equatable {
    /// Structural equality: both sides use the same constructor and their
    /// components are pairwise equal. Type variables compare by name only.
    static func ==(lhs: Syntax.Typ, rhs: Syntax.Typ) -> Bool {
        switch (lhs, rhs) {
        case (.IntType, .IntType):
            return true
        case let (.TyVar(name1), .TyVar(name2)):
            return name1 == name2
        case let (.Fun(in1, out1), .Fun(in2, out2)):
            guard in1 == in2 else { return false }
            return out1 == out2
        case (.IntType, _), (.TyVar(_), _), (.Fun(_, _), _):
            // Mismatched constructors. They are listed explicitly (no `default`)
            // so that adding a `Typ` case makes this switch non-exhaustive.
            return false
        }
    }
}
// TODO: DoctorPretty!
extension Syntax.Typ: CustomStringConvertible {
    /// Human-readable rendering: `Int`, a type variable's own name, or
    /// `(in -> out)` for function types. Function types are always
    /// parenthesised, so nested arrows are unambiguous.
    var description: String {
        switch self {
        case .IntType:
            return "Int"
        case .TyVar(let a):
            return a
        case .Fun(let t1In, let t1Out):
            // The arrow separator is required; joining the halves with ""
            // would print Int -> Int as "(IntInt)".
            return "(" + t1In.description + " -> " + t1Out.description + ")"
        }
    }
}
/// A type substitution: maps type-variable names to the types that replace them.
typealias TyContext = [Ident: Syntax.Typ]
/// Global fresh-name supply. Inference code draws new type-variable names
/// from it via `syntax.newIdent()`.
var syntax: Syntax = Syntax()
/// A type scheme: a monotype (`Tau`) under zero or more universally
/// quantified type variables (`ForAll`).
indirect enum TypScheme {
    case Tau(Syntax.Typ)
    case ForAll(Ident, TypScheme)

    /// Wraps every type of substitution `d` in an unquantified scheme, keeping
    /// the substitution's keys.
    // TODO: This seems messed up...
    // NOTE(review): the resulting keys are type-variable names, while other
    // `Context` values are keyed by term identifiers — presumably what the
    // TODO above refers to.
    static func tauLift(_ d: TyContext) -> Context {
        var lifted: Context = [:]
        for (key, typ) in d {
            lifted[key] = .Tau(typ)
        }
        return lifted
    }

    /// The set of variables bound by this scheme's `ForAll` layers.
    var idents: Set<Ident> {
        guard case .ForAll(let x, let rest) = self else {
            return []
        }
        return Set([x]).union(rest.idents)
    }

    /// Instantiates bound variables that are keys of `tyContext`: their
    /// `ForAll` layer is dropped and `tyContext` is applied to the body.
    /// Quantifiers whose variable is not a key of `tyContext` are kept.
    func monomorphize(tyContext: TyContext) -> TypScheme {
        switch self {
        case .Tau(let typ):
            return .Tau(typ.sub(tyContext: tyContext))
        case .ForAll(let typIdent, let scheme):
            let body = scheme.monomorphize(tyContext: tyContext)
            return tyContext[typIdent] == nil ? .ForAll(typIdent, body) : body
        }
    }

    /// Instantiates the single bound variable `typIdent1` with `t`.
    func monomorphize(_ t: Syntax.Typ, forTypIdent typIdent1: Ident) -> TypScheme {
        let single: TyContext = [typIdent1: t]
        return monomorphize(tyContext: single)
    }

    /// Replaces every bound variable with a fresh type variable drawn from the
    /// global `syntax` supply and returns the resulting monotype.
    func monomorphiseToFreeVar() -> Syntax.Typ {
        var freshFor: TyContext = [:]
        for ident in idents {
            freshFor[ident] = .TyVar(syntax.newIdent())
        }
        guard case .Tau(let typ) = monomorphize(tyContext: freshFor) else {
            fatalError("Expected scheme \(self) to be monomorphized")
        }
        return typ
    }
}
/// Term-variable context: maps each identifier in scope to its type scheme.
typealias Context = [Ident: TypScheme]

/// Generalizes `typ` with respect to context `a` (the paper's closure of a
/// type under an assumption set). Every type variable of `typ` that is not
/// free in some scheme of `a` becomes universally quantified.
///
/// The variables to keep monomorphic are the context's *free type variables*.
/// Subtracting `a.keys` instead would be wrong: those keys are term
/// identifiers, so a type variable still constrained by the context (e.g. a
/// lambda parameter's type) would be generalized unsoundly.
func forall(_ a: Context, typ: Syntax.Typ) -> TypScheme {
    // Free type variables of a scheme: those of its body minus the ones it binds.
    func freeTyvars(_ scheme: TypScheme) -> Set<Ident> {
        switch scheme {
        case .Tau(let body):
            return body.tyvars
        case .ForAll(let bound, let rest):
            var free = freeTyvars(rest)
            free.remove(bound)
            return free
        }
    }

    var contextFree = Set<Ident>()
    for scheme in a.values {
        contextFree.formUnion(freeTyvars(scheme))
    }
    return typ.tyvars.subtracting(contextFree).reduce(TypScheme.Tau(typ)) { scheme, typIdent in
        .ForAll(typIdent, scheme)
    }
}
/// Computes a most general unifier of `t1` and `t2` (Robinson unification).
///
/// - Returns: a substitution `s` such that `t1.sub(tyContext: s)` and
///   `t2.sub(tyContext: s)` are identical, or `nil` when none exists.
///   That happens on a constructor clash (`Int` vs a function type) or when
///   a type variable would have to contain itself (the occurs check).
func unify(_ t1: Syntax.Typ, _ t2: Syntax.Typ) -> TyContext? {
    switch (t1, t2) {
    // iota
    case (.IntType, .IntType):
        return [:]

    // A variable already unifies with itself; no binding is needed.
    case (.TyVar(let a), .TyVar(let b)) where a == b:
        return [:]

    // var1 / var2: bind the variable (t1's is tried first).
    case (.TyVar(let a), let other), (let other, .TyVar(let a)):
        // Occurs check: binding a to a type that mentions a (e.g. a -> Int)
        // would describe an infinite type, so such terms are rejected.
        if other.tyvars.contains(a) {
            return nil
        }
        return [a: other]

    // fun
    case (.Fun(let t1In, let t1Out), .Fun(let t2In, let t2Out)):
        // The outputs must be unified *after* applying the inputs' unifier;
        // otherwise constraints discovered on the inputs are ignored.
        guard let sIn = unify(t1In, t2In),
              let sOut = unify(t1Out.sub(tyContext: sIn), t2Out.sub(tyContext: sIn)) else {
            return nil
        }
        // Compose sOut ∘ sIn: push sIn's bindings through sOut so one map
        // means "apply sIn, then sOut". A plain dictionary merge would drop
        // sIn's bindings whenever sOut rebinds the same key.
        var composed = sOut
        for (typIdent, typ) in sIn {
            composed[typIdent] = typ.sub(tyContext: sOut)
        }
        return composed

    // constructor clash
    case (.IntType, .Fun(_, _)), (.Fun(_, _), .IntType):
        return nil
    }
}
/// Infers the type of `e` under context `a`. Runs `synW` and applies its
/// resulting substitution to the inferred type. Returns `nil` when inference
/// fails.
func syn(context a: Context, _ e: Syntax.Exp) -> Syntax.Typ? {
    guard let result = synW(context: a, e) else {
        return nil
    }
    let (subst, typ) = result
    return typ.sub(tyContext: subst)
}
/// Debugging helper: prints `t` and returns it wrapped in an Optional.
/// The Optional return lets a call sit inside an `if`/`guard` condition
/// list, e.g. `let _ = printIt("…")`.
func printIt(_ t: String) -> String? {
    let echoed: String? = t
    print(t)
    return echoed
}
/// Algorithm W (Damas & Milner): infers a type for `e` under context `a`.
///
/// - Parameters:
///   - a: the type schemes of the term variables in scope.
///   - e: the expression to type.
/// - Returns: `(s, t)`, where `s` is the substitution accumulated while
///   inferring `e` and `t` is the inferred type. Returns `nil` if `e` uses a
///   variable missing from `a` or some unification fails.
///
/// Each recursive result's substitution is *applied* to the context (the
/// paper's `S A`) and substitutions are *composed*, rather than merged into
/// one dictionary. This keeps earlier constraints from being lost or
/// contradicted later.
func synW(context a: Context, _ e: Syntax.Exp) -> (TyContext, Syntax.Typ)? {
    // later ∘ earlier as a single substitution: apply `earlier`, then `later`.
    func compose(_ later: TyContext, _ earlier: TyContext) -> TyContext {
        var composed = later
        for (typIdent, typ) in earlier {
            composed[typIdent] = typ.sub(tyContext: later)
        }
        return composed
    }

    // Applies `s` to a scheme's free type variables; bound ones are untouched.
    func applyToScheme(_ s: TyContext, _ scheme: TypScheme) -> TypScheme {
        switch scheme {
        case .Tau(let typ):
            return .Tau(typ.sub(tyContext: s))
        case .ForAll(let bound, let rest):
            var withoutBound = s
            withoutBound[bound] = nil
            return .ForAll(bound, applyToScheme(withoutBound, rest))
        }
    }

    // Applies `s` to every scheme in `ctx` (the paper's `S A`).
    func applyToContext(_ s: TyContext, _ ctx: Context) -> Context {
        var result: Context = [:]
        for (ident, scheme) in ctx {
            result[ident] = applyToScheme(s, scheme)
        }
        return result
    }

    switch e {
    // prim
    case .Val(_):
        return ([:], .IntType)

    // var: instantiate the scheme's bound variables with fresh ones
    case .Var(let ident):
        guard let scheme = a[ident] else { return nil }
        return ([:], scheme.monomorphiseToFreeVar())

    // app: this implementation infers the argument first, then the function
    // under S1 A
    case .App(let e1, let e2):
        guard let argResult = synW(context: a, e2) else { return nil }
        let (s1, t2) = argResult
        guard let funResult = synW(context: applyToContext(s1, a), e1) else { return nil }
        let (s2, t1) = funResult
        let beta = Syntax.Typ.TyVar(syntax.newIdent())
        // t2 was inferred before s2 existed, so s2 must be applied to it as well.
        guard let v = unify(t1.sub(tyContext: s2), .Fun(t2.sub(tyContext: s2), beta)) else {
            return nil
        }
        return (compose(v, compose(s2, s1)), beta.sub(tyContext: v))

    // abs: bind x to a fresh monomorphic type variable
    case .Lam(let x, let e1):
        let beta = Syntax.Typ.TyVar(syntax.newIdent())
        // A_x ∪ {x: beta}: assigning the key replaces any outer binding of x.
        var inner = a
        inner[x] = .Tau(beta)
        guard let bodyResult = synW(context: inner, e1) else { return nil }
        let (s1, t1) = bodyResult
        return (s1, .Fun(beta.sub(tyContext: s1), t1))

    // let: infer the bound expression, generalize it against S1 A, then infer
    // the body in S1 A extended with x
    case .Let(let x, let e1, body: let e2):
        guard let boundResult = synW(context: a, e1) else { return nil }
        let (s1, t1) = boundResult
        let substituted = applyToContext(s1, a)
        var bodyContext = substituted
        bodyContext[x] = forall(substituted, typ: t1)
        guard let bodyResult = synW(context: bodyContext, e2) else { return nil }
        let (s2, t2) = bodyResult
        return (compose(s2, s1), t2)
    }
}
// Test helpers ("TDD yo"): print the expression, the expected type and the
// synthesized type, and stop the program if inference fails or the
// synthesized type does not unify with the expected one. Because the
// comparison is unifiability, a polymorphic result such as `a -> a`
// satisfies an expected `Int -> Int`.
func check(e: Syntax.Exp, synthesizes t: Syntax.Typ) {
    check(e: e, synthesizes: t, underContext: [:])
}

func check(e: Syntax.Exp, synthesizes t: Syntax.Typ, underContext ctx: Context) {
    print("\n\n\nChecking that:\n\n\t\(e)\n\nSynthesizes:\n\n\t\(t)")
    let synthesized = syn(context: ctx, e)
    print("Synthesized:\n\n\t\(String(describing:synthesized))")
    let matches = synthesized.map { unify(t, $0) != nil } ?? false
    if !matches {
        fatalError("Assertion error: Expected to synthesize \(String(describing: t)) but did synthesize \(String(describing: synthesized))")
    }
}
// Top-level checks, run in order. Shared fixtures are scoped to this `do`
// block so they do not become file-level globals.
do {
    // foo bound to a (vacuously) quantified Int
    let polyFoo: Context = ["foo": .ForAll("alpha", .Tau(.IntType))]
    // foo bound to a plain, unquantified Int
    let monoFoo: Context = ["foo": .Tau(.IntType)]
    let intToInt: Syntax.Typ = .Fun(.IntType, .IntType)
    let constLam: Syntax.Exp = .Lam("c", .Lam("x", .Var("c")))

    // given [:], 4 ~> Int
    check(e: Syntax.Exp.prim(4), synthesizes: .IntType)

    // given foo: ∀alpha. Int, foo ~> Int
    check(e: .Var("foo"), synthesizes: .IntType, underContext: polyFoo)

    // given foo: ∀alpha. Int, \x -> foo ~> a -> Int (checked against Int -> Int)
    check(e: .Lam("x", .Var("foo")), synthesizes: intToInt, underContext: polyFoo)

    // given foo: Int, (\x -> foo)(3) ~> Int
    check(e: .App(.Lam("x", .Var("foo")), .prim(3)), synthesizes: .IntType, underContext: monoFoo)

    // identity, checked against Int -> Int
    check(e: .Lam("x", .Var("x")), synthesizes: intToInt)

    // const, checked against Int -> Int -> Int
    check(e: constLam, synthesizes: .Fun(.IntType, intToInt))

    // let-polymorphism: let const = \c -> \x -> c in const 3 4 ~> Int
    check(
        e: .Let("const", constLam, body: .App(.App(.Var("const"), .prim(3)), .prim(4))),
        synthesizes: .IntType
    )
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment