Last active
September 24, 2019 13:13
-
-
Save pedrominicz/4cd39439a17111adc0fda9696eb2e6e5 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- https://crypto.stanford.edu/~blynn/lambda/pcf.html | |
-- http://dev.stephendiehl.com/fun/006_hindley_milner.html | |
-- https://en.wikipedia.org/wiki/Hindley–Milner_type_system#Free_type_variables | |
-- https://en.wikipedia.org/wiki/Hindley–Milner_type_system#Algorithm_J | |
module Let where | |
import Control.Monad.Except | |
import Control.Monad.Reader | |
import Control.Monad.State | |
import Data.List (elemIndex) | |
import Safe (atMay) | |
import Text.Parsec hiding (State, parse) | |
-- | Identifier text: names of globals and of parser-bound variables.
type Name = String
-- | Abstract syntax of expressions.
--
-- Locally bound variables are stored as positions in the list of enclosing
-- binders (the parser resolves names with 'elemIndex', newest binder first),
-- so binders themselves carry no name.
data Expr
  = Ref Int                 -- ^ Locally bound variable, by position in the scope list.
  | Global Name             -- ^ Variable that is not locally bound; looked up by name.
  | Lam (Maybe Type) Expr   -- ^ Lambda with an optional parameter type annotation.
  | App Expr Expr           -- ^ Application of a function to an argument.
  | Num Integer             -- ^ Integer literal.
  | Bool Bool               -- ^ Boolean literal.
  | Let Expr Expr           -- ^ @let@: the bound expression, then the body.
  deriving (Show)
-- | Types: numbered type variables, function types and two base types.
data Type
  = TVar Int         -- ^ Type variable identified by a number.
  | LamT Type Type   -- ^ Function type: argument type, then result type.
  | NumT             -- ^ Type of numeric literals.
  | BoolT            -- ^ Type of boolean literals.
-- | Type scheme: a type together with the variables quantified over it.
-- 'instantiate' replaces exactly the listed variables with fresh ones.
data Scheme
  = Forall [Int] Type
  deriving (Show)
-- | Global typing environment: schemes of named globals.
type Environment = [(Name, Scheme)]

-- | Inference state: the number of the next fresh type variable, paired with
-- the bindings (variable number, bound type) recorded by unification.
type Binding = (Int, [(Int, Type)])

-- | Inference monad: 'Binding' state over a read-only 'Environment', with
-- failures reported as 'String' messages.
type Infer = StateT Binding (ReaderT Environment (Except String))
-- | Produce a fresh type variable and advance the counter in the state.
-- The recorded bindings are left untouched.
newType :: Infer Type
newType = state allocate
  where
    allocate (next, bindings) = (TVar next, (next + 1, bindings))
-- | Infer the type of an expression.
--
-- The first argument holds the schemes of the locally bound variables,
-- newest binder first, so @Ref i@ is looked up at position @i@.  Globals
-- are looked up by name in the 'Reader' environment.  The returned type may
-- still mention variables that have bindings in the state; callers resolve
-- it with 'applyBindings'.
--
-- Fails with @"unbound reference: ..."@ or @"unbound variable: ..."@ for
-- unknown variables, and with whatever error 'unify' raises.
infer :: [Scheme] -> Expr -> Infer Type
infer env (Ref x) =
  case atMay env x of
    Just scheme -> instantiate scheme
    Nothing -> throwError $ "unbound reference: " ++ show x
infer _ (Global x) = do
  globals <- ask
  case lookup x globals of
    Just scheme -> instantiate scheme
    Nothing -> throwError $ "unbound variable: " ++ x
infer env (Lam annotation body) = do
  -- Use the annotation when present, otherwise a fresh variable.
  argT <- maybe newType return annotation
  resT <- infer (Forall [] argT : env) body
  return $ LamT argT resT
infer env (App f arg) = do
  funT <- infer env f
  argT <- infer env arg
  resT <- newType
  funT `unify` LamT argT resT
  return resT
infer _ (Num _) = return NumT
infer _ (Bool _) = return BoolT
infer env (Let bound body) = do
  -- Resolve the bound expression's type and the environment types against
  -- the current bindings BEFORE generalizing.  Generalizing the raw type
  -- would quantify over variables that unification has already constrained
  -- and thereby discard those constraints at every use site.
  boundT <- infer env bound >>= applyBindings
  resolvedEnv <- mapM resolve env
  -- The resolved environment is only used to decide what may be quantified;
  -- the body is still checked against the original environment.
  infer (generalize resolvedEnv boundT : env) body
  where
    resolve (Forall vs t) = Forall vs <$> applyBindings t
-- | Instantiate a scheme: allocate one fresh type variable per quantified
-- variable (in list order) and substitute them into the scheme's type.
instantiate :: Scheme -> Infer Type
instantiate (Forall quantified body) = do
  fresh <- mapM (const newType) quantified
  return $ instantiate' (zip quantified fresh) body
-- | Apply a substitution (variable number to type) to a type in one pass.
-- Variables without an entry are left as they are; the first entry for a
-- variable wins.
instantiate' :: [(Int, Type)] -> Type -> Type
instantiate' subst = go
  where
    go ty = case ty of
      TVar v -> maybe ty id (lookup v subst)
      LamT a b -> LamT (go a) (go b)
      _ -> ty
-- | Generalize a type into a scheme by quantifying over the variables that
-- occur in the type but are not free in the given local environment.
--
-- A variable bound by an environment scheme's own @Forall@ is not free in
-- that scheme, so it does not block generalization.  Each variable is
-- quantified at most once, in order of first occurrence in the type.
generalize :: [Scheme] -> Type -> Scheme
generalize env x = Forall (collect [] (free x)) x
  where
    freeEnv = concatMap schemeFree env
    -- Free variables of a scheme: those of its body minus the quantified ones.
    schemeFree (Forall bound body) = filter (`notElem` bound) (free body)
    -- Keep variables that are neither free in the environment nor already
    -- collected.
    collect _ [] = []
    collect seen (v : vs)
      | v `elem` seen || v `elem` freeEnv = collect seen vs
      | otherwise = v : collect (v : seen) vs
-- | Type variables occurring in a type, left to right, with repetitions.
free :: Type -> [Int]
free ty = case ty of
  TVar v -> [v]
  LamT a b -> free a ++ free b
  _ -> []
-- | Unify two types: resolve both against the current bindings (left one
-- first), then hand the results to 'unify''.
unify :: Type -> Type -> Infer ()
unify x y =
  applyBindings x >>= \lhs ->
    applyBindings y >>= \rhs ->
      unify' lhs rhs
-- | Structural unification.  'unify' passes in types already resolved with
-- 'applyBindings'; new bindings are prepended to the list in the state.
--
-- Fails with @"cannot match types"@ when the two types have different
-- shapes, and with @"infinite type"@ when a variable would be bound to a
-- type that contains it.
unify' :: Type -> Type -> Infer ()
unify' (TVar x) y = do
  y' <- applyBindings y
  case y' of
    -- A variable trivially unifies with itself.  Recording @x := TVar x@
    -- would make every later lookup of @x@ fail the occurs check.
    TVar x' | x == x' -> return ()
    _ -> do
      -- Reject cyclic bindings when they are made, instead of only when the
      -- variable happens to be looked up later.
      occursGuard x y'
      modify (\(i, env) -> (i, (x, y') : env))
unify' x y@(TVar _) = unify' y x
unify' (LamT x x') (LamT y y') = do
  unify' x y
  -- Unifying the argument types may have added bindings, so the result
  -- types are resolved again before they are compared.
  x'' <- applyBindings x'
  y'' <- applyBindings y'
  unify' x'' y''
unify' NumT NumT = return ()
unify' BoolT BoolT = return ()
unify' _ _ = throwError "cannot match types"
-- | Resolve a type against the bindings in the state, following bound
-- variables recursively.  Each bound variable is checked with 'occursGuard'
-- against its binding before the binding is followed.
applyBindings :: Type -> Infer Type
applyBindings ty = case ty of
  TVar v -> do
    bindings <- gets snd
    case lookup v bindings of
      Nothing -> return ty
      Just bound -> occursGuard v bound >> applyBindings bound
  LamT a b -> LamT <$> applyBindings a <*> applyBindings b
  NumT -> return NumT
  BoolT -> return BoolT
-- | Fail with @"infinite type"@ if the given variable occurs syntactically
-- in the given type.  Bindings are not followed.
occursGuard :: Int -> Type -> Infer ()
occursGuard x ty = case ty of
  LamT a b -> occursGuard x a >> occursGuard x b
  TVar v | x == v -> throwError "infinite type"
  _ -> return ()
-- | Type-check a closed expression against a global environment.  Inference
-- starts with no local binders, variable counter 0 and no bindings; the
-- inferred type is resolved before it is returned.
check :: Environment -> Expr -> Except String Type
check env x = runReaderT (evalStateT resolved initial) env
  where
    resolved = infer [] x >>= applyBindings
    initial = (0, [])
-- Parser | |
-- | Parser over 'String' input.  The underlying 'Reader' carries the names
-- of the enclosing binders, newest first.
type Parser = ParsecT String () (Reader [String])
-- | Parse a complete expression, allowing leading whitespace and requiring
-- end of input.  Parsing starts with no binders in scope; a parse error is
-- reported as its 'show'n text.
parse :: String -> Except String Expr
parse s = either (throwError . show) return outcome
  where
    outcome = runReader (runParserT (whitespace *> expression <* eof) () "" s) []
-- | Whether a word is a keyword and therefore not usable as a variable name.
isReserved :: Name -> Bool
isReserved word = word `elem` keywords
  where
    keywords = ["let", "in", "Num", "Bool", "true", "false"]
expression :: Parser Expr | |
expression = letExpr | |
<|> lambda | |
<|> application | |
<|> boolean | |
<|> variable | |
<|> number | |
<|> parens expression | |
-- | @let NAME = EXPR in EXPR@.  The bound name is in scope only in the body,
-- not in the bound expression.  Backtracks completely on failure.
letExpr :: Parser Expr
letExpr = try $ do
  binder <- reserved "let" () *> name
  bound <- char '=' *> whitespace *> expression
  body <- reserved "in" () *> local (binder :) expression
  return (Let bound body)
-- | Lambda abstraction: an optional @λ@, the parameter name, an optional
-- type annotation, a dot, and the body parsed with the parameter in scope.
-- Backtracks completely on failure.
lambda :: Parser Expr
lambda = try $ do
  param <- optional (char 'λ' *> whitespace) *> name
  annotation <- maybeType <* char '.' <* whitespace
  Lam annotation <$> local (param :) expression
-- | Optional type annotation of the form @: TYPE@.
maybeType :: Parser (Maybe Type)
maybeType = optionMaybe (char ':' *> whitespace *> lambdaType)
lambdaType :: Parser Type | |
lambdaType = ty `chainr1` arrow | |
where ty = reserved "Num" NumT | |
<|> reserved "Bool" BoolT | |
<|> parens lambdaType | |
arrow = return LamT <* string "->" <* whitespace | |
-- | One or more atoms in sequence, combined as left-associative application.
-- Backtracks completely on failure.
application :: Parser Expr
application = try (chainl1 atom (return App))
  where
    atom = foldr1 (<|>) [boolean, variable, number, parens expression]
-- | The literals @true@ and @false@.
boolean :: Parser Expr
boolean = reserved "true" (Bool True) <|> reserved "false" (Bool False)
-- | A variable.  A name found among the enclosing binders becomes a 'Ref'
-- to its position (first match, i.e. the newest binder of that name); any
-- other name becomes a 'Global'.
variable :: Parser Expr
variable = do
  x <- name
  scope <- ask
  return $ maybe (Global x) Ref (elemIndex x scope)
-- | Integer literal: an optional @-@ directly followed by one or more
-- digits, then trailing whitespace.
number :: Parser Expr
number = do
  negative <- option False (True <$ char '-')
  digits <- many1 digit <* whitespace
  -- 'read' cannot fail here: 'digits' is a non-empty run of decimal digits.
  let magnitude = read digits
  return $ Num (if negative then negate magnitude else magnitude)
-- | An identifier: a letter followed by letters or digits, then trailing
-- whitespace.  Keywords (see 'isReserved') are rejected.
--
-- The parser is wrapped in 'try' so that rejecting a keyword consumes no
-- input.  Without it, a keyword following an application (for example the
-- @in@ of @let x = add 1 2 in x@) is consumed by the failed attempt to read
-- another argument, and the application parser fails instead of stopping
-- before the keyword.
name :: Parser String
name = try $ do
  c <- letter
  cs <- many alphaNum
  whitespace
  let s = c : cs
  if isReserved s
    then unexpected s
    else return s
-- | Parse the given keyword, which must not be followed by a letter or
-- digit, skip trailing whitespace, and return the supplied value.
-- Consumes nothing on failure.
reserved :: String -> a -> Parser a
reserved s x = try (x <$ (string s *> notFollowedBy alphaNum *> whitespace))
-- | Run a parser between @(@ and @)@, skipping whitespace after each
-- parenthesis.
parens :: Parser a -> Parser a
parens p = char '(' *> whitespace *> p <* char ')' <* whitespace
-- | Skip zero or more whitespace characters.
whitespace :: Parser ()
whitespace = () <$ many space
-- Main | |
-- | Render a type with 'showType', starting with no type variables named.
instance Show Type where
  show = (`evalState` []) . showType False
-- | Render a type.  The state lists the type variables met so far, in order
-- of first appearance; a variable is printed as the entry of 'vars' at its
-- position in that list.  The flag says whether the type is the argument
-- side of an arrow, in which case a function type is parenthesised.
showType :: Bool -> Type -> State [Int] String
showType _ (TVar v) = do
  seen <- get
  case elemIndex v seen of
    Just ix -> return (vars !! ix)
    Nothing -> do
      -- First occurrence: append it, so its position is the old length.
      put (seen ++ [v])
      return (vars !! length seen)
showType left (LamT a b) = do
  lhs <- showType True a
  rhs <- showType False b
  let arrow = lhs ++ " -> " ++ rhs
  return (if left then "(" ++ arrow ++ ")" else arrow)
showType _ NumT = return "Num"
showType _ BoolT = return "Bool"
-- | Infinite supply of display names for type variables:
-- @a@ .. @z@, then @a1@ .. @z1@, @a2@ .. @z2@, and so on.
vars :: [Name]
vars = concatMap alphabet [0 ..]
  where
    alphabet n = [c : suffix n | c <- ['a' .. 'z']]
    suffix :: Int -> String
    suffix n
      | n == 0 = ""
      | otherwise = show n
-- | Right-associative infix shorthand for the function type 'LamT'.
infixr 5 ~>

(~>) :: Type -> Type -> Type
argument ~> result = LamT argument result
-- | Global environment given to the type checker by 'main': a few numeric
-- functions with monomorphic types and the polymorphic combinators
-- @s@, @k@ and @i@.
prelude :: Environment
prelude =
  [ ("even", mono predicate)
  , ("odd", mono predicate)
  , ("add", mono binary)
  , ("mul", mono binary)
  , ("div", mono binary)
  , ("s", Forall [0, 1, 2] $ (a ~> b ~> c) ~> (a ~> b) ~> a ~> c)
  , ("k", Forall [0, 1] $ a ~> b ~> a)
  , ("i", Forall [0] $ a ~> a)
  ]
  where
    mono = Forall []
    predicate = NumT ~> BoolT
    binary = NumT ~> NumT ~> NumT
    a = TVar 0
    b = TVar 1
    c = TVar 2
-- | Read one line from standard input, parse and type-check it against
-- 'prelude', and print either the error message or the inferred type.
main :: IO ()
main = do
  line <- getLine
  either putStrLn print (runExcept (parse line >>= check prelude))
main |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment