Created
March 25, 2016 07:19
-
-
Save jozefg/56abfbc42e49f298458d to your computer and use it in GitHub Desktop.
Wadler's classic pattern-matching compilation algorithm, implemented for a small core language using the Bound library.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
module PatCompile where | |
import Bound | |
import Bound.Var | |
import Bound.Scope | |
import Control.Monad (ap) | |
import Prelude.Extras (Show1 (..), Eq1 (..)) | |
import Data.List | |
import Debug.Trace | |
-- | A data constructor, identified only by its integer tag.
newtype Con
  = MkCon Int
  deriving (Eq, Show, Ord, Enum)
-- | Constants of the core language: integer/character literals,
-- constructor tags, and a set of named built-ins (booleans, 'If',
-- equality tests, arithmetic, error raising and 'MatchFail'). Only the
-- names live here; whatever evaluates them is defined elsewhere.
data Constant
  = IntLit Int
  | CharLit Char
  | Con Con
  -- Non-pure data constants
  | TrueLit
  | FalseLit
  | If
  | EqInt
  | EqChar
  | Plus
  | Minus
  | Times
  | Div
  | RaiseError
  | MatchFail
  deriving (Eq, Show)
-- | Source patterns. A 'PVar' carries the index of the variable it binds
-- in the enclosing @Scope Int@ (see how 'match' pairs @PVar i@ with bound
-- variable @i@). 'PCon' patterns may nest arbitrarily.
data Pattern
  = PVar Int
  | PCon Con [Pattern]
  | PLit Constant
  deriving (Eq, Show)
-- | Core-language expressions over free variables of type @a@. Binders
-- use Bound's 'Scope': pattern-binding forms bind 'Int'-indexed pattern
-- variables, and 'LetRec' binds 'RecBV'-indexed ones. 'Bar' is what
-- 'match' uses to chain a right-hand side with its fallback.
data Exp a
  = Var a
  | Const Constant
  | App (Exp a) (Exp a)
  | Let Pattern (Exp a) (Scope Int Exp a)
  | LetRec [Bind a] (Scope RecBV Exp a)
  | Lambda Pattern (Scope Int Exp a)
  | Bar (Exp a) (Exp a)
  | Case (Exp a) [Alt a] -- Invariant: Exhaustive, non-overlapping, "simple"
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | A variable bound by a 'LetRec'. Judging by the field names,
-- 'patternNum' selects the binding and 'varNum' the variable within that
-- binding's pattern (NOTE(review): confirm against the LetRec consumer).
data RecBV = RecBV
  { patternNum :: Int
  , varNum :: Int
  }
  deriving (Eq, Show)
-- | One binding of a 'LetRec': a pattern plus a right-hand side that may
-- refer to the recursively bound variables (addressed by 'RecBV').
data Bind a
  = Bind Pattern (Scope RecBV Exp a)
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | One alternative of a 'Case': a pattern plus a right-hand side in which
-- that pattern's 'PVar' indices are bound.
data Alt a
  = Alt Pattern (Scope Int Exp a)
  deriving (Eq, Show, Functor, Foldable, Traversable)
-- | Lifted 'Show' / 'Eq' for 'Exp'. The derived 'Show' and 'Eq' instances
-- of 'Exp' contain 'Scope's, whose instances (in prelude-extras-era Bound)
-- ask for these lifted classes. Both simply reuse the derived instances
-- for @Exp a@.
instance Show1 Exp where
  showsPrec1 = showsPrec

instance Eq1 Exp where
  (==#) = (==)
-- | 'pure' injects a free variable; '(<*>)' is derived from the monad.
--
-- 'pure' is defined directly as 'Var' and 'return' is left at its default
-- ('pure'). Defining @pure = return@ with a separate 'return' is the
-- non-canonical form warned about by -Wnoncanonical-monad-instances.
instance Applicative Exp where
  pure = Var
  (<*>) = ap

-- | Bind substitutes an expression for every free variable. Bound's
-- '>>>=' carries the substitution under the binders stored in 'Scope's,
-- so bound pattern variables are left untouched.
instance Monad Exp where
  expr >>= f =
    case expr of
      Var a -> f a
      Const c -> Const c
      App l r -> App (l >>= f) (r >>= f)
      Let p rhs body -> Let p (rhs >>= f) (body >>>= f)
      LetRec binds body -> LetRec (map bindBind binds) (body >>>= f)
      Lambda p body -> Lambda p (body >>>= f)
      Bar l r -> Bar (l >>= f) (r >>= f)
      Case scrut alts -> Case (scrut >>= f) (map bindAlt alts)
    where
      bindAlt (Alt p s) = Alt p (s >>>= f)
      bindBind (Bind p s) = Bind p (s >>>= f)
-- | Embed a free variable into the body representation of a 'Scope'
-- (a free variable wrapped as a trivial expression).
mkVar :: a -> Exp (Var b (Exp a))
mkVar x = Var (F (Var x))
-- | Replace every occurrence of the bound variable @target@ in a scope with
-- the free variable @replacement@. Other bound variables stay bound, and
-- existing free variables are kept as they are.
fillBound :: Eq b => b -> a -> Scope b Exp a -> Scope b Exp a
fillBound target replacement scope = Scope (splat mkVar plug scope)
  where
    plug bv
      | bv == target = mkVar replacement
      | otherwise = Var (B bv)
-- | Convert a scope that has no remaining bound-variable occurrences back
-- into a plain expression. Calls 'error' (showing the scope) if any
-- bound variable still occurs.
stripScope :: (Show a, Show b) => Scope b Exp a -> Exp a
stripScope s = maybe notClosed id closedExp
  where
    -- Free variables become 'Just', bound ones 'Nothing'; sequencing
    -- fails exactly when some bound variable is present.
    closedExp = sequenceA (splat (Var . Just) (const (Var Nothing)) s)
    notClosed = error $ "Not closed: " ++ show s
-- | One row of the pattern matrix handled by 'match': one pattern per
-- remaining scrutinee, plus a right-hand side in which the rows' 'PVar'
-- indices are bound.
data FlexibleAlt a
  = FAlt [Pattern] (Scope Int Exp a)
  deriving (Functor, Foldable, Traversable, Show)
-- | Compile a parallel match of the scrutinee variables @scruts@ against a
-- matrix of pattern rows into 'Case' expressions that use only flat
-- ("simple") patterns. Rows are tried top to bottom; @def@ is the result
-- when no row matches. The result is more than a rewrite because 'Case'
-- requires alternatives that are exhaustive, flat and non-overlapping.
--
-- The matrix must be a rectangle: every row has exactly one pattern per
-- scrutinee. Ragged input is reported with 'error'.
--
-- A 'Con' is only a tag, so the full constructor set of a type is unknown.
-- Every generated 'Case' therefore ends with a catch-all 'PVar'
-- alternative that evaluates @def@. Without it, a scrutinee whose
-- constructor or literal is not mentioned would skip the later rows and
-- the default.
match :: (Eq a, Show a) => [a] -> [FlexibleAlt a] -> Exp a -> Exp a
match scruts alts def
  -- Base case: all scrutinees are consumed, so every row's pattern list
  -- is empty. Chain the right-hand sides with 'Bar', ending in 'def'.
  | [] <- scruts =
      foldr (Bar . finish) def alts
  -- Variable rule: every row starts with a variable. Substitute the
  -- scrutinee for it and move on to the next column.
  | scrut : remaining <- scruts
  , Just rows <- allVars alts =
      let bindRow (i, ps, s) = FAlt ps (fillBound i scrut s)
      in match remaining (map bindRow rows) def
  -- Constructor rule: every row starts with a constructor pattern. Emit
  -- one alternative per distinct constructor (each with its own arity),
  -- then a fallback.
  | scrut : remaining <- scruts
  , Just rows <- allCons alts =
      Case (Var scrut) $
        map (conAlt remaining) (gatherBy conOf rows) ++ [fallback]
  -- Literal rule: like the constructor rule, but literals bind nothing.
  | scrut : remaining <- scruts
  , Just rows <- allLits alts =
      Case (Var scrut) $
        map (litAlt remaining) (gatherBy litOf rows) ++ [fallback]
  -- Mixture rule: split the rows into maximal runs whose first patterns
  -- are of the same kind, and compile each run with the next run as its
  -- default.
  | otherwise =
      case splitChunks alts of
        -- A single run that none of the rules above accepted can only come
        -- from a ragged matrix; recursing would loop forever.
        [_] ->
          error $
            "match: pattern rows do not line up with scrutinees "
              ++ show scruts
        chunks -> foldr (match scruts) def chunks
  where
    finish (FAlt [] e) = stripScope e
    finish (FAlt ps _) =
      error $
        "match: " ++ show (length ps)
          ++ " pattern(s) left over after all scrutinees were consumed"

    conOf (c, _, _, _) = c
    litOf (l, _, _) = l

    -- Catch-all alternative: anything not listed above falls through to
    -- the default. The bound variable 0 is not referenced.
    fallback = Alt (PVar 0) (abstract (const Nothing) def)

    -- One 'Case' alternative for a constructor and its rows (in original
    -- order). The fresh variables name the constructor's fields and become
    -- new leading scrutinees for the sub-match. The arity is taken from
    -- this constructor's own rows, not from some other constructor.
    conAlt remaining (c, rows) =
      let arity = case rows of
            (_, args, _, _) : _ -> length args
            [] -> 0
          fresh = [0 .. arity - 1]
          subRows =
            [FAlt (args ++ ps) (F . Var <$> s) | (_, args, ps, s) <- rows]
      in Alt (PCon c (map PVar fresh)) . Scope $
           match
             (map B fresh ++ map (F . Var) remaining)
             subRows
             (F . Var <$> def)

    -- One 'Case' alternative for a literal and its rows.
    litAlt remaining (l, rows) =
      Alt (PLit l) . abstract (const Nothing) $
        match remaining [FAlt ps s | (_, ps, s) <- rows] def

    allVars = traverse $ \case
      FAlt (PVar i : ps) s -> Just (i, ps, s)
      _ -> Nothing

    allCons = traverse $ \case
      FAlt (PCon c args : ps) s -> Just (c, args, ps, s)
      _ -> Nothing

    allLits = traverse $ \case
      FAlt (PLit l : ps) s -> Just (l, ps, s)
      _ -> Nothing

    splitChunks = groupBy sameHead

    sameHead a b = case (a, b) of
      (FAlt (PVar _ : _) _, FAlt (PVar _ : _) _) -> True
      (FAlt (PCon _ _ : _) _, FAlt (PCon _ _ : _) _) -> True
      (FAlt (PLit _ : _) _, FAlt (PLit _ : _) _) -> True
      _ -> False

-- | Group elements by a key, whether or not equal keys are adjacent.
-- Groups appear in order of each key's first occurrence, and elements
-- keep their original relative order within a group (so first-match
-- semantics survive). Needs only 'Eq' on the key, which lets it work for
-- 'Constant' as well as 'Con'.
gatherBy :: Eq k => (x -> k) -> [x] -> [(k, [x])]
gatherBy _ [] = []
gatherBy key (x : xs) =
  let k = key x
      (same, rest) = partition ((== k) . key) xs
  in (k, x : same) : gatherBy key rest
-- | Bind the named free variables of an expression. The variable at
-- position @i@ of @names@ becomes bound variable @i@; names not in the
-- list stay free.
abstractF :: [String] -> Exp String -> Scope Int Exp String
abstractF names = abstract (\v -> elemIndex v names)
-- | Apply a function expression to arguments, associating to the left:
-- @app f [x, y] = App (App f x) y@.
app :: Exp a -> [Exp a] -> Exp a
app fun args = go fun args
  where
    go acc [] = acc
    go acc (arg : rest) = go (App acc arg) rest
-- | A constructor tag used as an expression.
con :: Con -> Exp a
con tag = Const (Con tag)
-- | Numeric literals for constructor tags, so code can write @PCon 0 []@
-- rather than @PCon (MkCon 0) []@. Every 'Num' method is defined by
-- lifting the corresponding 'Int' operation. This avoids
-- -Wmissing-methods and a runtime crash if e.g. @(+)@ is ever used on a
-- 'Con'.
instance Num Con where
  fromInteger = MkCon . fromIntegral
  MkCon a + MkCon b = MkCon (a + b)
  MkCon a - MkCon b = MkCon (a - b)
  MkCon a * MkCon b = MkCon (a * b)
  negate (MkCon a) = MkCon (negate a)
  abs (MkCon a) = MkCon (abs a)
  signum (MkCon a) = MkCon (signum a)
-- | Example compilation: two scrutinees, @hello@ and @world@, matched
-- against three overlapping rows, with 'MatchFail' as the default.
-- Constructor 0 is nullary and constructor 1 has two fields; the last row
-- rebuilds a nested structure and recurses through the free variable
-- @rec@. Presumably 0/1 play the roles of nil/cons.
test :: Exp String
test =
  match ["hello", "world"]
    [ FAlt [PCon 0 [], PVar 0] $ abstractF ["x"] (Var "x")
    , FAlt [PVar 0, PCon 0 []] $ abstractF ["x"] (Var "x")
    , FAlt [PCon 1 [PVar 0, PVar 1], PCon 1 [PVar 2, PVar 3]]
        . abstractF ["x", "xs", "y", "ys"]
        $ app (con 1)
            [ Var "x"
            , app (con 1) [Var "y", app (Var "rec") [Var "xs", Var "ys"]]
            ]
    ]
    (Const MatchFail)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment