Created
May 19, 2020 02:18
-
-
Save yelouafi/daaf624f80a6fe8a7fe609bb108fed2c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# Language TypeSynonymInstances, FlexibleInstances #-} | |
module Algj where | |
import Debug.Trace | |
import Data.Maybe | |
import qualified Data.Map as M | |
import qualified Data.Set as S | |
import Control.Monad.Trans.State | |
import Text.Parsec hiding (State, token) | |
import Text.Parsec.Char | |
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- AST | |
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- | Untyped source expressions produced by the parser.
data Exp
  = Var String          -- ^ variable reference
  | Int Int             -- ^ integer literal
  | Pair Exp Exp        -- ^ pair @(e1, e2)@
  | App Exp Exp         -- ^ application @f x@
  | Fun String Exp      -- ^ lambda @fn(x) body@
  | Let String Exp Exp  -- ^ binding @let x = e1 in e2@
  deriving (Show, Eq)
-- | Monomorphic types.
data Type
  = TVar String      -- ^ type variable
  | TInt             -- ^ the integer type
  | TPair Type Type  -- ^ product type
  | TFun Type Type   -- ^ function type
  deriving Eq

-- | A type scheme: the variables in the set are quantified (bound) in the
-- body type.
data Scheme = Forall (S.Set String) Type
  deriving (Eq)

-- | Pretty-printer for types. Arrows associate to the right, so only a
-- function-typed argument needs parentheses.
instance Show Type where
  show (TVar x) = x
  show TInt = "Int"
  show (TPair t1 t2) = "(" ++ show t1 ++ ", " ++ show t2 ++ ")"
  -- Parenthesise a function-typed argument, using the same "→" glyph as
  -- the equation below so all arrows print uniformly.
  show (TFun tyf@(TFun _ _) ty) = "(" ++ show tyf ++ ") → " ++ show ty
  show (TFun ty1 ty2) = show ty1 ++ " → " ++ show ty2

-- | Prints the quantified variables in ascending order, then the body.
instance Show Scheme where
  show (Forall xs m) = "∀(" ++ unwords (S.toList xs) ++ "). " ++ show m
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- Parser | |
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- Grammar (application is juxtaposition, left-associative):
-- e    = fact+
-- fact = n | x | fn(x) e | let x = e in e | (e) | (e, e)
-- | Parsers over 'String' input with no user state.
type Parser = Parsec String ()

-- | Reserved words; 'identifier' refuses to return any of these.
keywords :: [String]
keywords = ["fn", "let", "in"]

-- | Run @p@, then skip any trailing whitespace.
token :: Parser a -> Parser a
token p = p <* spaces

-- | Match the exact punctuation string @s@ (backtracking on failure) and
-- skip trailing whitespace.
sym :: String -> Parser String
sym s = (token . try) (string s)

-- | Match keyword @s@ as a whole word. The keyword must not be directly
-- followed by a letter or digit, so e.g. @in1@ is not read as @in 1@.
-- Backtracks on failure and skips trailing whitespace.
kw :: String -> Parser String
kw s = (token . try) (string s <* notFollowedBy alphaNum)
-- | Wrap @p@ in parentheses, skipping whitespace after each delimiter.
parens :: Parser a -> Parser a
parens p = between (sym "(") (sym ")") p

-- | A variable name: one or more letters that do not spell a keyword.
-- The whole attempt backtracks on failure, so a keyword parser can be
-- tried next.
identifier :: Parser String
identifier = token (try (many1 letter >>= reject))
  where
    reject name
      | name `elem` keywords = fail (name ++ " can't be a variable")
      | otherwise = return name

-- | A non-negative integer literal. 'read' is applied only to a non-empty
-- run of digits; out-of-range literals are not rejected here.
int :: Parser Exp
int = do
  digits <- token (many1 digit)
  return (Int (read digits))
-- | An atomic expression: integer, variable, @let@, lambda, or a
-- parenthesised expression / pair.
factor :: Parser Exp
factor =
      int
  <|> Var <$> identifier
  <|> let_
  <|> fn
  <|> parens groupOrPair
  where
    -- Parse the first expression once, then optionally ", e2". The former
    -- `try pair <|> expr` re-parsed the inner expression whenever there was
    -- no comma, which is exponential in the nesting depth of parentheses.
    groupOrPair = do
      e1 <- expr
      option e1 (Pair e1 <$> (sym "," *> expr))
-- | Lambda: @fn(x) body@.
fn :: Parser Exp
fn = do
  _ <- kw "fn"
  param <- parens identifier
  body <- expr
  return (Fun param body)

-- | Binding: @let x = e1 in e2@.
let_ :: Parser Exp
let_ = do
  _ <- kw "let"
  name <- identifier
  _ <- sym "="
  bound <- expr
  _ <- kw "in"
  body <- expr
  return (Let name bound body)

-- | Pair contents without the surrounding parentheses: @e1, e2@.
pair :: Parser Exp
pair = do
  lhs <- expr
  _ <- sym ","
  rhs <- expr
  return (Pair lhs rhs)
-- | Juxtaposition: one or more factors applied left-associatively, so
-- @f x y@ parses as @App (App f x) y@. Written as "first factor, then
-- many more" so no partial pattern on an empty list is needed.
expr :: Parser Exp
expr = foldl App <$> factor <*> many factor

-- | Whole-program parser: skip leading whitespace, parse one expression,
-- and require end of input.
main :: Parser Exp
main = spaces *> expr <* eof
-- | Run parser @p@ on input @s@ (source name "Test"), returning the result
-- or calling 'error' with the rendered parse error.
testParser :: Parser a -> String -> a
testParser p s = either (error . show) id (parse p "Test" s)
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- Type checker | |
----------------------------------------------------------------------------------- | |
----------------------------------------------------------------------------------- | |
-- | Typing context: variable name to its type scheme.
type Ctx = M.Map String Scheme

-- | Substitution from type-variable names to types.
type Subst = M.Map String Type

-- | Inference monad. The state is (fresh-variable counter, current
-- substitution).
type Ti = State (Int, Subst)

-- | The identity substitution.
emptySub :: Subst
emptySub = M.empty

infixr 5 |>
-- | Composition @s1 |> s2@: apply @s1@ to every type in the range of @s2@,
-- then add @s1@'s bindings; on a shared key the binding from @s1@ wins.
(|>) :: Subst -> Subst -> Subst
s1 |> s2 = s1 `M.union` M.map (s1 #) s2
infix 5 #

-- | Things that mention type variables: they have free type variables, and
-- a substitution can be applied to them.
class Types t where
  -- | Free type variables.
  ftv :: t -> S.Set String
  -- | Apply a substitution.
  (#) :: Subst -> t -> t

instance Types Type where
  ftv ty = case ty of
    TVar x      -> S.singleton x
    TPair t1 t2 -> ftv t1 `S.union` ftv t2
    TFun t1 t2  -> ftv t1 `S.union` ftv t2
    TInt        -> S.empty

  s # ty = case ty of
    TVar x      -> M.findWithDefault ty x s
    TPair t1 t2 -> TPair (s # t1) (s # t2)
    TFun t1 t2  -> TFun (s # t1) (s # t2)
    TInt        -> ty

instance Types Scheme where
  ftv (Forall bound t) = ftv t `S.difference` bound
  -- Quantified variables are bound, so the substitution must skip them.
  s # (Forall bound t) = Forall bound (M.withoutKeys s bound # t)

instance Types Ctx where
  ftv = S.unions . map ftv . M.elems
  s # ctx = M.map (s #) ctx
-- | Allocate a new type variable named "a<n>" from the state's counter,
-- then bump the counter.
fresh :: Ti Type
fresh = state $ \(n, sub) -> (TVar ('a' : show n), (n + 1, sub))

-- | Replace each quantified variable of the scheme with a fresh type
-- variable (allocated in ascending name order). The context argument is
-- not consulted.
instantiate :: Ctx -> Scheme -> Ti Type
instantiate _ (Forall bound t) = do
  renaming <- mapM (\v -> do tv <- fresh; return (v, tv)) (S.toAscList bound)
  return (M.fromList renaming # t)

-- | Quantify over the type variables free in @t@ but not free anywhere in
-- the context.
generalise :: Ctx -> Type -> Scheme
generalise ctx t = Forall (ftv t `S.difference` ftv ctx) t
-- | Infer the type of an expression in the given context. Unification
-- updates the single substitution threaded through the 'Ti' state (in the
-- style of Algorithm J), so the returned type may still mention variables
-- bound there. Apply the final substitution to get the resolved type.
--
-- Calls 'error' on an unbound variable or on a unification failure.
infer :: Ctx -> Exp -> Ti Type
infer _ (Int _) = return TInt
infer ctx (Var x) =
  case M.lookup x ctx of
    Nothing -> error $ "Unknown variable " ++ x
    Just sch -> instantiate ctx sch
infer ctx (Pair e1 e2) = do
  t1 <- infer ctx e1
  t2 <- infer ctx e2
  return (TPair t1 t2)
infer ctx (Fun x e) = do
  -- The parameter is monomorphic (empty quantifier set) inside the body.
  tv <- fresh
  t <- infer (M.insert x (Forall S.empty tv) ctx) e
  return (TFun tv t)
infer ctx (App e1 e2) = do
  tv <- fresh
  t1 <- infer ctx e1
  t2 <- infer ctx e2
  unifyM t1 (TFun t2 tv)
  return tv
infer ctx (Let x e1 e2) = do
  -- e1 is checked without x in scope, so the binding is non-recursive.
  t1 <- infer ctx e1
  -- Generalise against the context with the current substitution applied,
  -- so variables already constrained through the context are not quantified.
  (_, s) <- get
  let sch = generalise (s # ctx) (s # t1)
  infer (M.insert x sch ctx) e2
-- | Unify two types after applying the current substitution to both, so
-- 'unify' sees their most up-to-date form.
unifyM :: Type -> Type -> Ti ()
unifyM ty1 ty2 = get >>= \(_, s) -> unify (s # ty1) (s # ty2)

-- | Debugging aid: emit the current substitution via 'trace', prefixed
-- with @msg@.
traceSub :: (Monad m, Show s) => String -> StateT (a, s) m ()
traceSub msg = do
  (_, s) <- get
  trace (msg ++ ": " ++ show s) (return ())
-- | Structural unification, recording variable bindings in the state.
-- Called by 'unifyM' with the current substitution already applied to both
-- sides; sub-terms go back through 'unifyM' because earlier steps may have
-- extended the substitution. Calls 'error' when constructors differ.
unify :: Type -> Type -> Ti ()
unify ty1 ty2 = case (ty1, ty2) of
  (TInt, TInt) -> return ()
  (TVar x, _) -> varBind x ty2
  (_, TVar x) -> varBind x ty1
  (TPair t1 t2, TPair u1 u2) -> unifyM t1 u1 >> unifyM t2 u2
  (TFun t1 t2, TFun u1 u2) -> unifyM t1 u1 >> unifyM t2 u2
  _ -> error $ "Can't unify " ++ show ty1 ++ " and " ++ show ty2

-- | Bind type variable @x@ to @ty@ by composing the singleton substitution
-- onto the one in the state. Binding a variable to itself is a no-op;
-- binding it to a type that contains it fails the occurs check.
varBind :: String -> Type -> Ti ()
varBind x ty
  | TVar x == ty = return ()
  | S.member x (ftv ty) = error "Occurs check failed"
  | otherwise = modify (\(n, sub) -> (n, M.singleton x ty |> sub))
-- | Infer the type of a closed expression. Starts from the empty context,
-- counter 0 and the empty substitution, then applies the final
-- substitution to the inferred type.
typeof :: Exp -> Type
typeof e = finalSub # ty
  where
    (ty, (_, finalSub)) = runState (infer M.empty e) (0, emptySub)

-- | Parse a whole program with 'main' and infer its type. Both parse
-- errors and type errors are reported via 'error'.
typeCheck :: String -> Type
typeCheck src = typeof (testParser main src)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment