Last active
March 5, 2024 15:23
-
-
Save mbillingr/18a673b7588aaa3b0befe3d76f128de1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::cell::RefCell; | |
use std::collections::{HashMap, HashSet}; | |
use std::rc::Rc; | |
use std::sync::atomic::{AtomicU64, Ordering}; | |
fn main() { | |
use Term::*; | |
let id = Lam { | |
name: "x".into(), | |
body: Var { name: "x".into() }.into(), | |
}; | |
let expr = App { | |
lhs: id.clone().into(), | |
rhs: Lit { value: 42 }.into(), | |
}; | |
let ctx = Ctx::default(); | |
let st = ctx.type_term(&expr); | |
let ty = coalesce_type(&st); | |
println!("{:?}", ty); | |
let twice = Lam { | |
name: "f".into(), | |
body: Lam { | |
name: "x".into(), | |
body: App { | |
lhs: Var { name: "f".into() }.into(), | |
rhs: App { | |
lhs: Var { name: "f".into() }.into(), | |
rhs: Var { name: "x".into() }.into(), | |
} | |
.into(), | |
} | |
.into(), | |
} | |
.into(), | |
}; | |
let ctx = Ctx::default(); | |
println!("{:?}", coalesce_type(&ctx.type_term(&twice))) | |
} | |
/// Shared-ownership pointer used for AST and type nodes.
type Ref<T> = Rc<T>;
/// Payload type of integer literals.
type Int = i64;
/// Cheaply clonable immutable string used for identifiers and type names.
type Str = Ref<str>;
/// Source terms: lambda calculus with integer literals, records, field
/// selection and `let` bindings.
#[derive(Clone)]
enum Term {
    /// Integer literal; typed as the primitive `int`.
    Lit {
        value: Int,
    },
    /// Variable reference, resolved against the typing context's bindings.
    Var {
        name: Str,
    },
    /// Single-parameter function `fun name -> body`.
    Lam {
        name: Str,
        body: Ref<Term>,
    },
    /// Application `lhs rhs`.
    App {
        lhs: Ref<Term>,
        rhs: Ref<Term>,
    },
    /// Record literal: a list of (field name, field term) pairs.
    Rcd {
        fields: Vec<(Str, Term)>,
    },
    /// Field selection `receiver.field_name`.
    Sel {
        receiver: Ref<Term>,
        field_name: Str,
    },
    /// `let name = rhs in body`.
    /// NOTE(review): `is_rec` presumably marks a `let rec` binding where `rhs`
    /// may mention `name` — reading based on the field name, confirm.
    Let {
        is_rec: bool,
        name: Str,
        rhs: Ref<Term>,
        body: Ref<Term>,
    },
}
/// Types manipulated during inference. Constraints are recorded by mutating
/// the bound lists inside shared `VariableState`s.
#[derive(Clone, Debug)]
enum SimpleType {
    /// Inference variable; compared by pointer identity of the shared state.
    Variable(Ref<VariableState>),
    /// Named primitive type, e.g. `"int"` for literals.
    Primitive(Str),
    /// Function type `parameter -> result`.
    Function(Ref<SimpleType>, Ref<SimpleType>),
    /// Record type mapping field names to field types; compared by pointer identity.
    Record(Ref<HashMap<Str, SimpleType>>),
}
/// Mutable state of one inference variable.
#[derive(Debug)]
struct VariableState {
    /// Types `lb` with a recorded constraint `lb <: self`.
    lower_bounds: RefCell<List<SimpleType>>,
    /// Types `ub` with a recorded constraint `self <: ub`.
    upper_bounds: RefCell<List<SimpleType>>,
    /// Display name of the form `'N`, drawn from a global counter.
    unique_name: Str,
}
/// Output type syntax built by `coalesce_type` and rendered through its
/// `Debug` implementation.
enum Type {
    /// Top type `⊤` (not produced by the `go` coalescer in this file).
    Top,
    /// Bottom type `⊥` (not produced by the `go` coalescer in this file).
    Bot,
    /// Union `lhs ∨ rhs`; built for variables in positive position.
    Union(Ref<Type>, Ref<Type>),
    /// Intersection `lhs ∧ rhs`; built for variables in negative position.
    Inter(Ref<Type>, Ref<Type>),
    /// Function type `(lhs -> rhs)`.
    Function(Ref<Type>, Ref<Type>),
    /// Record type mapping field names to field types.
    Record(Ref<HashMap<Str, Type>>),
    /// Recursive type `(body as name)`; `name` may occur inside `body`.
    Recursive { name: Str, body: Ref<Type> },
    /// Type variable referenced by name.
    Variable(Str),
    /// Named primitive type.
    Primitive(Str),
}
/// An inference variable paired with a polarity. Used as the key of the
/// in-process set and the recursion map while coalescing in `go`.
#[derive(Clone)]
struct PolarVariable(Ref<VariableState>, P);
/// Polarity of a type position during coalescing.
/// `Val` is the starting (positive) polarity and reads lower bounds.
/// `Use` is the flipped (negative) polarity and reads upper bounds.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum P {
    Val,
    Use,
}
/// `!p` flips the polarity, e.g. when descending into a function parameter.
impl std::ops::Not for P {
    type Output = Self;
    fn not(self) -> Self {
        if self == P::Val {
            P::Use
        } else {
            P::Val
        }
    }
}
impl Eq for SimpleType {} | |
impl PartialEq for SimpleType { | |
fn eq(&self, other: &Self) -> bool { | |
use SimpleType::*; | |
match (self, other) { | |
(Variable(a), Variable(b)) => Ref::ptr_eq(a, b), | |
(Primitive(a), Primitive(b)) => a == b, | |
(Function(a1, r1), Function(a2, r2)) => a1 == a2 && r1 == r2, | |
(Record(a), Record(b)) => Ref::ptr_eq(a, b), | |
_ => false, | |
} | |
} | |
} | |
impl std::hash::Hash for SimpleType { | |
fn hash<H: std::hash::Hasher>(&self, h: &mut H) { | |
match self { | |
SimpleType::Variable(rc) => std::ptr::hash(Rc::as_ptr(rc), h), | |
SimpleType::Primitive(name) => name.hash(h), | |
SimpleType::Function(a, r) => { | |
a.hash(h); | |
r.hash(h); | |
} | |
SimpleType::Record(fs) => { | |
for f in fs.iter() { | |
f.hash(h); | |
} | |
} | |
} | |
} | |
} | |
impl Eq for PolarVariable {} | |
impl PartialEq for PolarVariable { | |
fn eq(&self, other: &Self) -> bool { | |
Rc::ptr_eq(&self.0, &other.0) && self.1 == other.1 | |
} | |
} | |
impl std::hash::Hash for PolarVariable { | |
fn hash<H: std::hash::Hasher>(&self, h: &mut H) { | |
std::ptr::hash(Rc::as_ptr(&self.0), h); | |
self.1.hash(h) | |
} | |
} | |
impl VariableState { | |
pub fn new() -> Self { | |
VariableState { | |
lower_bounds: Default::default(), | |
upper_bounds: Default::default(), | |
unique_name: unique_name(), | |
} | |
} | |
} | |
impl std::fmt::Debug for Type { | |
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | |
match self { | |
Type::Top => write!(f, "⊤"), | |
Type::Bot => write!(f, "⊥"), | |
Type::Union(lhs, rhs) => write!(f, "{lhs:?} ∨ {rhs:?}"), | |
Type::Inter(lhs, rhs) => write!(f, "{lhs:?} ∧ {rhs:?}"), | |
Type::Variable(v) => write!(f, "{v}"), | |
Type::Primitive(n) => write!(f, "{n}"), | |
Type::Function(lhs, rhs) => write!(f, "({:?} -> {:?})", lhs, rhs), | |
Type::Record(fs) => { | |
write!(f, "{{")?; | |
let mut fs = fs.iter(); | |
if let Some((n, t)) = fs.next() { | |
write!(f, "{n}: {t:?}")?; | |
} | |
for (n, t) in fs { | |
write!(f, ", {n}: {t:?}")?; | |
} | |
write!(f, "}}") | |
} | |
Type::Recursive{name, body} => write!(f, "({body:?} as {name})"), | |
} | |
} | |
} | |
impl SimpleType { | |
fn function(lhs: SimpleType, rhs: SimpleType) -> Self { | |
Self::Function(Ref::new(lhs), Ref::new(rhs)) | |
} | |
} | |
/// Typing context: the term variables in scope plus a constraint cache.
/// The cache is shared through `Ref` with every context derived by `bind_var`.
#[derive(Default)]
struct Ctx {
    /// Inferred types of the term variables currently in scope.
    vars: HashMap<Str, SimpleType>,
    /// Constraints `(lhs, rhs)` that `constrain` has already processed.
    /// A repeated constraint returns immediately instead of being re-propagated.
    constraint_cache: Ref<RefCell<HashSet<(SimpleType, SimpleType)>>>,
}
impl Ctx { | |
fn type_term(&self, term: &Term) -> SimpleType { | |
use SimpleType::*; | |
use Term::*; | |
match term { | |
Lit { .. } => Primitive("int".into()), | |
Var { name } => self | |
.vars | |
.get(&**name) | |
.cloned() | |
.unwrap_or_else(|| Self::err(format!("{} not found", name))), | |
Rcd { fields } => Record(Ref::new( | |
fields | |
.iter() | |
.map(|(n, t)| (n.clone(), self.type_term(t))) | |
.collect(), | |
)), | |
Lam { name, body } => { | |
let param = self.fresh_var(); | |
let ctx_ = self.bind_var(name.clone(), param.clone()); | |
SimpleType::function(param, ctx_.type_term(body)) | |
} | |
App { lhs, rhs } => { | |
let res = self.fresh_var(); | |
self.constrain( | |
self.type_term(lhs), | |
SimpleType::function(self.type_term(rhs), res.clone()), | |
); | |
res | |
} | |
Sel { | |
receiver, | |
field_name, | |
} => { | |
let res = self.fresh_var(); | |
let mut rec = HashMap::default(); | |
rec.insert(field_name.clone(), res.clone()); | |
self.constrain(self.type_term(receiver), Record(Ref::new(rec))); | |
res | |
} | |
_ => todo!(), | |
} | |
} | |
fn constrain(&self, lhs: SimpleType, rhs: SimpleType) { | |
let types = (lhs, rhs); | |
{ | |
let mut cc = self.constraint_cache.borrow_mut(); | |
if cc.contains(&types) { | |
return; | |
} | |
cc.insert(types.clone()); | |
} | |
use SimpleType::*; | |
match types { | |
(Primitive(a), Primitive(b)) if a == b => {} | |
(Function(a1, r1), Function(a2, r2)) => { | |
self.constrain((*a2).clone(), (*a1).clone()); | |
self.constrain((*r1).clone(), (*r2).clone()); | |
} | |
(Record(ref fs1), Record(fs2)) => { | |
for (n2, t2) in fs2.iter() { | |
match fs1.get(n2) { | |
Some(t1) => self.constrain(t1.clone(), t2.clone()), | |
None => Self::err(format!("missing field: {n2} in {:?}", types.0)), | |
} | |
} | |
} | |
(Variable(lhs), rhs) => { | |
cons_(rhs.clone(), &lhs.upper_bounds); | |
for lb in &(*lhs.lower_bounds.borrow()).clone() { | |
self.constrain(lb.clone(), rhs.clone()); | |
} | |
} | |
(lhs, Variable(rhs)) => { | |
cons_(lhs.clone(), &rhs.lower_bounds); | |
for ub in &(*rhs.upper_bounds.borrow()).clone() { | |
self.constrain(lhs.clone(), ub.clone()); | |
} | |
} | |
(lhs, rhs) => Self::err(format!("cannot constrain {lhs:?} <: {rhs:?}")), | |
} | |
} | |
fn fresh_var(&self) -> SimpleType { | |
SimpleType::Variable(Ref::new(VariableState::new())) | |
} | |
fn err(msg: impl ToString) -> ! { | |
panic!("type error: {}", msg.to_string()) | |
} | |
} | |
impl Ctx { | |
fn bind_var(&self, name: Str, ty: SimpleType) -> Self { | |
let mut vars = self.vars.clone(); | |
vars.insert(name, ty); | |
Ctx { | |
vars, | |
constraint_cache: self.constraint_cache.clone(), | |
} | |
} | |
} | |
fn coalesce_type(ty: &SimpleType) -> Type { | |
let mut recursive: HashMap<PolarVariable, Str> = Default::default(); | |
go(ty, P::Val, &Default::default(), &mut recursive) | |
} | |
fn go( | |
ty: &SimpleType, | |
polar: P, | |
in_process: &HashSet<PolarVariable>, | |
recursive: &mut HashMap<PolarVariable, Str>, | |
) -> Type { | |
match ty { | |
SimpleType::Primitive(name) => Type::Primitive(name.clone()), | |
SimpleType::Function(lhs, rhs) => Type::Function( | |
Ref::new(go(lhs, !polar, in_process, recursive)), | |
Ref::new(go(rhs, polar, in_process, recursive)), | |
), | |
SimpleType::Record(fs) => Type::Record(Ref::new( | |
fs.iter() | |
.map(|(n, t)| (n.clone(), go(t, polar, in_process, recursive))) | |
.collect(), | |
)), | |
SimpleType::Variable(vs) => { | |
let vs_pol = PolarVariable(vs.clone(), polar); | |
if in_process.contains(&vs_pol) { | |
let name = recursive.entry(vs_pol).or_insert_with(|| unique_name()); | |
Type::Variable(name.clone()) | |
} else { | |
let bounds = match polar { | |
P::Val => (*vs.lower_bounds.borrow()).clone(), | |
P::Use => (*vs.upper_bounds.borrow()).clone(), | |
}; | |
let mut ip_ = in_process.clone(); | |
ip_.insert(vs_pol.clone()); | |
let mut bound_types = vec![]; | |
for b in &bounds { | |
bound_types.push(go(b, polar, &ip_, recursive)) | |
} | |
let mut res = Type::Variable(vs.unique_name.clone()); | |
for t in bound_types { | |
match polar { | |
P::Val => res = Type::Union(Ref::new(t), Ref::new(res)), | |
P::Use => res = Type::Inter(Ref::new(t), Ref::new(res)), | |
} | |
} | |
match recursive.get(&vs_pol) { | |
None => res, | |
Some(name) => Type::Recursive { | |
name: name.clone(), | |
body: Ref::new(res), | |
}, | |
} | |
} | |
} | |
} | |
} | |
fn unique_name() -> Str { | |
format!("'{}", GLOBAL_COUNTER.fetch_add(1, Ordering::SeqCst)).into() | |
} | |
static GLOBAL_COUNTER: AtomicU64 = AtomicU64::new(0); | |
/// Immutable singly-linked list whose nodes are shared through `Rc`, so
/// cloning a list copies only the head pointer.
#[derive(Debug)]
enum List<T> {
    Nil,
    Item(Rc<(T, Self)>),
}
impl<T> Default for List<T> {
    /// The empty list.
    fn default() -> Self {
        List::new()
    }
}
impl<T> Clone for List<T> {
    /// Shares the existing nodes. Hand-written so that `T: Clone` is not required.
    fn clone(&self) -> Self {
        if let List::Item(node) = self {
            List::Item(Rc::clone(node))
        } else {
            List::Nil
        }
    }
}
impl<T> List<T> {
    /// The empty list.
    pub fn new() -> Self {
        Self::Nil
    }
}
/// Front-to-back iteration, implemented by advancing the borrowed list
/// reference itself; this also lets `for x in &list` work.
impl<'a, T> Iterator for &'a List<T> {
    type Item = &'a T;
    fn next(&mut self) -> Option<&'a T> {
        let current: &'a List<T> = *self;
        match current {
            List::Nil => None,
            List::Item(node) => {
                let (head, tail) = &**node;
                *self = tail;
                Some(head)
            }
        }
    }
}
/// Returns a new list with `x` in front of `xs`.
pub fn cons<T>(x: T, xs: List<T>) -> List<T> {
    List::Item(Rc::new((x, xs)))
}
/// Prepends `x` to the list stored in `xs`. Clones of the old list are
/// unaffected.
pub fn cons_<T>(x: T, xs: &RefCell<List<T>>) {
    let tail = xs.borrow().clone();
    let _ = xs.replace(cons(x, tail));
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment