Created
September 28, 2011 10:48
-
-
Save ddrone/1247631 to your computer and use it in GitHub Desktop.
Hindley-Millner type inference
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
public class Main { | |
// Expression type definitions | |
static abstract class Expression implements Comparable<Expression> { | |
@Override | |
public int compareTo(Expression arg0) { | |
return this.toString().compareTo(arg0.toString()); | |
} | |
} | |
static class VariableExpression extends Expression { | |
String varName; | |
public VariableExpression(String name) { | |
varName = name; | |
} | |
@Override | |
public String toString() { | |
return varName; | |
} | |
} | |
static class AbstractionExpression extends Expression { | |
String varName; | |
Expression abstrBody; | |
public AbstractionExpression(String name, Expression body) { | |
varName = name; | |
abstrBody = body; | |
} | |
@Override | |
public String toString() { | |
return "\\" + varName + "." + abstrBody.toString(); | |
} | |
} | |
static class ApplicationExpression extends Expression { | |
Expression applFunction; | |
Expression applArgument; | |
public ApplicationExpression(Expression func, Expression arg) { | |
applFunction = func; | |
applArgument = arg; | |
} | |
@Override | |
public String toString() { | |
return "(" + applFunction.toString() + " " + applArgument.toString() + ")"; | |
} | |
} | |
// "Type" type definition | |
static abstract class Type { | |
} | |
static class VariableType extends Type { | |
String varName; | |
public VariableType(String name) { | |
varName = name; | |
} | |
@Override | |
public String toString() { | |
return varName; | |
} | |
} | |
static class ArrowType extends Type { | |
Type arrowLeft; | |
Type arrowRight; | |
public ArrowType(Type left, Type right) { | |
arrowLeft = left; | |
arrowRight = right; | |
} | |
@Override | |
public String toString() { | |
String result; | |
if (arrowLeft instanceof ArrowType) { | |
result = "(" + arrowLeft.toString() + ")"; | |
} else { | |
result = arrowLeft.toString(); | |
} | |
result = result + " -> " + arrowRight.toString(); | |
return result; | |
} | |
} | |
// Name generator | |
static interface NameGenerator { | |
public String getNext(); | |
} | |
static class SimpleNameGenerator implements NameGenerator { | |
char curChar; | |
int curNumber; | |
public SimpleNameGenerator() { | |
curChar = 'a'; | |
curNumber = 0; | |
} | |
public String getNext() { | |
String result = Character.toString(curChar); | |
if (curNumber > 0) { | |
result += Integer.toString(curNumber); | |
} | |
if (curChar == 'z') { | |
curChar = 'a'; | |
curNumber++; | |
} else { | |
curChar++; | |
} | |
return result; | |
} | |
} | |
// Type inference | |
static void inferType(Expression expr, NameGenerator gen, TreeMap<Expression, Type> environment) throws UnificationError, InferenceError { | |
if (expr instanceof VariableExpression) { | |
if (environment.containsKey(expr)) { | |
return; | |
} else { | |
environment.put(expr, new VariableType(gen.getNext())); | |
} | |
} else if (expr instanceof ApplicationExpression) { | |
ApplicationExpression e = (ApplicationExpression) expr; | |
inferType(e.applArgument, gen, environment); | |
Type argType = environment.get(e.applArgument); | |
if (e.applFunction instanceof VariableExpression) { | |
VariableExpression v = (VariableExpression) e.applFunction; | |
if (environment.containsKey(v)) { | |
Type funcType = environment.get(v); | |
unifyTypes(new ArrowType(argType, new VariableType(gen.getNext())), funcType, environment); | |
funcType = environment.get(v); | |
if (funcType instanceof ArrowType) { | |
environment.put(expr, ((ArrowType) funcType).arrowRight); | |
} else { | |
System.err.println("ArrowType expected"); | |
throw new InferenceError(); | |
} | |
} else { | |
String curName = gen.getNext(); | |
environment.put(expr, new VariableType(curName)); | |
environment.put(e.applFunction, new ArrowType(argType, new VariableType(curName))); | |
} | |
} else if (e.applFunction instanceof ApplicationExpression) { | |
ApplicationExpression appl = (ApplicationExpression) e.applFunction; | |
inferType(appl, gen, environment); | |
Type funcType = environment.get(appl); | |
unifyTypes(funcType, new ArrowType(argType, new VariableType(gen.getNext())), environment); | |
funcType = environment.get(e.applFunction); | |
if (funcType instanceof ArrowType) { | |
ArrowType a = (ArrowType) funcType; | |
environment.put(expr, a.arrowRight); | |
} else { | |
System.err.println("ArrowType expected"); | |
throw new InferenceError(); | |
} | |
} else if (e.applFunction instanceof AbstractionExpression) { | |
AbstractionExpression abstr = (AbstractionExpression) e.applFunction; | |
inferType(abstr, gen, environment); | |
Type funcType = environment.get(abstr); | |
if (funcType instanceof ArrowType) { | |
ArrowType arr = (ArrowType) funcType; | |
if (areTypesCompatible(arr.arrowLeft, argType)) { | |
unifyTypes(arr.arrowLeft, argType, environment); | |
funcType = environment.get(abstr); | |
if (funcType instanceof ArrowType) { | |
environment.put(expr, ((ArrowType) funcType).arrowRight); | |
} else { | |
System.err.println("ArrowType expected"); | |
throw new InferenceError(); | |
} | |
} else { | |
throw new InferenceError(); | |
} | |
} else { | |
System.err.println("ArrowType expected"); | |
throw new InferenceError(); | |
} | |
} | |
} else if (expr instanceof AbstractionExpression) { | |
AbstractionExpression e = (AbstractionExpression) expr; | |
if (environment.containsKey(new VariableExpression (e.varName))) { | |
System.err.println("Duplicate bound variable in lambda abstraction!"); | |
throw new InferenceError(); | |
} else { | |
inferType(e.abstrBody, gen, environment); | |
Type bodyType = environment.get(e.abstrBody); | |
if (environment.containsKey(new VariableExpression (e.varName))) { | |
Type argType = environment.get(new VariableExpression (e.varName)); | |
environment.put(expr, new ArrowType(argType, bodyType)); | |
} else { | |
String curName = gen.getNext(); | |
environment.put(new VariableExpression (e.varName), new VariableType(curName)); | |
environment.put(expr, new ArrowType(new VariableType(curName), bodyType)); | |
} | |
} | |
} | |
} | |
// Helper function | |
static Type getType(Expression expr) throws UnificationError, InferenceError { | |
TreeMap<Expression, Type> env = new TreeMap<Expression, Type>(); | |
inferType(expr, new SimpleNameGenerator(), env); | |
return env.get(expr); | |
} | |
// Type compatibility checker | |
static boolean areTypesCompatible(Type t1, Type t2) { | |
if (t1 instanceof VariableType) { | |
return true; | |
} else if (t1 instanceof ArrowType && t2 instanceof ArrowType) { | |
ArrowType a1 = (ArrowType) t1; | |
ArrowType a2 = (ArrowType) t2; | |
return (areTypesCompatible(a1.arrowLeft, a2.arrowLeft) && areTypesCompatible(a1.arrowRight, a2.arrowRight)); | |
} | |
return false; | |
} | |
// Type variable substitution | |
static Type substituteTypeVariable(String var, Type replacement, Type haystack) { | |
if (haystack instanceof VariableType) { | |
VariableType t = (VariableType) haystack; | |
if (t.varName.equals(var)) { | |
return replacement; | |
} else { | |
return haystack; | |
} | |
} else if (haystack instanceof ArrowType) { | |
ArrowType t = (ArrowType) haystack; | |
return new ArrowType(substituteTypeVariable(var, replacement, t.arrowLeft), | |
substituteTypeVariable(var, replacement, t.arrowRight)); | |
} | |
return null; | |
} | |
// Type unification exception | |
static class UnificationError extends Exception { | |
private static final long serialVersionUID = 728975700118640646L; | |
} | |
// Type inference exception | |
static class InferenceError extends Exception { | |
private static final long serialVersionUID = 9186955675628337698L; | |
} | |
// Type unification | |
static void unifyTypes(Type t1, Type t2, TreeMap<Expression, Type> env) throws UnificationError { | |
if (t1 instanceof VariableType) { | |
VariableType v = (VariableType) t1; | |
for (Expression e : env.keySet()) { | |
env.put(e, substituteTypeVariable(v.varName, t2, env.get(e))); | |
} | |
} else if (t1 instanceof ArrowType && t2 instanceof ArrowType) { | |
ArrowType a1 = (ArrowType) t1; | |
ArrowType a2 = (ArrowType) t2; | |
unifyTypes(a1.arrowLeft, a2.arrowLeft, env); | |
unifyTypes(a1.arrowRight, a2.arrowRight, env); | |
} else { | |
throw new UnificationError(); | |
} | |
} | |
static void printTypedExpression(Expression expr) { | |
Type t = null; | |
try { | |
t = getType(expr); | |
} catch (Exception e) { | |
// e.printStackTrace(); | |
} | |
System.out.print(expr.toString() + " :: "); | |
if (t == null) { | |
System.out.println("Type inference error"); | |
} else { | |
System.out.println(t.toString()); | |
} | |
} | |
static class ParseError extends Exception { | |
private static final long serialVersionUID = 4127019321896484621L; | |
} | |
static class ExpressionParser { | |
char[] str; | |
int pos; | |
public ExpressionParser(String s) { | |
str = s.toCharArray(); | |
pos = 0; | |
} | |
public static boolean isAllowedChar(char c) { | |
return ('a' <= c && c <= 'z'); | |
} | |
public Expression parseExpression() throws ParseError { | |
if (pos >= str.length) { | |
throw new ParseError(); | |
} | |
if (isAllowedChar(str[pos])) { | |
return new VariableExpression(Character.toString(str[pos++])); | |
} else if (str[pos] == '\\') { | |
pos++; | |
Expression arg = parseExpression(); | |
if (arg instanceof VariableExpression) { | |
VariableExpression varArg = (VariableExpression) arg; | |
if (pos >= str.length) { | |
throw new ParseError(); | |
} | |
if (str[pos] == '.') { | |
pos++; | |
return new AbstractionExpression(varArg.varName, parseExpression()); | |
} else { | |
throw new ParseError(); | |
} | |
} else { | |
throw new ParseError(); | |
} | |
} else if (str[pos] == '(') { | |
pos++; | |
Expression func = parseExpression(); | |
if (pos >= str.length) { | |
throw new ParseError(); | |
} | |
if (str[pos] == ' ') { | |
pos++; | |
Expression arg = parseExpression(); | |
if (pos >= str.length) { | |
throw new ParseError(); | |
} | |
if (str[pos] == ')') { | |
pos++; | |
return new ApplicationExpression(func, arg); | |
} else { | |
throw new ParseError(); | |
} | |
} else { | |
throw new ParseError(); | |
} | |
} else { | |
throw new ParseError(); | |
} | |
} | |
} | |
static void printParsedTypedExpression(String str) { | |
try { | |
printTypedExpression(new ExpressionParser(str).parseExpression()); | |
} catch (ParseError e) { | |
// e.printStackTrace(); | |
} | |
} | |
public static void main(String[] args) { | |
Expression expr = new AbstractionExpression("x", | |
new AbstractionExpression("y", new ApplicationExpression( | |
new VariableExpression("x"), | |
new VariableExpression("y")))); | |
Expression expr2 = new AbstractionExpression("x", | |
new AbstractionExpression("y", | |
new VariableExpression("x"))); | |
Expression expr3 = new AbstractionExpression("x", | |
new AbstractionExpression("y", | |
new AbstractionExpression("z", | |
new ApplicationExpression( | |
new ApplicationExpression(new VariableExpression("x"), new VariableExpression("z")), | |
new ApplicationExpression(new VariableExpression("y"), new VariableExpression("z")))))); | |
printTypedExpression(expr); | |
printTypedExpression(expr2); | |
printTypedExpression(expr3); | |
printParsedTypedExpression("\\x.\\y.\\z.((y x) z)"); | |
printParsedTypedExpression("\\x.\\x.(x x)"); | |
printParsedTypedExpression("\\y.(\\x.(y x) y)"); | |
printParsedTypedExpression("\\x.\\y.((x y) y)"); | |
printParsedTypedExpression("\\x.\\x.x"); | |
return; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment