Created
August 3, 2017 13:28
-
-
Save jozefg/92c73b378a71c0d6dc828b01b1d0d4a0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE PatternGuards #-} | |
module Unification where | |
import Control.Monad | |
import Control.Monad.Gen | |
import Control.Monad.Trans | |
import qualified Data.Map.Strict as M | |
import Data.Foldable | |
import Data.List (foldl') | |
import Data.Monoid | |
import qualified Data.Set as S | |
-- | Identifier carried by 'FreeVar', 'MetaVar' and 'Constant' nodes.
type Id = Int

-- | De Bruijn index carried by a 'LocalVar' (bound variable).
type Index = Int
-- | Terms of the object language.
--
-- Bound variables are de Bruijn indices ('LocalVar').  The 'Id'-carrying
-- atoms also carry their type, which is what 'typeOf' reports for them.
data Term
  = FreeVar Id Term   -- ^ free variable: id and its type
  | LocalVar Index    -- ^ bound variable (de Bruijn index)
  | MetaVar Id Term   -- ^ unification variable: id and its type
  | Constant Id Term  -- ^ constant (axiom or definition): id and its type
  | Uni               -- ^ the universe
  | Ap Term Term      -- ^ application
  | Lam Term Term     -- ^ lambda: argument type, body
  | Pi Term Term      -- ^ dependent function type: domain, codomain
  deriving (Eq, Show, Ord)
-- | An equation between two terms that unification has to solve.
type Constraint = (Term, Term)

-- | What a constant stands for.  'Axiom' has no body.  'Def' carries the
-- body that 'reduce' unfolds the constant into.
data Def = Axiom | Def Term

-- | Definitions of constants, keyed by constant 'Id'.
type Env = M.Map Id Def
-- | @raise k t@ shifts every free de Bruijn index in @t@ up by @k@.
--
-- Indices bound inside @t@ are left alone.  These are the indices smaller
-- than the number of 'Lam'/'Pi' binders crossed so far.
-- 'FreeVar', 'MetaVar' and 'Constant' nodes are returned unchanged.  Their
-- type annotations are not traversed.
--
-- Fix: the original compared the shift amount, instead of the index, with
-- the binder depth.  It therefore shifted bound variables and left some free
-- variables under binders unshifted.
raise :: Int -> Term -> Term
raise = go 0
  where
    -- @lower@ counts the binders crossed; indices below it are bound locally.
    go lower k t = case t of
      FreeVar v tp -> FreeVar v tp
      LocalVar j
        | j >= lower -> LocalVar (j + k)
        | otherwise -> LocalVar j
      MetaVar m tp -> MetaVar m tp
      Constant c tp -> Constant c tp
      Uni -> Uni
      Ap l r -> go lower k l `Ap` go lower k r
      Lam tp body -> Lam (go lower k tp) (go (lower + 1) k body)
      Pi tp body -> Pi (go lower k tp) (go (lower + 1) k body)
-- | @subst new i t@ instantiates the binder with index @i@ by @new@ and
-- removes that binder.
--
-- Occurrences of @i@ become @new@.  Indices above @i@ are decremented,
-- because the binder they skipped over no longer exists.  Indices below @i@
-- are untouched.  When a 'Lam'/'Pi' is crossed, the target index moves up
-- by one and @new@ is raised to stay valid in the extended scope.
-- The call sites visible in this module all pass @i = 0@ to open the body of
-- a 'Lam' (beta reduction) or 'Pi' (instantiating a codomain).
--
-- Fix: the original kept @i@ fixed under binders and never decremented the
-- remaining free indices.
subst :: Term -> Int -> Term -> Term
subst new i t = case t of
  FreeVar v tp -> FreeVar v tp
  LocalVar j -> case compare j i of
    LT -> LocalVar j
    EQ -> new
    GT -> LocalVar (j - 1)
  MetaVar m tp -> MetaVar m tp
  Constant c tp -> Constant c tp
  Uni -> Uni
  Ap l r -> subst new i l `Ap` subst new i r
  Lam tp body -> Lam (subst new i tp) (subst (raise 1 new) (i + 1) body)
  Pi tp body -> Pi (subst new i tp) (subst (raise 1 new) (i + 1) body)
-- | @substMV new i t@ replaces metavariable @i@ by @new@ throughout @t@.
--
-- The replacement also happens inside the type annotations of 'FreeVar' and
-- 'MetaVar' nodes, but not inside 'Constant' annotations.  @new@ is raised
-- by one whenever a binder is crossed.
--
-- Fix: the 'FreeVar' pattern used to bind its own id as @i@.  That shadowed
-- the metavariable being replaced, so the recursion into the FreeVar's type
-- substituted for the wrong id.
substMV :: Term -> Id -> Term -> Term
substMV new i t = case t of
  FreeVar v tp -> FreeVar v (substMV new i tp)
  LocalVar j -> LocalVar j
  MetaVar j tp
    | i == j -> new
    | otherwise -> MetaVar j (substMV new i tp)
  Constant c tp -> Constant c tp
  Uni -> Uni
  Ap l r -> substMV new i l `Ap` substMV new i r
  Lam tp body -> Lam (substMV new i tp) (substMV (raise 1 new) i body)
  Pi tp body -> Pi (substMV new i tp) (substMV (raise 1 new) i body)
-- | Ids of all metavariables occurring in a term.
--
-- The type annotations of 'FreeVar' and 'MetaVar' nodes are searched.
-- 'Constant' annotations are not searched.
metavars :: Term -> S.Set Id
metavars = \case
  MetaVar m tp -> S.insert m (metavars tp)
  FreeVar _ tp -> metavars tp
  Ap f a -> metavars f `S.union` metavars a
  Lam tp body -> metavars tp `S.union` metavars body
  Pi tp body -> metavars tp `S.union` metavars body
  LocalVar _ -> S.empty
  Constant _ _ -> S.empty
  Uni -> S.empty
-- | Reduce a term to weak-head normal form.
--
-- A 'Constant' with a 'Def' in @env@ is unfolded.  An application whose head
-- reduces to a 'Lam' is beta-reduced.  Nothing under binders or in argument
-- position is touched, and every other node is returned as is.
reduce :: Env -> Term -> Term
reduce env = whnf
  where
    whnf term = case term of
      Constant c tp
        | Just (Def body) <- M.lookup c env -> whnf body
        | otherwise -> Constant c tp
      Ap f a -> case whnf f of
        Lam _ body -> whnf (subst a 0 body)
        f' -> Ap f' a
      _ -> term
-- | Computations that can draw fresh 'Id's (with 'gen') and can fail.
-- A failure ('mzero') surfaces as 'Nothing'.
type FreshM = GenT Id Maybe
-- | Infer the type of a closed term.  Also return the constraints that must
-- hold for the term to be well typed.
--
-- Constraints whose two sides are syntactically equal are dropped.  The
-- inference fails with 'mzero' when a 'LocalVar' is out of scope.
--
-- When the function part of an application does not have a syntactic 'Pi'
-- type, it works as follows:
--
-- * fresh metavariables @from : Uni@ and @to : Pi from Uni@ are created;
-- * the function's type is constrained to @Pi from (to x)@;
-- * the application gets type @to r@.
--
-- Fixes:
--
-- * a 'LocalVar''s type is now shifted into the scope where it is used;
-- * the non-'Pi' application case returns @to r@ instead of the bare family
--   @to@ (the old @subst r 0 to@ was a no-op on a metavariable);
-- * negative indices now fail instead of crashing in @(!!)@.
typeOf :: Term -> FreshM (Term, S.Set Constraint)
typeOf t = do
  (tp, cs) <- go [] t
  return (tp, S.filter (uncurry (/=)) cs)
  where
    -- @env@ holds the types of the enclosing binders, innermost first.  Each
    -- entry is expressed in the scope just outside its own binder.
    go env term = case term of
      FreeVar _ tp -> return (tp, S.empty)
      LocalVar j -> do
        guard (j >= 0)
        case drop j env of
          -- The stored type lives j + 1 binders further out.
          tp : _ -> return (raise (j + 1) tp, S.empty)
          [] -> mzero
      MetaVar _ tp -> return (tp, S.empty)
      Constant _ tp -> return (tp, S.empty)
      Uni -> return (Uni, S.empty)
      Ap l r -> go env l >>= \case
        (Pi from to, cs1) -> do
          (from', cs2) <- go env r
          return (subst r 0 to, cs1 <> cs2 <> S.singleton (from, from'))
        (fTp, cs1) -> do
          (from', cs2) <- go env r
          from <- MetaVar <$> gen <*> return Uni
          to <- MetaVar <$> gen <*> return (Pi from Uni)
          let cs3 = S.fromList [(fTp, Pi from (Ap to (LocalVar 0))), (from, from')]
          return (Ap to r, cs1 <> cs2 <> cs3)
      Pi l r -> do
        (tp1, cs1) <- go env l
        (tp2, cs2) <- go (l : env) r
        return (Uni, cs1 <> cs2 <> S.fromList [(tp1, Uni), (tp2, Uni)])
      Lam tp body -> do
        (tp1, cs1) <- go env tp
        (to, cs2) <- go (tp : env) body
        return (Pi tp to, cs1 <> cs2 <> S.singleton (tp1, Uni))
-- | A term counts as rigid here only when it is itself a 'Constant'.  An
-- application headed by a constant does not count.
isRigid :: Term -> Bool
isRigid = \case
  Constant {} -> True
  _ -> False
-- | True when the head of the term's application spine is a metavariable.
isStuck :: Term -> Bool
isStuck = \case
  Ap f _ -> isStuck f
  MetaVar {} -> True
  _ -> False
-- | Split @f a1 ... an@ into its head @f@ and arguments @[a1, ..., an]@.
-- The arguments are returned in application order.
peelApTelescope :: Term -> (Term, [Term])
peelApTelescope = collect []
  where
    collect args (Ap f a) = collect (a : args) f
    collect args hd = (hd, args)
-- | Apply a head to a list of arguments, left to right.  This is the inverse
-- of 'peelApTelescope'.
applyApTelescope :: Term -> [Term] -> Term
applyApTelescope hd [] = hd
applyApTelescope hd (a : rest) = applyApTelescope (Ap hd a) rest
-- | Build the non-dependent function type @A1 -> ... -> An -> retTp@ as
-- nested 'Pi's.
--
-- Everything to the right of each arrow is raised by one, so the types stay
-- in the outer scope rather than seeing the new binder.
applyPiTelescope :: Term -> [Term] -> Term
applyPiTelescope retTp = foldr (\argTp acc -> Pi argTp (raise 1 acc)) retTp
-- | Require that a type, after weak-head reduction in @env@, is a 'Pi'.
-- Return its domain, its codomain and the constraints collected on the way.
--
-- * A syntactic @Pi l r@ yields @(l, r, cs)@, where @cs@ are the typing
--   constraints of @t@.
-- * A metavariable applied to a spine @cxt@ is refined into a Pi:
--   - fresh metavariables @from@ and @to@ are created;
--   - the constraint @Pi (from cxt) (to cxt x) ≡ t'@ is emitted;
--   - the typing constraints of the spine are returned as well.
-- * Any other shape fails with 'mzero'.
--
-- NOTE(review): @to@ is typed over @cxtTps ++ [from]@, with the bare @from@
-- rather than @from@ applied to the spine.  Confirm this is the intended type.
--
-- Fix: the metavariable branch now also returns the typing constraints @cs@
-- of @t@.  Previously they were computed and silently discarded there.
assertPi :: Env -> Term -> FreshM (Term, Term, S.Set Constraint)
assertPi env t = do
  (_, cs) <- typeOf t
  case reduce env t of
    Pi l r -> return (l, r, cs)
    t' -> case peelApTelescope t' of
      (MetaVar _ _, cxt) -> do
        (cxtTps, css) <- unzip <$> mapM typeOf cxt
        from <- MetaVar <$> gen <*> return (applyPiTelescope Uni cxtTps)
        to <- MetaVar <$> gen <*> return (applyPiTelescope Uni (cxtTps ++ [from]))
        let fromMVar = applyApTelescope from cxt
            -- Codomain under the new binder: (to cxt) applied to the bound var.
            codomain = Ap (raise 1 (applyApTelescope to cxt)) (LocalVar 0)
            cs' = S.singleton (Pi fromMVar codomain, t')
        return (fromMVar, codomain, cs <> cs' <> fold css)
      _ -> mzero
-- | Perform one round of decomposition on a constraint.
--
-- Returns the residual constraints (possibly none) or fails with 'mzero' on
-- a rigid mismatch.  The cases are tried in this order:
--
-- * sides that are not in weak-head normal form are reduced first;
-- * equal rigid heads ('FreeVar' or 'Constant' with the same id) split into
--   pairwise argument constraints;
-- * two 'Lam's or two 'Pi's are opened with one fresh 'FreeVar';
-- * anything left is kept only when one side is stuck on a metavariable.
simplify :: Env -> Constraint -> FreshM (S.Set Constraint)
simplify env (t1, t2)
  | t1 == t2 = return S.empty
  | t1' /= t1 = simplify env (t1', t2)
  | t2' /= t2 = simplify env (t1, t2')
  | (FreeVar i _, args1) <- peelApTelescope t1
  , (FreeVar j _, args2) <- peelApTelescope t2
  , i == j =
      zipSpines args1 args2
  | (Constant i _, args1) <- peelApTelescope t1
  , (Constant j _, args2) <- peelApTelescope t2
  , i == j =
      zipSpines args1 args2
  | Lam dom1 body1 <- t1
  , Lam dom2 body2 <- t2 =
      openBinders dom1 body1 dom2 body2
  | Pi dom1 body1 <- t1
  , Pi dom2 body2 <- t2 =
      openBinders dom1 body1 dom2 body2
  {- Eta cases, currently disabled:
  | Lam tp body <- t1 = do
      (from, to, cs) <- assertPi env t2
      cs' <- simplify env (t1, Lam from (raise 1 t2 `Ap` LocalVar 0))
      return (cs <> cs')
  | Lam tp body <- t2 = do
      (from, to, cs) <- assertPi env t1
      cs' <- simplify env (Lam from (raise 1 t1 `Ap` LocalVar 0), t2)
      return (cs <> cs')
  -}
  | isStuck t1 || isStuck t2 = return (S.singleton (t1, t2))
  | otherwise = mzero
  where
    t1' = reduce env t1
    t2' = reduce env t2
    -- Same rigid head: the spines must have equal length and agree pointwise.
    zipSpines args1 args2 = do
      guard (length args1 == length args2)
      fold <$> mapM (simplify env) (zip args1 args2)
    -- Instantiate both bodies with one fresh variable typed by the left
    -- domain.  The two domains must also agree.
    openBinders dom1 body1 dom2 body2 = do
      v <- FreeVar <$> gen <*> return dom1
      return $ S.fromList [(subst v 0 body1, subst v 0 body2), (dom1, dom2)]
-- | Solutions for metavariables, keyed by metavariable 'Id'.
type Subst = M.Map Id Term
-- | Apply every metavariable solution in the substitution to a term.  Each
-- entry is applied once.
--
-- NOTE(review): entries are applied from the largest id down.  A solution
-- that mentions a metavariable with a larger id therefore leaves that
-- metavariable unresolved in the result.  Confirm this is acceptable for
-- callers that apply a finished substitution.
manySubst :: Subst -> Term -> Term
manySubst s t = foldr (\(mv, sol) acc -> substMV sol mv acc) t (M.toList s)
-- | Union of two substitutions whose domains must be disjoint.
-- Calls 'error' when a metavariable is solved by both.
(<+>) :: Subst -> Subst -> Subst
s1 <+> s2
  | M.null (M.intersection s1 s2) = M.union s1 s2
  | otherwise = error "Impossible"
-- | Produce candidate solutions for a flex-rigid constraint
-- @?m a1 ... an ≡ h b1 ... bk@, in either orientation.  @?m@ must not occur
-- on the other side.
--
-- The first result is an infinite list indexed by the number of new
-- arguments @nargs = 0, 1, 2, ...@.  Entry @nargs@ draws fresh metavariables
-- @?H1 .. ?Hnargs@ and yields two kinds of candidates:
--
-- * every projection @\x1..xn. xi (?H1 x1..xn) ... (?Hnargs x1..xn)@;
-- * the imitation @\x1..xn. h (?H1 x1..xn) ... (?Hnargs x1..xn)@.
--
-- The second result holds the typing constraints of the spine @a1 .. an@.
-- Fails with 'mzero' when neither side has a metavariable head that passes
-- the occurs check.
--
-- Fixes:
--
-- * projections ranged over n + 1 indices, one past the bound variables;
-- * the fresh metavariables were applied to @nargs@ indices instead of the
--   @n@ bound variables;
-- * zipping the fresh ids with the argument types truncated the new
--   arguments to @n@ whenever @nargs > n@, so e.g. @?m ≡ c x@ could never be
--   solved.
tryFlexRigid :: Constraint -> FreshM ([FreshM [Subst]], S.Set Constraint)
tryFlexRigid (t1, t2)
  | (MetaVar i _, spine) <- peelApTelescope t1,
    (stuckTerm, _) <- peelApTelescope t2,
    not (i `S.member` metavars t2) =
      candidates i spine stuckTerm
  | (MetaVar i _, spine) <- peelApTelescope t2,
    (stuckTerm, _) <- peelApTelescope t1,
    not (i `S.member` metavars t1) =
      candidates i spine stuckTerm
  | otherwise = mzero
  where
    -- Type the spine once, then offer solutions of every arity.
    candidates mv spine f = do
      (argTps, cs) <- fmap fold . unzip <$> mapM typeOf spine
      return (proj argTps mv f 0, cs)
    proj argTps mv f nargs =
      generateSubst argTps mv f nargs : proj argTps mv f (nargs + 1)
    generateSubst argTps mv f nargs = do
      let arity = length argTps
          -- The binders introduced by 'mkLam', innermost first.
          boundVars = map LocalVar [0 .. arity - 1]
          mkLam tm = foldr Lam tm argTps
          -- Apply a fresh metavariable to x1 .. xn (outermost binder first),
          -- matching the order of its Pi-telescope type below.
          saturateMV tm = applyApTelescope tm (reverse boundVars)
          mkSubst = M.singleton mv
          -- The result type of a new argument is unknown at this point.  Each
          -- one therefore gets type Pi argTps ?T for a fresh ?T : Uni.
          freshArgMV = do
            resTp <- MetaVar <$> gen <*> return Uni
            MetaVar <$> gen <*> return (applyPiTelescope resTp argTps)
      mvs <- replicateM nargs freshArgMV
      let args = map saturateMV mvs
      return
        [ mkSubst . mkLam $ applyApTelescope hd args
        | hd <- boundVars ++ [f]
        ]
-- | Run 'simplify' over the whole constraint set until it stops changing.
-- Fails if any single simplification fails.
repeatedlySimplify :: Env -> S.Set Constraint -> FreshM (S.Set Constraint)
repeatedlySimplify env = loop
  where
    loop current = do
      next <- S.unions <$> mapM (simplify env) (S.toList current)
      if next == current then return current else loop next
-- | Search for a substitution that solves a set of constraints.
--
-- Each step does the following:
--
-- 1. applies the current substitution @s@ and simplifies to a fixed point;
-- 2. splits the result into flex-flex constraints (both sides stuck) and
--    the rest;
-- 3. succeeds if only flex-flex constraints remain, returning @s@ and those
--    leftovers;
-- 4. otherwise takes the largest remaining constraint, asks 'tryFlexRigid'
--    for candidate solutions, and tries them in order, backtracking with
--    'mplus'.
unify :: Subst -> Env -> S.Set Constraint -> FreshM (Subst, S.Set Constraint)
unify s env cs = do
  simplified <- repeatedlySimplify env (S.map substBoth cs)
  let (flexflexes, flexrigids) = S.partition isFlexFlex simplified
  case S.maxView flexrigids of
    Nothing -> return (s, flexflexes)
    Just (target, _) -> do
      (candidates, newC) <- tryFlexRigid target
      trySubsts candidates (newC <> flexrigids <> flexflexes)
  where
    substBoth (l, r) = (manySubst s l, manySubst s r)
    isFlexFlex (l, r) = isStuck l && isStuck r
    -- Run each batch of candidates.  Try every substitution in the batch
    -- first, and only then fall back to the later batches.
    trySubsts [] _ = mzero
    trySubsts (batch : rest) pending = do
      ss <- batch
      msum [unify (newS <+> s) env pending | newS <- ss]
        `mplus` trySubsts rest pending
-- | Unify a single constraint, starting from an empty substitution and an
-- empty constant environment.  'Nothing' means the search failed.
--
-- NOTE(review): fresh ids come from the counter that @runGenT@ starts from.
-- They may therefore coincide with ids already present in the input
-- constraint.  Confirm that callers avoid such ids.
driver :: Constraint -> Maybe (Subst, S.Set Constraint)
driver c = runGenT (unify M.empty M.empty (S.singleton c))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment