Skip to content

Instantly share code, notes, and snippets.

@kccqzy
Created October 23, 2024 13:54
Show Gist options
  • Save kccqzy/d761b8adc840333af0303e1b822d769f to your computer and use it in GitHub Desktop.
Save kccqzy/d761b8adc840333af0303e1b822d769f to your computer and use it in GitHub Desktop.
-- Quick implementation of SimpleSub (see the paper "The Simple Essence of Algebraic Subtyping").
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Monad.Trans.State
import Control.Monad.Trans.Writer.CPS
import Data.Foldable
import qualified Data.IntMap as IM
import qualified Data.Map as M
import qualified Data.Map.Merge.Lazy as MM
import qualified Data.Sequence as Seq
import qualified Data.Set as Set
import Data.String
-- | Expressions of the toy language that is being type-checked.
data Exp
  = -- | Reference to a variable by name.
    EVar String
  | -- | A literal constant.
    ELit Lit
  | -- | Application of the first expression to the second.
    EApp Exp Exp
  | -- | Lambda abstraction binding the named parameter in the body.
    EAbs String Exp
  | -- | Non-recursive @let@: name, bound expression, body.
    ELet String Exp Exp
  | -- | Recursive @let@: the name is also visible in the bound expression.
    ELetRec String Exp Exp
  | -- | The built-in curried integer addition function.
    EPlusFun
  | -- | The built-in curried conditional function.
    EIfThenElse
  | -- | Construction of a named tuple from field expressions.
    EIntoTuple (M.Map String Exp)
  | -- | Projection of the named field out of a tuple expression.
    ETupleMember Exp String
  deriving (Show)
-- | Build the application of the built-in addition to two operands.
plus :: Exp -> Exp -> Exp
plus e1 e2 = EApp (EApp EPlusFun e1) e2
-- | Build the application of the built-in conditional to a condition, a
-- "then" branch and an "else" branch.
ifThenElse :: Exp -> Exp -> Exp -> Exp
ifThenElse i t e = EApp (EApp (EApp EIfThenElse i) t) e
-- | With @OverloadedStrings@, a string literal used as an 'Exp' denotes a
-- variable reference.
instance IsString Exp where
  fromString name = EVar name
-- | Literal constants.
data Lit
  = -- | An integer literal.
    LInt Integer
  | -- | A boolean literal.
    LBool Bool
  deriving (Show)
-- | Internal representation of types during inference. The constructor order
-- determines the derived 'Ord' instance and must not be changed casually.
data Type
  = -- | A type variable, identified by its index in the constraint store.
    TVar Int
  | -- | The integer type.
    TInt
  | -- | The boolean type.
    TBool
  | -- | A function type: parameter type, then result type.
    TFun Type Type
  | -- | A named tuple type with individually typed fields.
    TTuple (M.Map String Type)
  deriving (Eq, Ord, Show)
-- | A generalised type. The 'Int' is the level at which the binding was
-- introduced; 'instantiate' replaces every variable whose level is strictly
-- greater than it with a fresh copy.
data PolyType = PolyType Int Type
-- | What the environment stores for a name: either a generalised type that is
-- freshened on each use, or a plain type that 'instantiate' returns unchanged.
data TypeScheme = IsPoly PolyType | IsSimple Type
-- | Typing environment: maps variable names to their type schemes.
type TypeEnv = M.Map String TypeScheme
-- | Everything recorded about one type variable.
data VarConstraint = VarConstraint
  { -- | Types known to be subtypes of the variable.
    vcLowerBounds :: [Type],
    -- | Types known to be supertypes of the variable.
    vcUpperBounds :: [Type],
    -- | The level the variable was created at.
    vcLevel :: Int
  }
  deriving (Show)
-- | Inference state: the constraints of every type variable, where the
-- variable's number is its index in the sequence (see 'newTyVar').
type TIState = Seq.Seq VarConstraint
-- | Errors that type inference or type display can report.
data TIError
  = -- | The first type cannot be made a subtype of the second.
    ErrorTypeConstrain Type Type
  | -- | A variable was used that is not in the environment.
    ErrorUnboundVariable String
  | -- | The named field was required but the given tuple type lacks it.
    ErrorExpectedPresenceOfField String Type
  | -- | An intersection of the two given display types was rejected.
    ErrorIntersectionTypeImpossible DisplayType DisplayType
  deriving (Show)
-- | The inference monad: 'TIState' threaded through 'State', with 'TIError'
-- failures layered on top via 'ExceptT'.
type TypeCheck = ExceptT TIError (State TIState)
-- | Run a type-checking computation starting from an empty constraint store,
-- discarding the final store.
runTI :: TypeCheck a -> Either TIError a
runTI = flip evalState mempty . runExceptT
-- | Allocate a fresh type variable at the given level with no bounds and
-- return its number, which is its index in the constraint store.
newTyVar :: Int -> TypeCheck Int
newTyVar level = lift $ do
  k <- gets Seq.length
  modify (Seq.|> VarConstraint [] [] level)
  pure k
-- | Like 'newTyVar', but wraps the fresh variable as a 'Type'.
newTyVarType :: Int -> TypeCheck Type
newTyVarType = fmap TVar . newTyVar
-- | Render a variable number as a lowercase name: @a@..@z@, then @aa@, @ab@
-- and so on.
typeVarName :: Int -> String
typeVarName v = go v ""
  where
    -- Peel off the least significant letter and prepend it to the suffix
    -- built so far; the remaining quotient is shifted down by one so that
    -- two-letter names start at "aa" rather than "ba".
    go c suffix =
      let (q, r) = c `divMod` 26
          suffix' = toEnum (97 + r) : suffix
       in if q == 0 then suffix' else go (q - 1) suffix'
-- | Apply a function to the stored constraints of the given variable.
updateConstraints :: (VarConstraint -> VarConstraint) -> Int -> TypeCheck ()
updateConstraints f v = lift . modify $ Seq.adjust' f v
-- | Read the stored constraints of the given variable. The variable number
-- is used directly as an index into the store.
getConstraints :: Int -> TypeCheck VarConstraint
getConstraints v = lift (gets (\store -> Seq.index store v))
-- | Compute the level of a type: the level of a variable is the one stored
-- for it, primitives are at level 0, and compound types take the maximum of
-- their components (0 for an empty tuple).
getTypeLevel :: Type -> TypeCheck Int
getTypeLevel ty = case ty of
  TVar v -> vcLevel <$> getConstraints v
  TFun a b -> max <$> getTypeLevel a <*> getTypeLevel b
  TTuple fs -> do
    fieldLevels <- traverse getTypeLevel fs
    pure (if null fieldLevels then 0 else maximum fieldLevels)
  TInt -> pure 0
  TBool -> pure 0
-- | Turn a type scheme into a type usable at level @lvl@. A simple type is
-- returned unchanged. For a generalised type, every variable whose level is
-- above the scheme's level is replaced by a fresh variable at @lvl@ whose
-- bounds are freshened copies of the original's bounds.
instantiate :: Int -> TypeScheme -> TypeCheck Type
instantiate _ (IsSimple simpleType) = pure simpleType
instantiate lvl (IsPoly (PolyType level body)) = evalStateT (freshenAbove body) IM.empty
  where
    -- The state maps each already-copied variable to its fresh counterpart.
    freshenAbove :: Type -> StateT (IM.IntMap Int) TypeCheck Type
    freshenAbove = \case
      TFun l r -> TFun <$> freshenAbove l <*> freshenAbove r
      TTuple fs -> TTuple <$> traverse freshenAbove fs
      t@(TVar v) -> do
        varLevel <- lift (getTypeLevel t)
        if varLevel <= level
          then pure t
          else gets (IM.lookup v) >>= maybe (freshen v) (pure . TVar)
      -- Primitive types contain no variables.
      t -> pure t

    freshen :: Int -> StateT (IM.IntMap Int) TypeCheck Type
    freshen v = do
      fresh <- lift (newTyVar lvl)
      -- Register the copy before descending into the bounds so that cyclic
      -- bounds refer back to it instead of looping.
      modify (IM.insert v fresh)
      original <- lift (getConstraints v)
      uppers <- traverse freshenAbove (vcUpperBounds original)
      lowers <- traverse freshenAbove (vcLowerBounds original)
      lift (updateConstraints (\c -> c {vcUpperBounds = uppers, vcLowerBounds = lowers}) fresh)
      pure (TVar fresh)
-- | @typeConstrain@ records the requirement that the first type is equal to
-- or a subtype of the second type. Each call starts with an empty cache of
-- already-visited pairs.
typeConstrain :: Type -> Type -> TypeCheck ()
typeConstrain l r = evalStateT (typeConstrainWithCache l r) Set.empty
-- | Worker for 'typeConstrain'. The state is the set of @(lhs, rhs)@ pairs
-- already processed during this top-level call, which guarantees termination
-- when bounds are cyclic.
--
-- Fails with 'ErrorTypeConstrain' when the two types have incompatible
-- shapes and with 'ErrorExpectedPresenceOfField' when the right-hand tuple
-- requires a field that the left-hand tuple does not have.
typeConstrainWithCache :: Type -> Type -> StateT (Set.Set (Type, Type)) TypeCheck ()
typeConstrainWithCache l r
  -- Subtyping is reflexive, so there is nothing to record. Without this
  -- check, constraining a variable against itself would add the variable to
  -- its own bounds. This also covers the (TInt, TInt) and (TBool, TBool)
  -- cases.
  | l == r = pure ()
  | otherwise = do
      alreadyConstrained <- gets (Set.member (l, r))
      unless alreadyConstrained $ do
        modify (Set.insert (l, r))
        case (l, r) of
          (TFun l1 r1, TFun l2 r2) -> do
            -- Parameters are contravariant, results are covariant.
            typeConstrainWithCache l2 l1
            typeConstrainWithCache r1 r2
          (TTuple tt1, TTuple tt2) ->
            -- Extra fields on the left are fine; every field on the right must
            -- exist on the left and be a subtype field-wise.
            let throwFieldPresence = MM.traverseMissing (\name _ -> lift (throwE (ErrorExpectedPresenceOfField name l)))
                performConstrain = MM.zipWithMaybeAMatched (\_ t1 t2 -> typeConstrainWithCache t1 t2 >> pure Nothing)
             in void (MM.mergeA MM.dropMissing throwFieldPresence performConstrain tt1 tt2)
          (TVar x, TVar y) -> do
            (leftLevel, rightLevel) <- levels
            if rightLevel <= leftLevel
              then constrainLeftVariable x
              else constrainRightVariable y
          (TVar x, _) -> do
            (leftLevel, rightLevel) <- levels
            if rightLevel <= leftLevel
              then constrainLeftVariable x
              else extrudeRight leftLevel
          (_, TVar y) -> do
            (leftLevel, rightLevel) <- levels
            if leftLevel <= rightLevel
              then constrainRightVariable y
              else extrudeLeft rightLevel
          _ -> lift (throwE (ErrorTypeConstrain l r))
  where
    -- Levels are only needed when a variable is involved, so they are
    -- computed on demand rather than for every pair.
    levels = lift ((,) <$> getTypeLevel l <*> getTypeLevel r)
    -- Record @r@ as an upper bound of @x@ and propagate to x's lower bounds.
    constrainLeftVariable x = do
      lift (updateConstraints (\c -> c {vcUpperBounds = r : vcUpperBounds c}) x)
      lower <- vcLowerBounds <$> lift (getConstraints x)
      forM_ lower (`typeConstrainWithCache` r)
    -- Record @l@ as a lower bound of @y@ and propagate to y's upper bounds.
    constrainRightVariable y = do
      lift (updateConstraints (\c -> c {vcLowerBounds = l : vcLowerBounds c}) y)
      upper <- vcUpperBounds <$> lift (getConstraints y)
      forM_ upper (l `typeConstrainWithCache`)
    -- The other side lives at a deeper level than the variable: copy it down
    -- to the variable's level first, then constrain against the copy.
    extrudeRight leftLevel = do
      newR <- lift (extrude leftLevel r False)
      typeConstrainWithCache l newR
    extrudeLeft rightLevel = do
      newL <- lift (extrude rightLevel l True)
      typeConstrainWithCache newL r
-- | @extrude@ makes a copy of a type whose level is higher than requested,
-- such that the copy is at the requested level and is either a supertype
-- (polarity 'True') or a subtype (polarity 'False') of the original type.
-- Parts of the type already at or below the requested level are shared, not
-- copied.
--
-- The polarity flips under function parameters. Each variable is therefore
-- copied separately for each polarity it is reached in, and the direction of
-- the bound linking the original to its copy follows the /current/ polarity,
-- not the polarity the top-level call started with.
extrude :: Int -> Type -> Bool -> TypeCheck Type
extrude level ty pol = evalStateT (extrudeWithCache pol ty) M.empty
  where
    -- The cache is keyed on (variable, polarity): a copy made for a positive
    -- occurrence must not be reused for a negative one, because the two are
    -- related to the original in opposite directions.
    extrudeWithCache :: Bool -> Type -> StateT (M.Map (Int, Bool) Type) TypeCheck Type
    extrudeWithCache p t = do
      originalLevel <- lift (getTypeLevel t)
      if originalLevel <= level
        then pure t
        else case t of
          TInt -> pure t
          TBool -> pure t
          TFun l r -> TFun <$> extrudeWithCache (not p) l <*> extrudeWithCache p r
          TTuple fs -> TTuple <$> traverse (extrudeWithCache p) fs
          TVar v -> do
            alreadyExtruded <- gets (M.lookup (v, p))
            case alreadyExtruded of
              Just result -> pure result
              Nothing -> do
                nvs <- lift (newTyVar level)
                let result = TVar nvs
                -- Register before recursing so cyclic bounds terminate.
                modify (M.insert (v, p) result)
                if p
                  then do
                    -- Positive: original <: copy; the copy inherits the
                    -- original's (extruded) lower bounds.
                    lift (updateConstraints (\c -> c {vcUpperBounds = result : vcUpperBounds c}) v)
                    origLower <- vcLowerBounds <$> lift (getConstraints v)
                    newLower <- traverse (extrudeWithCache p) origLower
                    lift (updateConstraints (\c -> c {vcLowerBounds = newLower}) nvs)
                  else do
                    -- Negative: copy <: original; the copy inherits the
                    -- original's (extruded) upper bounds.
                    lift (updateConstraints (\c -> c {vcLowerBounds = result : vcLowerBounds c}) v)
                    origUpper <- vcUpperBounds <$> lift (getConstraints v)
                    newUpper <- traverse (extrudeWithCache p) origUpper
                    lift (updateConstraints (\c -> c {vcUpperBounds = newUpper}) nvs)
                pure result
-- | @inferType@ performs type inference for an expression in the given
-- environment at the given level. @let@-bound expressions are inferred one
-- level deeper and then stored as generalised types at the current level.
-- Fails with 'ErrorUnboundVariable' for names missing from the environment,
-- and with whatever 'typeConstrain' reports.
inferType :: TypeEnv -> Int -> Exp -> TypeCheck Type
inferType env level expr = case expr of
  EVar n ->
    maybe (throwE (ErrorUnboundVariable n)) (instantiate level) (M.lookup n env)
  ELit (LInt _) -> pure TInt
  ELit (LBool _) -> pure TBool
  EIntoTuple fs -> TTuple <$> traverse here fs
  EAbs n body -> do
    paramType <- newTyVarType level
    bodyType <- inferType (bind n (IsSimple paramType)) level body
    pure (TFun paramType bodyType)
  EApp e1 e2 -> do
    resultType <- newTyVarType level
    funType <- here e1
    argType <- here e2
    typeConstrain funType (TFun argType resultType)
    pure resultType
  ETupleMember tp mem -> do
    resultType <- newTyVarType level
    tpType <- here tp
    -- The tuple only needs to provide the one field being projected.
    typeConstrain tpType (TTuple (M.singleton mem resultType))
    pure resultType
  EPlusFun -> pure (TFun TInt (TFun TInt TInt))
  EIfThenElse ->
    -- Both branches and the result share a single fresh variable.
    (\res -> TFun TBool (TFun res (TFun res res))) <$> newTyVarType level
  ELet x e1 e2 -> do
    bodyType <- inferType env (level + 1) e1
    inferType (bind x (IsPoly (PolyType level bodyType))) level e2
  ELetRec x e1 e2 -> do
    -- The placeholder stands for the binding inside its own definition.
    liftedType <- newTyVarType (level + 1)
    bodyType <- inferType (bind x (IsSimple liftedType)) (level + 1) e1
    typeConstrain bodyType liftedType
    inferType (bind x (IsPoly (PolyType level liftedType))) level e2
  where
    -- Infer a sub-expression in the unchanged environment and level.
    here = inferType env level
    -- The environment extended with one more binding.
    bind name scheme = M.insert name scheme env
-- Human-readable form of an inferred type, produced by 'typeToDisplayType'.
-- The constructor order determines the derived 'Ord' instance, which in turn
-- fixes the element order inside the union and intersection sets.
data DisplayType
= -- | A type at the top of the hierarchy. It is the type of everything.
DTTop
| -- | A type at the bottom of the hierarchy. It is the type of nothing, i.e. it diverges.
DTBottom
| -- | A type that is either of the specified types.
DTUnion (Set.Set DisplayType)
| -- | A type that is all of the specified types.
DTIntersection (Set.Set DisplayType)
| -- | A function from one type to another.
DTFunction DisplayType DisplayType
| -- | A named tuple type where the fields are individually typed.
DTTuple (M.Map String DisplayType)
| -- | A recursive type using the specified variable. Whenever this variable occurs in the type, it can be replaced by the type itself as the one-step unroll.
DTRecursive String DisplayType
| -- | A type variable.
DTVar String
-- The integer type (displayed form of 'TInt').
| DTInt
-- The boolean type (displayed form of 'TBool').
| DTBool
deriving (Show, Eq, Ord)
-- | A type variable number paired with the polarity of the position it was
-- reached in while converting a type for display.
type VarWithPolarity = (Int, Bool)
-- | Convert an inferred type into its display form. Starting in positive
-- polarity, each variable is expanded into itself merged with its lower
-- bounds (as a union, in positive position) or its upper bounds (as an
-- intersection, in negative position). A variable reached again while it is
-- still being expanded in the same polarity becomes a recursive type. The
-- result is finally passed through 'removePolarVariables'.
typeToDisplayType :: Type -> TypeCheck DisplayType
typeToDisplayType ty = removePolarVariables <$> evalStateT (go True Set.empty ty) M.empty
  where
    -- The state maps each (variable, polarity) that turned out to be
    -- recursive to the name chosen for its recursion variable.
    go :: Bool -> Set.Set VarWithPolarity -> Type -> StateT (M.Map VarWithPolarity String) TypeCheck DisplayType
    go polar inProcess t = case t of
      TInt -> pure DTInt
      TBool -> pure DTBool
      TFun l r -> DTFunction <$> go (not polar) inProcess l <*> go polar inProcess r
      TTuple fs -> DTTuple <$> traverse (go polar inProcess) fs
      TVar v
        | Set.member (v, polar) inProcess -> recursiveOccurrence (v, polar)
        | otherwise -> expandVariable polar inProcess v

    -- The variable is already being expanded further up: refer to it through
    -- a recursion variable, allocating a name for it on first use.
    recursiveOccurrence :: VarWithPolarity -> StateT (M.Map VarWithPolarity String) TypeCheck DisplayType
    recursiveOccurrence vpol = do
      known <- gets (M.lookup vpol)
      case known of
        Just name -> pure (DTVar name)
        Nothing -> do
          -- Only the fresh variable's number is used (for its name); its
          -- level is deliberately left undefined.
          freshVar <- lift (newTyVar (error "no level for fresh var in recursive type"))
          let name = typeVarName freshVar
          modify (M.insert vpol name)
          pure (DTVar name)

    -- Expand a variable into the merge of itself and its relevant bounds.
    expandVariable :: Bool -> Set.Set VarWithPolarity -> Int -> StateT (M.Map VarWithPolarity String) TypeCheck DisplayType
    expandVariable polar inProcess v = do
      cons <- lift (getConstraints v)
      let vpol = (v, polar)
          (bounds, merge, wrap) =
            if polar
              then (vcLowerBounds cons, mergeWithUnion, DTUnion)
              else (vcUpperBounds cons, mergeWithIntersection, DTIntersection)
      boundTypes <- traverse (go polar (Set.insert vpol inProcess)) bounds
      merged <- lift (foldlM merge (Set.singleton (DTVar (typeVarName v))) boundTypes)
      -- If expanding the bounds referred back to this variable, wrap the
      -- result in the corresponding recursive binder.
      recursionName <- gets (M.lookup vpol)
      pure (maybe id DTRecursive recursionName (finishMerge wrap merged))
-- | Add one more alternative to the members of a union under construction.
-- A nested union is flattened into the set; anything else is inserted as a
-- single member. This never fails.
mergeWithUnion :: Set.Set DisplayType -> DisplayType -> TypeCheck (Set.Set DisplayType)
mergeWithUnion s (DTUnion ss) = pure (s <> ss)
mergeWithUnion s t = pure (Set.insert t s)
-- | Add one more conjunct to the members of an intersection under
-- construction, rejecting combinations that no value can satisfy.
--
-- Two members clash when both have a concrete shape (integer, boolean,
-- function, tuple) and the shapes differ. The check is symmetric: it does not
-- matter which of the two clashing members was merged first. A nested
-- intersection is flattened member by member so that its members are checked
-- against the existing ones as well.
--
-- Fails with 'ErrorIntersectionTypeImpossible', whose two arguments are the
-- clashing members in ascending 'Ord' order (e.g. @DTInt@ before @DTBool@, a
-- function before a primitive).
mergeWithIntersection :: Set.Set DisplayType -> DisplayType -> TypeCheck (Set.Set DisplayType)
mergeWithIntersection s t = case t of
  DTIntersection ss -> foldlM mergeWithIntersection s ss
  _ -> case find (clashesWith t) s of
    Just other -> throwE (ErrorIntersectionTypeImpossible (min t other) (max t other))
    Nothing -> pure (Set.insert t s)
  where
    clashesWith :: DisplayType -> DisplayType -> Bool
    clashesWith a b = case (shape a, shape b) of
      (Just x, Just y) -> x /= y
      _ -> False
    -- A tag for types with a fixed outermost shape; variables, recursive
    -- types, unions etc. have no fixed shape and never clash.
    shape :: DisplayType -> Maybe Int
    shape = \case
      DTInt -> Just 0
      DTBool -> Just 1
      DTFunction {} -> Just 2
      DTTuple {} -> Just 3
      _ -> Nothing
-- | Turn the members collected for a union or intersection into a single
-- value: a lone member is returned as is, several members are combined with
-- the given operation. The set is expected to be non-empty; an empty set is
-- a programming error.
finishMerge :: (Set.Set a -> a) -> Set.Set a -> a
finishMerge op s
  | Set.null s = error "impossible"
  | Set.size s == 1 = Set.findMin s
  | otherwise = op s
-- | @removePolarVariables@ removes those variables that only appear in a
-- positive or negative position. To do that, it makes two passes over the
-- structure: first, it collects all variables together with the polarity of
-- each occurrence; second, it rewrites the type, replacing every variable
-- that was seen in only one polarity by 'DTBottom' (in positive position) or
-- 'DTTop' (in negative position) and simplifying the enclosing unions and
-- intersections accordingly.
--
-- Inside this function the 'Bool' flag is 'False' for positive positions
-- (the top level) and 'True' for negative ones.
--
-- Known limitation: a recursive type nested inside another recursive type is
-- not supported by the collection pass and raises @error "unimplemented"@.
removePolarVariables :: DisplayType -> DisplayType
removePolarVariables dt = removeVariables False dt
  where
    removeVariables :: Bool -> DisplayType -> DisplayType
    removeVariables polar t = case t of
      DTInt -> t
      DTBool -> t
      DTTop -> t
      DTBottom -> t
      DTRecursive v innerType ->
        -- Drop the binder when simplification removed every use of it.
        let newInner = removeVariables polar innerType
         in if varOccurs v newInner then DTRecursive v newInner else newInner
      DTVar v
        | Set.member v uselessVariables -> if polar then DTTop else DTBottom
        | otherwise -> t
      DTFunction a b -> DTFunction (removeVariables (not polar) a) (removeVariables polar b)
      DTTuple tts -> DTTuple (removeVariables polar <$> tts)
      DTIntersection dts ->
        -- We can remove Top. If bottom is present, we make it the result.
        collapse DTTop DTBottom DTIntersection (Set.map (removeVariables polar) dts)
      DTUnion dts ->
        -- Dually, Bottom can be removed and Top absorbs everything.
        collapse DTBottom DTTop DTUnion (Set.map (removeVariables polar) dts)

    -- Simplify the members of a union/intersection: @unit@ members are
    -- dropped, an @absorbing@ member becomes the whole result, and if nothing
    -- is left (every member was the unit) the result is the unit itself
    -- rather than an operation over an empty set.
    collapse :: DisplayType -> DisplayType -> (Set.Set DisplayType -> DisplayType) -> Set.Set DisplayType -> DisplayType
    collapse unit absorbing op members
      | Set.member absorbing remaining = absorbing
      | Set.null remaining = unit
      | otherwise = finishMerge op remaining
      where
        remaining = Set.delete unit members

    -- Variables that occur in exactly one polarity.
    uselessVariables :: Set.Set String
    uselessVariables =
      let (a, b) = execWriter (collectAllVariables False dt)
       in Set.difference a b <> Set.difference b a

    -- Collects variable names into a pair of sets: the first for occurrences
    -- where the flag is 'False', the second for occurrences where it is 'True'.
    collectAllVariables :: Bool -> DisplayType -> Writer (Set.Set String, Set.Set String) ()
    collectAllVariables polar = \case
      DTUnion dts -> traverse_ (collectAllVariables polar) dts
      DTIntersection dts -> traverse_ (collectAllVariables polar) dts
      DTTuple tts -> traverse_ (collectAllVariables polar) tts
      DTInt -> pure ()
      DTBool -> pure ()
      DTTop -> pure ()
      DTBottom -> pure ()
      DTFunction a b -> do
        collectAllVariables (not polar) a
        collectAllVariables polar b
      DTVar v ->
        tell $
          if polar then (mempty, Set.singleton v) else (Set.singleton v, mempty)
      DTRecursive v t -> do
        -- I think we can just do a one-step unroll here.
        collectAllVariables polar t
        collectAllVariables polar (unrollRecursiveOnce v t t)

    -- Substitute @orig@ for every occurrence of the variable @v@ in @t@.
    unrollRecursiveOnce :: String -> DisplayType -> DisplayType -> DisplayType
    unrollRecursiveOnce v orig t = case t of
      DTRecursive {} -> error "unimplemented"
      DTInt -> t
      DTBool -> t
      DTTop -> t
      DTBottom -> t
      DTVar v'
        | v == v' -> orig
        | otherwise -> t
      DTFunction a b -> DTFunction (unrollRecursiveOnce v orig a) (unrollRecursiveOnce v orig b)
      DTUnion dts -> DTUnion (Set.map (unrollRecursiveOnce v orig) dts)
      DTIntersection dts -> DTIntersection (Set.map (unrollRecursiveOnce v orig) dts)
      DTTuple tts -> DTTuple (unrollRecursiveOnce v orig <$> tts)
-- | Does a variable with the given name occur anywhere inside the display
-- type? Binders of recursive types are not treated specially: only their
-- bodies are searched.
varOccurs :: String -> DisplayType -> Bool
varOccurs v = go
  where
    go = \case
      DTVar v' -> v == v'
      DTFunction a b -> go a || go b
      DTRecursive _ t -> go t
      DTTuple tts -> any go tts
      DTUnion dts -> any go dts
      DTIntersection dts -> any go dts
      -- DTInt, DTBool, DTTop and DTBottom contain no variables.
      _ -> False
-- | Infer the type of a closed expression (empty environment, level 0) and
-- convert it to its display form.
inferDisplayType :: Exp -> ExceptT TIError (State TIState) DisplayType
inferDisplayType e = inferType mempty 0 e >>= typeToDisplayType
-- Sample expressions without tuples (apart from the last one); 'main' infers
-- and prints the type of each. String literals in expression position are
-- variables (see the 'IsString' instance).
examples :: [Exp]
examples =
-- The identity function.
[ EAbs "x" (EVar "x"),
-- Apply a function twice.
EAbs "f" (EAbs "x" (EApp (EVar "f") (EApp (EVar "f") (EVar "x")))),
-- let-bound identity, returned as is.
ELet
"id"
(EAbs "x" (EVar "x"))
(EVar "id"),
-- let-bound identity applied to itself.
ELet
"id"
(EAbs "x" (EVar "x"))
(EApp (EVar "id") (EVar "id")),
-- Identity written with an inner let.
ELet
"id"
(EAbs "x" (ELet "y" (EVar "x") (EVar "y")))
(EApp (EVar "id") (EVar "id")),
ELet
"id"
(EAbs "x" (ELet "y" (EVar "x") (EVar "y")))
(EApp (EApp (EVar "id") (EVar "id")) (ELit (LInt 2))),
-- Self-application.
EAbs "x" (EApp (EVar "x") (EVar "x")),
ELet
"wrong"
(EAbs "x" (EApp (EVar "x") (EVar "x")))
(EVar "wrong"),
ELet
"wrong2"
(EAbs "x" (EApp (EApp (EVar "x") (EVar "x")) (EVar "x")))
(EVar "wrong2"),
-- A lambda parameter re-bound by nested lets and applied to a boolean.
EAbs
"m"
( ELet
"y"
(EVar "m")
( ELet
"x"
(EApp (EVar "y") (ELit (LBool True)))
(EVar "x")
)
),
-- An integer literal applied as if it were a function.
EApp (ELit (LInt 2)) (ELit (LInt 2)),
ELet "id" (EAbs "x" (EVar "x")) (EApp (EVar "id") (ELit (LInt 2))),
-- Self-application applied to self-application.
ELet
"omega"
(EApp (EAbs "x" (EApp (EVar "x") (EVar "x"))) (EAbs "x" (EApp (EVar "x") (EVar "x"))))
(EVar "omega"),
-- Same expression as the second entry.
EAbs "f" (EAbs "x" (EApp (EVar "f") (EApp (EVar "f") (EVar "x")))),
-- Uses of the built-in addition.
ELet
"plusOne"
(EAbs "x" (plus (ELit (LInt 1)) (EVar "x")))
"plusOne",
ELet
"plusOne"
(EAbs "x" (plus (ELit (LInt 1)) (EVar "x")))
(ELet "two" (ELit (LInt 2)) (EApp (EVar "plusOne") (EVar "two"))),
ELet
"plusplus"
(EAbs "x" (EAbs "y" (plus "x" (plus "y" "x"))))
(ELet "two" (ELit (LInt 2)) (EApp (EVar "plusplus") (EVar "two"))),
-- The apply-twice function used with an increment function and zero.
let f = EAbs "x" (plus (ELit (LInt 1)) (EVar "x"))
z = ELit (LInt 0)
in ELet "church2" (EAbs "f" (EAbs "x" (EApp (EVar "f") (EApp (EVar "f") (EVar "x"))))) (EApp (EApp (EVar "church2") f) z),
-- Recursive bindings.
ELetRec "loop" (EAbs "x" (plus "x" (EApp "loop" "x"))) "loop",
ELetRec "loop" (EAbs "x" (plus "x" (EApp "loop" "x"))) (EApp "loop" (ELit (LInt 42))),
-- The same parameter is used both as an addend and as a condition.
EAbs "intAndBool" (EIntoTuple (M.fromList [("int", plus "intAndBool" (ELit (LInt 1))), ("bool", ifThenElse "intAndBool" (ELit (LInt 1)) (ELit (LInt 0)))]))
]
-- Sample expressions that build named tuples or project fields out of them;
-- 'main' infers and prints the type of each after those in 'examples'.
tupleExamples :: [Exp]
tupleExamples =
-- Tuple of literals.
[ EIntoTuple (M.fromList [("x", ELit (LBool True)), ("y", ELit (LInt 42))]),
-- Tuple of let-bound variables.
ELet "x" (ELit (LBool True)) $
ELet "y" (ELit (LInt 42)) $
EIntoTuple (M.fromList [("x", EVar "x"), ("y", EVar "y")]),
-- As above, with "y" shadowed by a sum.
ELet "x" (ELit (LBool True)) $
ELet "y" (ELit (LInt 42)) $
ELet "y" (plus (EVar "y") (EVar "y")) $
EIntoTuple (M.fromList [("x", EVar "x"), ("y", EVar "y")]),
-- Field projection used as an addend.
plus (ELit (LInt 1)) $
ETupleMember
( EIntoTuple
( M.fromList
[ ("a", ELit (LInt 2)),
("b", ELit (LInt 3))
]
)
)
"b",
-- Projection out of a nested tuple.
ELet "a" (EIntoTuple (M.fromList [("b", EIntoTuple (M.fromList [("c", ELit (LInt 1))]))])) $
ELet "ret" (ETupleMember "a" "b") "ret",
ELet "x" (ELit (LInt 1)) $
ELet "inner" (ELet "y" (plus "x" (ELit (LInt 2))) (EIntoTuple (M.fromList [("y", "y")]))) $
EIntoTuple (M.fromList [("x", "x"), ("inner", "inner")]),
-- Functions returning tuples.
EAbs "a" (EIntoTuple (M.fromList [("a", "a"), ("r", plus "a" "a")])),
EApp (EAbs "a" (EIntoTuple (M.fromList [("a", "a"), ("r", plus "a" "a")]))) (ELit (LInt 42)),
ELet "c" (EAbs "a" (EIntoTuple (M.fromList [("a", "a"), ("r", plus "a" "a")]))) $
EApp "c" (ELit (LInt 42)),
EAbs "x" (EIntoTuple (M.fromList [("this", "x"), ("more", plus "x" (ELit (LInt 1)))])),
EAbs "f" (EAbs "x" (EIntoTuple (M.fromList [("zero", "x"), ("one", EApp "f" "x")]))),
-- A recursive function whose result tuple contains a recursive call.
ELetRec "f" (EAbs "x" (EIntoTuple (M.fromList [("zero", "x"), ("one", EApp "f" "x")]))) "f",
-- Two-argument selector functions "t" and "f" combined by "and".
ELet "t" (EAbs "a" (EAbs "b" "a")) $
ELet "f" (EAbs "a" (EAbs "b" "b")) $
ELet "and" (EAbs "p" (EAbs "q" (EApp (EApp "p" "q") "f"))) $
EIntoTuple
( M.fromList
[ ("trueAndFalse", EApp (EApp "and" "t") "f"),
("trueAndTrue", EApp (EApp "and" "t") "t"),
("falseAndTrue", EApp (EApp "and" "f") "t"),
("falseAndFalse", EApp (EApp "and" "f") "f")
]
),
-- As above, with every field additionally applied to True and False.
ELet "t" (EAbs "a" (EAbs "b" "a")) $
ELet "f" (EAbs "a" (EAbs "b" "b")) $
ELet "and" (EAbs "p" (EAbs "q" (EApp (EApp "p" "q") "f"))) $
let transform x = EApp (EApp x (ELit (LBool True))) (ELit (LBool False))
in EIntoTuple
( transform
<$> M.fromList
[ ("trueAndFalse", EApp (EApp "and" "t") "f"),
("trueAndTrue", EApp (EApp "and" "t") "t"),
("falseAndTrue", EApp (EApp "and" "f") "t"),
("falseAndFalse", EApp (EApp "and" "f") "f")
]
)
]
-- | Infer the display type of one expression and print either the
-- expression with its type, or the expression followed by the error.
test :: Exp -> IO ()
test e = either reportError reportType (runTI (inferDisplayType e))
  where
    reportError err = putStrLn (concat [show e, "\n ", show err, "\n"])
    reportType t = putStrLn (concat [show e, " :: ", show t, "\n"])
-- | Run 'test' on every sample expression, plain examples first.
main :: IO ()
main = traverse_ test (examples <> tupleExamples)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment