Last active
December 20, 2015 11:09
-
-
Save shhyou/6120923 to your computer and use it in GitHub Desktop.
A type inference program for the lambda calculus under the Hindley–Milner (HM) type system.
To be checked later.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{- | |
runSolve (Arrow (TypeVar 0) (TypeVar 1)) (Arrow (TypeVar 1) (TypeVar 0)) | |
runSolve (Arrow IntType (TypeVar 1)) (Arrow (TypeVar 0) (TypeVar 0)) | |
runSolve (Arrow (Arrow (TypeVar 0) (TypeVar 1)) (TypeVar 2)) (Arrow (TypeVar 3) (Arrow (TypeVar 4) (TypeVar 5))) | |
runSolve (TypeVar 0) (Arrow (TypeVar 0) (TypeVar 1)) | |
runSolve (TypeVar 0) (Arrow (TypeVar 1) (TypeVar 0)) | |
typeInfer prog_I | |
typeInfer prog_Paradox | |
typeInfer prog_letid | |
-} | |
import Control.Monad.State | |
import Control.Monad.Error | |
import Control.Monad.Identity | |
import Control.Applicative (Applicative, (<*>), (<$>), pure) | |
import Data.List (nub,(\\)) | |
import qualified Data.Map as Map | |
-- | Lambda-calculus terms extended with integer literals and @let@.
data Expr
  = ConstInt Int          -- ^ integer literal
  | Var String            -- ^ variable reference
  | Ap Expr Expr          -- ^ application: function, then argument
  | Lambda String Expr    -- ^ abstraction: parameter name, then body
  | Let String Expr Expr  -- ^ let binding: name, bound expression, body
  deriving Show
-- | Monomorphic types. The 'Show' instance is defined by hand in this module.
data Type
  = TypeVar Int      -- ^ type variable, identified by its number
  | IntType          -- ^ the integer type
  | Arrow Type Type  -- ^ function type: domain, then codomain
-- | The substitution built during unification: each entry maps a type
-- variable's number to the type that variable has been bound to.
type Context = Map.Map Int Type
-- | The distinct type variables of a type, in left-to-right order of first
-- occurrence (e.g. @[0, 1]@ for @v0 -> v1@).
typeVars :: Type -> [Int]
typeVars t = nub (collect t [])
  where
    -- Difference-list traversal. The left subtree's variables are prepended
    -- in front of the right subtree's, so the final list reads left to right.
    -- The previous composition order (right . left) reversed it.
    collect (TypeVar v)   = (v :)
    collect IntType       = id
    collect (Arrow t1 t2) = collect t1 . collect t2
-- | Render a type in arrow notation. The flag says whether an arrow at this
-- position must be parenthesised (true for the domain of an enclosing arrow).
prettyPrintType :: Type -> Bool -> String
prettyPrintType ty needsParens =
  case ty of
    TypeVar n   -> 'v' : show n
    IntType     -> "Int"
    Arrow t1 t2 -> wrap (prettyPrintType t1 True ++ " -> " ++ prettyPrintType t2 False)
  where
    wrap s
      | needsParens = "(" ++ s ++ ")"
      | otherwise   = s
-- | Render a @forall@ prefix naming every type variable of the type, in the
-- order 'typeVars' returns them. A type without variables yields @forall@.
prettyPrintQuantification :: Type -> String
prettyPrintQuantification t =
  "forall" ++ concatMap (\v -> " v" ++ show v) (typeVars t)
-- A type is shown at top level, so its outermost arrow is not parenthesised.
instance Show Type where
  show = (`prettyPrintType` False)
-- Identity combinator: \x -> x
prog_I :: Expr
prog_I = Lambda "x" $ Var "x"
-- S combinator: \x y z -> (x z) (y z)
prog_S :: Expr
prog_S =
  Lambda "x" $ Lambda "y" $ Lambda "z" $
    Var "x" `Ap` Var "z" `Ap` (Var "y" `Ap` Var "z")
-- K combinator: \x y -> x
prog_K :: Expr
prog_K = Lambda "x" $ Lambda "y" $ Var "x"
-- Self-application: \x -> x x
-- not typable
prog_Paradox :: Expr
prog_Paradox = Lambda "x" $ Var "x" `Ap` Var "x"
-- \x y -> \z -> z (x y) (y x)
-- not typable
prog_Cycle :: Expr
prog_Cycle =
  Lambda "x" $ Lambda "y" $ Lambda "z" $
    Var "z" `Ap` (Var "x" `Ap` Var "y") `Ap` (Var "y" `Ap` Var "x")
-- Y combinator: \f -> (\x -> f (x x)) (\x -> f (x x))
-- not typable
prog_Y :: Expr
prog_Y = Lambda "f" (Ap half half)
  where
    -- Both halves of the combinator are the same term.
    half = Lambda "x" (Ap (Var "f") (Ap (Var "x") (Var "x")))
-- \f g h x y -> h (g x) (g y) (f y) (f x)
prog_A :: Expr
prog_A =
  Lambda "f" $ Lambda "g" $ Lambda "h" $ Lambda "x" $ Lambda "y" $
    Var "h" `Ap` (Var "g" `Ap` Var "x")
            `Ap` (Var "g" `Ap` Var "y")
            `Ap` (Var "f" `Ap` Var "y")
            `Ap` (Var "f" `Ap` Var "x")
-- let id = \x -> x in id id
-- (id is let-bound, so it can be used at two different types)
prog_letid :: Expr
prog_letid =
  Let "id" (Lambda "x" (Var "x"))
           (Ap (Var "id") (Var "id"))
-- \f -> let f' = f in f' f'
-- Checks that let does not generalise a lambda-bound variable: f' stays
-- monomorphic, so applying it to itself fails.
-- not typable
prog_fid :: Expr
prog_fid = Lambda "f" $ Let "f'" (Var "f") $ Var "f'" `Ap` Var "f'"
-- | Look up a key in an association list, using the first matching entry.
-- Calls 'error' with "Unbound variable <key>" when the key is missing.
assertLookup :: (Eq a, Show a) => a -> [(a,b)] -> b
assertLookup x env =
  maybe (error ("Unbound variable " ++ show x)) id (lookup x env)
-- | Apply the substitution to a type: every bound type variable is replaced,
-- recursively, by what it is bound to. Unbound variables are kept as-is.
deepFind :: Context -> Type -> Type
deepFind sub = go
  where
    go (TypeVar v) = maybe (TypeVar v) go (Map.lookup v sub)
    go IntType     = IntType
    go (Arrow a b) = Arrow (go a) (go b)
-- | Occurs check: does type variable @v@ appear in the type once the type has
-- been resolved through the substitution?
occursIn :: Context -> Int -> Type -> Bool
occursIn cxt v = mentions . deepFind cxt
  where
    mentions (TypeVar w)   = w == v
    mentions IntType       = False
    mentions (Arrow t1 t2) = mentions t1 || mentions t2
-- | Unify two types, extending the substitution held in the 'Context' state.
--
-- A bound type variable is looked through before comparison. An unbound
-- variable is bound to the other side after resolving it with 'deepFind'
-- and passing an occurs check, so no variable is bound to a type containing
-- itself. Failures are raised with 'throwError' and include the
-- substitution in effect at that point.
solve :: (Applicative m, MonadState Context m, MonadError String m)
      => Type
      -> Type
      -> m ()
solve IntType IntType = return ()
solve (Arrow dom1 codom1) (Arrow dom2 codom2) =
  solve dom1 dom2 >> solve codom1 codom2
solve (TypeVar v1) t2 = do
  sub <- get
  maybe (bindFresh sub) (`solve` t2) (Map.lookup v1 sub)
  where
    -- v1 is unbound: bind it to the resolved right-hand side.
    bindFresh sub = case deepFind sub t2 of
      TypeVar v2 | v2 == v1 -> return ()  -- same variable, nothing to do
      t2'
        | occursIn sub v1 t2' ->
            throwError $ "Unable to construct type " ++ show (TypeVar v1) ++ " = " ++ show t2'
                           ++ "\nIn the context of: " ++ show sub
        | otherwise -> put (Map.insert v1 t2' sub)
solve t1 (TypeVar v2) = solve (TypeVar v2) t1
solve t1 t2 = do
  sub <- get
  throwError $ "Unable to solve " ++ show t1 ++ " with " ++ show t2
                 ++ "\nin the context of " ++ show sub
-- | Bump the counter and return its new value, used as a fresh variable id.
nextVar :: MonadState Int m
        => m Int
nextVar = do
  counter <- get
  let counter' = counter + 1
  put counter'
  return counter'
-- | Instantiate a type scheme: each listed (quantified) variable is renamed
-- to a fresh variable from 'nextVar'. Other variables are left untouched.
instantiate :: MonadState Int m
            => Type
            -> [Int]
            -> m Type
instantiate t oldVars = do
  fresh <- mapM (\_ -> nextVar) oldVars
  let renaming = zip oldVars fresh
      rename (TypeVar v) = TypeVar (maybe v id (lookup v renaming))
      rename IntType     = IntType
      rename (Arrow a b) = Arrow (rename a) (rename b)
  return (rename t)
-- | Resolve a type against the current substitution held in the inner state.
resolveInferred :: MonadState Context m => Type -> StateT Int m Type
resolveInferred t = lift (gets (`deepFind` t))

-- | Type variables free in the typing environment, resolved through the
-- substitution. This covers everything reachable from lambda-bound
-- variables, plus the non-quantified variables of let-bound schemes.
envFreeTypeVars :: Context -> [(String, Int)] -> [(String, (Type, [Int]))] -> [Int]
envFreeTypeVars sub cxt poly = monoVars ++ schemeVars
  where
    monoVars   = concatMap (typeVars . deepFind sub . TypeVar . snd) cxt
    schemeVars = concatMap (\(_, (t, bound)) -> typeVars (deepFind sub t) \\ bound) poly

-- | Remove every binding of a name from an environment (used for shadowing).
dropBinding :: String -> [(String, a)] -> [(String, a)]
dropBinding x = filter ((/= x) . fst)

-- | Infer the type of an expression.
--
-- The outer @StateT Int@ counter supplies fresh type variables via
-- 'nextVar'. The inner 'Context' substitution, reached with 'lift', is
-- extended by 'solve'.
--
-- Lambda-bound names are kept monomorphic in the first environment, which
-- maps each name to a type-variable id. Let-bound names are kept in the
-- second environment as schemes @(type, quantified variable ids)@ and are
-- instantiated afresh at each use. A name is only ever bound in one
-- environment, so inner bindings shadow outer ones of either kind.
--
-- An unbound variable is reported with 'throwError'.
buildUp :: (Applicative m, MonadState Context m, MonadError String m)
        => [(String, Int)]                -- context
        -> [(String, (Type, [Int]))]      -- polymorphic type
        -> Expr
        -> StateT Int m Type
buildUp _ _ (ConstInt _) =
  return IntType
buildUp cxt poly (Var x) =
  case lookup x cxt of
    Just v  -> resolveInferred (TypeVar v)
    Nothing -> case lookup x poly of
      Just (t, vars) -> instantiate t vars >>= resolveInferred
      -- Report through MonadError so callers get a Left instead of a crash.
      Nothing        -> lift . throwError $ "Unbound variable " ++ show x
buildUp cxt poly (Lambda x e) = do
  va <- nextVar
  -- The parameter shadows any let-bound variable with the same name.
  b  <- buildUp ((x, va) : cxt) (dropBinding x poly) e
  resolveInferred (Arrow (TypeVar va) b)
buildUp cxt poly (Ap e1 e2) = do
  a2b <- buildUp cxt poly e1
  a   <- buildUp cxt poly e2
  b   <- TypeVar <$> nextVar
  lift $ solve a2b (Arrow a b)
  resolveInferred b
buildUp cxt poly (Let x e body) = do
  t   <- buildUp cxt poly e >>= resolveInferred
  sub <- lift get
  -- Generalise only variables that are not free in the environment. The
  -- environment is compared after resolution: subtracting the raw
  -- lambda-variable ids would wrongly generalise the variables those ids
  -- have since been bound to. For example, \x -> let k = x (\z -> z) in k k
  -- must be rejected.
  let scheme = (t, typeVars t \\ envFreeTypeVars sub cxt poly)
  -- The let binding shadows any earlier binding of the same name.
  buildUp (dropBinding x cxt) ((x, scheme) : dropBinding x poly) body
-- | Run a stateful computation that can fail, starting from the given state.
-- Returns either the error message or the result paired with the final state.
runMonads :: s -> StateT s (ErrorT String Identity) a -> Either String (a, s)
runMonads initState action = runIdentity (runErrorT (runStateT action initState))
-- | Print an error message as-is, or 'show' a successful result.
printEitherType :: Show a => Either String a -> IO ()
printEitherType = either putStrLn print
-- | Unify two types from an empty substitution. On success, return the first
-- type resolved through the resulting substitution.
runSolve' :: Type -> Type -> Either String Type
runSolve' t1 t2 =
  case runMonads Map.empty (solve t1 t2) of
    Left err       -> Left err
    Right (_, sub) -> Right (deepFind sub t1)
-- | Unify two types and print either the unified type or the error.
runSolve :: Type -> Type -> IO ()
runSolve t1 = printEitherType . runSolve' t1
-- | Infer the type of a closed expression. The counter starts at 0 and the
-- substitution starts empty. Prints the inferred type or the type error.
typeInfer :: Expr -> IO ()
typeInfer expr =
  printEitherType $ fmap (fst . fst) $
    runMonads Map.empty (runStateT (buildUp [] [] expr) 0)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment