Skip to content

Instantly share code, notes, and snippets.

@vhxs
Last active November 30, 2022 03:13
Show Gist options
  • Save vhxs/436251d0f6df8e84d23bb4a0b5e29263 to your computer and use it in GitHub Desktop.
Save vhxs/436251d0f6df8e84d23bb4a0b5e29263 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference
# Hindley-Milner type inference in Pyret
# NOTE: this needs fixing; Pyret's syntax has changed substantially since 2013.
# This was written in a functional language.
# TODO: could this be ported to Haskell?
#lang pyret/whalesong
# Abstract syntax tree of the source language; `parse` builds these from s-expressions.
data Expr:
| idE(name :: String) # identifier reference
| numE(value :: Number) # numeric literal
| strE(value :: String) # string literal
| uopE(op :: UnaryOperator, arg :: Expr) # unary operator applied to one operand
| bopE(op :: Operator, left :: Expr, right :: Expr) # binary operator applied to two operands
| cifE(cond :: Expr, consq :: Expr, altern :: Expr) # conditional: condition, consequent, alternative
| letE(name :: String, expr :: Expr, body :: Expr) # binds name to expr inside body
| lamE(param :: String, body :: Expr) # one-parameter function
| appE(func :: Expr, arg :: Expr) # application of func to a single argument
| emptyE # the empty list
end
# Unary list operators; `parse` produces them from the `first`, `rest` and `empty?` forms.
data UnaryOperator:
| firstOp # result has the element type of the list operand
| restOp # result has the same list type as the operand
| emptyOp # tests whether a list is empty; the result is typed as numT
end
# Binary operators; `parse` produces them from the `+`, `-`, `++`, `==` and `link` forms.
data Operator:
| plus # number operands, number result
| minus # number operands, number result
| append # string operands, string result
| str-eq # string operands, number result (BaseType has no boolean)
| linkOp # an element and a list of that element type, list result
end
# Types manipulated by constraint generation and unification.
data Type:
| varT(name :: String) # type variable, identified by its name
| baseT(typ :: BaseType) # primitive type
| conT(constr :: ConstrType, args :: List<Type>) # type constructor applied to argument types
end
fun normalize(typ :: Type) -> Type:
  doc: "Rewrite a type so that its type variables are named a, b, ..., z, aa, ab, ... in the order they are first met."
  letters = "abcdefghijklmnopqrstuvwxyz"
  # Turn a zero-based index into a name: 0 -> a, 25 -> z, 26 -> aa, ...
  fun int-to-letter(n :: Number) -> String:
    if n >= 26:
      int-to-letter((n / 26).truncate() - 1) + letters.char-at(n.modulo(26))
    else:
      letters.char-at(n)
    end
  end
  var mapping = [list: ] # Pairs [old name, new name] recorded so far
  var count = 0 # Index of the next unused new name
  # Return the new name for v, allocating one the first time v is seen.
  fun lookup-var(v :: String) -> String:
    cases(Option) mapping.find(method(pair): pair.get(0) == v end):
      | none =>
        fresh-name = int-to-letter(count)
        count := count + 1
        mapping := mapping + [list: [list: v, fresh-name]]
        fresh-name
      | some(known) => known.get(1)
    end
  end
  fun normalize-rec(t :: Type) -> Type:
    cases(Type) t:
      | conT(c, args) => conT(c, map(normalize-rec, args))
      | baseT(b) => baseT(b)
      | varT(v) => varT(lookup-var(v))
    end
  end
  normalize-rec(typ)
end
fun same-type(t1 :: Type, t2 :: Type) -> Bool:
  doc: "True when the two types are equal once both are normalized, i.e. equal up to renaming of type variables."
  canonical-1 = normalize(t1)
  canonical-2 = normalize(t2)
  canonical-1 == canonical-2
end
# Primitive types of the language.
data BaseType:
| numT # type assigned to numeric literals
| strT # type assigned to string literals
end
# Type constructors used with conT.
data ConstrType:
| funT # function type; conT args are [parameter type, result type]
| listT # list type; conT args are [element type]
end
fun parse(prog) -> Expr:
  doc: "Translate an s-expression, such as the result of read-sexpr, into an Expr."
  fun convert(sexpr):
    # Operands of an operator form sit at positions 1 and 2 of the list.
    fun unary(op): uopE(op, convert(sexpr.get(1))) end
    fun binary(op): bopE(op, convert(sexpr.get(1)), convert(sexpr.get(2))) end
    if sexpr == "empty":
      emptyE
    else if List(sexpr):
      first-item = sexpr.first
      if first-item == "string":
        strE(sexpr.get(1))
      else if first-item == "if":
        cifE(convert(sexpr.get(1)), convert(sexpr.get(2)), convert(sexpr.get(3)))
      else if first-item == "let":
        # (let (name expr) body)
        binding = sexpr.get(1)
        letE(binding.get(0), convert(binding.get(1)), convert(sexpr.get(2)))
      else if first-item == "fun":
        # (fun (param) body)
        params = sexpr.get(1)
        when params.length() <> 1:
          raise("In Polly, functions always take exactly one argument.")
        end
        lamE(params.get(0), convert(sexpr.get(2)))
      else if first-item == "+": binary(plus)
      else if first-item == "-": binary(minus)
      else if first-item == "++": binary(append)
      else if first-item == "==": binary(str-eq)
      else if first-item == "link": binary(linkOp)
      else if first-item == "first": unary(firstOp)
      else if first-item == "rest": unary(restOp)
      else if first-item == "empty?": unary(emptyOp)
      else:
        # Anything else is an application: (func arg)
        func = convert(first-item)
        arg = convert(sexpr.get(1))
        appE(func, arg)
      end
    else if Number(sexpr):
      numE(sexpr)
    else if String(sexpr):
      idE(sexpr)
    end
  end
  convert(prog)
end
# Report whether the name s is present in the environment env.
fun lookup(s :: String, env :: List<String>) -> Bool:
  cases (List<String>) env:
    | link(t, rest) => (s == t) or lookup(s, rest)
    | empty => false
  end
end
# Collect the names bound by let and fun anywhere in the expression.
# Raises 'Unbound identifier' when an identifier is used outside the scope of a binding.
# env is the list of names in scope at e; vars is the set accumulated so far.
# The result is used later on to create fresh type variables.
fun gather-vars(e :: Expr, vars :: Set<String>, env :: List<String>) -> Set<String>:
  cases (Expr) e:
    | idE(name) =>
      if lookup(name, env):
        vars
      else:
        raise('Unbound identifier')
      end
    | numE(_) => vars
    | strE(_) => vars
    | emptyE => vars
    | uopE(_, arg) => gather-vars(arg, vars, env)
    | bopE(_, left, right) =>
      left-vars = gather-vars(left, vars, env)
      right-vars = gather-vars(right, vars, env)
      left-vars.union(right-vars)
    | cifE(cond, consq, altern) =>
      cond-vars = gather-vars(cond, vars, env)
      consq-vars = gather-vars(consq, vars, env)
      altern-vars = gather-vars(altern, vars, env)
      cond-vars.union(consq-vars).union(altern-vars)
    | letE(name, expr, body) =>
      # The bound name is visible in the body only, not in the bound expression.
      expr-vars = gather-vars(expr, vars, env)
      body-vars = gather-vars(body, vars.add(name), link(name, env))
      expr-vars.union(body-vars)
    | lamE(param, body) => gather-vars(body, vars.add(param), link(param, env))
    | appE(func, arg) =>
      func-vars = gather-vars(func, vars, env)
      arg-vars = gather-vars(arg, vars, env)
      func-vars.union(arg-vars)
  end
end
# Equality constraint, represented by a pair of Types.
# constr-gen produces lists of these and unify consumes them.
data Constraint:
| eq-con(lhs :: Type, rhs :: Type)
end
# Generates the constraints of a unary operation.
# id names the type of the whole expression and arg-id the type of its operand.
# In every case the operand must be a list of some fresh element type.
fun uop-gen(e :: Expr, id :: String, arg-id :: String, g :: (-> String)) -> List<Constraint>:
  elem = varT(g())
  list-of-elem = conT(listT, [list: elem])
  result-type = cases (UnaryOperator) e.op:
    | firstOp => elem
    | restOp => list-of-elem
    | emptyOp => baseT(numT)
  end
  [list: eq-con(varT(id), result-type), eq-con(varT(arg-id), list-of-elem)]
end
# Generates the constraints of a binary operation.
# id names the type of the whole expression; left-id and right-id name the operand types.
fun bop-gen(e :: Expr, id :: String, left-id :: String, right-id :: String, g :: (-> String)) -> List<Constraint>:
  # Operators whose result and operands are fixed base types.
  fun fixed(result-t :: BaseType, operand-t :: BaseType) -> List<Constraint>:
    [list:
      eq-con(varT(id), baseT(result-t)),
      eq-con(varT(left-id), baseT(operand-t)),
      eq-con(varT(right-id), baseT(operand-t))]
  end
  cases (Operator) e.op:
    | plus => fixed(numT, numT)
    | minus => fixed(numT, numT)
    | append => fixed(strT, strT)
    | str-eq => fixed(numT, strT)
    | linkOp =>
      # link joins an element onto a list of the same (fresh) element type.
      elem = varT(g())
      list-of-elem = conT(listT, [list: elem])
      [list:
        eq-con(varT(id), list-of-elem),
        eq-con(varT(left-id), elem),
        eq-con(varT(right-id), list-of-elem)]
  end
end
# Constraints of a conditional: the whole expression has the type of each
# branch, and the condition is a number.
fun cif-gen(e :: Expr, id :: String, cond-id :: String, consq-id :: String, altern-id :: String) -> List<Constraint>:
  whole = varT(id)
  same-as-consq = eq-con(whole, varT(consq-id))
  same-as-altern = eq-con(whole, varT(altern-id))
  cond-is-num = eq-con(varT(cond-id), baseT(numT))
  [list: same-as-consq, same-as-altern, cond-is-num]
end
# Constraints of a let: the whole expression has the type of the body, and the
# bound expression has the type variable named after the bound identifier.
fun let-gen(e :: Expr, id :: String, expr-id :: String, body-id :: String) -> List<Constraint>:
  result-is-body = eq-con(varT(id), varT(body-id))
  bound-is-expr = eq-con(varT(expr-id), varT(e.name))
  [list: result-is-body, bound-is-expr]
end
# Constraint of a function: its type is a function from the parameter's type
# variable (named after the parameter) to the type of the body.
fun lam-gen(e :: Expr, id :: String, body-id :: String) -> List<Constraint>:
  fun-type = conT(funT, [list: varT(e.param), varT(body-id)])
  [list: eq-con(varT(id), fun-type)]
end
# Constraint of an application: the applied expression must be a function from
# the argument's type to the type of the whole expression.
fun app-gen(e :: Expr, id :: String, func-id :: String, arg-id :: String) -> List<Constraint>:
  expected = conT(funT, [list: varT(arg-id), varT(id)])
  [list: eq-con(varT(func-id), expected)]
end
# Constraint of the empty list: its type is a list of some fresh element type.
# id names the type of the expression; g generates the fresh type variable name.
fun empty-gen(id :: String, g :: (-> String)) -> List<Constraint>:
  type-var = g()
  # The `list:` constructor is required here, as in every other *-gen function.
  [list: eq-con(varT(id), conT(listT, [list: varT(type-var)]))]
end
# Generates all constraints resulting from an expression.
# id names the type variable that stands for the type of e.
# The last argument is a function which generates fresh type variable names;
# one is requested for each sub-expression before recurring into it.
# An identifier's own name is used as the name of its type variable.
fun constr-gen(e :: Expr, id :: String, g :: (-> String)) -> List<Constraint>:
  cases (Expr) e:
    | idE(name) => [list: eq-con(varT(id), varT(name))]
    | numE(_) => [list: eq-con(varT(id), baseT(numT))]
    | strE(_) => [list: eq-con(varT(id), baseT(strT))]
    | uopE(op, arg) =>
      arg-id = g()
      constr-gen(arg, arg-id, g) + uop-gen(e, id, arg-id, g)
    | bopE(op, left, right) =>
      left-id = g()
      right-id = g()
      constr-gen(left, left-id, g) + constr-gen(right, right-id, g) + bop-gen(e, id, left-id, right-id, g)
    | cifE(cond, consq, altern) =>
      cond-id = g()
      consq-id = g()
      altern-id = g()
      constr-gen(cond, cond-id, g) + constr-gen(consq, consq-id, g) + constr-gen(altern, altern-id, g) +
        cif-gen(e, id, cond-id, consq-id, altern-id)
    | letE(name, expr, body) =>
      expr-id = g()
      body-id = g()
      constr-gen(expr, expr-id, g) + constr-gen(body, body-id, g) + let-gen(e, id, expr-id, body-id)
    | lamE(param, body) =>
      body-id = g()
      constr-gen(body, body-id, g) + lam-gen(e, id, body-id)
    | appE(func, arg) =>
      func-id = g()
      arg-id = g()
      constr-gen(func, func-id, g) + constr-gen(arg, arg-id, g) + app-gen(e, id, func-id, arg-id)
    | emptyE => empty-gen(id, g)
  end
end
# String representation of a Type, for debugging.
# Variables print as [name], functions as dom -> rng, lists as <elem>.
fun type-repr(typ :: Type) -> String:
  cases (Type) typ:
    | conT(stype, args) =>
      cases (ConstrType) stype:
        | listT => "<" + type-repr(args.first) + ">"
        | funT =>
          dom = args.first
          rng = args.rest.first
          type-repr(dom) + " -> " + type-repr(rng)
      end
    | baseT(stype) =>
      cases (BaseType) stype:
        | strT => "strT"
        | numT => "numT"
      end
    | varT(s) => "[" + s + "]"
  end
end
# Pretty-print all constraints, one "lhs = rhs" line per constraint.
fun print-cons(cs :: List<Constraint>):
  for map(c from cs):
    print(type-repr(c.lhs) + " = " + type-repr(c.rhs))
  end
end
# Pretty-print all substitutions, one "term is mustbe" line per substitution.
fun print-subst(theta :: List<Subst>):
  for map(subst from theta):
    print(type-repr(subst.term) + " is " + type-repr(subst.mustbe))
  end
end
# The occurs check: does the type variable t1 occur in t2?
# nontriv says whether a bare variable equal to t1 counts as an occurrence;
# callers pass false at the top level so that the trivial case x = x is ignored,
# and the recursion into constructor arguments always passes true.
fun occurs(t1 :: Type, t2 :: Type, nontriv :: Bool) -> Bool:
  cases (Type) t2:
    | baseT(_) => false
    | varT(name) => nontriv and (t1.name == name)
    | conT(ct, args) =>
      cases (ConstrType) ct:
        | listT => occurs(t1, args.first, true)
        | funT => occurs(t1, args.first, true) or occurs(t1, args.rest.first, true)
      end
  end
end
# A substitution is also a pair of types.
# term is the type being replaced and mustbe is the type it stands for;
# unify-one only ever builds these with a type variable as term.
data Subst:
| a-subst(term :: Type, mustbe :: Type)
end
# Replace all instances of the type variable t1 with t2 in t3.
fun replace-type(t1 :: Type, t2 :: Type, t3 :: Type) -> Type:
  # The same replacement applied to a component of t3.
  fun swap(part :: Type) -> Type:
    replace-type(t1, t2, part)
  end
  cases (Type) t3:
    | baseT(_) => t3
    | varT(name) =>
      if name == t1.name:
        t2
      else:
        t3
      end
    | conT(ct, args) =>
      cases (ConstrType) ct:
        | listT => conT(listT, [list: swap(args.first)])
        | funT => conT(funT, [list: swap(args.first), swap(args.rest.first)])
      end
  end
end
# Replace all instances of t1 with t2 on both sides of every substitution in theta.
fun replace-subst(t1 :: Type, t2 :: Type, theta :: List<Subst>) -> List<Subst>:
  for map(subst from theta):
    a-subst(replace-type(t1, t2, subst.term), replace-type(t1, t2, subst.mustbe))
  end
end
# Replace all instances of t1 with t2 on both sides of every constraint in cs.
fun replace-constr(t1 :: Type, t2 :: Type, cs :: List<Constraint>) -> List<Constraint>:
  # cs holds constraints, so the cases annotation is List<Constraint>, not List<Subst>.
  cases (List<Constraint>) cs:
    | empty => empty
    | link(c, rest) =>
      link(eq-con(replace-type(t1, t2, c.lhs), replace-type(t1, t2, c.rhs)), replace-constr(t1, t2, rest))
  end
end
# Add the substitution corresponding to one constraint, applying it to the
# substitutions already in theta.
# Two base types add nothing but must agree, otherwise this raises.
# Any other pair without a type variable leaves theta unchanged.
fun unify-one(c :: Constraint, theta :: List<Subst>) -> List<Subst>:
  lhs = c.lhs
  rhs = c.rhs
  # Record variable := replacement and rewrite the existing substitutions with it.
  fun extend-with(variable :: Type, replacement :: Type) -> List<Subst>:
    link(a-subst(variable, replacement), replace-subst(variable, replacement, theta))
  end
  if is-baseT(lhs) and is-baseT(rhs):
    if lhs.typ == rhs.typ:
      theta
    else:
      raise('Failed to unify: inconsistent base types')
    end
  else if is-varT(lhs):
    # x = x carries no information.
    if is-varT(rhs) and (lhs.name == rhs.name):
      theta
    else:
      extend-with(lhs, rhs)
    end
  else if is-varT(rhs):
    extend-with(rhs, lhs)
  else:
    theta
  end
end
# Update the remaining constraint set cs once the constraint c has been taken out of it.
# A variable on either side is substituted throughout cs; two matching
# constructors contribute one new constraint per argument; mismatches raise.
fun update-constr(c :: Constraint, cs :: List<Constraint>) -> List<Constraint>:
  lhs = c.lhs
  rhs = c.rhs
  if is-baseT(lhs) and is-baseT(rhs):
    cs
  else if is-varT(lhs):
    replace-constr(lhs, rhs, cs)
  else if is-varT(rhs):
    replace-constr(rhs, lhs, cs)
  else if is-conT(lhs) and is-conT(rhs):
    if is-funT(lhs.constr) and is-funT(rhs.constr):
      dom-con = eq-con(lhs.args.first, rhs.args.first)
      rng-con = eq-con(lhs.args.rest.first, rhs.args.rest.first)
      link(dom-con, link(rng-con, cs))
    else if is-listT(lhs.constr) and is-listT(rhs.constr):
      elem-con = eq-con(lhs.args.first, rhs.args.first)
      link(elem-con, cs)
    else:
      raise('failed to unify: function with list')
    end
  else if is-conT(lhs) or is-conT(rhs):
    raise('Failed to unify: constructor with non-constructor')
  else:
    raise('Failed to unify')
  end
end
# The unifier: processes the constraints one at a time, growing the substitution
# list theta, and returns theta once no constraints remain.
# Raises when a variable would have to contain itself (occurs check) or when
# unify-one / update-constr meet incompatible types.
fun unify(cs :: List<Constraint>, theta :: List<Subst>) -> List<Subst>:
  cases (List<Constraint>) cs:
    | empty => theta
    | link(c, rest) =>
      lhs = c.lhs
      rhs = c.rhs
      when (is-varT(lhs) and occurs(lhs, rhs, false)) or (is-varT(rhs) and occurs(rhs, lhs, false)):
        raise('Failed to unify: occurs check failed')
      end
      next-theta = unify-one(c, theta)
      next-cs = update-constr(c, rest)
      unify(next-cs, next-theta)
  end
end
# Search theta for the substitution whose term is the type variable named id
# and return the type it must be; raises when there is none.
fun lookup-type(id :: String, theta :: List<Subst>) -> Type:
  cases (List<Subst>) theta:
    | link(subst, rest) =>
      if subst.term.name <> id:
        lookup-type(id, rest)
      else:
        subst.mustbe
      end
    | empty => raise('Program not assigned type!')
  end
end
fun type-infer(prog :: String):
  doc: "Infer the type of the program text prog: parse it, generate constraints, unify them and look up the program's type."
  # Read the text and build the abstract syntax tree
  tree = parse(read-sexpr(prog))
  # Names bound in the program (this also rejects unbound identifiers)
  bound-names = gather-vars(tree, set(empty), empty)
  # '0' followed by every bound name glued together; intended to be a name
  # that cannot appear in the original program
  prefix = '0' + (for fold(acc from '', nm from bound-names.to-list()): acc + nm end)
  # Fresh type variable names come from gensym applied to that prefix
  next-name = method(): gensym(prefix) end
  # Type variable standing for the type of the whole program
  root-id = next-name()
  constraints = constr-gen(tree, root-id, next-name)
  solution = unify(constraints, empty)
  #print-subst(solution)
  lookup-type(root-id, solution)
end
#p = fun (x): parse(read-sexpr(x)) end
# End-to-end tests of type-infer. Expected types are compared with same-type,
# so the names of type variables in the expected types do not matter.
check:
  # Basic types
  type-infer('0') satisfies same-type(_, baseT(numT))
  type-infer('""') satisfies same-type(_, baseT(strT))
  # Unbound identifier
  type-infer('x') raises ''
  # Occurs check
  # A = List(A)
  type-infer('(let (x empty) (link (first x) (first x)))') raises ''
  type-infer('(let (x empty) (link (rest x) (rest x)))') raises ''
  # A = A -> B
  type-infer('((fun (x) (x x)) (fun (x) (x x)))') raises ''
  # Arithmetic
  # Success
  type-infer('(+ 0 0)') satisfies same-type(_, baseT(numT))
  type-infer('(- 0 0)') satisfies same-type(_, baseT(numT))
  type-infer('(++ "" "")') satisfies same-type(_, baseT(strT))
  type-infer('(== "" "")') satisfies same-type(_, baseT(numT))
  # Failure
  type-infer('(+ 0 "")') raises ''
  type-infer('(- "" 0)') raises ''
  type-infer('(++ 0 0)') raises ''
  type-infer('(== 0 0)') raises ''
  # If
  type-infer('(if 0 0 0)') satisfies same-type(_, baseT(numT))
  type-infer('(if 0 "" "")') satisfies same-type(_, baseT(strT))
  type-infer('(if 0 empty empty)') satisfies same-type(_, conT(listT, [list: varT('A')]))
  type-infer('(if 0 (link "" empty) empty)') satisfies same-type(_, conT(listT, [list: baseT(strT)]))
  type-infer('(if 0 (fun (x) x) (fun (y) (+ y y)))') satisfies same-type(_, conT(funT, [list: baseT(numT), baseT(numT)]))
  type-infer('(if 0 0 "")') raises ''
  type-infer('(if 0 empty 0)') raises ''
  type-infer('(if 0 "" 0)') raises ''
  type-infer('(if "" 0 0)') raises ''
  type-infer('(if empty 0 0)') raises ''
  # Let
  type-infer('(let (x 0) x)') satisfies same-type(_, baseT(numT))
  type-infer('(let (x "") x)') satisfies same-type(_, baseT(strT))
  type-infer('(let (x empty) (rest x))') satisfies same-type(_, conT(listT, [list: varT('A')]))
  type-infer('(let (x empty) (+ (first x) 0))') satisfies same-type(_, baseT(numT))
  type-infer('(let (x 0) (+ 0 x))') satisfies same-type(_, baseT(numT))
  type-infer('(let (x (fun (y) y)) x)') satisfies same-type(_, conT(funT, [list: varT('A'), varT('A')]))
  type-infer('(let (x (fun (y) y)) (x 0))') satisfies same-type(_, baseT(numT))
  type-infer('(let (x (fun (y) y)) (== (x 0) ""))') raises ''
  type-infer('(let (x 0) (rest x))') raises ''
  type-infer('(let (x "") (+ 0 x))') raises ''
  # Functions
  type-infer('(fun (x) x)') satisfies same-type(_, conT(funT, [list: varT('A'), varT('A')]))
  type-infer('(fun (x) 0)') satisfies same-type(_, conT(funT, [list: varT('A'), baseT(numT)]))
  type-infer('(fun (x) (+ x x))') satisfies same-type(_, conT(funT, [list: baseT(numT), baseT(numT)]))
  type-infer('(fun (x) (++ x x))') satisfies same-type(_, conT(funT, [list: baseT(strT), baseT(strT)]))
  type-infer('(fun (x) (first x))') satisfies same-type(_, conT(funT, [list: conT(listT, [list: varT('A')]), varT('A')]))
  type-infer('(fun (x) (fun (y) (y x)))') satisfies same-type(_, conT(funT, [list: varT('A'), conT(funT, [list: conT(funT, [list: varT('A'), varT('B')]), varT('B')])]))
  # Applications
  type-infer('((fun (x) (+ x 0)) 0)') satisfies same-type(_, baseT(numT))
  type-infer('((fun (x) (++ x "")) "")') satisfies same-type(_, baseT(strT))
  type-infer('(+ 0 ((fun (x) (+ x 0)) 0))') satisfies same-type(_, baseT(numT))
  type-infer('(++ "" ((fun (x) (++ x "")) ""))') satisfies same-type(_, baseT(strT))
  type-infer('(+ 0 ((fun (x) (== x "")) ""))') satisfies same-type(_, baseT(numT))
  type-infer('(link 0 ((fun (x) (rest x)) empty))') satisfies same-type(_, conT(listT, [list: baseT(numT)]))
  # Return clash
  type-infer('(link 0 ((fun (x) (link "" x)) empty))') raises ''
  # The application must be ((fun ...) 0): with the closing parenthesis after the 0,
  # the program failed inside parse and never exercised the type clash.
  type-infer('(rest ((fun (x) (+ x x)) 0))') raises ''
  # Argument clash
  type-infer('((fun (x) (+ x x)) empty)') raises ''
  type-infer('((fun (x) (link 0 x)) (link "" empty))') raises ''
  # Lists
  type-infer('empty') satisfies same-type(_, conT(listT, [list: varT('A')]))
  type-infer('(empty? empty)') satisfies same-type(_, baseT(numT))
  type-infer('(empty? 0)') raises ''
  type-infer('(first empty)') satisfies same-type(_, varT('A'))
  type-infer('(first (link 0 empty))') satisfies same-type(_, baseT(numT))
  type-infer('(first "")') raises ''
  type-infer('(rest empty)') satisfies same-type(_, conT(listT, [list: varT('A')]))
  type-infer('(rest (link (fun (x) x) empty))') satisfies same-type(_, conT(listT, [list: conT(funT, [list: varT('A'), varT('A')])]))
  type-infer('(rest (fun (x) x))') raises ''
  type-infer('(link 0 empty)') satisfies same-type(_, conT(listT, [list: baseT(numT)]))
  type-infer('(link "" empty)') satisfies same-type(_, conT(listT, [list: baseT(strT)]))
  type-infer('(link empty empty)') satisfies same-type(_, conT(listT, [list: conT(listT, [list: varT('A')])]))
  type-infer('(link (fun (x) x) (link (fun (y) (+ y y)) empty))') satisfies same-type(_, conT(listT, [list: conT(funT, [list: baseT(numT), baseT(numT)])]))
  type-infer('(link 0 (link "" empty))') raises ''
  type-infer('(link (fun (x) (== x x)) (link (fun (y) (+ y y)) empty))') raises ''
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment