Last active
September 4, 2024 21:32
-
-
Save gelisam/90ca9dff6906abacf11387d960d46cab to your computer and use it in GitHub Desktop.
A recursion scheme for mutually-recursive types
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- Defining a custom recursion scheme to manipulate two mutually-recursive
-- types, in the context of a toy bidirectional type checker.
{-# LANGUAGE DerivingStrategies, GeneralizedNewtypeDeriving, ScopedTypeVariables #-} | |
module Main where | |
import Test.DocTest | |
import Control.Monad (when) | |
import Data.Bifunctor (Bifunctor(bimap)) | |
import Data.Bifoldable (Bifoldable(bifoldMap), bitraverse_) | |
import qualified Data.List as List | |
-- | Types of the toy language: function arrows plus two base types.
data Type
  = Arr Type Type  -- ^ function type: argument type, then result type
  | UnitType       -- ^ the unit type
  | IntType        -- ^ the integer type
  deriving (Eq, Show)
-- | Typing context: an association list from variable names to their types.
-- New bindings are consed onto the front and lookups use 'List.lookup', so
-- the most recently added binding for a name shadows any older one.
type Context = [(String, Type)]
-- Let's begin by defining our language's terms.
--
--   data Term
--     = Lam String Term
--     | Var String
--     | App Term Term
--     | Ann Term Type
--     | UnitTerm
--     | IntTerm Int
--     | Negate Term
--     | Add Term Term
--     | Sub Term Term
--     | Mul Term Term
--
-- The definition is commented out because there are two changes I want to
-- make. First, I want to partition it into "normal" and "neutral" forms:
--
--   data Normal
--     = Lam String Normal
--     | Neu Neutral
--
--   data Neutral
--     = Var String
--     | App Neutral Normal
--     | Ann Normal Type
--     | UnitTerm
--     | IntTerm Int
--     | Negate Normal
--     | Add Normal Normal
--     | Sub Normal Normal
--     | Mul Normal Normal
--
-- Second, I want to turn it into a base functor:
--
--   data TermF term
--     = Lam String term
--     | Var String
--     | App term term
--     | Ann term Type
--     | UnitTerm
--     | IntTerm Int
--     | Negate term
--     | Add term term
--     | Sub term term
--     | Mul term term
-- Let's discuss the partition first. This is a relatively common trick which
-- avoids the two catch-all cases you'd otherwise find at the end of the
-- typechecker's 'check' and 'infer' functions:
--
--   check :: Context -> Term -> Type -> Either String ()
--   check ctx (Lam x body) tp = ...
--   check ...
--   check ctx term expectedType = do
--     actualType <- infer ctx term
--     when (actualType /= expectedType) $ do
--       Left $ "expected " ++ show expectedType
--           ++ ", found " ++ show actualType
--
--   infer :: Context -> Term -> Either String Type
--   infer ctx (Var x) = ...
--   infer ...
--   infer ctx term = do
--     Left $ "ambiguous type, please add a type annotation "
--         ++ "around " ++ show term
--
-- Thanks to the partitioned representation, the 'check' catch-all case is now
-- a regular case (the 'Neu' case), and the 'infer' catch-all case is dropped
-- entirely because the error condition is now unrepresentable. Cool!
-- Next, the base functor representation. A single recursive type can be
-- turned into a base functor by adding a type parameter and replacing all the
-- recursive occurrences of the type with that type parameter. That trick
-- doesn't work here, because the partitioned representation means we now have
-- _two_ mutually-recursive types, and thus two different recursive occurrences
-- to replace. We must thus add _two_ type parameters, replacing each recursive
-- occurrence with the appropriate one. The result is two base _bifunctors_!
-- | Base bifunctor for normal (checked) terms. A recursive position holding a
-- normal sub-term uses the @normal@ parameter; one holding a neutral sub-term
-- uses the @neutral@ parameter.
data NormalF normal neutral
  = Lam String normal  -- ^ lambda: bound variable name, then body
  | Neu neutral        -- ^ a neutral term used where a normal one is expected
  deriving (Show)
-- | Base bifunctor for neutral (inferred) terms. A recursive position holding
-- a normal sub-term uses the @normal@ parameter; one holding a neutral
-- sub-term uses the @neutral@ parameter.
data NeutralF normal neutral
  = Var String          -- ^ variable reference
  | App neutral normal  -- ^ application: neutral function, normal argument
  | Ann normal Type     -- ^ term annotated with its expected type
  | UnitTerm            -- ^ the unit value
  | IntTerm Int         -- ^ integer literal
  | Negate normal       -- ^ unary numeric operation
  | Add normal normal   -- ^ binary numeric operation
  | Sub normal normal   -- ^ binary numeric operation
  | Mul normal normal   -- ^ binary numeric operation
  deriving (Show)
-- | Maps the first function over normal positions, the second over neutral ones.
instance Bifunctor NormalF where
  bimap onNormal onNeutral term = case term of
    Lam name body -> Lam name (onNormal body)
    Neu neu       -> Neu (onNeutral neu)
-- | Maps the first function over normal positions, the second over neutral
-- ones; leaf constructors and the annotation's 'Type' are left untouched.
instance Bifunctor NeutralF where
  bimap onNormal onNeutral term = case term of
    Var name    -> Var name
    App fun arg -> App (onNeutral fun) (onNormal arg)
    Ann body tp -> Ann (onNormal body) tp
    UnitTerm    -> UnitTerm
    IntTerm n   -> IntTerm n
    Negate x    -> Negate (onNormal x)
    Add x y     -> Add (onNormal x) (onNormal y)
    Sub x y     -> Sub (onNormal x) (onNormal y)
    Mul x y     -> Mul (onNormal x) (onNormal y)
-- | Folds the single recursive position of each constructor.
instance Bifoldable NormalF where
  bifoldMap onNormal onNeutral term = case term of
    Lam _ body -> onNormal body
    Neu neu    -> onNeutral neu
-- | Folds the recursive positions left to right; leaf constructors contribute
-- 'mempty'. Function-before-argument order is kept for 'App'.
instance Bifoldable NeutralF where
  bifoldMap onNormal onNeutral term = case term of
    Var _       -> mempty
    App fun arg -> onNeutral fun <> onNormal arg
    Ann body _  -> onNormal body
    UnitTerm    -> mempty
    IntTerm _   -> mempty
    Negate x    -> onNormal x
    Add x y     -> onNormal x <> onNormal y
    Sub x y     -> onNormal x <> onNormal y
    Mul x y     -> onNormal x <> onNormal y
-- | Normal terms: a 'NormalF' layer whose normal positions are again 'Normal'
-- and whose neutral positions are 'Neutral'. 'Show' is the wrapped layer's
-- instance, so the 'Normal' wrapper itself does not appear in the output.
newtype Normal = Normal { unNormal :: NormalF Normal Neutral }
  deriving newtype (Show)
-- | Neutral terms: a 'NeutralF' layer whose normal positions are 'Normal' and
-- whose neutral positions are again 'Neutral'. 'Show' is the wrapped layer's
-- instance, so the 'Neutral' wrapper itself does not appear in the output.
newtype Neutral = Neutral { unNeutral :: NeutralF Normal Neutral }
  deriving newtype (Show)
-- Next, I want to define a recursion scheme which captures the mutual
-- recursion between 'check' and 'infer'. The usual recursion scheme for
-- capturing mutual recursion is a mutumorphism:
--
--   mutu
--     :: (f (a,b) -> a)  -- first algebra
--     -> (f (a,b) -> b)  -- second algebra
--     -> ( Fix f -> a    -- first mutually-recursive function
--        , Fix f -> b    -- second mutually-recursive function
--        )
--
-- The idea is that each algebra has access to the recursive results for both
-- of the mutually-recursive functions, which is the recursion-schemes
-- equivalent of both functions being able to call each other.
--
-- Because of our partitioned representation, however, we need a variant of
-- 'mutu' in which each algebra manipulates a different 'f'. This variant is
-- pretty easy to define:
-- | Mutumorphism over the mutually-recursive pair 'Normal' / 'Neutral'.
--
-- Each algebra receives one layer in which every normal position has been
-- replaced by the result of the first fold on that sub-term, and every
-- neutral position by the result of the second fold. This lets the two
-- resulting functions effectively call each other.
myMutu
  :: forall a b
   . (NormalF a b -> a)
  -> (NeutralF a b -> b)
  -> ( Normal -> a
     , Neutral -> b
     )
myMutu normalAlg neutralAlg = (foldNormal, foldNeutral)
  where
    -- Both folds share the same 'bimap foldNormal foldNeutral' step, which
    -- is what makes the recursion mutual.
    foldNormal :: Normal -> a
    foldNormal (Normal layer) = normalAlg (bimap foldNormal foldNeutral layer)

    foldNeutral :: Neutral -> b
    foldNeutral (Neutral layer) = neutralAlg (bimap foldNormal foldNeutral layer)
-- | We are now ready to start implementing the type checker. Using it to check
-- that the expression
--
-- > \s z -> s (s z)
--
-- has (among others) the type
--
-- > (Int -> Int) -> Int -> Int
--
-- will look like this:
--
-- >>> :{
-- check []
--   ( Normal $ Lam "s"
--   $ Normal $ Lam "z"
--   $ Normal $ Neu
--   $ Neutral $ App (Neutral $ Var "s")
--   $ Normal $ Neu
--   $ Neutral $ App (Neutral $ Var "s")
--   $ Normal $ Neu
--   $ Neutral $ Var "z"
--   )
--   (Arr (Arr IntType IntType) (Arr IntType IntType))
-- :}
-- Right ()
check
  :: Context
  -> Normal
  -> Type
  -> Either String ()
infer
  :: Context
  -> Neutral
  -> Either String Type
(check, infer)
  = -- I want my 'check' and 'infer' functions to have the standard API in
    -- which the context is given before the term, just like in the typing
    -- judgement "Γ ⊢ e : 𝜏". However, 'myMutu' requires the term to be the
    -- first argument, so a bit of argument juggling is needed.
    ( \ctx normal tp -> check' normal ctx tp
    , \ctx neutral -> infer' neutral ctx
    )
  where
    -- The reason the term must come first is because 'myMutu' requires the
    -- two algebras to have these types:
    --
    --   NormalF a b -> a
    --   NeutralF a b -> b
    --
    -- Thus all the remaining arguments must be obtained by specializing 'a'
    -- and 'b' to function types:
    --
    --   a ~ (Context -> Type -> Either String ())
    --   b ~ (Context -> Either String Type)
    check'
      :: Normal
      -> Context
      -> Type
      -> Either String ()
    infer'
      :: Neutral
      -> Context
      -> Either String Type
    (check', infer')
      = -- I still want to use the natural parameter order in my
        -- implementation of the two algebras though, so more argument
        -- juggling is needed.
        myMutu
          (\normal ctx tp -> checkF ctx normal tp)
          (\neutral ctx -> inferF ctx neutral)

    -- We are finally ready to implement the two algebras! Depending on
    -- whether a recursive position contains a normal or a neutral sub-term,
    -- that position in the base bifunctor holds either check' or infer'
    -- partially applied to that sub-term. Thus, we need to supply the
    -- remaining arguments (the context, and for check' the expected type) in
    -- order to get the resulting @Either String ()@ or @Either String Type@.
    checkF
      :: Context
      -> NormalF
           (Context -> Type -> Either String ())
           (Context -> Either String Type)
      -> Type
      -> Either String ()
    checkF ctx (Lam x checkBody) tp = do
      case tp of
        Arr tArg tOut -> do
          -- The body is checked in a context extended with the parameter.
          checkBody ((x,tArg) : ctx) tOut
        _ -> do
          Left $ "expected " ++ show tp ++ ", found lambda"
    checkF ctx (Neu inferNeutral) tp = do
      -- This is the former catch-all of 'check': infer, then compare.
      tp' <- inferNeutral ctx
      when (tp /= tp') $ do
        Left $ "expected " ++ show tp ++ ", found " ++ show tp'

    inferF
      :: Context
      -> NeutralF
           (Context -> Type -> Either String ())
           (Context -> Either String Type)
      -> Either String Type
    inferF ctx (Var name) = do
      case List.lookup name ctx of
        Nothing -> do
          Left $ "variable " ++ show name ++ " not in scope"
        Just tp -> do
          pure tp
    inferF ctx (App inferFun checkArg) = do
      tp <- inferFun ctx
      case tp of
        Arr tArg tOut -> do
          checkArg ctx tArg
          pure tOut
        _ -> do
          Left $ show tp ++ " is not a function"
    inferF ctx (Ann checkTerm tp) = do
      checkTerm ctx tp
      pure tp
    inferF _ UnitTerm = do
      pure UnitType
    inferF ctx numericTerm = do
      -- And now, the moment we are all waiting for: the payoff! All of this
      -- was a lot more verbose than the general-recursion version, so we
      -- better get something in return.
      --
      -- What we get is the ability to write a single generic case for
      -- 'IntTerm', 'Negate', 'Add', 'Sub', and 'Mul'. We can do that because
      -- we can use the base bifunctor's Bifoldable instance to traverse all
      -- of the sub-terms, without having to know how many sub-terms there are
      -- nor which ones are normal or neutral.
      --
      -- NOTE(review): this equation also catches any constructor added to
      -- 'NeutralF' in the future, treating it as a numeric operation.
      let onNormal checkSubTerm = do
            checkSubTerm ctx IntType
          onNeutral inferSubTerm = do
            actualType <- inferSubTerm ctx
            when (actualType /= IntType) $ do
              Left $ "numeric operation requires Int argument, "
                  ++ "found " ++ show actualType
      bitraverse_ onNormal onNeutral numericTerm
      pure IntType
-- So, was all of that worth it, just to save on a few easy cases at the end
-- of the type checker? It depends! Some languages have a very large number of
-- term forms, so the cost might be worth paying in that case. Some teams value
-- simplicity over succinctness and so will never be ready to pay the cost.
--
-- The real benefit is the knowledge gained along the way: whether
-- you're writing a typechecker or something else, you now have another tool
-- in your toolbox, waiting for you to encounter a challenge it solves well.
-- | Prints a fixed message. Presumably the point is that reaching this at all
-- means the module compiled; the real checks live in the doctests run by
-- 'test'.
main :: IO ()
main = putStrLn "typechecks."
-- | Runs doctest over @src/Main.hs@, executing the @>>>@ examples in its
-- Haddock comments. The path is relative to the working directory and is
-- presumably this file — confirm against the project layout.
test :: IO ()
test = doctest ["src/Main.hs"]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment