Last active
January 23, 2019 10:57
-
-
Save sergv/58ea87d730963d9cf59de54312331a9e to your computer and use it in GitHub Desktop.
Use recursion schemes to add type information to expressions in a modular way
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
---------------------------------------------------------------------------- | |
-- Tested with ghc 8.2.2 | |
---------------------------------------------------------------------------- | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE TupleSections #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
import Control.Arrow ((&&&), first) | |
import Control.Monad | |
-- Generic helpers. | |
-- | Fixed point of a functor @f@: each 'Fix' wraps one layer of @f@ whose
-- recursive positions are themselves 'Fix' values.
newtype Fix f = Fix (f (Fix f))
-- | Remove the outer 'Fix' wrapper, exposing a single layer.
unFix :: Fix f -> f (Fix f)
unFix wrapped = case wrapped of
  Fix layer -> layer
-- Structural instances, available whenever the wrapped layer supports them.
deriving instance Show (f (Fix f)) => Show (Fix f)
deriving instance Eq (f (Fix f)) => Eq (Fix f)
deriving instance Ord (f (Fix f)) => Ord (Fix f)
-- | Catamorphism: fold a 'Fix'-encoded structure bottom-up. Every child is
-- folded first, then @alg@ collapses the layer holding the folded results.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix layer) = alg (fmap (cata alg) layer)
-- | Monadic catamorphism: recursively fold the children with 'traverse',
-- then run @alg@ on the layer of folded results.
cataM :: (Traversable f, Monad m) => (f a -> m a) -> Fix f -> m a
cataM alg (Fix layer) = traverse (cataM alg) layer >>= alg
-- | Paramorphism: like 'cata', but @alg@ sees each folded child result
-- paired with the original subtree it was computed from.
para :: Functor f => (f (a, Fix f) -> a) -> Fix f -> a
para alg (Fix layer) = alg (fmap (para alg &&& id) layer)
-- | Monadic paramorphism: each child is folded via 'traverse', and every
-- result is paired with the subtree it came from before @alg@ runs on the
-- layer.
paraM :: (Traversable f, Monad m) => (f (a, Fix f) -> m a) -> Fix f -> m a
paraM alg = alg <=< traverse foldWithSubtree . unFix
  where
    -- Fold one child and keep the child itself next to its result.
    foldWithSubtree subtree =
      fmap (\result -> (result, subtree)) (paraM alg subtree)
-- | A tree in which every layer of @f@ carries an annotation of type @a@,
-- stored to the left of ':<'.
data Cofree f a = a :< f (Cofree f a)
-- Structural instances, available whenever both the annotation and the
-- layer support them.
deriving instance (Show (f (Cofree f a)), Show a) => Show (Cofree f a)
deriving instance (Eq (f (Cofree f a)), Eq a) => Eq (Cofree f a)
deriving instance (Ord (f (Cofree f a)), Ord a) => Ord (Cofree f a)
-- | Fold an annotated ('Cofree') tree bottom-up. @alg@ receives each node's
-- annotation together with the layer of already-folded children.
cataAnn :: Functor f => (b -> f a -> a) -> Cofree f b -> a
cataAnn alg (ann :< children) = alg ann (fmap (cataAnn alg) children)
-- The AST and rest of the compiler. | |
-- | One layer of the expression language; recursive positions are
-- abstracted as @e@ so the type can be tied with 'Fix' or 'Cofree'.
data ExprF e
  = ILit Int   -- ^ integer literal
  | BLit Bool  -- ^ boolean literal
  | Add e e    -- ^ addition of two subexpressions
  | Lt e e     -- ^ less-than comparison of two subexpressions
  | If e e e   -- ^ conditional: condition, true branch, false branch
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | Types of the expression language.
data Type
  = TInt
  | TBool
  deriving (Eq, Ord, Show)
-- | Untyped expressions, e.g. right after parsing.
type Expr = Fix ExprF
-- | Typed expressions, in which every recursion level is annotated with
-- its 'Type'.
type TExpr = Cofree ExprF Type
-- | Read the type annotation stored at the root of a typed expression.
getType :: TExpr -> Type
getType typed = case typed of
  ty :< _ -> ty
-- | Typecheck a single layer of an expression.
--
-- Each child position holds the child's already-computed 'Type', paired
-- with the child's untyped subtree (the shape 'paraM' supplies).
-- The result is either the type of this layer, or an error message that
-- includes the whole offending subexpression.
typecheckAlg :: ExprF (Type, Expr) -> Either String Type
typecheckAlg e = case fmap fst e of
  ILit{} -> Right TInt
  BLit{} -> Right TBool
  Add TInt TInt -> Right TInt
  Add lt rt -> Left $ malformedAddErr lt rt
  Lt TInt TInt -> Right TBool
  Lt lt rt -> Left $ malformedLtErr lt rt
  If TBool tt ft
    | tt == ft -> Right tt
    | otherwise -> Left $ malformedIfBranchesErr tt ft
  -- Reached only when the condition is not TBool (the case above covers TBool).
  If ct _ _ -> Left $ malformedIfConditionErr ct
  where
    -- Rebuild the untyped subexpression rooted at this layer so that error
    -- messages can show it in full.
    expr = Fix (fmap snd e)
    malformedAddErr, malformedLtErr, malformedIfBranchesErr :: Type -> Type -> String
    malformedAddErr lt rt =
      "Malformed addition: left branch is of type " ++ show lt ++ " and right is of type " ++ show rt ++ ". Whole expression: " ++ show expr
    malformedLtErr lt rt =
      "Malformed comparison: left branch is of type " ++ show lt ++ " and right is of type " ++ show rt ++ ". Whole expression: " ++ show expr
    malformedIfBranchesErr tt ft =
      "Malformed condition: both branches must have the same type, but true branch is of type " ++ show tt ++ " and false branch is of type " ++ show ft ++ ". Whole expression: " ++ show expr
    malformedIfConditionErr :: Type -> String
    malformedIfConditionErr ct =
      "Malformed condition: condition must have boolean type, but really it is of type " ++ show ct ++ ". Whole expression: " ++ show expr
-- | Check whether an expression is well typed. Returns the type of the whole
-- expression, or the error produced by 'typecheckAlg'.
typecheck :: Expr -> Either String Type
typecheck untyped = paraM typecheckAlg untyped
-- | Check that an expression is well typed and annotate every one of its
-- layers with the corresponding type, producing a 'TExpr'.
typecheckAll :: Expr -> Either String TExpr
typecheckAll = paraM annotate
  where
    -- Typecheck this layer using the children's root annotations, then put
    -- the resulting type on top of the already-annotated children.
    annotate :: ExprF (TExpr, Expr) -> Either String TExpr
    annotate layer =
      (:< fmap fst layer) <$> typecheckAlg (first getType <$> layer)
-- 1 + 2 | |
expr1 :: Expr
expr1 = Fix (Add one two)
  where
    one = Fix (ILit 1)
    two = Fix (ILit 2)
-- typecheckAll expr1 | |
-- => Right (TInt :< Add (TInt :< ILit 1) (TInt :< ILit 2)) | |
-- if (if (3 < 2) False True) (3 + 4) False <- note the error: the branches have different types (Int vs Bool)
expr2 :: Expr
expr2 = Fix (If condition trueBranch falseBranch)
  where
    int = Fix . ILit
    boolean = Fix . BLit
    condition = Fix (If (Fix (Lt (int 3) (int 2))) (boolean False) (boolean True))
    trueBranch = Fix (Add (int 3) (int 4))
    falseBranch = boolean False
-- typecheckAll expr2 | |
-- => Left "Malformed condition: both branches must have the same type, but true brach is of type TInt and false branch is of type TBool. Whole expression: Fix (If (Fix (If (Fix (Lt (Fix (ILit 3)) (Fix (ILit 2)))) (Fix (BLit False)) (Fix (BLit True)))) (Fix (Add (Fix (ILit 3)) (Fix (ILit 4)))) (Fix (BLit False)))" | |
-- if (if (3 < 2) False True) (3 + 4) 42 <- expr2 with the error fixed: both branches are now Int
expr3 :: Expr
expr3 = Fix (If condition trueBranch falseBranch)
  where
    int = Fix . ILit
    boolean = Fix . BLit
    condition = Fix (If (Fix (Lt (int 3) (int 2))) (boolean False) (boolean True))
    trueBranch = Fix (Add (int 3) (int 4))
    falseBranch = int 42
-- typecheck expr3 | |
-- => Right TInt | |
-- typecheckAll expr3 | |
-- => Right (TInt :< If (TBool :< If (TBool :< Lt (TInt :< ILit 3) (TInt :< ILit 2)) | |
-- (TBool :< BLit False) | |
-- (TBool :< BLit True)) | |
-- (TInt :< Add (TInt :< ILit 3) (TInt :< ILit 4)) | |
-- (TInt :< ILit 42)) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment