Skip to content

Instantly share code, notes, and snippets.

@pedrominicz
Last active September 21, 2019 19:06
Show Gist options
  • Save pedrominicz/2d8e6d13e829b72e739a7a386eb58417 to your computer and use it in GitHub Desktop.
Save pedrominicz/2d8e6d13e829b72e739a7a386eb58417 to your computer and use it in GitHub Desktop.
Hindley-Milner Type Inference for Lambda Calculus (doodle made while reading a tutorial).
module Infer where
-- https://crypto.stanford.edu/~blynn/lambda/hm.html
import Data.List (elemIndex)
import System.IO (isEOF)

import Control.Monad.Reader
import Control.Monad.State
import Safe (atMay)
import Text.Parsec hiding (State, parse)
-- | Identifiers: names of global variables and lambda parameters, and also
-- the display names of type variables (see 'vars').
type Name = String
-- | Lambda-calculus terms as produced by the parser.
data Expr
  = Ref Int               -- ^ bound variable as a de Bruijn index (0 = innermost enclosing lambda)
  | Global Name           -- ^ variable not bound by any enclosing lambda; looked up by name
  | Lam (Maybe Type) Expr -- ^ abstraction with an optional parameter type annotation, and its body
  | App Expr Expr         -- ^ application of a function to an argument
  | Num Integer           -- ^ integer literal
  | Bool Bool             -- ^ boolean literal
  deriving Show
-- | Types assigned to expressions during inference.
data Type
  = Type Int       -- ^ type variable, identified by the number 'nextType' handed out
  | LamT Type Type -- ^ function type: argument type, then result type
  | NumT           -- ^ type of integer literals
  | BoolT          -- ^ type of boolean literals
-- | An equation between two types that unification must make hold.
type Constraint = (Type, Type)
-- | Types of the names in scope. Lambda parameters are pushed onto the front
-- under the empty name @""@, so a de Bruijn index is a direct position in
-- this list; globals are found by name behind them.
type Environment = [(Name, Type)]
-- | Assign a type to every node of an expression, recording one equality
-- constraint per application.
--
-- The state is the list of constraints gathered so far (newest first) and the
-- next unused type-variable number. An un-annotated lambda parameter gets a
-- fresh type variable; an application @f x@ gets a fresh result variable @r@
-- plus the constraint @type(f) ~ type(x) -> r@. Fails (through the underlying
-- 'Maybe') when a de Bruijn index or a global name is not in the environment.
gather :: Environment -> Expr -> StateT ([Constraint], Int) Maybe Type
gather env expr = case expr of
  Ref i -> lift $ snd <$> atMay env i
  Global g -> lift $ lookup g env
  Lam annotation body -> do
    -- Only un-annotated parameters consume a fresh type variable.
    argT <- maybe nextType return annotation
    bodyT <- gather (("", argT) : env) body
    return $ LamT argT bodyT
  App fun arg -> do
    funT <- gather env fun
    argT <- gather env arg
    resT <- nextType
    modify $ \(cs, n) -> ((funT, LamT argT resT) : cs, n)
    return resT
  Num _ -> return NumT
  Bool _ -> return BoolT
-- | Hand out a fresh type variable: return the current counter as a
-- 'Type' and bump the counter, leaving the constraint list untouched.
nextType :: StateT ([(Type, Type)], Int) Maybe Type
nextType = state $ \(cs, n) -> (Type n, (cs, n + 1))
-- | Solve type equations by first-order unification.
--
-- The reader environment is the substitution built so far, newest binding
-- first. Each new binding @(n, t)@ is applied to all remaining constraints
-- before continuing, so variable @n@ never appears in any binding made after
-- it. Returns the final substitution, or 'Nothing' on a constructor clash or
-- a failed occurs check (which would require an infinite type).
unify :: [Constraint] -> ReaderT [(Int, Type)] Maybe [(Int, Type)]
unify [] = ask
-- @a ~ a@ holds trivially and must be dropped here: the occurs check below
-- would find @x@ inside @Type x@ and reject well-typed terms such as
-- @λf. λx. add (f x) (f x)@, whose constraints reduce to exactly this case.
unify ((Type x, Type y) : rest)
  | x == y = unify rest
unify ((Type x, y) : rest)
  | occurs x y = lift Nothing
  | otherwise = local ((x, y) :) $ unify $ tuplify (substitute (x, y)) <$> rest
-- Orient variable-on-the-right constraints so the case above handles them.
unify ((x, y@(Type _)) : rest) = unify ((y, x) : rest)
unify ((LamT x y, LamT x' y') : rest) = unify $ (x, x') : (y, y') : rest
unify ((NumT, NumT) : rest) = unify rest
unify ((BoolT, BoolT) : rest) = unify rest
unify _ = lift Nothing
-- | Does type variable @v@ appear anywhere inside the given type?
occurs :: Int -> Type -> Bool
occurs v t = case t of
  Type v' -> v == v'
  LamT a b -> occurs v a || occurs v b
  _ -> False
-- | Apply the same function to both components of a homogeneous pair.
tuplify :: (a -> a) -> (a, a) -> (a, a)
tuplify f pair = case pair of
  (a, b) -> (f a, f b)
-- | Replace every occurrence of type variable @v@ with @replacement@,
-- recursing through function types; other types are returned unchanged.
substitute :: (Int, Type) -> Type -> Type
substitute binding@(v, replacement) t = case t of
  Type v' | v == v' -> replacement
  LamT a b -> LamT (substitute binding a) (substitute binding b)
  _ -> t
-- | Infer the type of an expression under the given environment.
--
-- Gathers constraints (type variables numbered from 0), unifies them, then
-- applies the resulting bindings to the expression's type, oldest binding
-- first. Type variables left unconstrained stay as 'Type' variables.
-- 'Nothing' if a name is unbound or the constraints cannot be satisfied.
solve :: Environment -> Expr -> Maybe Type
solve env expr = do
  (t, (constraints, _)) <- runStateT (gather env expr) ([], 0)
  subst <- runReaderT (unify constraints) []
  return $ foldr substitute t subst
-- Parser
-- | Parsec over a 'Reader' holding the names of the enclosing lambda
-- parameters, innermost first; 'variable' uses it to turn names into de
-- Bruijn indices.
type Parser = ParsecT String () (Reader [String])
-- | Parse a whole input string (leading whitespace allowed, nothing may
-- follow the expression). Parse errors are discarded and reported as
-- 'Nothing'.
parse :: String -> Maybe Expr
parse s = either (const Nothing) Just result
  where
    parser = whitespace *> expression <* eof
    -- Start with no enclosing lambdas in scope.
    result = runReader (runParserT parser () "" s) []
-- | An expression: a lambda abstraction, or otherwise a left-associative
-- chain of applications of atoms (a single atom is a chain of length one).
--
-- The atoms (booleans, variables, numbers, parenthesised expressions) are
-- not listed again here: 'application' already tries each of them as its
-- first operand, so after it fails without consuming input they would fail
-- identically and could never succeed.
expression :: Parser Expr
expression = lambda <|> application
-- | Lambda abstraction: @λx. body@ or @x. body@ (the λ is optional), with an
-- optional @: T@ parameter annotation. The body is parsed with the parameter
-- pushed onto the scope. Wrapped in 'try' so that input starting with a
-- plain name (e.g. an application) backtracks when no @.@ follows.
lambda :: Parser Expr
lambda = try $ do
  optional (char 'λ' *> whitespace)
  param <- name
  annotation <- maybeType
  _ <- char '.'
  whitespace
  body <- local (param :) expression
  pure $ Lam annotation body
-- | Optional parameter annotation of the form @: T@.
maybeType :: Parser (Maybe Type)
maybeType = optionMaybe (char ':' *> whitespace *> lambdaType)
-- | Type annotations: @Num@, @Bool@, parenthesised types, and @->@ arrows,
-- which associate to the right (@a -> b -> c@ is @a -> (b -> c)@).
lambdaType :: Parser Type
lambdaType = chainr1 atom arrow
  where
    atom = reserved "Num" NumT <|> reserved "Bool" BoolT <|> parens lambdaType
    arrow = LamT <$ (string "->" *> whitespace)
-- | One or more atoms side by side, read as left-nested applications
-- (@f x y@ is @(f x) y@). A lone atom is returned unchanged.
application :: Parser Expr
application = chainl1 atom (pure App)
  where
    atom = boolean <|> variable <|> number <|> parens expression
-- | Boolean literal: the keyword @true@ or @false@.
boolean :: Parser Expr
boolean = Bool <$> literal
  where
    literal = reserved "true" True <|> reserved "false" False
-- | A variable: a de Bruijn 'Ref' (position of the nearest enclosing lambda
-- with that name) if one is in scope, otherwise a 'Global' by name.
variable :: Parser Expr
variable = do
  ident <- name
  scope <- ask
  pure $ maybe (Global ident) Ref (elemIndex ident scope)
-- | Integer literal: an optional leading @-@ followed by one or more digits.
number :: Parser Expr
number = do
  negative <- option False (True <$ char '-')
  digits <- many1 digit
  whitespace
  -- 'read' cannot fail here: many1 digit guarantees a non-empty digit string.
  let magnitude = read digits
  pure $ Num $ if negative then negate magnitude else magnitude
-- | Identifier: a letter followed by letters or digits, plus trailing
-- whitespace.
name :: Parser String
name = (:) <$> letter <*> many alphaNum <* whitespace
-- | Match the keyword @keyword@ as a whole identifier and yield @value@.
-- Any other identifier is rejected and, thanks to 'try', backtracked over.
reserved :: String -> a -> Parser a
reserved keyword value = try (name >>= check)
  where
    check word
      | word == keyword = pure value
      | otherwise = unexpected word
-- | Run a parser between parentheses, skipping whitespace after each one.
parens :: Parser a -> Parser a
parens = between (symbol '(') (symbol ')')
  where
    symbol c = char c <* whitespace
-- | Skip zero or more whitespace characters.
whitespace :: Parser ()
whitespace = () <$ many space
-- Main
-- | Pretty-print types in surface syntax, e.g. @(a -> b) -> a -> b@.
instance Show Type where
  show = showType False
-- | Render a type. The flag says whether the type sits to the left of an
-- arrow; a function type in that position is parenthesised, since arrows
-- associate to the right. Type variable @n@ is named @vars !! n@.
showType :: Bool -> Type -> String
showType nested t = case t of
  Type i -> vars !! i
  NumT -> "Num"
  BoolT -> "Bool"
  LamT a b ->
    let arrow = showType True a ++ " -> " ++ showType False b
     in if nested then "(" ++ arrow ++ ")" else arrow
-- | Infinite supply of type-variable names: @a@..@z@, then @a1@..@z1@,
-- @a2@..@z2@, and so on.
vars :: [Name]
vars = [ch : tag | tag <- "" : map show [1 :: Integer ..], ch <- ['a' .. 'z']]
-- | Infix synonym for 'LamT'. Right-associative, so @a ~> b ~> c@ means
-- @a ~> (b ~> c)@, matching how arrow types are written.
infixr 5 ~>
(~>) :: Type -> Type -> Type
arg ~> res = LamT arg res
-- | Global names available to every expression, with their (monomorphic)
-- types.
prelude :: Environment
prelude =
  [ ("even", predicate)
  , ("odd", predicate)
  , ("add", binary)
  , ("mul", binary)
  , ("div", binary)
  ]
  where
    predicate = NumT ~> BoolT
    binary = NumT ~> NumT ~> NumT
-- | Read-eval-print loop: for each line on standard input, parse it, infer
-- its type against 'prelude', and print the type, or
-- @untypeable expression@ for both parse failures and type errors.
--
-- Stops when input is exhausted: checking 'isEOF' first keeps 'getLine'
-- from throwing an end-of-file exception when stdin closes.
main :: IO ()
main = do
  atEnd <- isEOF
  if atEnd
    then return ()
    else do
      input <- getLine
      putStrLn $ maybe "untypeable expression" show (parse input >>= solve prelude)
      main
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment