Skip to content

Instantly share code, notes, and snippets.

@KiJeong-Lim
Created July 20, 2020 15:51
Show Gist options
  • Save KiJeong-Lim/c7243c261ead9046980c0e4a2f86bd76 to your computer and use it in GitHub Desktop.
Save KiJeong-Lim/c7243c261ead9046980c0e4a2f86bd76 to your computer and use it in GitHub Desktop.
Higher-Order Pattern Unification
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Monad.Trans.State.Strict
import Data.IORef
import qualified Data.List as List
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Unique
-- | Fixity of the singleton-substitution operator @+->@.
infix 5 +->
-- | Names are either plain strings (as produced by the 'Read' instance) or
-- freshly generated 'Unique's (as produced by 'getNewSymbol').
type Symbol = Either String Unique
-- | Logic (unification) variables share the 'Symbol' representation.
type LogicVar = Symbol
-- | Constants share the 'Symbol' representation.
type Constant = Symbol
-- | De Bruijn index; valid indices start at 1.
type DeBruijn = Int
-- | Environment of a suspension; entry @i@ (1-based) gives the meaning of
-- index @i@ for indices covered by the suspension.
type SuspEnv = [SuspItem]
-- | Label attached to a symbol; labels are compared with @<@ / @<=@ to decide
-- which symbols may occur in the instantiation of which logic variables.
type ScopeLevel = Int
-- | Lambda terms using de Bruijn indices, extended with explicit suspensions
-- that delay substitution and index shifting (see 'rewriteWithSusp').
--
-- NOTE: the derived 'Eq' / 'Ord' are structural, so a suspended term and its
-- rewritten form compare unequal.
data TermNode
    = LVar LogicVar
    -- ^ logic variable (replaced by 'substLVars')
    | NCon Constant
    -- ^ constant
    | NIdx DeBruijn
    -- ^ bound variable as a de Bruijn index (starting at 1)
    | NApp TermNode TermNode
    -- ^ application
    | NAbs TermNode
    -- ^ abstraction; the body refers to its own binder as index 1
    | Susp { getSubstitutee :: TermNode, getOL :: Int, getNL :: Int, getSuspEnv :: SuspEnv }
    -- ^ suspension @[[t, ol, nl, env]]@: inside @t@, an index @i > ol@
    -- stands for @i - ol + nl@ and an index @i <= ol@ is resolved by entry
    -- @i@ of @env@
    deriving (Eq, Ord)
-- | Entry of a suspension environment.
data SuspItem
    = Dummy Int
    -- ^ a binder that is still abstracted; with level @l@ the index becomes
    -- @nl - l@
    | Binds TermNode Int
    -- ^ a term bound to the index, recorded at level @l@; it is lifted by
    -- @nl - l@ when looked up
    deriving (Eq, Ord)
-- | Substitution for logic variables: each key is replaced by its image.
data SubstLVar
    = SubstLVar { getSubstLVar :: Map.Map LogicVar TermNode }
    deriving (Eq)
-- | How far 'rewriteWithSusp' reduces a term.
data ReductionOption
    = WHNF
    -- ^ stop at an abstraction or non-redex application, subterms suspended
    | HNF
    -- ^ also reduce under abstractions and the head of applications
    | NF
    -- ^ also reduce the arguments of applications
    deriving (Eq)
-- | A pair of terms to be unified; shown as @`lhs' =?= `rhs'@.
data Disagreement
    = Disagreement { getLHS :: TermNode, getRHS :: TermNode }
    deriving (Eq)
-- | State threaded through unification: 'getLabeling' maps every known
-- symbol to its label, 'getBinding' is the substitution accumulated so far.
data EnvOfHOPU
    = EnvOfHOPU { getLabeling :: Map.Map Symbol ScopeLevel, getBinding :: SubstLVar }
    deriving ()
-- | Values that may contain logic variables.
class HasLVar a where
    -- | @getFreeLVs x acc@ adds the logic variables occurring in @x@ to @acc@.
    getFreeLVs :: a -> Set.Set LogicVar -> Set.Set LogicVar
    -- | Replaces logic variables according to the substitution. (For
    -- 'TermNode' this is a single pass: images are inserted unchanged.)
    substLVars :: SubstLVar -> a -> a
-- | Debug printer for terms.
--
-- Indices print as @#i@, generated symbols as @c_N@ / @v_N@ (N is the
-- 'hashUnique' of the symbol), abstractions with a leading backslash, and
-- suspensions as @{ t, ol = .., nl = .., env = [..] }@. Abstractions that
-- appear as the function or as an argument of an application are wrapped in
-- parentheses, so @(\\#1) y@ and @f (\\#1) g@ are not printed ambiguously as
-- @\\#1 y@ / @f \\#1 g@.
instance Show TermNode where
    showsPrec = showTerm
      where
        -- Levels: 0 application, 1 abstraction, 2 atom; any other level
        -- wraps the term in parentheses.
        showTerm :: Int -> TermNode -> String -> String
        showTerm 0 (NApp t1 t2) str = showFun t1 (" " ++ showTerm 2 t2 str)
        showTerm 0 t str = showTerm 1 t str
        showTerm 1 (NAbs t) str = "\\" ++ showTerm 1 t str
        showTerm 1 t str = showTerm 2 t str
        showTerm 2 (NIdx i) str = "#" ++ showsPrec 0 i str
        showTerm 2 (NCon (Left c)) str = c ++ str
        showTerm 2 (NCon (Right c)) str = "c_" ++ showsPrec 0 (hashUnique c) str
        showTerm 2 (LVar (Left v)) str = v ++ str
        showTerm 2 (LVar (Right v)) str = "v_" ++ showsPrec 0 (hashUnique v) str
        showTerm 2 (Susp t ol nl env) str = "{ " ++ showTerm 0 t (", ol = " ++ showsPrec 0 ol (", nl = " ++ showsPrec 0 nl (", env = " ++ showEnv env (" }" ++ str))))
        showTerm 2 t str = showTerm 3 t str
        showTerm _ t str = "(" ++ showTerm 0 t (")" ++ str)
        -- Function part of an application: nested applications chain without
        -- parentheses, anything else is printed at atom level (so an
        -- abstraction gets parenthesised).
        showFun :: TermNode -> String -> String
        showFun t@(NApp _ _) str = showTerm 0 t str
        showFun t str = showTerm 2 t str
        showItem :: SuspItem -> String -> String
        showItem item str = case item of
            Binds t l -> "(" ++ showTerm 0 t (", " ++ showsPrec 0 l (")" ++ str))
            Dummy l -> "@" ++ showsPrec 0 l str
        showEnv :: SuspEnv -> String -> String
        showEnv [] str = "[]" ++ str
        showEnv (item : items) str = "[" ++ go item items str
          where
            go :: SuspItem -> [SuspItem] -> String -> String
            go item' [] str' = showItem item' ("]" ++ str')
            go item' (next : rest) str' = showItem item' (", " ++ go next rest str')
-- | Parser for terms written with named bound variables.
--
-- Accepted syntax (spacing is significant):
--
-- * constant: an upper-case letter followed by letters and digits;
-- * variable: a lower-case letter followed by lower-case letters, digits or
--   @_@; a name bound by an enclosing abstraction becomes a de Bruijn index,
--   any other one a logic variable;
-- * abstraction: @x\\ body@ (name, backslash, one space, body);
-- * application: terms separated by single spaces;
-- * parentheses: @(t)@ with no padding inside.
--
-- Precedence 0 admits abstractions, 1 applications and 2 only atoms.
-- Higher precedences (e.g. 11 when a term is read as a constructor argument)
-- are treated as 2, and negative ones as 0; previously they hit a missing
-- equation of 'go' and crashed.
instance Read TermNode where
    readsPrec prec = go (max 0 (min 2 prec)) []
      where
        cond1 :: Char -> Bool
        cond1 ch = ch `elem` ['A' .. 'Z'] || ch `elem` ['a' .. 'z'] || ch `elem` ['0' .. '9']
        cond2 :: Char -> Bool
        cond2 ch = ch `elem` ['a' .. 'z'] || ch `elem` ['0' .. '9'] || ch == '_'
        readCon :: String -> [(String, String)]
        readCon (ch : str)
            | ch `elem` ['A' .. 'Z'] = [(ch : takeWhile cond1 str, dropWhile cond1 str)]
        readCon _ = []
        readVar :: String -> [(String, String)]
        readVar (ch : str)
            | ch `elem` ['a' .. 'z'] = [(ch : takeWhile cond2 str, dropWhile cond2 str)]
        readVar _ = []
        -- A name bound at depth i (innermost first) becomes index i + 1.
        installVar :: [String] -> String -> TermNode
        installVar env str = case str `List.elemIndex` env of
            Nothing -> LVar (Left str)
            Just i -> NIdx (i + 1)
        -- Zero or more repetitions; the longest parse is listed first.
        many :: (String -> [(a, String)]) -> (String -> [([a], String)])
        many p str0 = [ (x : xs, str2) | (x, str1) <- p str0, (xs, str2) <- many p str1 ] ++ [([], str0)]
        go :: Int -> [String] -> String -> [(TermNode, String)]
        go 0 env str0 = [ (NAbs t, str2) | (str, '\\' : ' ' : str1) <- readVar str0, (t, str2) <- go 0 (str : env) str1 ] ++ go 1 env str0
        go 1 env str0 = [ (List.foldl' NApp t ts, str2) | (t, str1) <- go 2 env str0, (ts, str2) <- many (getArgs env) str1 ]
        go 2 env ('(' : str0) = [ (t, str1) | (t, ')' : str1) <- go 0 env str0 ]
        go 2 env str0 = [ (NCon (Left str), str1) | (str, str1) <- readCon str0 ] ++ [ (installVar env str, str1) | (str, str1) <- readVar str0 ]
        getArgs :: [String] -> String -> [(TermNode, String)]
        getArgs env (' ' : str0) = go 2 env str0
        getArgs _ _ = []
-- | Logic-variable operations on terms.
instance HasLVar TermNode where
    -- A suspension is first rewritten to head normal form, so only the logic
    -- variables of the reduced term are collected.
    getFreeLVs term acc = case term of
        LVar v -> Set.insert v acc
        NCon _ -> acc
        NIdx _ -> acc
        NApp t1 t2 -> getFreeLVs t2 (getFreeLVs t1 acc)
        NAbs body -> getFreeLVs body acc
        Susp body ol nl env -> getFreeLVs (rewriteWithSusp body ol nl env HNF) acc
    -- Replaces every logic variable that has an image in the substitution;
    -- images are inserted as-is (not substituted again).
    substLVars (SubstLVar mapping) = visit
      where
        visit :: TermNode -> TermNode
        visit t = case t of
            LVar v -> Map.findWithDefault (LVar v) v mapping
            NCon c -> NCon c
            NIdx i -> NIdx i
            NApp t1 t2 -> mkNApp (visit t1) (visit t2)
            NAbs body -> mkNAbs (visit body)
            Susp body ol nl env -> mkSusp (visit body) ol nl (map visitItem env)
        visitItem :: SuspItem -> SuspItem
        visitItem entry = case entry of
            Dummy l -> mkDummy l
            Binds t l -> mkBinds (visit t) l
-- | Lists: all elements are inspected and substituted.
instance HasLVar a => HasLVar [a] where
    getFreeLVs xs acc = foldr getFreeLVs acc xs
    substLVars theta = map (substLVars theta)
-- | Pairs: only the second component is inspected and substituted.
instance HasLVar b => HasLVar (a, b) where
    getFreeLVs pair = getFreeLVs (snd pair)
    substLVars theta = fmap (substLVars theta)
-- | Maps: only the values are inspected and substituted.
instance HasLVar a => HasLVar (Map.Map k a) where
    getFreeLVs m = getFreeLVs (Map.elems m)
    substLVars theta = Map.map (substLVars theta)
-- | Composition: @newer <> older@ acts like applying @older@ first and then
-- @newer@. Images of @older@ are rewritten with @newer@; on a key present in
-- both, the (rewritten) image from @older@ is kept.
instance Semigroup SubstLVar where
    newer@(SubstLVar newerMap) <> SubstLVar olderMap = composed `seq` SubstLVar composed
      where
        composed :: Map.Map LogicVar TermNode
        composed = Map.union (substLVars newer olderMap) newerMap
-- | The empty substitution.
instance Monoid SubstLVar where
    mempty = SubstLVar Map.empty
-- | Renders a pair as @`lhs' =?= `rhs'@, each side shown at precedence 5.
instance Show Disagreement where
    showsPrec _ (Disagreement lhs rhs) str =
        (showChar '`' . showsPrec 5 lhs . showString "' =?= `" . showsPrec 5 rhs . showChar '\'') str
-- | Both sides of a pair are inspected and substituted.
instance HasLVar Disagreement where
    getFreeLVs (Disagreement lhs rhs) = getFreeLVs lhs . getFreeLVs rhs
    substLVars theta (Disagreement lhs rhs) = Disagreement { getLHS = substLVars theta lhs, getRHS = substLVars theta rhs }
-- | The set of logic variables occurring in a value.
getFreeLVars :: HasLVar a => a -> Set.Set LogicVar
getFreeLVars x = getFreeLVs x Set.empty
-- | Builds a 'Binds' entry after forcing the term, then the level.
mkBinds :: TermNode -> Int -> SuspItem
mkBinds t l = t `seq` (Binds t $! l)
-- | Builds a 'Dummy' entry with its level forced.
mkDummy :: Int -> SuspItem
mkDummy l = Dummy $! l
-- | Builds an application after forcing both sides to WHNF.
mkNApp :: TermNode -> TermNode -> TermNode
mkNApp t1 t2 = t1 `seq` (NApp t1 $! t2)
-- | Builds an abstraction after forcing the body to WHNF.
mkNAbs :: TermNode -> TermNode
mkNAbs t1 = NAbs $! t1
-- | Builds a suspension; the trivial suspension @[[t, 0, 0, []]]@ is just @t@.
mkSusp :: TermNode -> Int -> Int -> SuspEnv -> TermNode
mkSusp t ol nl env
    | ol == 0 && nl == 0 && null env = t
    | otherwise = t `seq` Susp t ol nl env
-- | Splits off leading abstractions: @(number of binders, body)@.
viewLambda :: TermNode -> (Int, TermNode)
viewLambda = strip 0
  where
    strip :: Int -> TermNode -> (Int, TermNode)
    strip depth term = case term of
        NAbs body -> strip (depth + 1) body
        _ -> (depth, term)
-- | Wraps a body in the given number of abstractions (none if it is <= 0).
makeLambda :: (Int, TermNode) -> TermNode
makeLambda (n, t)
    | n <= 0 = t
    | otherwise = NAbs (makeLambda (n - 1, t))
-- | Splits an application spine into its head and its arguments (in order).
unFoldNApp :: TermNode -> (TermNode, [TermNode])
unFoldNApp = collect []
  where
    collect :: [TermNode] -> TermNode -> (TermNode, [TermNode])
    collect args term = case term of
        NApp fun arg -> collect (arg : args) fun
        _ -> (term, args)
-- | Rigid atoms: constants and de Bruijn indices.
isR :: TermNode -> Bool
isR term = case term of
    NCon _ -> True
    NIdx _ -> True
    _ -> False
-- | Evaluates the suspension @[[t, ol, nl, env]]@ to the form selected by
-- @option@.
--
-- Inside the suspension an index @i > ol@ becomes @i - ol + nl@, while an
-- index @1 <= i <= ol@ is resolved through entry @i@ of @env@: a 'Dummy'
-- yields an index, a 'Binds' yields its (lifted) term. With @0 0 []@ (as used
-- by 'rewrite') this just reduces @t@.
--
-- * 'WHNF': abstractions and non-redex applications are returned with their
--   subterms still suspended.
-- * 'HNF': additionally reduces under abstractions and the head of an
--   application; arguments stay suspended.
-- * 'NF': additionally reduces the arguments.
rewriteWithSusp :: TermNode -> Int -> Int -> SuspEnv -> ReductionOption -> TermNode
-- Logic variables and constants are unaffected by the environment.
rewriteWithSusp (LVar v) ol nl env option = v `seq` LVar v
rewriteWithSusp (NCon c) ol nl env option = c `seq` NCon c
rewriteWithSusp (NIdx i) ol nl env option
    -- Index not covered by the environment: shift it into the new context.
    | i > ol =
        let j = i - ol + nl
        in j `seq` NIdx j
    -- Index covered by the environment. NOTE(review): assumes env has at
    -- least ol entries; (!!) fails otherwise.
    | i >= 1 = case env !! (i - 1) of
        Dummy l ->
            let j = nl - l
            in j `seq` NIdx j
        -- Lift the bound term's free indices by nl - l.
        Binds t l -> rewriteWithSusp t 0 (nl - l) [] option
    -- Indices start at 1; anything smaller is an invariant violation.
    | otherwise = error ("rewriteWithSusp (NIdx " ++ show i ++ ") " ++ show ol ++ " " ++ show nl ++ " ")
rewriteWithSusp (NApp t1 t2) ol nl env option = case rewriteWithSusp t1 ol nl env WHNF of
    -- The function is an abstraction whose body is still suspended with a
    -- matching dummy on top: bind the (suspended) argument in place of that
    -- dummy instead of wrapping another suspension around it.
    NAbs (Susp t1' ol1' nl1' (Dummy nl1 : env1'))
        | ol1' > 0 && nl1 + 1 == nl1' ->
            let t2' = mkSusp t2 ol nl env
            in t2' `seq` rewriteWithSusp t1' ol1' nl1 (mkBinds t2' nl1 : env1') option
    -- General beta step: the suspended argument becomes index 1 of the body.
    NAbs t1' ->
        let t2' = mkSusp t2 ol nl env
        in t2' `seq` rewriteWithSusp t1' 1 0 [mkBinds t2' 0] option
    -- Not a redex: rebuild the application as far as option requires.
    t1' -> case option of
        NF -> mkNApp (rewriteWithSusp t1' 0 0 [] option) (rewriteWithSusp t2 ol nl env option)
        HNF -> mkNApp (rewriteWithSusp t1' 0 0 [] option) (mkSusp t2 ol nl env)
        WHNF -> mkNApp t1' (mkSusp t2 ol nl env)
rewriteWithSusp (NAbs t1) ol nl env option
    -- Entering a binder: one more old and new level, with a dummy for it.
    | option == WHNF = mkNAbs (mkSusp t1 (ol + 1) (nl + 1) (mkDummy nl : env))
    | otherwise = mkNAbs (rewriteWithSusp t1 (ol + 1) (nl + 1) (mkDummy nl : env) option)
rewriteWithSusp (Susp t ol nl env) ol' nl' env' option
    -- Trivial inner suspension: only the outer one matters.
    | ol == 0 && nl == 0 = rewriteWithSusp t ol' nl' env' option
    -- Outer suspension covers no indices: fold its shift nl' into the inner
    -- one (env' is not consulted; presumably empty when ol' == 0).
    | ol' == 0 = rewriteWithSusp t ol (nl + nl') env option
    -- Otherwise evaluate the inner suspension first, then apply the outer.
    | otherwise = case rewriteWithSusp t ol nl env option of
        NAbs t'
            | option == WHNF -> mkNAbs (mkSusp t' (ol' + 1) (nl' + 1) (mkDummy nl' : env'))
            | otherwise -> mkNAbs (rewriteWithSusp t' (ol' + 1) (nl' + 1) (mkDummy nl' : env') option)
        t' -> rewriteWithSusp t' ol' nl' env' option
-- | Reduces a term that has no pending suspension of its own:
-- 'rewriteWithSusp' with the empty suspension @0 0 []@.
rewrite :: ReductionOption -> TermNode -> TermNode
rewrite option = \t -> rewriteWithSusp t 0 0 [] option
-- | Singleton substitution mapping one logic variable to a term.
(+->) :: LogicVar -> TermNode -> SubstLVar
(+->) v t = SubstLVar { getSubstLVar = Map.singleton v t }
-- | Generates a fresh symbol (a new 'Unique') and records it in the labeling
-- with the given label.
getNewSymbol :: MonadIO m => ScopeLevel -> StateT EnvOfHOPU m Symbol
getNewSymbol label = do
    fresh <- liftIO newUnique
    let sym = Right fresh
    modify (\env -> env { getLabeling = Map.insert sym label (getLabeling env) })
    return sym
-- | Records a new substitution @theta@ in the environment.
--
-- The binding becomes @theta <> binding@ (the old binding first, then
-- @theta@). Every labelled symbol @v@ that occurs in the image of some
-- variable @v'@ of @theta@ gets its label lowered to the minimum of its own
-- label and the (old) label of @v'@.
--
-- The free variables of each image are computed once, instead of once per
-- labelled symbol as before. Computing them involves head-normal-form
-- rewriting, so this avoids repeated work.
substEnv :: SubstLVar -> EnvOfHOPU -> EnvOfHOPU
substEnv theta (EnvOfHOPU { getLabeling = labeling, getBinding = binding }) =
    EnvOfHOPU { getLabeling = labeling', getBinding = binding' }
  where
    -- (old label of a bound variable, free logic variables of its image),
    -- for every bound variable that has a label.
    boundInfo :: [(ScopeLevel, Set.Set LogicVar)]
    boundInfo =
        [ (label', getFreeLVars t')
        | (v', t') <- Map.toList (getSubstLVar theta)
        , Just label' <- [Map.lookup v' labeling]
        ]
    labeling' :: Map.Map Symbol ScopeLevel
    labeling' = Map.mapWithKey lower labeling
    lower :: Symbol -> ScopeLevel -> ScopeLevel
    lower v label = foldr min label [ label' | (label', fvs) <- boundInfo, v `Set.member` fvs ]
    binding' :: SubstLVar
    binding' = theta <> binding
-- | Set-like insertion into a list: appends @x@ unless an equal element is
-- already present, in which case the list is returned unchanged.
--
-- Previously an existing equal element was removed together with @x@. When
-- used to re-queue postponed disagreements, that silently discarded a
-- constraint.
insert :: Eq a => a -> [a] -> [a]
insert x [] = [x]
insert x (y : ys)
    | x == y = y : ys
    | otherwise = y : insert x ys
-- | True when no two elements of the list are equal (quadratic comparison,
-- only 'Eq' is required).
areAllDistinct :: Eq a => [a] -> Bool
areAllDistinct xs = case xs of
    [] -> True
    y : rest -> all (\z -> not (y == z)) rest && areAllDistinct rest
-- | Checks that the arguments @ts@ of logic variable @v@ are pairwise
-- distinct rigid atoms (constants or indices), and that every constant
-- among them has a label strictly greater than @v@'s.
isPattern :: Map.Map Symbol ScopeLevel -> LogicVar -> [TermNode] -> Bool
isPattern labeling v ts
    | not (all isR ts) = False
    | not (areAllDistinct ts) = False
    | otherwise = and [ labeling Map.! v < labeling Map.! c | NCon c <- ts ]
-- | @zs `down` ts@ maps each element of @zs@ that occurs in @ts@ to the de
-- Bruijn index of its position when @ts@ are the binders of an abstraction
-- (first element = outermost); elements not in @ts@ are dropped. Both lists
-- must consist of distinct rigid atoms, otherwise "down failed." is thrown.
down :: [TermNode] -> [TermNode] -> StateT EnvOfHOPU (ExceptT String IO) [TermNode]
zs `down` ts
    | distinctRigid ts && distinctRigid zs = return [ NIdx (len - i) | z <- zs, Just i <- [List.elemIndex z ts] ]
    | otherwise = lift (throwE "down failed.")
  where
    len :: Int
    len = length ts
    distinctRigid :: [TermNode] -> Bool
    distinctRigid xs = areAllDistinct xs && all isR xs
-- | @ts `up` y@ keeps the constants of @ts@ whose label is at most the label
-- of @y@ (in their original order). @ts@ must consist of distinct rigid
-- atoms, otherwise "up failed." is thrown.
up :: [TermNode] -> LogicVar -> StateT EnvOfHOPU (ExceptT String IO) [TermNode]
ts `up` y
    | not (areAllDistinct ts && all isR ts) = lift (throwE "up failed.")
    | otherwise = do
        labeling <- gets getLabeling
        return [ NCon c | NCon c <- ts, labeling Map.! c <= labeling Map.! y ]
-- | @makeSubst x rhs params env@ tries to solve @x params =?= rhs@ by binding
-- @x@ (callers are expected to have checked that @params@ is a pattern).
--
-- Returns @Just env'@ with the extended binding on success, and 'Nothing'
-- when a non-pattern occurrence is met (the internal "NotAPattern"
-- failure), so the caller can postpone the pair. Every other failure (occurs
-- check, rigid head out of scope, ...) is rethrown in 'ExceptT'.
makeSubst :: LogicVar -> TermNode -> [TermNode] -> EnvOfHOPU -> ExceptT String IO (Maybe EnvOfHOPU)
makeSubst = go
  where
    -- | @bnd x rhs ts l@ builds the term standing for @rhs@ in the body of
    -- x's binding. @ts@ are x's arguments, and @l@ counts the abstractions
    -- of the right-hand side already entered. Logic variables met in @rhs@
    -- may get bound to fresh variables; those bindings are recorded in the
    -- state.
    bnd :: LogicVar -> TermNode -> [TermNode] -> Int -> StateT EnvOfHOPU (ExceptT String IO) TermNode
    bnd x rhs ts l
        | (m, t) <- viewLambda rhs'
        , m > 0
        = do
            result <- bnd x t params (l + m)
            return (makeLambda (m, result))
        | (r, args) <- unFoldNApp rhs'
        , isR r
        = do
            env <- get
            r' <- get_r' x r (getLabeling env) al
            results <- foldbnd x args params l []
            return (List.foldl' mkNApp r' (reverse results))
        | (LVar y, bs) <- unFoldNApp rhs'
        = do
            env <- get
            if x == y
                then lift (throwE "occurs check failed.")
                else do
                    let bl = map (rewrite NF) bs
                        -- NOTE(review): NIdx l (the outermost abstraction
                        -- entered in rhs, when l > 0) is excluded from the
                        -- shared arguments, so it is always pruned from y.
                        -- That looks less general than necessary; confirm.
                        zl = Set.toList (Set.delete (NIdx l) (Set.fromList al) `Set.intersection` Set.fromList bl)
                    if isPattern (getLabeling env) y bl
                        then
                            if getLabeling env Map.! x < getLabeling env Map.! y
                                then do
                                    cs <- al `up` y
                                    ws <- cs `down` al
                                    us <- zl `down` bl
                                    vs <- zl `down` al
                                    h <- getNewSymbol (getLabeling env Map.! x)
                                    modify (substEnv (y +-> makeLambda (length bl, List.foldl' mkNApp (LVar h) (cs ++ us))))
                                    return (List.foldl' mkNApp (LVar h) (ws ++ vs))
                                else do
                                    cs <- bl `up` x
                                    ws <- cs `down` bl
                                    us <- zl `down` al
                                    vs <- zl `down` bl
                                    h <- getNewSymbol (getLabeling env Map.! x)
                                    modify (substEnv (y +-> makeLambda (length bl, List.foldl' mkNApp (LVar h) (ws ++ vs))))
                                    return (List.foldl' mkNApp (LVar h) (cs ++ us))
                        else lift (throwE "NotAPattern")
        | otherwise
        = lift (throwE "bnd failed.")
      where
        rhs' :: TermNode
        rhs' = rewrite HNF rhs
        params :: [TermNode]
        params = map (rewrite HNF) ts
        -- x's arguments lifted over the l entered binders, followed by those
        -- binders themselves (outermost first).
        al :: [TermNode]
        al = [ rewriteWithSusp a 0 l [] NF | a <- params ] ++ map NIdx [l, l - 1 .. 1]
    -- | Head of x's binding for the rigid head @r@ of the right-hand side.
    -- A constant whose label does not exceed x's is kept as is; otherwise
    -- @r@ must occur in @scope@ and becomes the matching index.
    get_r' :: LogicVar -> TermNode -> Map.Map Symbol ScopeLevel -> [TermNode] -> StateT EnvOfHOPU (ExceptT String IO) TermNode
    get_r' x r labeling scope
        | NCon c <- r
        , labeling Map.! x >= labeling Map.! c
        = return r
        | Just d <- r `List.elemIndex` scope
        = return (NIdx (length scope - d))
        | otherwise
        = lift (throwE "flex-rigid failed.")
    -- | Runs 'bnd' over arguments left to right. After each step the current
    -- binding is applied to the remaining arguments. Results are accumulated
    -- in reverse order.
    foldbnd :: LogicVar -> [TermNode] -> [TermNode] -> Int -> [TermNode] -> StateT EnvOfHOPU (ExceptT String IO) [TermNode]
    foldbnd _ [] _ _ results = return results
    foldbnd x (t : ts) al l results = do
        result <- bnd x t al l
        env <- get
        foldbnd x (substLVars (getBinding env) ts) al l (result : results)
    -- | Binds @x@ so that @x params@ equals @rhs@.
    mkSubst :: LogicVar -> TermNode -> [TermNode] -> StateT EnvOfHOPU (ExceptT String IO) ()
    mkSubst x rhs params
        -- Flex-flex with the same variable: x params =?= \^k. x args.
        | (k, rhs') <- viewLambda (rewrite HNF rhs)
        , (x', args) <- unFoldNApp rhs'
        , LVar x == x'
        = do
            env <- get
            if isPattern (getLabeling env) x args
                then do
                    h <- getNewSymbol (getLabeling env Map.! x)
                    let al = [ rewriteWithSusp a 0 k [] NF | a <- params ] ++ map NIdx [k, k - 1 .. 1]
                        n = length params
                        -- Keep the positions where both sides agree. Under
                        -- n + k binders, position i (1-based) is index
                        -- n + k - i + 1; the old "n + k - i" produced the
                        -- invalid index 0 for the last position.
                        wl = [ NIdx (n + k - i + 1) | (i, a, b) <- zip3 [1 ..] al args, a == b ]
                    modify (substEnv (x +-> makeLambda (k + n, List.foldl' mkNApp (LVar h) wl)))
                else lift (throwE "NotAPattern")
        | otherwise
        = do
            result <- bnd x rhs params 0
            env <- get
            -- Sub-results built early by 'bnd' may mention variables bound
            -- later in the same traversal. Resolve them against the current
            -- binding so x's image contains no already-bound variables.
            let n = length params
                image = substLVars (getBinding env) (makeLambda (n, result))
            modify (substEnv (x +-> image))
    go :: LogicVar -> TermNode -> [TermNode] -> EnvOfHOPU -> ExceptT String IO (Maybe EnvOfHOPU)
    go v t params env = catchE (Just <$> execStateT (mkSubst v t params) env) $ \str ->
        if str == "NotAPattern" then return Nothing else throwE str
-- | Solves a list of disagreements.
--
-- @runHOPU disagreements labeling@ repeatedly simplifies the pairs and binds
-- the logic variables of flexible pattern pairs, until a pass changes
-- nothing. It returns the final environment together with the pairs that
-- were postponed, or 'Nothing' if any step failed.
runHOPU :: [Disagreement] -> Map.Map Symbol ScopeLevel -> IO (Maybe (EnvOfHOPU, [Disagreement]))
runHOPU = go
  where
    -- | One simplification pass. It sets @changed@ whenever a variable gets
    -- bound.
    simpl :: IORef Bool -> [Disagreement] -> StateT EnvOfHOPU (ExceptT String IO) [Disagreement]
    simpl _ [] = return []
    simpl changed (Disagreement lhs rhs : disagreements)
        -- Both sides are abstractions: strip the common binders.
        | (n1, t1) <- viewLambda lhs'
        , (n2, t2) <- viewLambda rhs'
        , n1 > 0 && n2 > 0
        = case n1 `compare` n2 of
            LT -> simpl changed (Disagreement t1 (makeLambda (n2 - n1, t2)) : disagreements)
            EQ -> simpl changed (Disagreement t1 t2 : disagreements)
            GT -> simpl changed (Disagreement (makeLambda (n1 - n2, t1)) t2 : disagreements)
        -- Abstraction against a rigid term: eta-expand the rigid side.
        | (n1, t1) <- viewLambda lhs'
        , (r2, ts2) <- unFoldNApp rhs'
        , isR r2 && n1 > 0
        = simpl changed (Disagreement t1 (etaExpand n1 r2 ts2) : disagreements)
        | (n2, t2) <- viewLambda rhs'
        , (r1, ts1) <- unFoldNApp lhs'
        , isR r1 && n2 > 0
        = simpl changed (Disagreement (etaExpand n2 r1 ts1) t2 : disagreements)
        -- Rigid-rigid: heads must agree, then compare arguments pairwise.
        | (r1, ts1) <- unFoldNApp lhs'
        , (r2, ts2) <- unFoldNApp rhs'
        , isR r1 && isR r2
        = if r1 == r2
            then simpl changed (zipWith Disagreement ts1 ts2 ++ disagreements)
            else lift (throwE "rigid-rigid pairs does not match.")
        -- Flexible pairs: try to bind the head variable (left side first).
        | (LVar v1, ts1) <- unFoldNApp lhs'
        = solveFlex v1 ts1 rhs'
        | (LVar v2, ts2) <- unFoldNApp rhs'
        = solveFlex v2 ts2 lhs'
        | otherwise
        = doNext
      where
        lhs' :: TermNode
        lhs' = rewrite HNF lhs
        rhs' :: TermNode
        rhs' = rewrite HNF rhs
        -- Binds @v@ (applied to @ts@) to match @other@ when its arguments
        -- form a pattern; otherwise, or when 'makeSubst' gives up, the pair
        -- is postponed.
        solveFlex :: LogicVar -> [TermNode] -> TermNode -> StateT EnvOfHOPU (ExceptT String IO) [Disagreement]
        solveFlex v ts other = do
            env <- get
            let params = map (rewrite HNF) ts
            if isPattern (getLabeling env) v params
                then do
                    result <- lift (makeSubst v other params env)
                    case result of
                        Nothing -> doNext
                        Just env' -> do
                            put env'
                            liftIO (writeIORef changed True)
                            simpl changed (substLVars (getBinding env') disagreements)
                else doNext
        -- Postpones the current pair: simplify the rest, then re-queue this
        -- pair with the final binding applied.
        doNext :: StateT EnvOfHOPU (ExceptT String IO) [Disagreement]
        doNext = do
            disagreements' <- simpl changed disagreements
            env' <- get
            return (insert (substLVars (getBinding env') (Disagreement lhs' rhs')) disagreements')
    -- | @etaExpand n r ts@ moves the rigid term @r ts@ under @n@ new binders
    -- and applies it to them (@#n .. #1@), for comparison with the body of an
    -- @n@-fold abstraction. The free indices of @r@ and @ts@ must be shifted
    -- by @n@. Without the shift they were captured by the new binders, so
    -- e.g. @\\. #2 #1 =?= #1@ was wrongly reported as a rigid mismatch.
    etaExpand :: Int -> TermNode -> [TermNode] -> TermNode
    etaExpand n r ts = List.foldl' mkNApp (rewriteWithSusp r 0 n [] NF) (map (\t -> mkSusp t 0 n []) ts ++ map NIdx [n, n - 1 .. 1])
    -- | Repeats 'simpl' until a pass binds no variable.
    loop :: IORef Bool -> [Disagreement] -> StateT EnvOfHOPU (ExceptT String IO) [Disagreement]
    loop changed disagreements = do
        disagreements' <- simpl changed disagreements
        hasChanged <- liftIO (readIORef changed)
        if hasChanged
            then do
                liftIO (writeIORef changed False)
                loop changed disagreements'
            else return disagreements'
    go :: [Disagreement] -> Map.Map Symbol ScopeLevel -> IO (Maybe (EnvOfHOPU, [Disagreement]))
    go disagreements labeling = do
        changed <- newIORef False
        result <- runExceptT (runStateT (loop changed disagreements) (EnvOfHOPU labeling mempty))
        case result of
            Right (disagreements', env') -> return (Just (env', disagreements'))
            Left _ -> return Nothing
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment