Skip to content

Instantly share code, notes, and snippets.

@shhyou
Last active December 20, 2015 11:09
Show Gist options
  • Save shhyou/6120923 to your computer and use it in GitHub Desktop.
Save shhyou/6120923 to your computer and use it in GitHub Desktop.
A type inference program for the Hindley–Milner (HM) type system over the lambda calculus. To be checked later...
{-# LANGUAGE FlexibleContexts #-}
{-
runSolve (Arrow (TypeVar 0) (TypeVar 1)) (Arrow (TypeVar 1) (TypeVar 0))
runSolve (Arrow IntType (TypeVar 1)) (Arrow (TypeVar 0) (TypeVar 0))
runSolve (Arrow (Arrow (TypeVar 0) (TypeVar 1)) (TypeVar 2)) (Arrow (TypeVar 3) (Arrow (TypeVar 4) (TypeVar 5)))
runSolve (TypeVar 0) (Arrow (TypeVar 0) (TypeVar 1))
runSolve (TypeVar 0) (Arrow (TypeVar 1) (TypeVar 0))
typeInfer prog_I
typeInfer prog_Paradox
typeInfer prog_letid
-}
import Control.Monad.State
import Control.Monad.Error
import Control.Monad.Identity
import Control.Applicative (Applicative, (<*>), (<$>), pure)
import Data.List (nub,(\\))
import qualified Data.Map as Map
-- | Terms of the untyped lambda calculus, extended with integer literals
-- and a non-recursive @let@ (the bound expression is typed without the
-- new name in scope; see 'buildUp').
data Expr
  = ConstInt Int          -- ^ integer literal
  | Var String            -- ^ variable reference
  | Ap Expr Expr          -- ^ application: function, then argument
  | Lambda String Expr    -- ^ abstraction: bound name, then body
  | Let String Expr Expr  -- ^ let name = bound expression in body
  deriving (Show)
-- | Simple types: numbered type variables, the base type @Int@, and
-- function types. Rendering goes through the hand-written 'Show' instance.
-- 'Eq' is derived so types can be compared structurally (e.g. in tests).
data Type
  = TypeVar Int      -- ^ type variable, printed as @vN@
  | IntType          -- ^ the base type @Int@
  | Arrow Type Type  -- ^ function type: domain, then codomain
  deriving (Eq)

-- | A substitution: maps a type-variable number to the type it is bound to.
-- A binding may itself mention other bound variables; 'deepFind' chases
-- these chains to produce a fully resolved type.
type Context = Map.Map Int Type
-- | The distinct type variables occurring in a type.
--
-- Variables are collected by prepending them in a left-to-right walk, so
-- the right-most occurrence ends up first; duplicates are then removed
-- with 'nub' (keeping the first occurrence in that list).
typeVars :: Type -> [Int]
typeVars = nub . collect []
  where
    collect acc (TypeVar n)     = n : acc
    collect acc IntType         = acc
    collect acc (Arrow dom cod) = collect (collect acc dom) cod
-- | Render a type. The flag is 'True' when the type sits in the argument
-- (left) position of an arrow; arrow types in that position are
-- parenthesised, since @->@ associates to the right.
prettyPrintType :: Type -> Bool -> String
prettyPrintType ty inArgPos =
  case ty of
    TypeVar n -> 'v' : show n
    IntType   -> "Int"
    Arrow dom cod ->
      let inner = prettyPrintType dom True ++ " -> " ++ prettyPrintType cod False
      in if inArgPos then "(" ++ inner ++ ")" else inner
-- | Render the quantifier prefix for a type: the word @forall@ followed by
-- @ vN@ for each variable, in the order 'typeVars' returns them.
-- A type with no variables yields just @forall@; no trailing dot is added.
prettyPrintQuantification :: Type -> String
prettyPrintQuantification = ("forall" ++) . concatMap ((" v" ++) . show) . typeVars
-- | Types are shown via 'prettyPrintType' at top level (no outer parens).
instance Show Type where
  show = (`prettyPrintType` False)
-- I combinator: \x -> x
prog_I :: Expr
prog_I = Lambda "x" (Var "x")
-- S combinator: \x y z -> (x z) (y z)
prog_S :: Expr
prog_S =
  Lambda "x"
    (Lambda "y"
      (Lambda "z"
        (Ap (Ap (Var "x") (Var "z"))
            (Ap (Var "y") (Var "z")))))
-- K combinator: \x y -> x
prog_K :: Expr
prog_K = Lambda "x" (Lambda "y" (Var "x"))
-- Self-application: \x -> x x
-- Not typable: x would need a type of the form a = a -> b.
prog_Paradox :: Expr
prog_Paradox = Lambda "x" (Ap (Var "x") (Var "x"))
-- \x y -> \z -> z (x y) (y x)
-- Not typable: x and y are each applied to the other, giving a cyclic type.
prog_Cycle :: Expr
prog_Cycle =
  Lambda "x"
    (Lambda "y"
      (Lambda "z"
        (Ap (Ap (Var "z")
                (Ap (Var "x") (Var "y")))
            (Ap (Var "y") (Var "x")))))
-- Y combinator: \f -> (\x -> f (x x)) (\x -> f (x x))
-- Not typable: contains the self-application x x.
prog_Y :: Expr
prog_Y =
  Lambda "f"
    (Ap
      (Lambda "x" (Ap (Var "f") (Ap (Var "x") (Var "x"))))
      (Lambda "x" (Ap (Var "f") (Ap (Var "x") (Var "x")))))
-- \f g h x y -> h (g x) (g y) (f y) (f x)
-- Lambda-bound f and g are used at both x and y, forcing x and y to share a type.
prog_A :: Expr
prog_A =
  Lambda "f"
    (Lambda "g"
      (Lambda "h"
        (Lambda "x"
          (Lambda "y"
            (Ap (Ap (Ap (Ap
                  (Var "h")
                  (Ap (Var "g") (Var "x")))
                  (Ap (Var "g") (Var "y")))
                  (Ap (Var "f") (Var "y")))
                  (Ap (Var "f") (Var "x")))))))
-- let id = \x -> x in id id
-- Typable only because the let-bound id is generalized.
prog_letid :: Expr
prog_letid =
  Let "id" (Lambda "x" (Var "x")) $
    Ap (Var "id") (Var "id")
-- \f -> let f' = f in f' f'
-- Checks that the type of a lambda-bound variable is not generalized by let.
-- Not typable.
prog_fid :: Expr
prog_fid =
  Lambda "f"
    (Let "f'" (Var "f")
      (Ap (Var "f'") (Var "f'")))
-- | Look a key up in an association list (first match wins), calling
-- 'error' with @Unbound variable <key>@ when it is absent.
assertLookup :: (Eq a, Show a) => a -> [(a,b)] -> b
assertLookup x env =
  maybe (error ("Unbound variable " ++ show x)) id (lookup x env)
-- | Apply a substitution exhaustively: every bound type variable is
-- replaced by its binding, and bindings are followed recursively until only
-- unbound variables remain. Loops if the substitution is cyclic; 'solve'
-- uses an occurs check so that it never builds such a cycle.
deepFind :: Context -> Type -> Type
deepFind subst = resolve
  where
    resolve IntType       = IntType
    resolve (TypeVar n)   = maybe (TypeVar n) resolve (Map.lookup n subst)
    resolve (Arrow a b)   = Arrow (resolve a) (resolve b)
-- | Occurs check: does type variable @v@ appear in the type once the
-- substitution has been fully applied to it?
occursIn :: Context -> Int -> Type -> Bool
occursIn cxt v = mentions . deepFind cxt
  where
    mentions (TypeVar n) = n == v
    mentions IntType     = False
    mentions (Arrow a b) = mentions a || mentions b
-- | Unify two types, extending the substitution held in the state.
--
-- A type variable that is already bound is replaced by its binding and
-- unification continues. An unbound variable is bound to the other side
-- after that side has been fully resolved with 'deepFind'; the occurs check
-- ('occursIn') rejects bindings that would create an infinite type.
-- Mismatches are reported through 'throwError'. The error type is fixed to
-- 'String', so messages are thrown directly; no 'strMsg' from the
-- deprecated @Error@ class is needed.
solve :: (Applicative m, MonadState Context m, MonadError String m)
      => Type
      -> Type
      -> m ()
solve IntType IntType =
  return ()
solve (Arrow dom1 codom1) (Arrow dom2 codom2) = do
  solve dom1 dom2
  solve codom1 codom2
solve (TypeVar v1) t2 = do
  cxt <- get
  case Map.lookup v1 cxt of
    -- v1 is already bound: unify its binding instead.
    Just t1 -> solve t1 t2
    Nothing -> case deepFind cxt t2 of
      -- The same unbound variable on both sides: nothing to record.
      TypeVar v2 | v1 == v2 -> return ()
      t2'
        | occursIn cxt v1 t2' ->
            throwError $ "Unable to construct type " ++ show (TypeVar v1) ++ " = " ++ show t2'
              ++ "\nIn the context of: " ++ show cxt
        | otherwise -> put $ Map.insert v1 t2' cxt
-- Variable on the right only: flip it so the clause above handles it.
solve t1 (TypeVar v2) =
  solve (TypeVar v2) t1
solve t1 t2 = do
  cxt <- get
  throwError $ "Unable to solve " ++ show t1 ++ " with " ++ show t2
    ++ "\nin the context of " ++ show cxt
-- | Produce a fresh type-variable number: increment the counter in the
-- state and return the new value (so a counter starting at 0 yields 1 first).
nextVar :: MonadState Int m
        => m Int
nextVar = state (\n -> let n' = n + 1 in (n', n'))
-- | Instantiate a type scheme: each variable listed in @oldVars@ is renamed
-- to a fresh variable from 'nextVar'; all other variables are left alone.
instantiate :: MonadState Int m
            => Type
            -> [Int]
            -> m Type
instantiate t oldVars = do
  freshVars <- mapM (\_ -> nextVar) oldVars
  let renaming = zip oldVars freshVars
      rename IntType     = IntType
      rename (TypeVar n) = TypeVar (maybe n id (lookup n renaming))
      rename (Arrow a b) = Arrow (rename a) (rename b)
  return (rename t)
-- | Infer the type of an expression.
--
-- The outer 'StateT' 'Int' supplies fresh type-variable numbers through
-- 'nextVar'. The inner monad holds the substitution ('Context') that
-- 'solve' extends. Every case returns its type with the current
-- substitution applied.
--
-- * @cxt@ maps lambda-bound (monomorphic) names to their type variable.
-- * @poly@ maps let-bound names to a type plus the list of type variables
--   quantified in it; those variables are renamed on every use
--   ('instantiate').
--
-- Unbound variables are reported through 'throwError' as
-- @Unbound variable "name"@.
buildUp :: (Applicative m, MonadState Context m, MonadError String m)
        => [(String, Int)]              -- context
        -> [(String, (Type, [Int]))]    -- polymorphic type
        -> Expr
        -> StateT Int m Type
buildUp _ _ (ConstInt _) =
  return IntType
buildUp cxt poly (Var x) =
  case lookup x cxt of
    Just v -> deepFind <$> lift get <*> pure (TypeVar v)
    Nothing -> case lookup x poly of
      Just (t, qs) -> deepFind <$> lift get <*> instantiate t qs
      Nothing -> throwError ("Unbound variable " ++ show x)
buildUp cxt poly (Lambda x e) = do
  va <- nextVar
  b <- buildUp ((x, va) : cxt) poly e
  deepFind <$> lift get <*> pure (Arrow (TypeVar va) b)
buildUp cxt poly (Ap e1 e2) = do
  a2b <- buildUp cxt poly e1
  a <- buildUp cxt poly e2
  b <- TypeVar <$> nextVar
  lift $ solve a2b (Arrow a b)
  deepFind <$> lift get <*> pure b
buildUp cxt poly (Let x e body) = do
  t <- buildUp cxt poly e
  subst <- lift get
  let -- Generalize the fully resolved type, not a possibly stale snapshot.
      t' = deepFind subst t
      -- Free type variables of the environment. A lambda-bound variable's
      -- number may already be bound to a type mentioning other variables,
      -- so resolve it before collecting. For let-bound names, only the
      -- non-quantified variables count.
      monoVars = concatMap (typeVars . deepFind subst . TypeVar . snd) cxt
      polyVars = concatMap (\(_, (pt, qs)) -> typeVars (deepFind subst pt) \\ qs) poly
      quantified = typeVars t' \\ (monoVars ++ polyVars)
      -- The new let binding shadows any lambda-bound variable of the same
      -- name; 'Var' consults @cxt@ before @poly@, so remove it there.
      cxt' = filter ((/= x) . fst) cxt
  buildUp cxt' ((x, (t', quantified)) : poly) body
-- | Run a stateful computation that may fail with a 'String' error,
-- returning either the error or the result paired with the final state.
runMonads :: s -> StateT s (ErrorT String Identity) a -> Either String (a, s)
runMonads initState action = runIdentity (runErrorT (runStateT action initState))
-- | Print a result: the 'Left' message verbatim with 'putStrLn', or the
-- 'Right' value with 'print'.
printEitherType :: Show a => Either String a -> IO ()
printEitherType = either putStrLn print
-- | Unify two types starting from an empty substitution. On success, return
-- the first type with the resulting substitution applied; otherwise return
-- the error message.
runSolve' :: Type -> Type -> Either String Type
runSolve' t1 t2 =
  case runMonads Map.empty (solve t1 t2) of
    Left err           -> Left err
    Right (_, subst)   -> Right (deepFind subst t1)
-- | Unify two types and print the resolved first type, or the error message.
runSolve :: Type -> Type -> IO ()
runSolve t1 = printEitherType . runSolve' t1
-- | Infer the type of a closed expression and print it, or print the error
-- message produced during inference.
typeInfer :: Expr -> IO ()
typeInfer =
  printEitherType
    . fmap (fst . fst)      -- keep only the inferred type
    . runMonads Map.empty   -- start from an empty substitution
    . flip runStateT 0      -- fresh-variable counter; nextVar hands out 1 first
    . buildUp [] []         -- no variables in scope initially
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment