Skip to content

Instantly share code, notes, and snippets.

@pedrominicz
Last active September 24, 2019 13:13
Show Gist options
  • Save pedrominicz/4cd39439a17111adc0fda9696eb2e6e5 to your computer and use it in GitHub Desktop.
Save pedrominicz/4cd39439a17111adc0fda9696eb2e6e5 to your computer and use it in GitHub Desktop.
Hindley-Milner type inference.
-- https://crypto.stanford.edu/~blynn/lambda/pcf.html
-- http://dev.stephendiehl.com/fun/006_hindley_milner.html
-- https://en.wikipedia.org/wiki/Hindley–Milner_type_system#Free_type_variables
-- https://en.wikipedia.org/wiki/Hindley–Milner_type_system#Algorithm_J
module Let where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.List (elemIndex, nub)
import Safe (atMay)
import System.IO (isEOF)
import Text.Parsec hiding (State, parse)
-- | Identifier of a variable, as written in the source text.
type Name = String
-- | Abstract syntax of the lambda calculus handled by this module.
data Expr
  = Ref Int
    -- ^ Locally bound variable, as an index into the enclosing binders
    -- (0 = innermost; the parser pushes each binder with @local (x:)@).
  | Global Name
    -- ^ Variable not bound locally; resolved in the global 'Environment'.
  | Lam (Maybe Type) Expr
    -- ^ Lambda with an optional parameter type annotation, and its body.
  | App Expr Expr
    -- ^ Application of a function to an argument.
  | Num Integer
    -- ^ Integer literal.
  | Bool Bool
    -- ^ Boolean literal.
  | Let Expr Expr
    -- ^ @let x = e in body@: the bound expression, then the body, which
    -- refers to @x@ as @Ref 0@.
  deriving Show
-- | Monomorphic types. The 'Show' instance is hand-written (see 'showType').
data Type
  = TVar Int
    -- ^ Type variable, identified by a number from the fresh-variable counter.
  | LamT Type Type
    -- ^ Function type: argument type, then result type.
  | NumT
    -- ^ Type of integer literals.
  | BoolT
    -- ^ Type of boolean literals.
-- | Type scheme: a type with the listed type-variable numbers universally
-- quantified. @Forall [] t@ is used for monomorphic (lambda-bound) types.
data Scheme = Forall [Int] Type
  deriving Show
-- | Global names and their type schemes (read-only during inference).
type Environment = [(Name, Scheme)]
-- | Inference state: the next fresh type-variable number, and the current
-- substitution as an association list from variable number to type
-- (new bindings are consed onto the front).
type Binding = (Int, [(Int, Type)])
-- | Inference monad: mutable 'Binding' state, read-only globals, and
-- failure with a 'String' message.
type Infer = StateT Binding (ReaderT Environment (Except String))
-- | Allocate a fresh type variable by bumping the counter held in the first
-- component of the inference state; the substitution is left untouched.
newType :: Infer Type
newType = state $ \(next, subst) -> (TVar next, (next + 1, subst))
-- | Infer the type of an expression (algorithm J style: unification updates
-- the substitution stored in the 'Infer' state as a side effect).
--
-- @env@ is the local scope, innermost binder first, indexed by 'Ref'.
-- Globals are looked up in the 'Environment' provided by the reader.
-- The returned type may still mention variables bound in the substitution;
-- callers resolve it with 'applyBindings'.
infer :: [Scheme] -> Expr -> Infer Type
infer env (Ref x) =
  case atMay env x of
    Just x' -> instantiate x'
    Nothing -> throwError $ "unbound reference: " ++ show x
infer _ (Global x) = do
  env <- ask
  case lookup x env of
    Just x' -> instantiate x'
    Nothing -> throwError $ "unbound variable: " ++ x
infer env (Lam (Just t) x) = do
  tx <- infer (Forall [] t : env) x
  return $ LamT t tx
infer env (Lam Nothing x) = do
  t <- newType
  tx <- infer (Forall [] t : env) x
  return $ LamT t tx
infer env (App x y) = do
  tx <- infer env x
  ty <- infer env y
  t <- newType
  tx `unify` LamT ty t
  return t
infer _ (Num _) = return NumT
infer _ (Bool _) = return BoolT
infer env (Let x y) = do
  -- Generalization must see the *resolved* type and environment. Without
  -- this, variables already solved by unification (e.g. bound to NumT)
  -- would be quantified, letting `let f = λx. add x 1 in f true` check.
  -- Resolving also forces the occurs check on the bound expression's type.
  tx <- applyBindings =<< infer env x
  env' <- mapM resolve env
  infer (generalize env' tx : env) y
  where
    resolve (Forall vs t) = Forall vs <$> applyBindings t
-- | Replace every quantified variable of a scheme with a fresh type
-- variable, yielding a monomorphic type for this particular use site.
instantiate :: Scheme -> Infer Type
instantiate (Forall quantified body) = do
  fresh <- mapM (const newType) quantified
  let renaming = zip quantified fresh
  return (instantiate' renaming body)
-- | Simultaneously substitute type variables according to the given
-- association list; variables not in the list are left as they are.
instantiate' :: [(Int, Type)] -> Type -> Type
instantiate' subst = go
  where
    go (LamT dom cod) = LamT (go dom) (go cod)
    go t@(TVar v) = maybe t id (lookup v subst)
    go t = t
-- | Quantify over the type variables of a type that are not free in the
-- given local environment.
--
-- Each variable is quantified at most once, and a scheme's own quantified
-- variables are not counted as free in the environment. The result is only
-- meaningful if the caller has already applied the current substitution to
-- both the type and the environment.
generalize :: [Scheme] -> Type -> Scheme
generalize env x = Forall (filter (`notElem` freeEnv) freeVar) x
  where
    freeEnv = concatMap (\(Forall bound x') -> filter (`notElem` bound) (free x')) env
    -- 'free' reports a variable once per occurrence; deduplicate (keeping
    -- first-occurrence order) so no variable is quantified twice.
    freeVar = nub (free x)
-- | Type variables occurring in a type, left to right, one entry per
-- occurrence (duplicates are kept).
free :: Type -> [Int]
free t = collect t []
  where
    collect (TVar v) acc = v : acc
    collect (LamT dom cod) acc = collect dom (collect cod acc)
    collect _ acc = acc
-- | Unify two types after resolving both against the current substitution,
-- extending the substitution or failing.
unify :: Type -> Type -> Infer ()
unify lhs rhs = do
  lhs' <- applyBindings lhs
  unify' lhs' =<< applyBindings rhs
-- | Structural unification of two already-resolved types. A type variable
-- is bound by consing onto the substitution; structurally different types
-- fail with "cannot match types".
unify' :: Type -> Type -> Infer ()
unify' (TVar x) y = do
  y' <- applyBindings y
  case y' of
    -- Unifying a variable with itself is a no-op. Recording x := x would
    -- make 'applyBindings' later reject it as an "infinite type"
    -- (e.g. in `λf. λx. add (f x) (f x)`).
    TVar x' | x' == x -> return ()
    _ -> do
      -- Occurs check at binding time, so a cyclic binding cannot slip
      -- through just because the variable is never resolved again.
      occursGuard x y'
      modify (\(i, env) -> (i, (x, y') : env))
unify' x y@(TVar _) = unify' y x
unify' (LamT x x') (LamT y y') = do
  unify' x y
  -- Re-resolve the result types: unifying the argument types may have
  -- bound variables that occur in them.
  x'' <- applyBindings x'
  y'' <- applyBindings y'
  unify' x'' y''
unify' NumT NumT = return ()
unify' BoolT BoolT = return ()
unify' _ _ = throwError "cannot match types"
-- | Resolve a type against the current substitution, following variable
-- bindings recursively. Fails with "infinite type" when a binding mentions
-- the variable it binds.
applyBindings :: Type -> Infer Type
applyBindings t = case t of
  TVar v -> do
    subst <- gets snd
    case lookup v subst of
      Nothing -> return t
      Just bound -> occursGuard v bound >> applyBindings bound
  LamT dom cod -> LamT <$> applyBindings dom <*> applyBindings cod
  _ -> return t
-- | Fail with "infinite type" if the given variable occurs syntactically in
-- the type (the type is inspected as-is, without resolving bindings).
occursGuard :: Int -> Type -> Infer ()
occursGuard v t = case t of
  TVar w | w == v -> throwError "infinite type"
  LamT dom cod -> occursGuard v dom >> occursGuard v cod
  _ -> return ()
-- | Type-check a closed expression against the given globals. Starts with
-- an empty local scope, a fresh-variable counter of 0 and an empty
-- substitution, and returns the fully resolved type.
check :: Environment -> Expr -> Except String Type
check globals expr = runReaderT (evalStateT solve (0, [])) globals
  where
    solve = infer [] expr >>= applyBindings
-- Parser
-- | Parser over 'String' input. The reader holds the names of the enclosing
-- binders, innermost first, so 'variable' can turn a name into a 'Ref' index.
type Parser = ParsecT String () (Reader [String])
-- | Parse a complete expression (leading whitespace allowed, no trailing
-- input), starting with an empty binder scope. Parse errors are rendered
-- with 'show'.
parse :: String -> Except String Expr
parse s = either (throwError . show) return result
  where
    result = runReader (runParserT program () "" s) []
    program = whitespace *> expression <* eof
-- | Whether a word is a keyword and therefore cannot be used as a name.
isReserved :: Name -> Bool
isReserved s = s `elem` keywords
  where
    keywords = ["let", "in", "Num", "Bool", "true", "false"]
-- | Any expression; alternatives are tried in the order listed.
expression :: Parser Expr
expression = foldr1 (<|>)
  [ letExpr
  , lambda
  , application
  , boolean
  , variable
  , number
  , parens expression
  ]
-- | @let x = e in body@. The body is parsed with @x@ pushed onto the binder
-- scope; backtracks completely on failure.
letExpr :: Parser Expr
letExpr = try $ do
  reserved "let" ()
  binder <- name
  _ <- char '=' <* whitespace
  bound <- expression
  reserved "in" ()
  Let bound <$> local (binder :) expression
-- | @λx: T. body@ where both the leading @λ@ and the @: T@ annotation are
-- optional. The body is parsed with @x@ pushed onto the binder scope;
-- backtracks completely on failure.
lambda :: Parser Expr
lambda = try $ do
  optional (char 'λ' >> whitespace)
  binder <- name
  annotation <- maybeType
  char '.' >> whitespace
  Lam annotation <$> local (binder :) expression
-- | Optional @: T@ type annotation.
maybeType :: Parser (Maybe Type)
maybeType = optionMaybe (char ':' *> whitespace *> lambdaType)
-- | Type annotation syntax: @Num@, @Bool@, parenthesized types, and
-- right-associative @->@.
lambdaType :: Parser Type
lambdaType = chainr1 atom arrow
  where
    atom = foldr1 (<|>) [reserved "Num" NumT, reserved "Bool" BoolT, parens lambdaType]
    arrow = LamT <$ (string "->" *> whitespace)
-- | One or more atomic terms, combined by left-associative application.
-- Lambdas and lets must be parenthesized to appear here.
application :: Parser Expr
application = try (chainl1 atom (return App))
  where
    atom = foldr1 (<|>) [boolean, variable, number, parens expression]
-- | Boolean literal @true@ or @false@.
boolean :: Parser Expr
boolean = reserved "true" (Bool True) <|> reserved "false" (Bool False)
-- | A name: a 'Ref' to its position in the binder scope if it is locally
-- bound, otherwise a 'Global'.
variable :: Parser Expr
variable = do
  x <- name
  scope <- ask
  return $ maybe (Global x) Ref (elemIndex x scope)
-- | Integer literal with an optional leading minus sign.
number :: Parser Expr
number = do
  negative <- option False (True <$ char '-')
  digits <- many1 digit
  whitespace
  let magnitude = read digits
  return $ Num $ if negative then negate magnitude else magnitude
-- | An identifier: a letter followed by letters/digits, then whitespace.
-- Reserved words are rejected.
--
-- Wrapped in 'try' so that rejecting a reserved word consumes no input.
-- Without it, @in@ after a multi-term application (e.g.
-- @let f = g 1 in f@) aborts the enclosing 'chainl1' and the parse fails.
name :: Parser String
name = try $ do
  c <- letter
  cs <- many alphaNum
  whitespace
  let s = c : cs
  if isReserved s
    then unexpected s
    else return s
-- | Keyword @s@ (not followed by another alphanumeric character) plus
-- trailing whitespace, yielding @x@; backtracks completely on failure.
reserved :: String -> a -> Parser a
reserved s x = try $ x <$ (string s *> notFollowedBy alphaNum *> whitespace)
-- | Run a parser between parentheses, skipping whitespace after each one.
parens :: Parser a -> Parser a
parens = between (symbol '(') (symbol ')')
  where
    symbol c = char c <* whitespace
-- | Skip zero or more whitespace characters.
whitespace :: Parser ()
whitespace = () <$ many space
-- Main
-- | Render a type, naming its variables a, b, c, … in order of first
-- appearance.
instance Show Type where
  show = flip evalState [] . showType False
-- | Pretty-print a type. The flag requests parentheses around a function
-- type (used for the left side of an arrow). The state lists variable
-- numbers in order of first appearance; position @k@ is printed as
-- @vars !! k@.
showType :: Bool -> Type -> State [Int] String
showType _ NumT = pure "Num"
showType _ BoolT = pure "Bool"
showType parenthesize (LamT from to) = do
  domain <- showType True from
  codomain <- showType False to
  let arrow = domain ++ " -> " ++ codomain
  pure $ if parenthesize then "(" ++ arrow ++ ")" else arrow
showType _ (TVar v) = do
  seen <- get
  case elemIndex v seen of
    Just idx -> pure (vars !! idx)
    Nothing -> do
      -- First sighting: v takes the next display name.
      put (seen ++ [v])
      pure (vars !! length seen)
-- | Infinite supply of display names: a..z, then a1..z1, a2..z2, and so on.
vars :: [Name]
vars = concatMap batch [0 ..]
  where
    batch :: Int -> [Name]
    batch n = map (: suffix n) ['a' .. 'z']
    suffix 0 = ""
    suffix n = show n
-- | Infix synonym for 'LamT'; right-associative, so @a ~> b ~> c@ means
-- @a ~> (b ~> c)@.
(~>) :: Type -> Type -> Type
dom ~> cod = LamT dom cod
infixr 5 ~>
-- | Built-in globals: arithmetic/parity primitives on Num and the S, K, I
-- combinators with polymorphic schemes.
prelude :: Environment
prelude =
  [ ("even", mono $ NumT ~> BoolT)
  , ("odd", mono $ NumT ~> BoolT)
  , ("add", mono $ NumT ~> NumT ~> NumT)
  , ("mul", mono $ NumT ~> NumT ~> NumT)
  , ("div", mono $ NumT ~> NumT ~> NumT)
  , ("s", Forall [0, 1, 2] $ (a ~> b ~> c) ~> (a ~> b) ~> a ~> c)
  , ("k", Forall [0, 1] $ a ~> b ~> a)
  , ("i", Forall [0] $ a ~> a)
  ]
  where
    mono = Forall []
    a = TVar 0
    b = TVar 1
    c = TVar 2
-- | REPL: read one expression per line and print its type or an error
-- message. Stops cleanly at end of input instead of letting 'getLine'
-- throw an end-of-file exception.
main :: IO ()
main = do
  eof <- isEOF
  if eof
    then return ()
    else do
      expr <- getLine
      case runExcept (parse expr >>= check prelude) of
        Left e -> putStrLn e
        Right x -> print x
      main
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment