Created July 13, 2024 17:55
-
-
Save kccqzy/fa8a8ae12a198b41c6339e8a5c45978a to your computer and use it in GitHub Desktop.
A toy implementation of Algorithm W for HN readers (modified from https://github.com/wh5a/Algorithm-W-Step-By-Step/blob/master/AlgorithmW.lhs)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Control.Monad.Except | |
import Control.Monad.State | |
import qualified Data.IntMap as IntMap | |
import qualified Data.IntSet as IntSet | |
import qualified Data.Map as Map | |
-- | Expressions of the toy language.
data Exp
  = -- | A term variable.
    EVar String
  | -- | A literal constant.
    ELit Lit
  | -- | Function application: @EApp f x@ is @f x@.
    EApp Exp Exp
  | -- | Lambda abstraction: @EAbs x body@ is @\\x -> body@.
    EAbs String Exp
  | -- | @ELet x e1 e2@ is @let x = e1 in e2@.
    ELet String Exp Exp
  | -- | Integer addition, @e1 + e2@.
    EPlus Exp Exp
  | -- | The built-in addition function itself (typed @Int -> Int -> Int@).
    EPlusFun
  deriving (Eq, Ord, Show)
-- | Literal constants.
data Lit
  = -- | An integer literal.
    LInt Integer
  | -- | A boolean literal.
    LBool Bool
  deriving (Eq, Ord, Show)
-- | Monotypes: types without quantifiers.
data Type
  = -- | A type variable, identified by its number.
    TVar Int
  | -- | The integer type.
    TInt
  | -- | The boolean type.
    TBool
  | -- | A function type from the first type to the second.
    TFun Type Type
  deriving (Eq, Ord, Show)
-- | A polytype (type scheme) is a type with a number of universally
-- quantified type variables (foralls) in front of it. For example the
-- identity function has type "forall x. x -> x", which is represented by
-- putting the number of the variable standing for "x" in the set.
--
-- Quantified variables are only introduced by @generalize@; lambda-bound
-- variables are stored with an empty set, i.e. as plain monotypes.
--
-- PolyTypes only appear at the top level (in the type environment): 'Type'
-- cannot contain a PolyType, so higher-rank polymorphism is not expressible.
data PolyType = PolyType IntSet.IntSet Type deriving (Show)
-- | The typing context: maps each term variable in scope to its polytype.
newtype TypeEnv = TypeEnv
  { getTypeEnv :: Map.Map String PolyType
  }
-- | A substitution: a finite map from type-variable numbers to the types
-- that replace them.
newtype Subst = Subst (IntMap.IntMap Type)
  deriving Show
-- | Composition of substitutions. Applying @later <> earlier@ to a type gives
-- the same result as applying @earlier@ first and then @later@: the bindings
-- of @earlier@ get @later@ applied to their right-hand sides, and they take
-- precedence for variables bound by both operands.
instance Semigroup Subst where
  outer@(Subst later) <> Subst earlier =
    Subst (IntMap.union (IntMap.map (applySubstToType outer) earlier) later)
-- | The empty substitution, which leaves every type unchanged.
instance Monoid Subst where
  mempty = Subst IntMap.empty
-- | @ftvFromType@ collects the type variables occurring in a type. A monotype
-- has no binders, so all of them are free.
ftvFromType :: Type -> IntSet.IntSet
ftvFromType ty = case ty of
  TVar v -> IntSet.singleton v
  TFun arg res -> ftvFromType arg <> ftvFromType res
  _ -> IntSet.empty
-- | @applySubstToType@ replaces every type variable bound by the substitution
-- with its image. It makes a single pass: the replacement types are not
-- themselves substituted again.
applySubstToType :: Subst -> Type -> Type
applySubstToType sub@(Subst bindings) ty = case ty of
  TVar v -> IntMap.findWithDefault ty v bindings
  TFun arg res -> TFun (applySubstToType sub arg) (applySubstToType sub res)
  _ -> ty
-- | Free type variables of a polytype: those of its body, minus the
-- quantified ones.
ftvFromPolyType :: PolyType -> IntSet.IntSet
ftvFromPolyType (PolyType quantified body) =
  IntSet.difference (ftvFromType body) quantified
-- | Applies a substitution to a polytype. Bindings for the polytype's
-- quantified variables are dropped first, so those variables stay untouched.
applySubstToPolyType :: Subst -> PolyType -> PolyType
applySubstToPolyType (Subst bindings) (PolyType quantified body) =
  let freeOnly = Subst (IntMap.withoutKeys bindings quantified)
   in PolyType quantified (applySubstToType freeOnly body)
-- | Free type variables of all the polytypes in the environment.
ftvFromTypeEnv :: TypeEnv -> IntSet.IntSet
ftvFromTypeEnv (TypeEnv env) =
  IntSet.unions (map ftvFromPolyType (Map.elems env))
-- | Applies a substitution to every polytype in the environment.
applySubstToTypeEnv :: Subst -> TypeEnv -> TypeEnv
applySubstToTypeEnv s = TypeEnv . Map.map (applySubstToPolyType s) . getTypeEnv
-- | @generalize@ builds a polytype by universally quantifying every type
-- variable that is free in the type but not free in the environment.
generalize :: TypeEnv -> Type -> PolyType
generalize env t = PolyType quantified t
  where
    quantified = IntSet.difference (ftvFromType t) (ftvFromTypeEnv env)
-- | @instantiate@ replaces each quantified type variable of a polytype with a
-- fresh type variable. Fresh variables are allocated in ascending order of
-- the quantified variables' numbers.
instantiate :: PolyType -> TypeCheck Type
instantiate (PolyType vars t) = do
  fresh <- sequenceA (IntMap.fromSet (const newTyVar) vars)
  pure (applySubstToType (Subst fresh) t)
-- | Inference state: the number of the next fresh type variable that
-- 'newTyVar' will hand out.
newtype TIState
  = TIState Int
-- | Errors reported by type inference.
data TIError
  = -- | The two types cannot be made equal.
    ErrorTypeUnify Type Type
  | -- | Binding the variable to the type would create an infinite type.
    ErrorOccursCheck Int Type
  | -- | A term variable was used without being in scope.
    ErrorUnboundVariable String
  | -- | An error, wrapped with the expression being inferred when it arose.
    ErrorContext Exp TIError
  deriving (Show)
-- | Runs an inference action. Any error it throws is rethrown wrapped in
-- 'ErrorContext' together with the given expression.
addErrorContext :: Exp -> TypeCheck a -> TypeCheck a
addErrorContext e action = catchError action (throwError . ErrorContext e)
-- | The inference monad. It can fail with a 'TIError' and threads a
-- 'TIState' supply of fresh type variables.
type TypeCheck = ExceptT TIError (State TIState)
-- | Runs an inference computation with fresh type variables numbered from 0.
-- Returns either the error it raised or its result.
runTI :: TypeCheck a -> Either TIError a
runTI = flip evalState (TIState 0) . runExceptT
-- | Returns a type variable that has not been handed out before, and advances
-- the counter.
newTyVar :: TypeCheck Type
newTyVar = state $ \(TIState next) -> (TVar next, TIState (next + 1))
-- | @typeUnify@ computes a substitution that makes the two types equal when
-- it is applied to both. It throws 'ErrorTypeUnify' for incompatible type
-- constructors and, via 'varBind', 'ErrorOccursCheck' for infinite types.
typeUnify :: Type -> Type -> TypeCheck Subst
typeUnify (TFun l r) (TFun l' r') = do
  s1 <- typeUnify l l'
  s2 <- typeUnify (applySubstToType s1 r) (applySubstToType s1 r')
  -- s2 was computed on types that already had s1 applied, so s1 must act
  -- first: @s2 <> s1@. With @s1 <> s2@, unifying (a -> a) with (b -> Int)
  -- would yield {b := Int, a := b}, which maps a to b instead of Int and is
  -- therefore not a unifier.
  return (s2 <> s1)
typeUnify (TVar u) t = varBind u t
typeUnify t (TVar u) = varBind u t
typeUnify TInt TInt = return mempty
typeUnify TBool TBool = return mempty
typeUnify t1 t2 = throwError (ErrorTypeUnify t1 t2)
-- | @varBind@ produces the substitution that binds type variable @u@ to @t@.
-- Binding a variable to itself gives the empty substitution. Binding it to a
-- type that contains it fails the occurs check, because the result would be
-- an infinite type.
varBind :: Int -> Type -> TypeCheck Subst
varBind u t
  | TVar u == t = pure mempty
  | occurs = throwError (ErrorOccursCheck u t)
  | otherwise = pure (Subst (IntMap.singleton u t))
  where
    occurs = IntSet.member u (ftvFromType t)
-- | @ti@ performs type inference for an expression under the given
-- environment. It returns the substitution accumulated during inference
-- together with the inferred type. Notably, it returns types, not polytypes.
--
-- Substitutions compose right to left: @s2 <> s1@ applies @s1@ first and
-- then @s2@ (see the 'Semigroup' instance for 'Subst'). Each case therefore
-- puts the substitution it computed later on the left.
ti :: TypeEnv -> Exp -> TypeCheck (Subst, Type)
ti (TypeEnv env) (EVar n) =
  case Map.lookup n env of
    Nothing -> throwError (ErrorUnboundVariable n)
    Just sigma -> do
      t <- instantiate sigma
      return (mempty, t)
ti _ (ELit (LInt _)) = pure (mempty, TInt)
ti _ (ELit (LBool _)) = pure (mempty, TBool)
ti env e@(EAbs n body) = addErrorContext e $ do
  tv <- newTyVar
  -- Map.insert replaces any outer binding of n, so the parameter shadows it.
  -- The parameter is monomorphic: its polytype quantifies nothing.
  let env' = TypeEnv (Map.insert n (PolyType mempty tv) (getTypeEnv env))
  (s1, t1) <- ti env' body
  return (s1, TFun (applySubstToType s1 tv) t1)
ti env e@(EApp e1 e2) = addErrorContext e $ do
  tv <- newTyVar
  (s1, t1) <- ti env e1
  (s2, t2) <- ti (applySubstToTypeEnv s1 env) e2
  s3 <- typeUnify (applySubstToType s2 t1) (TFun t2 tv)
  return (s3 <> s2 <> s1, applySubstToType s3 tv)
ti env e@(EPlus e1 e2) =
  -- The built-in plus operator has type Int -> Int -> Int, which is
  -- represented by the special value EPlusFun. Therefore @e1 + e2@ is inferred
  -- exactly like the application @EPlusFun e1 e2@. Errors raised inside also
  -- carry the desugared EApp nodes as context.
  --
  -- If we were to inline this, we would find the following. For the inner
  -- EApp, s1 = mempty and t1 = TFun TInt (TFun TInt TInt); (s2, t2) is
  -- inferred normally; s3 unifies t1 against (TFun t2 tv), which forces
  -- t2 ~ TInt and tv ~ TFun TInt TInt. For the outer EApp, s1 is the
  -- substitution from the inner call and t1 = TFun TInt TInt; (s2, t2) is
  -- inferred normally; s3 unifies t1 against (TFun t2 tv), which forces
  -- t2 ~ TInt and tv ~ TInt.
  addErrorContext e (ti env (EApp (EApp EPlusFun e1) e2))
ti _ EPlusFun = pure (mempty, TFun TInt (TFun TInt TInt))
ti env e@(ELet x e1 e2) = addErrorContext e $ do
  (s1, t1) <- ti env e1
  -- Generalize against the environment e1 was inferred in, updated by s1.
  let t' = generalize (applySubstToTypeEnv s1 env) t1
      env' = TypeEnv (Map.insert x t' (getTypeEnv env))
  (s2, t2) <- ti (applySubstToTypeEnv s1 env') e2
  -- s2 was computed against an environment that already had s1 applied, so
  -- s1 must act first: @s2 <> s1@, matching the EApp case. The reverse order
  -- can leave bindings in s1 that still mention variables s2 resolves.
  return (s2 <> s1, t2)
-- | Infers the type of an expression, then applies the final substitution to
-- that type.
typeInference :: TypeEnv -> Exp -> TypeCheck Type
typeInference env e = uncurry applySubstToType <$> ti env e
-- | Sample programs run by 'main'. Each one is printed together with either
-- its inferred type or the type error that inference reported.
examples :: [Exp]
examples =
  [ EAbs "x" (EVar "x"),
    EAbs "f" (EAbs "x" (EApp (EVar "f") (EApp (EVar "f") (EVar "x")))),
    ELet
      "id"
      (EAbs "x" (EVar "x"))
      (EVar "id"),
    ELet
      "id"
      (EAbs "x" (EVar "x"))
      (EApp (EVar "id") (EVar "id")),
    ELet
      "id"
      (EAbs "x" (ELet "y" (EVar "x") (EVar "y")))
      (EApp (EVar "id") (EVar "id")),
    ELet
      "id"
      (EAbs "x" (ELet "y" (EVar "x") (EVar "y")))
      (EApp (EApp (EVar "id") (EVar "id")) (ELit (LInt 2))),
    -- Self-application: rejected by the occurs check.
    ELet
      "wrong"
      (EAbs "x" (EApp (EVar "x") (EVar "x")))
      (EVar "wrong"),
    -- Self-application: rejected by the occurs check.
    ELet
      "wrong2"
      (EAbs "x" (EApp (EApp (EVar "x") (EVar "x")) (EVar "x")))
      (EVar "wrong2"),
    EAbs
      "m"
      ( ELet
          "y"
          (EVar "m")
          ( ELet
              "x"
              (EApp (EVar "y") (ELit (LBool True)))
              (EVar "x")
          )
      ),
    -- Applying an integer: rejected with a unification error.
    EApp (ELit (LInt 2)) (ELit (LInt 2)),
    ELet "id" (EAbs "x" (EVar "x")) (EApp (EVar "id") (ELit (LInt 2))),
    -- Self-application: rejected by the occurs check.
    ELet
      "omega"
      (EApp (EAbs "x" (EApp (EVar "x") (EVar "x"))) (EAbs "x" (EApp (EVar "x") (EVar "x"))))
      (EVar "omega"),
    ELet
      "plusOne"
      (EAbs "x" (EPlus (ELit (LInt 1)) (EVar "x")))
      (ELet "two" (ELit (LInt 2)) (EApp (EVar "plusOne") (EVar "two"))),
    ELet
      "plusplus"
      (EAbs "x" (EAbs "y" (EPlus (EVar "x") (EPlus (EVar "y") (EVar "x")))))
      (ELet "two" (ELit (LInt 2)) (EApp (EVar "plusplus") (EVar "two"))),
    let f = EAbs "x" (EPlus (ELit (LInt 1)) (EVar "x"))
        z = ELit (LInt 0)
     in ELet "church2" (EAbs "f" (EAbs "x" (EApp (EVar "f") (EApp (EVar "f") (EVar "x"))))) (EApp (EApp (EVar "church2") f) z),
    -- Regression: \f -> \x -> (\a -> \b -> a) (let y = f x in y + 1) (f x True).
    -- The let forces `f x` to Int, so applying `f x` to True must be
    -- rejected. This relies on the substitution returned for a let being
    -- composed in the right order.
    EAbs
      "f"
      ( EAbs
          "x"
          ( EApp
              ( EApp
                  (EAbs "a" (EAbs "b" (EVar "a")))
                  (ELet "y" (EApp (EVar "f") (EVar "x")) (EPlus (EVar "y") (ELit (LInt 1))))
              )
              (EApp (EApp (EVar "f") (EVar "x")) (ELit (LBool True)))
          )
      ),
    -- Regression: \w -> (\a -> \b -> a) ((\h -> h (h w)) (\y -> 1)) (w True).
    -- The middle application forces w to Int, so `w True` must be rejected.
    -- This relies on unification of function types composing its two
    -- substitutions in the right order.
    EAbs
      "w"
      ( EApp
          ( EApp
              (EAbs "a" (EAbs "b" (EVar "a")))
              ( EApp
                  (EAbs "h" (EApp (EVar "h") (EApp (EVar "h") (EVar "w"))))
                  (EAbs "y" (ELit (LInt 1)))
              )
          )
          (EApp (EVar "w") (ELit (LBool True)))
      )
  ]
-- | Runs inference on an expression in the empty environment. Prints the
-- expression together with either its type or the error.
test :: Exp -> IO ()
test e = putStrLn (either describeError describeType result)
  where
    result = runTI (typeInference (TypeEnv Map.empty) e)
    describeError err = show e ++ "\n " ++ show err ++ "\n"
    describeType t = show e ++ " :: " ++ show t ++ "\n"
-- | Type-checks and prints every entry of 'examples', in order.
main :: IO ()
main = sequence_ [test e | e <- examples]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment