Last active
February 18, 2024 17:43
-
-
Save seven1m/205e64d05ff56c36b68416691a0dbe7c to your computer and use it in GitHub Desktop.
Polymorphic Hindley-Milner type checking and inference algorithm as described by Cardelli
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Type checking and inference algorithm as described in | |
# Basic Polymorphic Typechecking by Luca Cardelli [1987] | |
# https://pages.cs.wisc.edu/~horwitz/CS704-NOTES/PAPERS/cardelli.pdf | |
# | |
# The history around this algorithm is confusing to me since I'm not very | |
# academic. I like this paper because it provides a lot of working code, | |
# which I can port to Ruby. I understand this to be generally known as | |
# the polymorphic Hindley-Milner type checking algorithm. | |
# | |
# Some interesting history bits from the paper: | |
# | |
# > Polymorphic types were already known as type schemas in combinatory | |
# > logic [Curry 58]. Extending Curry's work, and collaborating with him, | |
# > Hindley introduced the idea of a principal type schema, which is the | |
# > most general polymorphic type of an expression, and showed that if a | |
# > combinatorial term has a type, then it has a principal type [Hindley 69]. | |
# | |
# > The pragmatics of polymorphic typechecking has so far been restricted | |
# > to a small group of people. The only published description of the | |
# > algorithm is the one in [Milner 78] which is rather technical, | |
# > and mostly oriented towards the theoretical background. In the hope | |
# > of making the algorithm accessible to a larger group of people, | |
# > we present an implementation (in the form of a Modula-2 program) | |
# > which is very close to the one used in LCF, Hope and ML | |
# > [Gordon 79, Burstall 80, Milner 84]. | |
# | |
# I tried to port the Modula-2 program as directly to Ruby as possible, | |
# with a few changes to make it more idiomatic in Ruby. I kept most of the | |
# class and variable names the same where my taste allowed. | |
# | |
# I made one notable change to the core algorithm, in the `prune` method:
# the paper's code doesn't recursively prune TypeOperators, but I wasn't
# seeing the results I expected without doing that.
require 'set' | |
# AST node for a conditional expression.
Cond = Struct.new(
  :test,     # condition; the checker unifies its type with bool
  :if_true,  # expression for the "then" branch
  :if_false  # expression for the "else" branch; unified with the "then" branch
)
# AST node for a one-argument function abstraction: `binder` is the parameter
# name and `body` is the expression evaluated with that parameter in scope.
Lambda = Struct.new(:binder, :body) do
  # Renders the abstraction as "(fn <binder> => <body>)".
  def to_s
    format('(fn %s => %s)', binder, body)
  end
end
# AST node for a reference to a named value (or, when the name is all digits,
# an integer literal as far as the type checker is concerned).
Identifier = Struct.new(:name) do
  # An identifier prints as its bare name.
  def to_s
    name
  end
end
# AST node for applying the function expression `fun` to the single
# argument expression `arg`.
Apply = Struct.new(:fun, :arg) do
  # Renders the application as "(<fun> <arg>)".
  def to_s
    format('(%s %s)', fun, arg)
  end
end
# AST node for a scoped declaration: `decl` is analyzed first and `scope` is
# then analyzed in the environment that declaration produces.
Block = Struct.new(:decl, :scope)

# Declaration binding the name `binder` to the expression `def`.
Def = Struct.new(:binder, :def)

# Two declarations in sequence; `second` is analyzed in the environment
# produced by `first`.
Seq = Struct.new(:first, :second)

# Wrapper marking the declaration `rec` as recursive, so its binders are
# visible inside their own definitions.
Rec = Struct.new(:rec)
# A type variable. It starts out uninstantiated (`instance` is nil) and may
# later be pointed at another type by assigning `instance`.
class TypeVariable
  attr_accessor :id, :instance

  # type_checker must respond to `next_variable_id` and `next_variable_name`;
  # the id is drawn from it immediately.
  def initialize(type_checker)
    @type_checker = type_checker
    @id = type_checker.next_variable_id
  end

  # The printable name is requested from the type checker only the first time
  # it is needed, then remembered.
  def name
    return @name if @name

    @name = @type_checker.next_variable_name
  end

  # A variable prints as its name.
  def to_s
    name
  end

  def inspect
    "TypeVariable(id = #{@id})"
  end
end
# A type built by applying the operator `name` to the argument types `types`
# (an empty list for basic types such as int).
TypeOperator = Struct.new(:name, :types) do
  # Nullary operators print as their name, binary operators print infix in
  # parentheses, and every other arity prints prefix with space-separated
  # arguments.
  def to_s
    arity = types.size
    return name if arity == 0

    if arity == 2
      left = types[0]
      right = types[1]
      "(#{left} #{name} #{right})"
    else
      "#{name} #{types.join(' ')}"
    end
  end
end
# The function type, i.e. the binary type operator '->' whose first argument
# is the parameter type and whose second argument is the result type.
class Function < TypeOperator
  def initialize(from_type, to_type)
    super('->', [from_type, to_type])
  end

  def inspect
    domain = types[0].inspect
    codomain = types[1].inspect
    "#<Function #{domain} -> #{codomain}>"
  end
end
# Basic (nullary) types. These objects are shared by every expression that
# mentions them, so they are frozen: an accidental mutation would otherwise
# silently affect every later type check.
IntType = TypeOperator.new('int', [].freeze).freeze
BoolType = TypeOperator.new('bool', [].freeze).freeze
# Polymorphic type checker / inferencer, ported from the Modula-2 program in
# Cardelli's "Basic Polymorphic Typechecking".
#
# `TypeChecker.new.analyze(exp)` returns the inferred type of the expression
# tree `exp` (built from Identifier, Cond, Lambda, Apply and Block nodes) or
# raises one of the Error subclasses below.
class TypeChecker
  class Error < StandardError; end
  class RecursiveUnification < Error; end
  class TypeClash < Error; end
  class UndefinedSymbol < Error; end

  # An identifier consisting solely of decimal digits is an integer literal.
  # Anchored with \A/\z so a multi-line name cannot match on one of its lines.
  INTEGER_LITERAL = /\A\d+\z/.freeze

  # Infers the type of `exp`.
  #
  # exp              - expression node to analyze
  # env              - Hash mapping names to types (defaults to the primitives
  #                    from build_initial_env)
  # non_generic_vars - collection of type variables that must be shared rather
  #                    than copied when a type is retrieved
  #
  # Returns the fully pruned type. Raises UndefinedSymbol, TypeClash or
  # RecursiveUnification.
  def analyze(exp, env = build_initial_env, non_generic_vars = Set.new)
    prune(analyze_exp(exp, env, non_generic_vars))
  end

  # Returns the next type-variable id: 0 on the first call, then 1, 2, ...
  # Public because TypeVariable calls it.
  def next_variable_id
    @next_variable_id = @next_variable_id ? @next_variable_id + 1 : 0
  end

  # Returns the next printable variable name: 'a' on the first call, then the
  # String#succ successor of the previous name. Public because TypeVariable
  # calls it.
  def next_variable_name
    @next_variable_name = @next_variable_name ? @next_variable_name.succ : 'a'
  end

  private

  # Returns the (not necessarily pruned) type of `exp`, or nil when `exp` is
  # not one of the recognized node classes.
  def analyze_exp(exp, env, non_generic_vars)
    case exp
    when Identifier
      if (type = retrieve_type(exp.name, env, non_generic_vars))
        type
      elsif INTEGER_LITERAL.match?(exp.name)
        IntType
      else
        raise UndefinedSymbol, "undefined symbol #{exp.name}"
      end
    when Cond
      unify_type(analyze_exp(exp.test, env, non_generic_vars), BoolType)
      type_of_then = analyze_exp(exp.if_true, env, non_generic_vars)
      type_of_else = analyze_exp(exp.if_false, env, non_generic_vars)
      unify_type(type_of_then, type_of_else)
      type_of_then
    when Lambda
      # The binder's variable is non-generic while analyzing the body.
      type_of_binder = TypeVariable.new(self)
      body_env = env.merge(exp.binder => type_of_binder)
      body_non_generic_vars = non_generic_vars + [type_of_binder]
      type_of_body = analyze_exp(exp.body, body_env, body_non_generic_vars)
      Function.new(type_of_binder, type_of_body)
    when Apply
      type_of_fun = analyze_exp(exp.fun, env, non_generic_vars)
      type_of_arg = analyze_exp(exp.arg, env, non_generic_vars)
      type_of_res = TypeVariable.new(self)
      unify_type(type_of_fun, Function.new(type_of_arg, type_of_res))
      type_of_res
    when Block
      decl_env = analyze_decl(exp.decl, env, non_generic_vars)
      analyze_exp(exp.scope, decl_env, non_generic_vars)
    end
  end

  # Returns the environment produced by the declaration `decl`. Neither `env`
  # nor `non_generic_vars` is modified.
  def analyze_decl(decl, env, non_generic_vars)
    case decl
    when Def
      env.merge(decl.binder => analyze_exp(decl.def, env, non_generic_vars))
    when Seq
      analyze_decl(decl.second, analyze_decl(decl.first, env, non_generic_vars), non_generic_vars)
    when Rec
      # analyze_rec_decl_bind mutates what it is given, so hand it copies.
      # Mutating the caller's objects would leak the recursive bindings out of
      # their Block and would keep their type variables non-generic in the
      # scope, preventing polymorphic use of a rec-defined function.
      rec_env = env.dup
      rec_non_generic_vars = non_generic_vars.dup
      analyze_rec_decl_bind(decl.rec, rec_env, rec_non_generic_vars)
      analyze_rec_decl(decl.rec, rec_env, rec_non_generic_vars)
      rec_env
    end
  end

  # (p. 19) The first pass AnalyzeRecDeclBind simply creates a new set of
  # non-generic type variables and associates them with identifiers.
  #
  # Mutates both `env` and `non_generic_vars` in place.
  def analyze_rec_decl_bind(decl, env, non_generic_vars)
    case decl
    when Def
      new_type_var = TypeVariable.new(self)
      env[decl.binder] = new_type_var
      non_generic_vars << new_type_var
    when Seq
      analyze_rec_decl_bind(decl.first, env, non_generic_vars)
      analyze_rec_decl_bind(decl.second, env, non_generic_vars)
    when Rec
      analyze_rec_decl_bind(decl.rec, env, non_generic_vars)
    end
  end

  # (p. 19) The second pass AnalyzeRecDecl analyzes the declarations and makes
  # calls to UnifyType to ensure the recursive type constraints.
  def analyze_rec_decl(decl, env, non_generic_vars)
    case decl
    when Def
      unify_type(
        retrieve_type(decl.binder, env, non_generic_vars),
        analyze_exp(decl.def, env, non_generic_vars)
      )
    when Seq
      analyze_rec_decl(decl.first, env, non_generic_vars)
      analyze_rec_decl(decl.second, env, non_generic_vars)
    when Rec
      analyze_rec_decl(decl.rec, env, non_generic_vars)
    end
  end

  # Analogous to EnvMod.Retrieve: returns a fresh copy of the type bound to
  # `name`, or nil when `name` is not in `env`.
  def retrieve_type(name, env, non_generic_vars)
    bound_type = env[name]
    return unless bound_type

    fresh_type(bound_type, non_generic_vars)
  end

  # (p. 19) FreshType makes a copy of a type expression, duplicating the generic variables
  # and sharing the non-generic ones.
  #
  # `env` maps each generic variable already encountered to its copy, so that
  # repeated occurrences of one variable map to the same new variable.
  def fresh_type(type_exp, non_generic_vars, env = {})
    type_exp = prune(type_exp)
    case type_exp
    when TypeVariable
      if occurs_in_type_list?(type_exp, non_generic_vars)
        type_exp
      else
        env[type_exp] ||= TypeVariable.new(self)
      end
    when TypeOperator
      TypeOperator.new(
        type_exp.name,
        type_exp.types.map { |t| fresh_type(t, non_generic_vars, env) }
      )
    end
  end

  # (p. 19) The function Prune is used whenever a type expression has to be inspected: it will always
  # return a type expression which is either an uninstantiated type variable or a type operator; i.e. it
  # will skip instantiated variables, and will actually prune them from expressions to remove long
  # chains of instantiated variables.
  def prune(type_exp)
    case type_exp
    when TypeVariable
      if type_exp.instance.nil?
        type_exp
      else
        # Path compression: remember the pruned target on the variable.
        type_exp.instance = prune(type_exp.instance)
      end
    when TypeOperator
      # NOTE: The paper doesn't recursively prune TypeOperators -- it returns the type_exp here.
      # I could not get a proper result from the algorithm without this change.
      # It's very possible I messed something up somewhere else that made this a necessity. :-/
      TypeOperator.new(
        type_exp.name,
        type_exp.types.map { |t| prune(t) }
      )
    end
  end

  # (p. 19) The function OccursInType checks whether a type variable occurs in a type expression.
  def occurs_in_type?(type_var, type_exp)
    type_exp = prune(type_exp)
    case type_exp
    when TypeVariable
      type_var == type_exp
    when TypeOperator
      occurs_in_type_list?(type_var, type_exp.types)
    end
  end

  # True when `type_var` occurs in any of the type expressions in `list`.
  def occurs_in_type_list?(type_var, list)
    list.any? { |t| occurs_in_type?(type_var, t) }
  end

  # Makes `a` and `b` equal by instantiating type variables.
  #
  # Raises RecursiveUnification when a variable would have to contain itself,
  # and TypeClash when two operators differ in name or in number of arguments.
  def unify_type(a, b)
    a = prune(a)
    b = prune(b)
    case a
    when TypeVariable
      if occurs_in_type?(a, b)
        unless a == b
          raise RecursiveUnification, "recursive unification: #{b} contains #{a}"
        end
      else
        a.instance = b
      end
    when TypeOperator
      case b
      when TypeVariable
        unify_type(b, a)
      when TypeOperator
        # The arity check keeps unify_args from zipping a short argument list
        # against nil padding and "unifying" silently.
        unless a.name == b.name && a.types.size == b.types.size
          raise TypeClash, "#{a} cannot unify with #{b}"
        end

        unify_args(a.types, b.types)
      end
    end
  end

  # Unifies two equally long argument lists pairwise.
  def unify_args(list1, list2)
    list1.zip(list2) do |a, b|
      unify_type(a, b)
    end
  end

  # (p. 7) In general, the type of an expression is determined by a set of type combination rules for the
  # language constructs, and by the types of the primitive operators. The initial type environment
  # could contain the following primitives for booleans, integers, pairs and lists (where → is the
  # function type operator, × is cartesian product, and list is the list operator):
  def build_initial_env
    pair_first = TypeVariable.new(self)
    pair_second = TypeVariable.new(self)
    pair_type = TypeOperator.new('×', [pair_first, pair_second])
    list_type = TypeVariable.new(self)
    list = TypeOperator.new('list', [list_type])
    list_pair_type = TypeOperator.new('×', [list_type, list])
    {
      'true' => BoolType,
      'false' => BoolType,
      'succ' => Function.new(IntType, IntType),
      'pred' => Function.new(IntType, IntType),
      'zero?' => Function.new(IntType, BoolType),
      'times' => Function.new(IntType, Function.new(IntType, IntType)),
      'minus' => Function.new(IntType, Function.new(IntType, IntType)),
      'pair' => Function.new(pair_first, Function.new(pair_second, pair_type)),
      'fst' => Function.new(pair_type, pair_first), # car
      'snd' => Function.new(pair_type, pair_second), # cdr
      'nil' => list,
      'cons' => Function.new(list_pair_type, list),
      'head' => Function.new(list, list_type),
      'tail' => Function.new(list, list),
      'null?' => Function.new(list, BoolType),
    }
  end
end
# Prints an indented dump of the type `type` to standard output.
#
# type   - a Function, TypeOperator, TypeVariable or nil; anything else
#          prints nothing
# indent - number of spaces to prefix to each line (nested types add 4)
#
# Plain TypeOperators are dumped as well: the type checker's `prune` rebuilds
# Function instances as plain TypeOperators, so without that branch an
# analyzed result would print nothing at all.
#
# Note that dumping a TypeVariable reads its `name`.
#
# Always returns nil.
def debug(type, indent = 0)
  pad = ' ' * indent
  case type
  when Function # must precede TypeOperator, its superclass
    puts "#{pad}Function"
    puts "#{pad} arg:"
    debug(type.types[0], indent + 4)
    puts "#{pad} body:"
    debug(type.types[1], indent + 4)
  when TypeOperator
    puts "#{pad}TypeOperator"
    puts "#{pad} name: #{type.name}"
    puts "#{pad} types:"
    type.types.each { |t| debug(t, indent + 4) }
  when TypeVariable
    puts "#{pad}TypeVariable"
    puts "#{pad} id: #{type.id}"
    puts "#{pad} name: #{type.name}"
    puts "#{pad} instance:"
    debug(type.instance, indent + 4)
  when nil
    puts "#{pad}nil"
  end
  nil
end
# Specs; they run only when this file is executed directly.
#
# Each scenario is its own example, so one failing inference does not hide
# the results of the others.
if $PROGRAM_NAME == __FILE__
  require 'minitest/autorun'
  require 'minitest/spec'

  describe TypeChecker do
    describe '#analyze' do
      # Every call uses a fresh checker, so variable names start over at 'a'.
      def analyze(exp)
        TypeChecker.new.analyze(exp)
      end

      # Shorthand for an identifier node.
      def var(name)
        Identifier.new(name)
      end

      # Curried application: app(f, x, y) builds Apply(Apply(f, x), y).
      def app(fun, *args)
        args.reduce(fun) { |applied, arg| Apply.new(applied, arg) }
      end

      it 'infers the identity function' do
        exp = Lambda.new('f', var('f'))
        expect(analyze(exp).to_s).must_equal '(a -> a)'
      end

      it 'infers function composition' do
        exp = Lambda.new('f',
                Lambda.new('g',
                  Lambda.new('arg',
                    app(var('g'), app(var('f'), var('arg'))))))
        expect(analyze(exp).to_s).must_equal '((a -> b) -> ((b -> c) -> (a -> c)))'
      end

      it 'keeps lambda-bound variables non-generic inside a block' do
        exp = Lambda.new('g',
                Block.new(
                  Def.new('f', Lambda.new('x', var('g'))),
                  app(var('pair'),
                      app(var('f'), var('3')),
                      app(var('f'), var('true')))))
        expect(analyze(exp).to_s).must_equal '(a -> (a × a))'
      end

      it 'allows a let-bound function to be applied to itself' do
        exp = Block.new(
          Def.new('g', Lambda.new('f', var('5'))),
          app(var('g'), var('g')))
        expect(analyze(exp).to_s).must_equal 'int'
      end

      it 'lets a let-bound function be used at two different types' do
        pair = app(var('pair'),
                   app(var('f'), var('4')),
                   app(var('f'), var('true')))
        exp = Block.new(
          Def.new('f', Lambda.new('x', var('x'))),
          pair)
        expect(analyze(exp).to_s).must_equal '(int × bool)'
      end

      it 'infers a recursive definition' do
        # def factorial(n)
        #   if n.zero?
        #     1
        #   else
        #     n * factorial(n - 1)
        #   end
        # end
        exp = Block.new(
          Rec.new(
            Def.new('factorial',
              Lambda.new('n',
                Cond.new(
                  app(var('zero?'), var('n')),                      # (zero? n)
                  var('1'),                                         # then 1
                  app(var('times'), var('n'),                       # else ((times n) ...)
                      app(var('factorial'),                         # (factorial ((minus n) 1))
                          app(var('minus'), var('n'), var('1')))))))),
          app(var('factorial'), var('5')))
        expect(analyze(exp).to_s).must_equal 'int'
      end

      it 'infers the element type of a list' do
        # (list 1 2)
        exp = app(var('cons'),
                  app(var('pair'), var('1'),
                      app(var('cons'),
                          app(var('pair'), var('2'), var('nil')))))
        expect(analyze(exp).to_s).must_equal 'list int'
      end

      it 'makes earlier Seq bindings visible to later ones' do
        # I think Seq is like Scheme's `let*`, where the bindings are evaluated
        # one-by-one so that subsequent bindings can refer to previous ones.
        exp = Block.new(
          Seq.new(
            Def.new('x', var('2')),
            Def.new('fn', Lambda.new('n', app(var('times'), var('n'), var('x'))))),
          app(var('pair'), var('x'), app(var('fn'), var('3'))))
        expect(analyze(exp).to_s).must_equal '(int × int)'
      end

      it 'raises an error on type clash' do
        exp = Lambda.new('x',
                app(var('pair'),
                    app(var('x'), var('3')),
                    app(var('x'), var('true'))))
        err = expect { analyze(exp) }.must_raise TypeChecker::TypeClash
        expect(err.message).must_equal 'int cannot unify with bool'
      end

      it 'raises an error on type clash with a list' do
        list = app(var('cons'),
                   app(var('pair'), var('true'),
                       app(var('cons'),
                           app(var('pair'), var('2'), var('nil')))))
        err = expect { analyze(list) }.must_raise TypeChecker::TypeClash
        expect(err.message).must_equal 'bool cannot unify with int'
      end

      it 'raises an error when the symbol is undefined' do
        exp = app(var('pair'),
                  app(var('f'), var('4')),
                  app(var('f'), var('true')))
        err = expect { analyze(exp) }.must_raise TypeChecker::UndefinedSymbol
        expect(err.message).must_equal 'undefined symbol f'
      end

      it 'raises an error on recursive unification' do
        exp = Lambda.new('f', app(var('f'), var('f')))
        err = expect { analyze(exp) }.must_raise TypeChecker::RecursiveUnification
        expect(err.message).must_equal 'recursive unification: (a -> b) contains a'
      end
    end
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment