Last active
May 14, 2026 06:46
-
-
Save neel-krishnaswami/7241bb1b7bce4275ac9318b6cf8b53cd to your computer and use it in GitHub Desktop.
An example of bidirectional type inference that doesn't stop at the first error, but instead annotates each subterm with error information. This is very useful for building LSP services!
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| (* Code associated with my blog post, "Bidirectional Typechecking That Does Not Stop", at | |
| https://semantic-domain.blogspot.com/2026/05/bidirectional-typechecking-that-does.html | |
| *) | |
(* The object language's types. *)
type tp =
  | Bool                (* the type of booleans *)
  | Arrow of tp * tp    (* function type: argument type, then result type *)
  | Tuple of tp list    (* n-ary product; the list gives the component types *)
(* [dist os] is [Some xs] when every element of [os] is [Some _], where [xs]
   collects the contents in their original order, and [None] as soon as any
   element is [None].  The accumulator makes the loop tail-recursive, so long
   lists cannot overflow the stack; the previous version also bound an unused
   [fail] (warning 26). *)
let dist os =
  let rec go acc = function
    | [] -> Some (List.rev acc)
    | Some x :: rest -> go (x :: acc) rest
    | None :: _ -> None
  in
  go [] os
(* Expression syntax.  Every node carries an annotation of type ['a]; the
   elaborators in this file rebuild trees whose annotations are types (or
   optional types). *)
module Exp = struct
  type var = string

  (* One layer of syntax; its immediate subterms are annotated trees. *)
  type 'a exp' =
    | BLit of bool                        (* boolean literal *)
    | If of 'a t * 'a t * 'a t            (* conditional *)
    | Lam of var * 'a t                   (* lambda abstraction *)
    | App of 'a t * 'a t                  (* application *)
    | Annot of 'a t * tp                  (* type ascription *)
    | Var of var                          (* variable *)
    | Let of var * 'a t * 'a t            (* let x = e1 in e2 *)
    | Tuple of 'a t list                  (* tuple *)
    | LetTuple of var list * 'a t * 'a t  (* let (x1, ..., xn) = e1 in e2 *)

  (* A node: its annotation paired with its shape (one tuple argument). *)
  and 'a t = In of ('a * 'a exp')

  (* The syntax layer of a node.  Wildcards replace the previously unused
     pattern variables, which triggered warning 27. *)
  let shape (In (_, e')) = e'

  (* The annotation of a node. *)
  let info (In (info, _)) = info

  (* Build a node from an annotation and a syntax layer. *)
  let make info e' = In (info, e')

  (* Replace the root annotation with [f] applied to it; the shape, and so
     every subterm, is kept unchanged. *)
  let update f e = make (f (info e)) (shape e)

  (* Smart constructors: [Mk.c info args] builds a node of form [c]
     annotated with [info]. *)
  module Mk = struct
    let blit info b = make info (BLit b)
    let if' info e1 e2 e3 = make info (If (e1, e2, e3))
    let lam info x e = make info (Lam (x, e))
    let app info e1 e2 = make info (App (e1, e2))
    let annot info e tp = make info (Annot (e, tp))
    let var info x = make info (Var x)
    let let' info x e1 e2 = make info (Let (x, e1, e2))
    let tuple info es = make info (Tuple es)
    let lettuple info xs e1 e2 = make info (LetTuple (xs, e1, e2))
  end
end
(* Destructors for [tp]: each returns the components of a type of the
   expected form in [Some], or [None] when the type has another form.
   The unused [fail]/[let+] bindings of the previous version (warning 26)
   are gone. *)
module Get = struct
  (* [Some ()] iff the type is [Bool]. *)
  let bool = function
    | Bool -> Some ()
    | _ -> None

  (* Argument and result types of a function type. *)
  let arrow = function
    | Arrow (dom, cod) -> Some (dom, cod)
    | _ -> None

  (* Component types of a tuple type with exactly [n] components.
     [compare_length_with] stops after [n] elements instead of measuring
     the whole list. *)
  let tuple n = function
    | Tuple tps when List.compare_length_with tps n = 0 -> Some tps
    | _ -> None

  (* [Some ()] iff the two types are structurally equal. *)
  let eq tp1 tp2 =
    if tp1 = tp2 then Some () else None
end
(* Plain bidirectional typechecker: it only decides whether a term is well
   typed, and every function returns [None] at the first error it meets. *)
module Basic = struct
  open Exp

  type ctx = (var * tp) list

  (* Type of [x] in [ctx]; the innermost (most recent) binding wins. *)
  let lookup ctx x = List.assoc_opt x ctx

  (* check : ctx -> 'a Exp.t -> tp -> unit option
     [Some ()] iff [e] has type [tp] in [ctx]. *)
  let rec check ctx e tp =
    let ( let* ) = Option.bind in
    match shape e with
    | BLit _ -> Get.bool tp
    | If (e1, e2, e3) ->
      let* tp1 = synth ctx e1 in
      let* () = Get.bool tp1 in
      let* () = check ctx e2 tp in
      check ctx e3 tp
    | Lam (x, body) ->
      let* (tp1, tp2) = Get.arrow tp in
      check ((x, tp1) :: ctx) body tp2
    | Tuple es ->
      let* tps = Get.tuple (List.length es) tp in
      (* [Get.tuple] guarantees equal lengths, so [map2] cannot raise. *)
      let* _ = dist (List.map2 (check ctx) es tps) in
      Some ()
    | LetTuple (xs, e1, e2) ->
      let* tp1 = synth ctx e1 in
      let* tps = Get.tuple (List.length xs) tp1 in
      check (List.combine xs tps @ ctx) e2 tp
    | Let (x, e1, e2) ->
      let* tp1 = synth ctx e1 in
      check ((x, tp1) :: ctx) e2 tp
    | App _ | Var _ | Annot _ ->
      (* Mode switch: synthesize a type, then compare with the expected one. *)
      let* tp' = synth ctx e in
      Get.eq tp tp'

  (* synth : ctx -> 'a Exp.t -> tp option
     Only applications, variables and annotations synthesize a type; every
     other form gives [None]. *)
  and synth ctx e =
    let ( let* ) = Option.bind in
    match shape e with
    | App (e1, e2) ->
      let* tp1 = synth ctx e1 in
      let* (tp2, tp) = Get.arrow tp1 in
      let* () = check ctx e2 tp2 in
      Some tp
    | Var x -> lookup ctx x
    | Annot (e', tp) ->
      let* () = check ctx e' tp in
      Some tp
    | BLit _ | If _ | Lam _ | Tuple _ | LetTuple _ | Let _ -> None
end
(* Elaborator that stops at the first error: on success it returns a copy of
   the input tree with every node annotated by its type, otherwise [None]. *)
module ElabFail = struct
  open Exp

  type ctx = (var * tp) list

  (* Type of [x] in [ctx]; the innermost (most recent) binding wins. *)
  let lookup ctx x = List.assoc_opt x ctx

  (* check : ctx -> 'a Exp.t -> tp -> tp Exp.t option
     Elaborates [e] against the expected type [tp]. *)
  let rec check ctx e tp =
    let ( let* ) = Option.bind in
    match shape e with
    | BLit b ->
      let* () = Get.bool tp in
      Some (Mk.blit tp b)
    | Lam (x, body) ->
      let* (tp1, tp2) = Get.arrow tp in
      let* t = check ((x, tp1) :: ctx) body tp2 in
      Some (Mk.lam tp x t)
    | If (e1, e2, e3) ->
      let* t1 = synth ctx e1 in
      let* () = Get.bool (info t1) in
      let* t2 = check ctx e2 tp in
      let* t3 = check ctx e3 tp in
      Some (Mk.if' tp t1 t2 t3)
    | Tuple es ->
      let* tps = Get.tuple (List.length es) tp in
      (* [Get.tuple] guarantees equal lengths, so [map2] cannot raise. *)
      let* ts = dist (List.map2 (check ctx) es tps) in
      Some (Mk.tuple tp ts)
    | LetTuple (xs, e1, e2) ->
      let* t1 = synth ctx e1 in
      let* tps = Get.tuple (List.length xs) (info t1) in
      let* t2 = check (List.combine xs tps @ ctx) e2 tp in
      Some (Mk.lettuple tp xs t1 t2)
    | Let (x, e1, e2) ->
      let* t1 = synth ctx e1 in
      let* t2 = check ((x, info t1) :: ctx) e2 tp in
      Some (Mk.let' tp x t1 t2)
    | App _ | Var _ | Annot _ ->
      (* Mode switch: synthesize, then require the expected type. *)
      let* t = synth ctx e in
      let* () = Get.eq (info t) tp in
      Some t

  (* synth : ctx -> 'a Exp.t -> tp Exp.t option
     Only applications, variables and annotations synthesize; every other
     form gives [None]. *)
  and synth ctx e =
    let ( let* ) = Option.bind in
    match shape e with
    | App (e1, e2) ->
      let* t1 = synth ctx e1 in
      let* (tp2, tp) = Get.arrow (info t1) in
      let* t2 = check ctx e2 tp2 in
      Some (Mk.app tp t1 t2)
    | Var x ->
      let* tp = lookup ctx x in
      Some (Mk.var tp x)
    | Annot (e', tp) ->
      let* t = check ctx e' tp in
      Some (Mk.annot tp t tp)
    | BLit _ | If _ | Lam _ | Tuple _ | LetTuple _ | Let _ -> None
end
(* Views of a possibly-unknown type [tp option] ([None] = unknown or
   erroneous) through one type former.  [get] splits a type into the
   components the view expects, giving unknown components when the type has
   another shape; [put] rebuilds a type from components and is [None] when
   any component is unknown. *)
module TpView = struct
  type 'a t = {
    get : tp option -> 'a;
    put : 'a -> tp option;
  }

  (* [Bool] has no components: [get] only records whether it matched. *)
  let bool =
    let get = function
      | Some Bool -> Some ()
      | _ -> None
    in
    let put = function
      | Some () -> Some Bool
      | None -> None
    in
    { get; put }

  (* Function types: the components are the argument and result types. *)
  let arrow =
    let get = function
      | Some (Arrow (tp1, tp2)) -> (Some tp1, Some tp2)
      | _ -> (None, None)
    in
    let put = function
      | Some tp1, Some tp2 -> Some (Arrow (tp1, tp2))
      | _, _ -> None
    in
    { get; put }

  (* Tuple types with exactly [n] components; [get] always returns a list of
     length [n]. *)
  let tuple n =
    let get = function
      | Some (Tuple tps) when List.compare_length_with tps n = 0 ->
        List.map Option.some tps
      | _ -> List.init n (fun _ -> None)
    in
    let put tps' =
      match dist tps' with
      | Some tps when List.compare_length_with tps n = 0 -> Some (Tuple tps)
      | _ -> None
    in
    { get; put }

  (* The view matching exactly the type [tp']. *)
  let eq (tp' : tp option) =
    let get = function
      | tp'' when tp' = tp'' -> Some ()
      | _ -> None
    in
    let put = function
      | Some () -> tp'
      | None -> None
    in
    { get; put }

  (* [is view tp] is [tp] when [tp] has the shape [view] describes, and
     [None] otherwise.  The bare round trip [put (get tp)] is not enough:
     for [tuple 0], [get] returns the empty list for every input, so [put]
     would turn [Some Bool] (or [None]) into [Some (Tuple [])] and hide the
     mismatch.  Requiring the round trip to reproduce [tp] rules this out;
     for the other views nothing changes, since a [Some] result of the round
     trip already equals [tp]. *)
  let is view tp =
    if view.put (view.get tp) = tp then tp else None
end
(* Elaborator that never stops: it always returns a tree annotated at every
   node with a [tp option], where [None] marks a node whose type could not be
   established (a type error at that node). *)
module ElabNonstop = struct
  open TpView
  open Exp

  (* Variables may be bound to an unknown type ([None]), e.g. by a [let]
     whose right-hand side failed to typecheck. *)
  type ctx = (var * tp option) list

  (* Unbound variables and variables of unknown type both give [None]. *)
  let lookup ctx x = Option.join (List.assoc_opt x ctx)

  (* check : ctx -> 'a Exp.t -> tp option -> tp option Exp.t
     Elaborates [e] against the (possibly unknown) expected type [tp]. *)
  let rec check ctx e tp =
    match shape e with
    | BLit b ->
      Mk.blit (is bool tp) b
    | Lam (x, e') ->
      let (tp1, tp2) = arrow.get tp in
      let t = check ((x, tp1) :: ctx) e' tp2 in
      Mk.lam (is arrow tp) x t
    | If (e1, e2, e3) ->
      (* A non-boolean condition is marked on the condition itself. *)
      let t1 = update (is bool) (synth ctx e1) in
      let t2 = check ctx e2 tp in
      let t3 = check ctx e3 tp in
      Mk.if' tp t1 t2 t3
    | Tuple es ->
      let n = List.length es in
      (* [(tuple n).get] always yields [n] components, so [map2] cannot raise. *)
      let tps = (tuple n).get tp in
      let ts = List.map2 (check ctx) es tps in
      Mk.tuple (is (tuple n) tp) ts
    | LetTuple (xs, e1, e2) ->
      let n = List.length xs in
      (* Mark the scrutinee [None] unless it really is an [n]-tuple, just as
         [If] marks its condition.  Without this, the error only surfaced at
         uses of [xs], and vanished entirely if they were unused. *)
      let t1 = update (is (tuple n)) (synth ctx e1) in
      let tps = (tuple n).get (info t1) in
      let t2 = check (List.combine xs tps @ ctx) e2 tp in
      Mk.lettuple tp xs t1 t2
    | Let (x, e1, e2) ->
      let t1 = synth ctx e1 in
      let t2 = check ((x, info t1) :: ctx) e2 tp in
      Mk.let' tp x t1 t2
    | App _ | Var _ | Annot _ ->
      (* Mode switch: keep the synthesized type only if it equals [tp]. *)
      update (is (eq tp)) (synth ctx e)

  (* synth : ctx -> 'a Exp.t -> tp option Exp.t
     Forms that cannot synthesize are checked against an unknown type. *)
  and synth ctx e =
    match shape e with
    | App (e1, e2) ->
      let t1 = synth ctx e1 in
      (* A non-function head gives [None] for both argument and result. *)
      let (tp2, tp) = arrow.get (info t1) in
      let t2 = check ctx e2 tp2 in
      Mk.app tp t1 t2
    | Var x ->
      Mk.var (lookup ctx x) x
    | Annot (e', tp) ->
      let t = check ctx e' (Some tp) in
      Mk.annot (Some tp) t tp
    | BLit _ | If _ | Lam _ | Tuple _ | LetTuple _ | Let _ ->
      check ctx e None
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment