Skip to content

Instantly share code, notes, and snippets.

@sergv
Last active January 23, 2019 10:57
Show Gist options
  • Save sergv/58ea87d730963d9cf59de54312331a9e to your computer and use it in GitHub Desktop.
Save sergv/58ea87d730963d9cf59de54312331a9e to your computer and use it in GitHub Desktop.
Use recursion schemes to add type information to expressions in a modular way
----------------------------------------------------------------------------
-- Tested with ghc 8.2.2
----------------------------------------------------------------------------
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}
import Control.Arrow ((&&&), first)
import Control.Monad
-- Generic helpers.
-- | Fixed point of a functor: ties the recursive knot so that a
-- non-recursive "layer" type @f@ describes an arbitrarily deep tree.
newtype Fix f = Fix (f (Fix f))

-- | Peel off the outermost layer of a 'Fix'.
unFix :: Fix f -> f (Fix f)
unFix (Fix layer) = layer

-- Structural instances, available whenever one layer (applied to 'Fix f'
-- itself) supports them. The contexts mention 'Fix f' again, which is why
-- UndecidableInstances is enabled for this file.
deriving instance Show (f (Fix f)) => Show (Fix f)
deriving instance Eq (f (Fix f)) => Eq (Fix f)
deriving instance Ord (f (Fix f)) => Ord (Fix f)
-- | Catamorphism: fold a 'Fix'-tree bottom-up. Each layer is collapsed by
-- @alg@ once all of its children have already been folded.
cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix layer) = alg (fmap (cata alg) layer)
-- | Monadic catamorphism: like 'cata', but the algebra runs in a monad.
-- All children of a layer are folded first (in the order given by the
-- layer's 'Traversable' instance), then the algebra is applied to the
-- collected results.
cataM :: (Traversable f, Monad m) => (f a -> m a) -> Fix f -> m a
cataM alg (Fix layer) = traverse (cataM alg) layer >>= alg
-- | Paramorphism: like 'cata', but the algebra also receives each original
-- child subtree next to the result computed for it.
para :: Functor f => (f (a, Fix f) -> a) -> Fix f -> a
para alg (Fix layer) = alg (fmap (para alg &&& id) layer)
-- | Monadic paramorphism: like 'para', but the algebra runs in a monad.
-- Every child is processed recursively first, and its result is paired with
-- the untouched child before the algebra sees the whole layer.
paraM :: (Traversable f, Monad m) => (f (a, Fix f) -> m a) -> Fix f -> m a
paraM alg (Fix layer) = traverse annotate layer >>= alg
  where
    annotate child = (\result -> (result, child)) <$> paraM alg child
-- | A tree built from layers of @f@ in which every node carries an
-- annotation of type @a@ (the value to the left of ':<').
data Cofree f a = a :< f (Cofree f a)

deriving instance (Show (f (Cofree f a)), Show a) => Show (Cofree f a)
deriving instance (Eq (f (Cofree f a)), Eq a) => Eq (Cofree f a)
deriving instance (Ord (f (Cofree f a)), Ord a) => Ord (Cofree f a)
-- | Fold an annotated tree bottom-up. The algebra receives the annotation
-- of the current node together with the already-folded children.
cataAnn :: Functor f => (b -> f a -> a) -> Cofree f b -> a
cataAnn alg (ann :< layer) = alg ann (fmap (cataAnn alg) layer)
-- The AST and rest of the compiler.
-- | One layer of the expression language, with the recursive positions left
-- abstract as @e@. Tying the knot with 'Fix' gives plain expressions; with
-- 'Cofree' it gives annotated ones.
data ExprF e
  = ILit Int   -- ^ Integer literal.
  | BLit Bool  -- ^ Boolean literal.
  | Add e e    -- ^ Addition of two subexpressions.
  | Lt e e     -- ^ Less-than comparison of two subexpressions.
  | If e e e   -- ^ Condition, true branch, false branch.
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable)
-- | The types an expression of the language can have.
data Type
  = TInt
  | TBool
  deriving (Eq, Ord, Show)
-- | Plain, unannotated expressions (e.g. what a parser would produce).
type Expr = Fix ExprF

-- | Expressions in which every node, at every depth, carries its 'Type'.
type TExpr = Cofree ExprF Type

-- | The type annotation stored at the root of a typed expression.
getType :: TExpr -> Type
getType texpr = case texpr of
  ty :< _ -> ty
-- | Type-check a single layer of an expression. Each child has already been
-- type-checked and is given as (its type, the child itself); the children
-- are kept so that error messages can show the whole offending expression.
--
-- Rules:
--
--   * literals have their obvious type;
--   * @Add@ requires two 'TInt' operands and yields 'TInt';
--   * @Lt@ requires two 'TInt' operands and yields 'TBool';
--   * @If@ requires a 'TBool' condition and branches of equal type, which is
--     then the type of the whole expression.
typecheckAlg :: ExprF (Type, Expr) -> Either String Type
typecheckAlg e = case fmap fst e of
  ILit{} -> Right TInt
  BLit{} -> Right TBool
  Add TInt TInt -> Right TInt
  Add lt rt -> Left $ malformedAddErr lt rt
  Lt TInt TInt -> Right TBool
  Lt lt rt -> Left $ malformedLtErr lt rt
  If TBool tt ft
    | tt == ft -> Right tt
    | otherwise -> Left $ malformedIfBranchesErr tt ft
  -- Reached only when the condition is not 'TBool'; this error is reported
  -- even if the branches disagree as well.
  If ct _ _ -> Left $ malformedIfConditionErr ct
  where
    -- The untyped expression for this layer, rebuilt for error messages.
    expr :: Expr
    expr = Fix (fmap snd e)

    malformedAddErr, malformedLtErr, malformedIfBranchesErr :: Type -> Type -> String
    malformedAddErr lt rt =
      "Malformed addition: left branch is of type " ++ show lt
        ++ " and right is of type " ++ show rt
        ++ ". Whole expression: " ++ show expr
    malformedLtErr lt rt =
      "Malformed comparison: left branch is of type " ++ show lt
        ++ " and right is of type " ++ show rt
        ++ ". Whole expression: " ++ show expr
    malformedIfBranchesErr tt ft =
      "Malformed condition: both branches must have the same type, but true branch is of type " ++ show tt
        ++ " and false branch is of type " ++ show ft
        ++ ". Whole expression: " ++ show expr

    malformedIfConditionErr :: Type -> String
    malformedIfConditionErr ct =
      "Malformed condition: condition must have boolean type, but really it is of type " ++ show ct
        ++ ". Whole expression: " ++ show expr

-- | Check whether an expression is well typed. Returns the type of the whole
-- expression, or the message for the first ill-typed layer found (children
-- are checked before their parent).
typecheck :: Expr -> Either String Type
typecheck = paraM typecheckAlg

-- | Check that an expression is well typed and annotate every node, at every
-- depth, with its type.
typecheckAll :: Expr -> Either String TExpr
typecheckAll = paraM alg
  where
    alg :: ExprF (TExpr, Expr) -> Either String TExpr
    alg e = do
      -- Reuse the single-layer checker by projecting each already-annotated
      -- child down to the type stored at its root.
      t <- typecheckAlg $ fmap (first getType) e
      pure $ t :< fmap fst e

-- 1 + 2
expr1 :: Expr
expr1 = Fix (Add (Fix (ILit 1)) (Fix (ILit 2)))
-- typecheckAll expr1
-- => Right (TInt :< Add (TInt :< ILit 1) (TInt :< ILit 2))

-- if (if (3 < 2) False True) (3 + 4) False <- note the error
expr2 :: Expr
expr2 =
  Fix (If (Fix (If (Fix (Lt (Fix (ILit 3)) (Fix (ILit 2))))
                   (Fix (BLit False))
                   (Fix (BLit True))))
          (Fix (Add (Fix (ILit 3))
                    (Fix (ILit 4))))
          (Fix (BLit False)))
-- typecheckAll expr2
-- => Left "Malformed condition: both branches must have the same type, but true branch is of type TInt and false branch is of type TBool. Whole expression: Fix (If (Fix (If (Fix (Lt (Fix (ILit 3)) (Fix (ILit 2)))) (Fix (BLit False)) (Fix (BLit True)))) (Fix (Add (Fix (ILit 3)) (Fix (ILit 4)))) (Fix (BLit False)))"
-- | @if (if (3 < 2) False True) (3 + 4) 42@: 'expr2' with its error fixed,
-- so both branches of the outer @if@ now have type 'TInt'.
--
-- typecheck expr3
-- => Right TInt
--
-- typecheckAll expr3
-- => Right (TInt :< If (TBool :< If (TBool :< Lt (TInt :< ILit 3) (TInt :< ILit 2))
--                                   (TBool :< BLit False)
--                                   (TBool :< BLit True))
--                     (TInt :< Add (TInt :< ILit 3) (TInt :< ILit 4))
--                     (TInt :< ILit 42))
expr3 :: Expr
expr3 = Fix (If condition thenBranch elseBranch)
  where
    lit n = Fix (ILit n)
    condition =
      Fix (If (Fix (Lt (lit 3) (lit 2)))
              (Fix (BLit False))
              (Fix (BLit True)))
    thenBranch = Fix (Add (lit 3) (lit 4))
    elseBranch = lit 42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment