Created
February 5, 2024 21:36
-
-
Save reverofevil/4f7434275756fcf3c78bae8cce359197 to your computer and use it in GitHub Desktop.
Remy's algorithm in TypeScript
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** Name of a term-level variable. */
type VarName = string;

/** An expression of the object language: variable, application, lambda or `let`. */
type Expr = Var | App | Lam | Let;

/** Reference to the variable called `name`. */
type Var = { readonly $: 'var'; readonly name: VarName };
/** Builds a {@link Var} node. */
const Var = (name: VarName): Var => ({ $: 'var', name });

/** Application of the expression `left` to the argument `right`. */
type App = { readonly $: 'app'; readonly left: Expr; readonly right: Expr };
/** Builds an {@link App} node. */
const App = (left: Expr, right: Expr): App => ({ $: 'app', left, right });

/** Lambda abstraction with parameter `name` and body `expr`. */
type Lam = { readonly $: 'lam'; readonly name: VarName; readonly expr: Expr };
/** Builds a {@link Lam} node. */
const Lam = (name: VarName, expr: Expr): Lam => ({ $: 'lam', name, expr });

/** `let name = def in body`. */
type Let = {
  readonly $: 'let';
  readonly name: VarName;
  readonly def: Expr;
  readonly body: Expr;
};
/** Builds a {@link Let} node. */
const Let = (name: VarName, def: Expr, body: Expr): Let => ({ $: 'let', name, def, body });
/** Numeric level attached to type variables and arrow types. */
type Level = number;

/**
 * Level given to generalized (quantified) type variables; instantiation
 * replaces variables carrying this level with fresh ones.
 */
const genericLevel: Level = 100000000;

/**
 * Temporary level written into an arrow's `levels.new` while that arrow is
 * being traversed; meeting it again during the traversal signals a cycle.
 */
const markedLevel: Level = -1;
/** An inferred type: either a type variable or a function (arrow) type. */
type Type = TVar | TArrow;

/** Type variable; `link` is mutable and is overwritten as inference proceeds. */
type TVar = { readonly $: 'tvar'; link: TV };
/** Builds a {@link TVar} whose state is `link`. */
const TVar = (link: TV): TVar => ({ $: 'tvar', link });

/** Function type `left -> right` together with its mutable level record. */
type TArrow = {
  readonly $: 'tarrow';
  readonly left: Type;
  readonly right: Type;
  readonly levels: Levels;
};
/** Builds a {@link TArrow}; `levels` is stored by reference, not copied. */
const TArrow = (left: Type, right: Type, levels: Levels): TArrow => ({ $: 'tarrow', left, right, levels });

/** State of a type variable: still unbound, or linked to another type. */
type TV = Unbound | Link;

/** Unbound variable state: display name plus the variable's level. */
type Unbound = { readonly $: 'unbound'; name: string; level: Level };
/** Builds an {@link Unbound} state. */
const Unbound = (name: string, level: Level): Unbound => ({ $: 'unbound', name, level });

/** Bound variable state: the variable stands for `type`. */
type Link = { readonly $: 'link'; type: Type };
/** Builds a {@link Link} state. */
const Link = (type: Type): Link => ({ $: 'link', type });

/** Pair of levels carried by every arrow type. */
type Levels = { old: Level; new: Level };

/** Typing environment: maps variable names to their types. */
type Env = Record<VarName, Type>;
/**
 * Resolves a type through any chain of linked type variables.
 *
 * Every linked variable visited on the way is re-pointed directly at the
 * final result, so later lookups through the same variable take one step.
 *
 * @param type Type to resolve.
 * @returns `type` itself unless it is a linked variable, otherwise the type
 *          found at the end of the link chain.
 */
const repr = (type: Type): Type => {
  if (type.$ === 'tvar' && type.link.$ === 'link') {
    const r = repr(type.link.type);
    // Path compression: skip the intermediate links next time.
    type.link = Link(r);
    return r;
  }
  return type;
};
/**
 * Reads the current level of a type.
 *
 * @param type An unbound type variable or an arrow type.
 * @returns The variable's `level`, or the arrow's `levels.new`.
 * @throws Error('Impossible') when `type` is a linked type variable; callers
 *         are expected to resolve links first.
 */
const getLevel = (type: Type): Level => {
  if (type.$ === 'tvar' && type.link.$ === 'unbound') {
    return type.link.level;
  }
  if (type.$ === 'tarrow') {
    return type.levels.new;
  }
  throw new Error('Impossible');
};
/**
 * Verifies that a type contains no cycle through arrow types.
 *
 * Each arrow is temporarily stamped with `markedLevel` while its children are
 * visited; reaching a stamped arrow again means the type contains itself.
 * The original level is restored once the children have been checked.
 *
 * @param type Type to check; linked variables are followed.
 * @throws Error('Occurs check') when a cycle is found.
 */
const checkOccurs = (type: Type): void => {
  if (type.$ === 'tvar') {
    if (type.link.$ === 'link') {
      checkOccurs(type.link.type);
    }
  } else if (type.$ === 'tarrow') {
    if (type.levels.new === markedLevel) {
      throw new Error('Occurs check');
    }
    const level = type.levels.new;
    type.levels.new = markedLevel;
    checkOccurs(type.left);
    checkOccurs(type.right);
    type.levels.new = level;
  }
};
/**
 * Infers the type of an expression, starting from an empty environment.
 *
 * All inference state (fresh-name counter, current level, queue of arrows
 * awaiting level adjustment, environment) is local to one call.
 *
 * @param exp Expression to type-check.
 * @returns The inferred type. Type variables inside it may still be linked.
 * @throws Error('Unbound variable: <name>') when a variable is not in scope.
 * @throws Error('Unification error') when two types cannot be unified.
 * @throws Error('Occurs check') when a cyclic type is detected.
 */
const top_type_check = (exp: Expr): Type => {
  let nextVarId = 0;
  // Incremented while the definition of a `let` is being inferred.
  let currentLevel = 1;
  // toBeLevelAdjusted should be a linked list for better performance
  let toBeLevelAdjusted: TArrow[] = [];

  /** Next fresh variable name: `a`..`z`, then `t26`, `t27`, ... */
  const gensym = (): string => {
    const n = nextVarId++;
    return n < 26 ? String.fromCharCode('a'.charCodeAt(0) + n) : `t${n}`;
  };

  /** Fresh unbound type variable at the current level. */
  const newvar = (): TVar => TVar(Unbound(gensym(), currentLevel));

  /** Arrow type whose old and new levels are both the current level. */
  const new_arrow = (ty1: Type, ty2: Type): TArrow =>
    TArrow(ty1, ty2, {
      new: currentLevel,
      old: currentLevel,
    });

  /**
   * Lowers the level of `type` to `l` when `l` is smaller. `type` must be an
   * unbound variable or an arrow. An arrow whose levels still agree is queued
   * in `toBeLevelAdjusted` before its new level is lowered, so its children
   * can be adjusted later.
   */
  const updateLevel = (l: Level, type: Type): void => {
    if (type.$ === 'tvar') {
      if (type.link.$ === 'link' || type.link.level === genericLevel) {
        throw new Error('Impossible');
      }
      if (l < type.link.level) {
        type.link = Unbound(type.link.name, l);
      }
    } else if (type.$ === 'tarrow') {
      if (type.levels.new === genericLevel) {
        throw new Error('Impossible');
      }
      if (type.levels.new === markedLevel) {
        throw new Error('Occurs check');
      }
      if (l < type.levels.new) {
        if (type.levels.new === type.levels.old) {
          toBeLevelAdjusted.push(type);
        }
        type.levels.new = l;
      }
    }
  };

  /** Lowers the level of `ty1` to `l`, then unifies it with `ty2`. */
  const unifyLevels = (l: Level, ty1: Type, ty2: Type): void => {
    ty1 = repr(ty1);
    updateLevel(l, ty1);
    unify(ty1, ty2);
  };

  /** Makes `t1` and `t2` equal by linking type variables, or throws. */
  const unify = (t1: Type, t2: Type): void => {
    if (t1 === t2) return;
    t1 = repr(t1);
    t2 = repr(t2);
    if (t1.$ === 'tvar' && t2.$ === 'tvar' && t1.link.$ === 'unbound' && t2.link.$ === 'unbound') {
      if (t1.link === t2.link) return;
      // Link the variable with the higher level to the other one.
      if (t1.link.level > t2.link.level) {
        t1.link = Link(t2);
      } else {
        t2.link = Link(t1);
      }
    } else if (t1.$ === 'tvar' && t1.link.$ === 'unbound') {
      updateLevel(t1.link.level, t2);
      t1.link = Link(t2);
    } else if (t2.$ === 'tvar' && t2.link.$ === 'unbound') {
      updateLevel(t2.link.level, t1);
      t2.link = Link(t1);
    } else if (t1.$ === 'tarrow' && t2.$ === 'tarrow') {
      if (t1.levels.new === markedLevel || t2.levels.new === markedLevel) {
        throw new Error('Occurs check');
      }
      const minLevel = Math.min(t1.levels.new, t2.levels.new);
      // Mark both arrows while their components are unified, so that running
      // into either of them again is reported as an occurs-check failure.
      t1.levels.new = t2.levels.new = markedLevel;
      unifyLevels(minLevel, t1.left, t2.left);
      unifyLevels(minLevel, t1.right, t2.right);
      t1.levels.new = t2.levels.new = minLevel;
    } else {
      throw new Error('Unification error');
    }
  };

  /**
   * Processes the queue of arrows whose level was lowered. Arrows whose old
   * level is not above the current level are queued again; for the others the
   * new level is pushed down into their components.
   */
  const forceDelayedAdjustment = (): void => {
    const delayAgain: TArrow[] = [];
    const loop = (level: Level, ty: Type): void => {
      ty = repr(ty);
      if (ty.$ === 'tvar' && ty.link.$ === 'unbound' && ty.link.level > level) {
        // Unbound variable above `level`: lower it.
        ty.link = Unbound(ty.link.name, level);
      } else if (ty.$ === 'tarrow') {
        // Arrow: a marked arrow means we came back to one being adjusted.
        if (ty.levels.new === markedLevel) {
          throw new Error('Occurs check');
        }
        ty.levels.new = Math.min(ty.levels.new, level);
        adjustOne(ty);
      }
    };
    const adjustOne = (ty: TArrow): void => {
      if (ty.levels.old <= currentLevel) {
        // Not relevant for generalization at this level yet; update later.
        delayAgain.push(ty);
      } else if (ty.levels.old !== ty.levels.new) {
        const level = ty.levels.new;
        ty.levels.new = markedLevel;
        loop(level, ty.left);
        loop(level, ty.right);
        ty.levels.new = ty.levels.old = level;
      }
    };
    for (const ty of toBeLevelAdjusted) {
      adjustOne(ty);
    }
    toBeLevelAdjusted = delayAgain;
  };

  /**
   * Instantiates a type: every variable at `genericLevel` is replaced by a
   * fresh variable (the same fresh variable for the same name) and every
   * arrow at `genericLevel` is rebuilt. Everything else is returned as is.
   */
  const inst = (ty: Type): Type => {
    const subst: Record<string, TVar> = Object.create(null);
    const loop = (ty: Type): Type => {
      if (ty.$ === 'tvar' && ty.link.$ === 'unbound' && ty.link.level === genericLevel) {
        const name = ty.link.name;
        if (!(name in subst)) {
          subst[name] = newvar();
        }
        return subst[name];
      } else if (ty.$ === 'tvar' && ty.link.$ === 'link') {
        return loop(ty.link.type);
      } else if (ty.$ === 'tarrow' && ty.levels.new === genericLevel) {
        return new_arrow(loop(ty.left), loop(ty.right));
      } else {
        return ty;
      }
    };
    return loop(ty);
  };

  /**
   * Generalizes a type in place: unbound variables above the current level
   * get `genericLevel`; arrows above the current level get both levels set to
   * the maximum level of their two components.
   */
  const gen = (ty: Type): void => {
    forceDelayedAdjustment();
    const loop = (ty: Type): void => {
      ty = repr(ty);
      if (ty.$ === 'tvar' && ty.link.$ === 'unbound' && ty.link.level > currentLevel) {
        ty.link = Unbound(ty.link.name, genericLevel);
      } else if (ty.$ === 'tarrow' && ty.levels.new > currentLevel) {
        const ty1 = repr(ty.left);
        const ty2 = repr(ty.right);
        loop(ty1);
        loop(ty2);
        const l = Math.max(getLevel(ty1), getLevel(ty2));
        ty.levels.old = ty.levels.new = l;
      }
    };
    loop(ty);
  };

  // Null-prototype record: names such as "constructor" or "toString" must not
  // resolve to members inherited from Object.prototype.
  const env: Env = Object.create(null);

  /** Runs `f` with `name` bound to `type`, then restores the previous binding. */
  const withEnv = <T>(name: string, type: Type, f: () => T): T => {
    const prev = env[name];
    env[name] = type;
    const res = f();
    if (prev === undefined) {
      delete env[name];
    } else {
      env[name] = prev;
    }
    return res;
  };

  const infer = (exp: Expr): Type => {
    if (exp.$ === 'var') {
      const bound = env[exp.name];
      if (bound === undefined) {
        throw new Error(`Unbound variable: ${exp.name}`);
      }
      return inst(bound);
    } else if (exp.$ === 'lam') {
      const tyX = newvar();
      const tyE = withEnv(exp.name, tyX, () => infer(exp.expr));
      return new_arrow(tyX, tyE);
    } else if (exp.$ === 'app') {
      const ty_fun = infer(exp.left);
      const ty_arg = infer(exp.right);
      const ty_res = newvar();
      unify(ty_fun, new_arrow(ty_arg, ty_res));
      return ty_res;
    } else if (exp.$ === 'let') {
      // The definition is inferred one level deeper, then generalized.
      ++currentLevel;
      const ty_e = infer(exp.def);
      --currentLevel;
      gen(ty_e);
      return withEnv(exp.name, ty_e, () => infer(exp.body));
    } else {
      throw new Error('Impossible');
    }
  };

  const ty = infer(exp);
  checkOccurs(ty);
  return ty;
};
/**
 * One handler per constructor of `Type` / `TV`. `U` is the result produced
 * for a variable state (`TV`), `T` the result produced for a `Type`.
 */
interface TypeAlg<U, T> {
  unbound: (name: string, level: Level) => U;
  link: (type: T) => U;
  tvar: (link: U) => T;
  tarrow: (left: T, right: T, levels: Levels) => T;
}

/**
 * Builds a bottom-up fold over types from the given handlers.
 *
 * @param alg Handlers for `tarrow`, `tvar`, `link` and `unbound`.
 * @returns A function that folds a `Type` into a `T`: children are folded
 *          first and their results are passed to the matching handler.
 */
const visitType = <U, T>({ tarrow, tvar, link, unbound }: TypeAlg<U, T>) => {
  // Both helpers are written as total expressions so every path returns.
  const visitTV = (tv: TV): U =>
    tv.$ === 'unbound' ? unbound(tv.name, tv.level) : link(visitT(tv.type));

  const visitT = (type: Type): T =>
    type.$ === 'tarrow'
      ? tarrow(visitT(type.left), visitT(type.right), type.levels)
      : tvar(visitTV(type.link));

  return visitT;
};
/** Position of a sub-term relative to its parent: none, left or right operand. */
enum Side { N, L, R }

/** Deferred printer: renders a type given the parent's priority and the side it sits on. */
type ShowType = (prio: number, side: Side) => string;

/** Returns `s` wrapped in parentheses when `cond` holds, otherwise `s` unchanged. */
const wrap = (cond: boolean, s: string) => (cond ? `(${s})` : s);
/**
 * Pretty-printing handlers. Variables print as their name and links are
 * transparent. An arrow renders its operands at priority 1 and is itself
 * parenthesized when the caller's priority exceeds 1, or equals 1 and the
 * arrow sits on the left side — so `(a -> b) -> c` keeps its parentheses
 * while `a -> b -> c` needs none.
 */
const showTypeAlg: TypeAlg<ShowType, ShowType> = {
  unbound: (name) => () => name,
  link: (type) => type,
  tvar: (link) => link,
  tarrow: (left, right) => (prio, side) => {
    return wrap(
      prio > 1 || (prio === 1 && side === Side.L),
      `${left(1, Side.L)} -> ${right(1, Side.R)}`,
    );
  },
};

/** Folds a type into a `ShowType`; call the result with a priority and side to get the string. */
const showType = visitType(showTypeAlg);
/** Identity function: `x -> x`. */
const id = Lam("x", Var("x"));

/** `x -> y -> x y`. */
const c1 = Lam("x", Lam("y", App(Var("x"), Var("y"))));

/** Level record shared by every arrow in the expected results below. */
const level1 = { new: 1, old: 1 };
/** Successful inference together with the resulting type. */
type Ok = { readonly $: 'ok', value: Type };
/** Inference threw an error. */
type Fail = { readonly $: 'fail' };
type TestResult = Ok | Fail;

/** An expression paired with the outcome the type checker should produce. */
type Test = {
  expr: Expr,
  result: TestResult,
}

// Expected types are compared structurally (including variable names and
// remaining links), so each value mirrors the exact shape inference produces.
const tests: Test[] = [
  {
    expr: id,
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("a", 1)), TVar(Unbound("a", 1)), level1),
    },
  },
  {
    expr: c1,
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TArrow(TVar(Unbound("b", 1)), TVar(Unbound("c", 1)), level1))), TArrow(TVar(Unbound("b", 1)), TVar(Unbound("c", 1)), level1), level1),
    },
  },
  {
    expr: Let("x", c1, Var("x")),
    result: {
      $: 'ok',
      value: TArrow(TArrow(TVar(Unbound("d", 1)), TVar(Unbound("e", 1)), level1), TArrow(TVar(Unbound("d", 1)), TVar(Unbound("e", 1)), level1), level1),
    },
  },
  {
    expr: Let("y", Lam("z", Var("z")), Var("y")),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("b", 1)), TVar(Unbound("b", 1)), level1),
    },
  },
  {
    expr: Lam("x", Let("y", Lam("z", Var("z")), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("a", 1)), TArrow(TVar(Unbound("c", 1)), TVar(Unbound("c", 1)), level1), level1),
    },
  },
  {
    expr: Lam("x", Let("y", Lam("z", Var("z")), App(Var("y"), Var("x")))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TVar(Unbound("c", 1)))), TVar(Link(TVar(Unbound("c", 1)))), level1)
    },
  },
  {
    expr: Lam("x", App(Var("x"), Var("x"))),
    result: {
      $: 'fail',
    },
  },
  {
    expr: Let("x", Var("x"), Var("x")),
    result: {
      $: 'fail',
    },
  },
  {
    // y -> y (z -> y z)
    expr: Lam("y", App(Var("y"), (Lam("z", (App(Var("y"), Var("z"))))))),
    result: {
      $: 'fail',
    },
  },
  {
    // x -> y -> k -> k (k x y) (k y x)
    expr: Lam("x", Lam("y", Lam("k", App(App(Var("k"), App(App(Var("k"), Var("x")), Var("y"))), App(App(Var("k"), Var("y")), Var("x")))))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("a", 1)), TArrow(TVar(Link(TVar(Unbound("a", 1)))), TArrow(TVar(Link(TArrow(TVar(Unbound("a", 1)), TVar(Link(TArrow(TVar(Link(TVar(Unbound("a", 1)))), TVar(Link(TVar(Unbound("a", 1)))), level1))), level1))), TVar(Link(TVar(Unbound("a", 1)))), level1), level1), level1),
    },
  },
  {
    expr: Let("id", id, App(Var("id"), Var("id"))),
    result: {
      $: 'ok',
      value: TVar(Link(TArrow(TVar(Unbound("c", 1)), TVar(Unbound("c", 1)), level1))),
    },
  },
  {
    expr: Let("x", c1, Let("y", Let("z", App(Var("x"), id), Var("z")), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("i", 1)), TVar(Unbound("i", 1)), level1),
    },
  },
  {
    // x -> y -> let x = x y in x -> y x
    expr: Lam("x", Lam("y", Let("x", App(Var("x"), Var("y")), Lam("x", App(Var("y"), Var("x")))))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TArrow(TVar(Link(TArrow(TVar(Unbound("d", 1)), TVar(Unbound("e", 1)), level1))), TVar(Unbound("c", 1)), level1))), TArrow(TVar(Link(TArrow(TVar(Unbound("d", 1)), TVar(Unbound("e", 1)), level1))), TArrow(TVar(Unbound("d", 1)), TVar(Unbound("e", 1)), level1), level1), level1),
    },
  },
  {
    expr: Lam("x", Let("y", Var("x"), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("a", 1)), TVar(Unbound("a", 1)), level1),
    },
  },
  {
    expr: Lam("x", Let("y", Lam("z", Var("x")), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Unbound("a", 1)), TArrow(TVar(Unbound("c", 1)), TVar(Unbound("a", 1)), level1), level1)
    },
  },
  {
    expr: Lam("x", Let("y", Lam("z", App(Var("x"), Var("z"))), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TArrow(TVar(Unbound("b", 1)), TVar(Unbound("c", 1)), level1))), TArrow(TVar(Unbound("b", 1)), TVar(Unbound("c", 1)), level1), level1),
    },
  },
  {
    expr: Lam("x", Lam("y", Let("x", App(Var("x"), Var("y")), App(Var("x"), Var("y"))))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TArrow(TVar(Unbound("b", 1)), TVar(Link(TArrow(TVar(Unbound("b", 1)), TVar(Unbound("d", 1)), level1))), level1))), TArrow(TVar(Unbound("b", 1)), TVar(Unbound("d", 1)), level1), level1),
    },
  },
  {
    expr: Lam("x", Let("y", Var("x"), App(Var("y"), Var("y")))),
    result: {
      $: 'fail',
    },
  },
  {
    // x -> let y = let z = x (x -> x) in z in y
    expr: Lam("x", Let("y", Let("z", App(Var("x"), id), Var("z")), Var("y"))),
    result: {
      $: 'ok',
      value: TArrow(TVar(Link(TArrow(TArrow(TVar(Unbound("b", 1)), TVar(Unbound("b", 1)), level1), TVar(Unbound("c", 1)), level1))), TVar(Unbound("c", 1)), level1),
    }
  },
];
/**
 * Type-checks `expr` and converts the outcome into a `TestResult`.
 *
 * @param expr Expression to check.
 * @returns `{ $: 'ok', value }` with the inferred type, or `{ $: 'fail' }`
 *          when `top_type_check` throws anything at all.
 */
const execTest = (expr: Expr): TestResult => {
  try {
    return { $: 'ok', value: top_type_check(expr) };
  } catch {
    // Any exception counts as a failed inference; the cause is not reported.
    return { $: 'fail' };
  }
};
// Runs every test case and prints `ok`, or `fail` with both results.
for (const { expr, result: expected } of tests) {
  const factual = execTest(expr);
  // Two successes must serialize identically; otherwise the outcomes only
  // have to agree on ok/fail.
  const passed =
    factual.$ === 'ok' && expected.$ === 'ok'
      ? JSON.stringify(expected) === JSON.stringify(factual)
      : factual.$ === expected.$;
  if (passed) {
    console.log('ok');
  } else {
    console.log('fail', { expected, factual });
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment