Created
August 21, 2024 14:41
-
-
Save zesterer/de62cbe06a47efee779859159394529c to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//! Type checking and inference in 100 lines of Rust | |
//! ---------------------------------- | |
//! (if you don't count comments) | |
#![allow(dead_code)] | |
/// Handle naming a type variable: an index into the solver's table of type information.
#[derive(Clone, Copy, PartialEq, Debug)]
struct TyVar(usize);
/// Possibly-incomplete information known about a type variable's type. | |
#[derive(Copy, Clone, Debug)] | |
enum TyInfo { | |
/// No information is known about the type. | |
Unknown, | |
/// The type is equal to another type. | |
Ref(TyVar), | |
/// The type is an `Int` | |
Int, | |
/// The type is a `Bool` | |
Bool, | |
// Function, `A -> B` | |
Func(TyVar, TyVar), | |
} | |
/// A node in the abstract syntax tree of the toy language.
///
/// `'a` is the lifetime of the source text that names (variables, arguments) borrow from.
#[derive(Debug)]
enum Expr<'a> {
    /// An integer literal, e.g. `42`.
    Int(u64),
    /// A boolean literal.
    /// NOTE(review): the parser in this file does not currently produce this variant.
    Bool(bool),
    /// A reference to a variable by name.
    Var(&'a str),
    /// `let lhs = rhs; then` — binds the name `lhs` to `rhs` within `then`.
    Let {
        lhs: &'a str,
        rhs: Box<Expr<'a>>,
        then: Box<Expr<'a>>,
    },
    /// A single-argument anonymous function (lambda/closure), `fn(arg) body`.
    Func {
        arg: &'a str,
        body: Box<Expr<'a>>,
    },
    /// Calling `func` with `arg`, written `func(arg)`.
    Apply {
        func: Box<Expr<'a>>,
        arg: Box<Expr<'a>>,
    },
}
/// A fully-resolved type, as produced by `Solver::solve` once inference has finished.
#[derive(Debug)]
enum Ty {
    /// The integer type.
    Int,
    /// The boolean type.
    Bool,
    /// A function type `A -> B`: argument type first, result type second.
    Func(Box<Ty>, Box<Ty>),
}
/// Contains the state of the type solver. | |
#[derive(Default)] | |
struct Solver { vars: Vec<TyInfo> } | |
impl Solver { | |
/// Create a new type variable in the type solver' environment, with the given information. | |
fn create_ty(&mut self, info: TyInfo) -> TyVar { self.vars.push(info); TyVar(self.vars.len() - 1) } | |
/// Unify two type variables together, forcing them to be equal. | |
fn unify(&mut self, a: TyVar, b: TyVar) { | |
match (self.vars[a.0], self.vars[b.0]) { | |
(TyInfo::Unknown, _) => self.vars[a.0] = TyInfo::Ref(b), | |
(_, TyInfo::Unknown) => self.vars[b.0] = TyInfo::Ref(a), | |
(TyInfo::Ref(a), _) => self.unify(a, b), | |
(_, TyInfo::Ref(b)) => self.unify(a, b), | |
(TyInfo::Int, TyInfo::Int) | (TyInfo::Bool, TyInfo::Bool) => {}, | |
(TyInfo::Func(a_i, a_o), TyInfo::Func(b_i, b_o)) => { | |
self.unify(a_i, b_i); | |
self.unify(a_o, b_o); | |
}, | |
(a, b) => panic!("Type mismatch between {a:?} and {b:?}"), | |
} | |
} | |
/// Type-check an expression, returning a type variable representing its type, with the given environment/scope. | |
fn check<'ast>(&mut self, expr: &Expr<'ast>, env: &mut Vec<(&'ast str, TyVar)>) -> TyVar { | |
match expr { | |
// Literal expressions are easy, their type doesn't need inferring. | |
Expr::Int(_) => self.create_ty(TyInfo::Int), | |
Expr::Bool(_) => self.create_ty(TyInfo::Bool), | |
// We search the environment backward until we find a binding matching the variable name. | |
Expr::Var(name) => env.iter_mut().rev().find(|(n, _)| n == name).expect("No such variable in scope").1, | |
// In a let expression, `rhs` gets bound with name `lhs` in the environment used to type-check `then`. | |
Expr::Let { lhs, rhs, then } => { | |
let rhs = self.check(rhs, env); | |
env.push((lhs, rhs)); | |
let out = self.check(then, env); | |
env.pop(); | |
out | |
}, | |
// In a function, the argument becomes an unknown type in the environment used to type-check `body`. | |
Expr::Func { arg, body } => { | |
let arg_ty = self.create_ty(TyInfo::Unknown); | |
env.push((arg, arg_ty)); | |
let body = self.check(body, env); | |
env.pop(); | |
self.create_ty(TyInfo::Func(arg_ty, body)) | |
}, | |
// During function application, both argument and function are type-checked and then we force the latter to be a function of the former. | |
Expr::Apply { func, arg } => { | |
let func = self.check(func, env); | |
let arg = self.check(arg, env); | |
let out = self.create_ty(TyInfo::Unknown); | |
let func_ty = self.create_ty(TyInfo::Func(arg, out)); | |
self.unify(func_ty, func); | |
out | |
}, | |
} | |
} | |
/// Convert a type variable into a final type once type-checking has finished. | |
pub fn solve(&self, var: TyVar) -> Ty { | |
match self.vars[var.0] { | |
TyInfo::Unknown => panic!("Cannot infer type"), | |
TyInfo::Ref(var) => self.solve(var), | |
TyInfo::Int => Ty::Int, | |
TyInfo::Bool => Ty::Bool, | |
TyInfo::Func(i, o) => Ty::Func(Box::new(self.solve(i)), Box::new(self.solve(o))), | |
} | |
} | |
} | |
fn expect(tokens: &mut &[Token], expected: Token) -> Result<(), String> { | |
match tokens { | |
[tok, tail @ ..] if *tok == expected => Ok(*tokens = tail), | |
[tok, ..] => Err(format!("Expected {expected:?}, found {tok:?}")), | |
[] => Err(format!("Expected {expected:?}, found end of input")), | |
} | |
} | |
fn parse_list<'a, R>(tokens: &mut &'a [Token], mut f: impl FnMut(&mut &'a [Token]) -> Result<R, String>) -> Result<Vec<R>, String> { | |
let mut items = Vec::new(); | |
loop { | |
items.push(f(tokens)?); | |
match *tokens { | |
[Token::Comma] | [] => break Ok(items), | |
[Token::Comma, tail @ ..] => { | |
*tokens = tail; | |
}, | |
[tok, ..] => return Err(format!("Expected argument, found {tok:?}")), | |
} | |
} | |
} | |
fn parse_ident<'a>(tokens: &mut &'a [Token]) -> Result<&'a str, String> { | |
match *tokens { | |
[Token::Ident(ident), tail @ ..] => { | |
*tokens = tail; | |
Ok(ident) | |
}, | |
[tok, ..] => Err(format!("Expected ident, found {tok:?}")), | |
[] => Err(format!("Expected ident, found end of input")), | |
} | |
} | |
fn parse_expr<'a>(tokens: &mut &'a [Token]) -> Result<Expr<'a>, String> { | |
let mut expr = match *tokens { | |
[Token::Ident(name), tail @ ..] => { | |
*tokens = tail; | |
Expr::Var(name) | |
}, | |
[Token::Int(x), tail @ ..] => { | |
*tokens = tail; | |
Expr::Int(*x) | |
}, | |
[Token::Let, Token::Ident(lhs), tail @ ..] => { | |
*tokens = tail; | |
expect(tokens, Token::Eq)?; | |
let rhs = Box::new(parse_expr(tokens)?); | |
expect(tokens, Token::Semicolon)?; | |
let then = Box::new(parse_expr(tokens)?); | |
Expr::Let { lhs, rhs, then } | |
}, | |
[Token::Fn, Token::Parens(args), tail @ ..] => { | |
*tokens = tail; | |
let args = parse_list(&mut &args[..], parse_ident)?; | |
args.into_iter().rev().fold( | |
parse_expr(tokens)?, | |
|body, arg| Expr::Func { arg, body: Box::new(body) }, | |
) | |
}, | |
[tok, ..] => return Err(format!("Expected expression, found {tok:?}")), | |
[] => return Err(format!("Expected expression, found end of input")), | |
}; | |
while let [Token::Parens(args), tail @ ..] = *tokens { | |
*tokens = tail; | |
let args = parse_list(&mut &args[..], parse_expr)?; | |
expr = args.into_iter().fold( | |
expr, | |
|func, arg| Expr::Apply { func: Box::new(func), arg: Box::new(arg) }, | |
); | |
} | |
Ok(expr) | |
} | |
/// A lexical token.
///
/// Parenthesised groups nest, so a token stream is a tree rather than a flat list.
/// `'a` is the lifetime of the source text that identifiers borrow from.
#[derive(PartialEq, Debug)]
enum Token<'a> {
    /// Unsigned integer literal.
    Int(u64),
    /// Identifier name.
    Ident(&'a str),
    /// The `let` keyword.
    Let,
    /// The `fn` keyword.
    Fn,
    /// `=`
    Eq,
    /// `;`
    Semicolon,
    /// `,`
    Comma,
    /// The tokens between a `(` and its matching `)`.
    Parens(Vec<Token<'a>>),
}
/// Split off and return the longest prefix of `s` whose characters all satisfy `f`.
///
/// `s` is advanced past that prefix. If every character matches, the whole string is returned
/// and `s` becomes empty.
fn take<'a>(s: &mut &'a str, mut f: impl FnMut(&char) -> bool) -> &'a str {
    let text: &'a str = *s;
    // Byte offset of the first character that fails the predicate, or the end of the string.
    let end = text.find(|c: char| !f(&c)).unwrap_or(text.len());
    let (head, rest) = text.split_at(end);
    *s = rest;
    head
}
/// Advance `s` past its first character; an empty string is left empty.
fn skip(s: &mut &str) {
    let text = *s;
    if let Some(first) = text.chars().next() {
        // Step over the whole (possibly multi-byte) character, not just one byte.
        *s = &text[first.len_utf8()..];
    }
}
fn lex<'a>(src: &mut &'a str) -> Result<Vec<Token<'a>>, String> { | |
let mut tokens = Vec::new(); | |
loop { | |
tokens.push(match src.chars().next() { | |
Some(c) if c.is_ascii_digit() => { | |
let x = take(src, char::is_ascii_digit); | |
Token::Int(x.parse().unwrap()) | |
}, | |
Some(c) if c.is_ascii_alphabetic() => { | |
match take(src, char::is_ascii_alphanumeric) { | |
"let" => Token::Let, | |
"fn" => Token::Fn, | |
x => Token::Ident(x), | |
} | |
}, | |
Some('=') => { skip(src); Token::Eq }, | |
Some(';') => { skip(src); Token::Semicolon }, | |
Some('(') => { | |
skip(src); | |
Token::Parens(lex(src)?) | |
}, | |
Some(c) if c.is_whitespace() => { skip(src); continue }, | |
Some(')') | None => { skip(src); break Ok(tokens) }, | |
Some(c) => break Err(format!("Unexpected character {c:?}")), | |
}) | |
} | |
} | |
fn main() { | |
let tokens = lex(&mut "let f = fn(x) x; f(42)").unwrap(); | |
let expr = parse_expr(&mut &*tokens).unwrap(); | |
println!("{expr:?}"); | |
let mut solver = Solver::default(); | |
let program_ty = solver.check(&expr, &mut Vec::new()); | |
println!("The expression outputs type `{:?}`", solver.solve(program_ty)); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment