Last active
April 9, 2023 16:00
-
-
Save DarinM223/44c7bf5d0a98232f6b01a7435a570810 to your computer and use it in GitHub Desktop.
Haskell translation of `sound_eager.ml` from https://okmij.org/ftp/ML/generalization.html
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE ImplicitParams #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE OverloadedStrings #-} | |
module SoundEager where | |
import Control.Monad (unless) | |
import Control.Monad.Primitive (PrimMonad, PrimState) | |
import Control.Monad.ST (runST) | |
import Data.Char (chr, ord) | |
import Data.Functor ((<&>)) | |
import Data.Functor.Classes (Eq1, eq1, liftEq) | |
import Data.Functor.Identity (Identity (Identity), runIdentity) | |
import Data.Maybe (fromJust) | |
import Data.Primitive (MutVar, newMutVar, readMutVar, writeMutVar) | |
import Data.Primitive.PrimVar (PrimVar, modifyPrimVar, newPrimVar, readPrimVar, writePrimVar) | |
import Data.Text (Text) | |
import Data.Text qualified as T | |
import Unsafe.Coerce (unsafeCoerce) | |
-- | Names of term-level variables.
type Varname = Text

-- | Source terms: the untyped lambda calculus extended with @let@.
data Exp
  = -- | A variable occurrence.
    Var Varname
  | -- | Application of a function to an argument.
    App Exp Exp
  | -- | Lambda abstraction over one variable.
    Lam Varname Exp
  | -- | @Let x e body@ binds @x@ to @e@ within @body@; the type of @e@ is
    -- generalized before @body@ is checked.
    Let Varname Exp Exp
-- | Names of quantified (generalized) type variables.
type Qname = Text

-- | Binding depth used by level-based generalization.
type Level = Int

-- | Types, parameterized over the cell type @r@ that stores a type variable:
-- a mutable @Ref m@ during inference, 'Identity' once the result is frozen.
data Typ r
  = -- | A type variable held in a cell (unbound, or linked to another type).
    TVar (r (Tv r))
  | -- | A generalized variable; replaced by a fresh variable on instantiation.
    QVar Qname
  | -- | Function type: argument, then result.
    TArrow (Typ r) (Typ r)
-- | Debug rendering of a frozen type.
--
-- Arrows associate to the right when read back. An arrow that appears as the
-- /argument/ of another arrow is therefore parenthesized; otherwise
-- @(a -> b) -> c@ and @a -> (b -> c)@ would both print as @a -> b -> c@.
instance Show (Typ Identity) where
  show (TVar (Identity tv)) = "TVar (" ++ show tv ++ ")"
  show (QVar name) = "QVar " ++ T.unpack name
  show (TArrow ty1@TArrow {} ty2) = "(" ++ show ty1 ++ ") -> " ++ show ty2
  show (TArrow ty1 ty2) = show ty1 ++ " -> " ++ show ty2
-- | Structural equality on types.
--
-- Variable cells are compared with 'eq1'. For 'Ref' cells that is identity of
-- the underlying mutable variable, so two distinct cells are unequal even
-- when their contents look the same.
instance (Eq1 r) => Eq (Typ r) where
  TVar lhs == TVar rhs = eq1 lhs rhs
  QVar lhs == QVar rhs = lhs == rhs
  TArrow dom1 cod1 == TArrow dom2 cod2 = dom1 == dom2 && cod1 == cod2
  _ == _ = False
-- | Contents of a type-variable cell: either still free (with its name and
-- level), or linked to the type it has been unified with.
data Tv r
  = Unbound Text Level
  | Link (Typ r)
  deriving (Eq)

-- | Debug rendering of a frozen type-variable cell.
instance Show (Tv Identity) where
  show tv = case tv of
    Unbound name level ->
      concat ["Unbound (", T.unpack name, " ", show level, ")"]
    Link ty -> concat ["Link (", show ty, ")"]
-- | A mutable 'Int' in the state thread of monad @m@ (counters, levels).
type IntRef m = PrimVar (PrimState m) Int

-- | A mutable cell in monad @m@. The derived equality is that of the wrapped
-- 'MutVar', i.e. reference identity rather than equality of the contents.
newtype Ref m a = Ref {unRef :: MutVar (PrimState m) a} deriving (Eq)

-- | Compare cells by identity, even across element types. The element
-- comparator is intentionally ignored.
--
-- Why 'unsafeCoerce' is acceptable here: the coerced 'MutVar' is only handed to
-- 'MutVar' equality, which compares the references themselves. Its contents
-- are never read or written, so the pretended element type is never
-- observed. In this file 'eq1' is only reached from @Eq (Typ r)@, where both
-- cells already have the same type.
instance Eq1 (Ref m) where
  liftEq _ (Ref lhs) (Ref rhs) = unsafeCoerce rhs == lhs
-- | Read the current contents of a cell.
readRef :: (PrimMonad m) => Ref m a -> m a
readRef (Ref var) = readMutVar var
{-# INLINE readRef #-}

-- | Overwrite the contents of a cell.
writeRef :: (PrimMonad m) => Ref m a -> a -> m ()
writeRef (Ref var) value = writeMutVar var value
{-# INLINE writeRef #-}
-- | Rebuild a type over a different cell representation.
--
-- For each 'TVar':
--
-- * @f@ reads the source cell;
-- * a 'Link' inside it is rebuilt recursively;
-- * @constr@ wraps the rebuilt contents in a target cell.
--
-- 'QVar's are copied unchanged, and arrows are rebuilt left to right.
transformRef ::
  (Monad m) =>
  (Tv r' -> m (r' (Tv r'))) ->
  (r (Tv r) -> m (Tv r)) ->
  Typ r ->
  m (Typ r')
transformRef constr f = go
  where
    go = \case
      TVar ref -> do
        contents <- f ref
        rebuilt <- case contents of
          Unbound name level -> pure (Unbound name level)
          Link typ -> Link <$> go typ
        TVar <$> constr rebuilt
      QVar name -> pure (QVar name)
      TArrow ty1 ty2 -> TArrow <$> go ty1 <*> go ty2
-- | Freeze a mutable type into an immutable snapshot, following links.
toIdentity' :: (PrimMonad m) => Typ (Ref m) -> m (Typ Identity)
toIdentity' typ = transformRef (\tv -> pure (Identity tv)) readRef typ
-- | Thaw a frozen type into freshly allocated mutable cells.
--
-- Every 'TVar' occurrence gets its own new cell. Two occurrences of the same
-- variable in the input therefore become distinct, unshared cells in the
-- result.
toRef :: (PrimMonad m) => Typ Identity -> m (Typ (Ref m))
toRef typ =
  transformRef (\tv -> Ref <$> newMutVar tv) (\(Identity tv) -> pure tv) typ
-- | Freeze a mutable type into an immutable snapshot, following links.
-- Unbound variables keep their name and level.
--
-- This uses the shared 'transformRef' traversal (the same one behind
-- @toIdentity'@) rather than a hand-written copy of that recursion, so the
-- two conversions cannot drift apart.
toIdentity :: (PrimMonad m) => Typ (Ref m) -> m (Typ Identity)
toIdentity = transformRef (pure . Identity) readRef
-- | Produce a fresh type-variable name and advance the counter.
--
-- Counter values 0 to 25 map to the letters @a@ to @z@. Any later value @n@
-- becomes @t@ followed by @n@ in decimal.
gensym :: (PrimMonad m) => (?gensym :: IntRef m) => m Text
gensym = do
  counter <- readPrimVar ?gensym
  writePrimVar ?gensym (counter + 1)
  pure (render counter)
  where
    render k
      | k < 26 = T.singleton (chr (ord 'a' + k))
      | otherwise = "t" <> T.pack (show k)
-- | Go one binding level deeper. 'typeof' does this around the bound
-- expression of a @let@.
enterLevel :: (PrimMonad m) => (?level :: IntRef m) => m ()
enterLevel = modifyPrimVar ?level (1 +)

-- | Return to the enclosing binding level; undoes 'enterLevel'.
leaveLevel :: (PrimMonad m) => (?level :: IntRef m) => m ()
leaveLevel = modifyPrimVar ?level (\l -> l - 1)
-- | Implicit state threaded through inference: the fresh-name counter
-- (@?gensym@) and the current binding level (@?level@).
type Constr m = (?level :: IntRef m, ?gensym :: IntRef m)

-- | Allocate a fresh unbound type variable at the current binding level.
newVar :: (PrimMonad m, Constr m) => m (Typ (Ref m))
newVar = do
  name <- gensym
  lvl <- readPrimVar ?level
  cell <- newMutVar (Unbound name lvl)
  pure (TVar (Ref cell))
-- | Occurs check combined with level adjustment.
--
-- Fails with @error "Occurs check"@ if the cell @tvr@ is reachable in the type,
-- following links. Every other unbound variable it meets is rewritten with the
-- minimum of its own level and @tvr@'s level (or its own level alone if @tvr@
-- is already linked). 'unify' links @tvr@ to the checked type right after this
-- call, and 'gen' decides what to generalize by comparing these levels.
occurs :: (PrimMonad m) => Ref m (Tv (Ref m)) -> Typ (Ref m) -> m ()
occurs tvr = check
  where
    check (TVar cell)
      | tvr == cell = error "Occurs check"
      | otherwise = readRef cell >>= visit cell
    check (TArrow dom cod) = check dom >> check cod
    check (QVar _) = pure ()

    visit _ (Link target) = check target
    visit cell (Unbound name l') = do
      outer <- readRef tvr
      let lowered = case outer of
            Unbound _ l -> min l l'
            Link _ -> l'
      writeRef cell (Unbound name lowered)
-- | Destructively unify two types, binding unbound variables with 'Link'.
--
-- Links on either side are followed /before/ any variable is bound. Binding
-- first yields a false occurs-check failure in the following situation:
--
-- * @a := Link (TVar b)@ and @b@ is unbound;
-- * unifying @TVar a@ with @TVar b@ runs @occurs b (TVar a)@;
-- * that follows the link straight back to @b@ and fails with "Occurs check",
--   although both sides already denote the same type.
--
-- This arises for well-typed terms such as @λy. λf. λk. (λz. k (f z) (f y)) y@.
--
-- Calls 'error' on a genuine occurs-check failure or a constructor mismatch.
unify :: (PrimMonad m) => Typ (Ref m) -> Typ (Ref m) -> m ()
-- Types built from the very same cells need no work.
unify t1 t2 = unless (t1 == t2) $ do
  (tv1, tv2) <- (,) <$> getTv t1 <*> getTv t2
  case (tv1, tv2, t1, t2) of
    -- Dereference links first; see the note above.
    (Just (Link t1'), _, _, _) -> unify t1' t2
    (_, Just (Link t2'), _, _) -> unify t1 t2'
    -- Neither side is a link at the top now: bind an unbound variable.
    (Just Unbound {}, _, TVar tv, t') -> occurs tv t' >> writeRef tv (Link t')
    (_, Just Unbound {}, t', TVar tv) -> occurs tv t' >> writeRef tv (Link t')
    (_, _, TArrow tyl1 tyl2, TArrow tyr1 tyr2) ->
      unify tyl1 tyr1 >> unify tyl2 tyr2
    _ -> error "Invalid types for unification"
  where
    -- Contents of the top-level cell, when the type is a variable.
    getTv = \case TVar ref -> Just <$> readRef ref; _ -> pure Nothing
-- | Typing environment: variable names paired with their types. It is
-- searched with 'lookup', so the most recently prepended binding of a name wins.
type Env m = [(Varname, Typ (Ref m))]

-- | Generalize a type.
--
-- Unbound variables whose level is greater than the current level become
-- 'QVar's. Other unbound variables are kept as they are, links are followed,
-- and arrows are rebuilt.
gen :: (PrimMonad m, Constr m) => Typ (Ref m) -> m (Typ (Ref m))
gen = \case
  TVar ref -> do
    contents <- readRef ref
    case contents of
      Link ty -> gen ty
      Unbound name l -> do
        currLevel <- readPrimVar ?level
        pure (if l > currLevel then QVar name else TVar ref)
  TArrow ty1 ty2 -> TArrow <$> gen ty1 <*> gen ty2
  ty@QVar {} -> pure ty
-- | Instantiate a type scheme.
--
-- Every 'QVar' is replaced by a fresh variable, and repeated occurrences of
-- the same name share that one variable. Links are followed, and unbound
-- 'TVar's are returned unchanged.
inst :: (PrimMonad m, Constr m) => Typ (Ref m) -> m (Typ (Ref m))
inst ty0 = fst <$> walk [] ty0
  where
    -- Threads the name-to-fresh-variable substitution built so far.
    walk sub = \case
      QVar name
        | Just ty <- lookup name sub -> pure (ty, sub)
        | otherwise -> do
            fresh <- newVar
            pure (fresh, (name, fresh) : sub)
      TVar ref ->
        readRef ref >>= \case
          Unbound {} -> pure (TVar ref, sub)
          Link ty -> walk sub ty
      TArrow ty1 ty2 -> do
        (ty1', sub1) <- walk sub ty1
        (ty2', sub2) <- walk sub1 ty2
        pure (TArrow ty1' ty2', sub2)
-- | Infer the type of an expression in the given environment.
--
-- Uses destructive unification ('unify') and level-based generalization
-- ('gen'). Calls 'error' when a variable is not bound in the environment,
-- and propagates the errors raised by 'unify'.
typeof :: (PrimMonad m, Constr m) => Env m -> Exp -> m (Typ (Ref m))
typeof env (Var x) = case lookup x env of
  -- Instantiate so each use of a generalized binding gets fresh variables.
  Just ty -> inst ty
  -- Name the offending variable instead of failing inside 'fromJust'.
  Nothing -> error ("Unbound variable: " ++ T.unpack x)
typeof env (Lam x e) = do
  tyX <- newVar
  TArrow tyX <$> typeof ((x, tyX) : env) e
typeof env (App fun arg) = do
  tyFun <- typeof env fun
  tyArg <- typeof env arg
  tyRes <- newVar
  tyRes <$ unify tyFun (TArrow tyArg tyRes)
typeof env (Let x e rest) = do
  -- Type the bound expression one level deeper. Variables created there that
  -- keep the deeper level are the ones 'gen' turns into 'QVar's.
  tyE <- enterLevel *> typeof env e <* leaveLevel
  tyE' <- gen tyE
  typeof ((x, tyE') : env) rest
-- | Run an inference action with fresh implicit state. The name counter
-- starts at 0 and the binding level starts at 1.
runInfer :: (PrimMonad m) => ((Constr m) => m a) -> m a
runInfer action = do
  nameCounter <- newPrimVar 0
  levelVar <- newPrimVar 1
  let ?gensym = nameCounter
      ?level = levelVar
  action
-- | Infer the type of a closed term, starting from an empty environment, and
-- freeze the result.
testAlg :: Exp -> Typ Identity
testAlg e = runST (toIdentity =<< runInfer (typeof [] e))
-- | The identity function, @λx. x@.
test1 :: Typ Identity
test1 = testAlg (Lam "x" (Var "x"))

-- | Application under two binders, @λx. λy. x y@.
test2 :: Typ Identity
test2 = testAlg (Lam "x" (Lam "y" (App (Var "x") (Var "y"))))

-- | @λy. y (λz. y z)@. This would require @y@'s type to contain itself, so
-- the occurs check rejects it and forcing this value raises an error.
testOccurs :: Typ Identity
testOccurs =
  testAlg (Lam "y" (App (Var "y") (Lam "z" (App (Var "y") (Var "z")))))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.