Created
March 31, 2016 22:26
-
-
Save jozefg/58def43f9ad2865c4755954303c49ba0 to your computer and use it in GitHub Desktop.
WIP - Doodle of a fully lazy lambda lifter
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE DeriveFoldable #-} | |
{-# LANGUAGE DeriveTraversable #-} | |
{-# LANGUAGE TypeSynonymInstances #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE TypeFamilies #-} | |
module Lift where | |
import Control.Monad.Gen | |
import Data.Functor.Foldable (Fix(..)) | |
import Data.List (partition) | |
import qualified Data.Functor.Foldable as F | |
import qualified Data.Set as S | |
import qualified Data.Map as M | |
-- | Identifiers, used both for variable occurrences and for binders.
type Name = String

-- | Numeric constructor tag (see 'Con' and 'ConAlt').
type Tag = Int

-- | Whether the bindings of a @let@ scope over their own right-hand
-- sides ('Rec') or only over the body ('NoRec').
data Binding
  = Rec
  | NoRec

-- | Literal constants.
data Lit
  = IntLit Int

-- | Built-in primitive operators.
data Op
  = Plus
  | Minus
  | Times
  | Div
  | Neg
-- | One layer of the core language. Recursive positions are abstracted as
-- @a@ so that the passes below can be written with recursion-schemes over
-- 'Fix'. @var@ is the type of binders (let-bound names, lambda parameters
-- and constructor-pattern variables); occurrences ('Var') are always plain
-- 'Name's.
--
-- 'Foldable' and 'Traversable' are derived alongside 'Functor' (the
-- corresponding extensions are enabled at the top of the module), so a
-- pass can fold over or traverse the immediate children of a node
-- generically.
data CoreF var a
  = Var Name
  | Ap a a
  | Let Binding [(var, a)] a
  | Case a [AltF var a]
  | Lam [var] a
  | Op Op
  | Lit Lit
  | Con Tag Int -- tag, arity
  deriving (Functor, Foldable, Traversable)

-- | A single case alternative whose body is a recursive position @a@.
data AltF var a
  = Default a
  | LitAlt Lit a
  | ConAlt Tag [var] a
  deriving (Functor, Foldable, Traversable)

-- | A case alternative whose body is a full core expression.
type Alt var = (AltF var (Core var))

-- | Core expressions: the fixed point of 'CoreF'.
type Core var = Fix (CoreF var)

-- | Core expressions whose binders are plain names.
type Expr = Core Name

-- | A supercombinator: a parameter list together with a body.
-- NOTE(review): not used by any pass visible in this module yet.
data SC var = SC [var] (Core var)
-- Smart constructors: each one wraps a single 'CoreF' layer in 'Fix'.

-- | A variable occurrence.
var :: Name -> Core a
var name = Fix (Var name)

-- | Application of a function to a single argument.
ap :: Core a -> Core a -> Core a
ap fun arg = Fix (Ap fun arg)

-- | A @let@ with the given binding kind, bindings and body.
let_ :: Binding -> [(a, Core a)] -> Core a -> Core a
let_ binding binds body = Fix (Let binding binds body)

-- | A @case@ on a scrutinee with the given alternatives.
case_ :: Core a -> [Alt a] -> Core a
case_ scrut alts = Fix (Case scrut alts)

-- | A lambda abstraction over the given parameters.
lam :: [a] -> Core a -> Core a
lam params body = Fix (Lam params body)

-- | A primitive operator.
op :: Op -> Core a
op prim = Fix (Op prim)

-- | A literal constant.
lit :: Lit -> Core a
lit value = Fix (Lit value)

-- | A constructor with the given tag and arity.
con :: Tag -> Int -> Core a
con tag arity = Fix (Con tag arity)
-- This is a little bit of a gamble, but it'd be cool if we could use
-- recursion-schemes seamlessly with annotations:
-- | A tree of @f@-layers in which every node carries a label of type
-- @annot@.
data Annot annot f = Annot annot (f (Annot annot f))

-- | One layer of an 'Annot' tree: the node's label together with the node,
-- whose recursive positions are abstracted as @b@.
data AnnotF annot f b = AnnotF annot (f b)
  deriving (Functor)

-- | Core expressions with binders of type @b@, labelled with @a@ at every
-- node.
type AnnotExpr a b = Annot a (CoreF b)

-- | The label stored at the root of an annotated tree.
annot :: Annot annot f -> annot
annot (Annot label _) = label

type instance F.Base (Annot annot f) = AnnotF annot f

-- Peeling off and re-adding the outermost layer is just a swap between
-- the 'Annot' and 'AnnotF' constructors.
instance F.Foldable (AnnotExpr annot var) where
  project (Annot label node) = AnnotF label node

instance F.Unfoldable (AnnotExpr annot var) where
  embed (AnnotF label node) = Annot label node
-- The point of this module is to take the nice high-level language
-- and restrict it in various ways so that it's suitable for
-- compilation to something like an STG machine, or even just TIM or
-- G-machine variants.
-- | |
-- In our case we'd like to implement one simple transformation here. Our
-- current language contains lambdas which
--
--   1. may have free variables, and
--   2. may appear anywhere in an expression.
--
-- We'd like to fix this by lifting lambdas to the top level while | |
-- also adding enough arguments to them so that they become closed. | |
-- To make this a bit more fun, we also make the transformation preserve
-- full laziness: if a subexpression can be shared between invocations of
-- a lifted lambda, we make every effort to share it.
-- A warm up | |
-- | The set of names occurring free in an expression. (A warm-up for
-- 'annFree', which records the same sets at every node.)
freeVars :: Expr -> S.Set Name
freeVars = F.cata layer
  where
    layer node = case node of
      Var n -> S.singleton n
      Ap fun arg -> S.union fun arg
      Let NoRec binds body ->
        -- Non-recursive: the binders scope over the body only.
        let (names, rhsFree) = splitBinds binds
        in S.union rhsFree (body `without` names)
      Let Rec binds body ->
        -- Recursive: the binders also scope over every right-hand side.
        let (names, rhsFree) = splitBinds binds
        in S.union rhsFree body `without` names
      Case scrut alts -> S.unions (scrut : map alt alts)
      Lam params body -> body `without` params
      Op _ -> S.empty
      Con _ _ -> S.empty
      Lit _ -> S.empty

    alt a = case a of
      Default body -> body
      LitAlt _ body -> body
      ConAlt _ fields body -> body `without` fields

    -- Binder names, and the union of the right-hand sides' free sets.
    splitBinds binds = (map fst binds, S.unions (map snd binds))

    without set names = set `S.difference` S.fromList names
-- | Label every node of an expression with the set of names free in the
-- subterm rooted there (the same sets 'freeVars' computes).
annFree :: Expr -> AnnotExpr (S.Set Name) Name
annFree = F.cata layer
  where
    -- The node itself is kept unchanged; only its label is computed.
    layer node = Annot (nodeFree node) node

    nodeFree node = case node of
      Var n -> S.singleton n
      Ap fun arg -> annot fun `S.union` annot arg
      Let NoRec binds body ->
        let (names, rhsFree) = splitBinds binds
        in rhsFree `S.union` (annot body `without` names)
      Let Rec binds body ->
        let (names, rhsFree) = splitBinds binds
        in (rhsFree `S.union` annot body) `without` names
      Case scrut alts -> S.unions (annot scrut : map altFree alts)
      Lam params body -> annot body `without` params
      Op _ -> S.empty
      Con _ _ -> S.empty
      Lit _ -> S.empty

    altFree a = case a of
      Default body -> annot body
      LitAlt _ body -> annot body
      ConAlt _ fields body -> annot body `without` fields

    -- Binder names, and the union of the right-hand sides' labels.
    splitBinds binds = (map fst binds, S.unions (map (annot . snd) binds))

    without set names = set `S.difference` S.fromList names
-- | Turn every multi-parameter lambda into a chain of single-parameter
-- lambdas, so @\\x y -> e@ becomes @\\x -> \\y -> e@. A lambda with an
-- empty parameter list is replaced by its body.
separateLambdas :: Core a -> Core a
separateLambdas = F.cata layer
  where
    layer node = case node of
      Lam params body -> nest params body
      other -> Fix other

    nest [] body = body
    nest (p : ps) body = lam [p] (nest ps body)
-- A few quick helper functions | |
-- | The level of an expression with the given free variables: the largest
-- level of any of them, or 0 if there are none.
--
-- A name with no entry in @levels@ is not bound anywhere inside the
-- expression being processed (e.g. a reference to a top-level definition
-- or primitive), so it is treated as living at the top level, 0. Looking
-- it up with 'M.!' would crash on every open input.
getLevel :: S.Set Name -> M.Map Name Int -> Int
getLevel free levels = S.foldr (max . levelOf) 0 free
  where levelOf n = M.findWithDefault 0 n levels
-- | Tag every name with the given level.
annVars :: [Name] -> Int -> [(Int, Name)]
annVars vars curr = [ (curr, v) | v <- vars ]
-- | Record that every name in @vars@ is bound at level @curr@. The new
-- entries replace any existing entry for the same name (shadowing).
extend :: [Name] -> Int -> M.Map Name Int -> M.Map Name Int
extend vars curr levels = foldr (\v -> M.insert v curr) levels vars
-- The meat of our transformation
-- | Label every node with @(level, free variables)@ and every binder with
-- the level at which it is bound.
--
-- Levels count enclosing lambdas: the top level is 0 and the parameters
-- of a lambda sit one level deeper than the context it appears in. A
-- node's level is 'getLevel' of its free variables.
--
-- * Lambda parameters are bound at @curr + 1@; the body is processed at
--   that depth.
-- * @let@ binders are bound at the level of their right-hand sides' free
--   variables (excluding the binders themselves for recursive groups).
-- * Constructor-pattern variables are bound at the current depth.
addLevels :: AnnotExpr (S.Set Name) Name
          -> AnnotExpr (Int, S.Set Name) (Int, Name)
addLevels e = F.cata go e M.empty 0
  where go (AnnotF free ef) levels curr =
          Annot (getLevel free levels, free) $ case ef of
            Var n -> Var n
            Op o -> Op o
            Con t a -> Con t a
            Lit i -> Lit i
            Ap l r -> Ap (l levels curr) (r levels curr)
            Lam vars body ->
              Lam (annVars vars (curr + 1))
                  (body (extend vars (curr + 1) levels) (curr + 1))
            Let NoRec binds body ->
              let (vars, vals) = unzip binds
                  vals' = map (\v -> v levels curr) vals
                  bindingLevel = getLevel (foldMap (snd . annot) vals') levels
                  vars' = annVars vars bindingLevel
                  body' = body (extend vars bindingLevel levels) curr
              in Let NoRec (zip vars' vals') body'
            Let Rec binds body ->
              let (vars, vals) = unzip binds
                  -- Free variables of the right-hand sides, minus the
                  -- recursive binders. A node's free set comes straight
                  -- from the input annotation and never depends on the
                  -- level map, so probing each value with the outer map
                  -- only forces that set.
                  rhsFree = foldMap (\v -> snd (annot (v levels curr))) vals
                              `S.difference` S.fromList vars
                  bindingLevel = getLevel rhsFree levels
                  -- The right-hand sides see the binders at their real
                  -- level. Giving them a placeholder level (such as 0)
                  -- would let subexpressions mentioning the binders be
                  -- labelled shallower than the binders themselves, and
                  -- so be floated out of their scope.
                  levels' = extend vars bindingLevel levels
                  vals' = map (\v -> v levels' curr) vals
                  vars' = annVars vars bindingLevel
                  body' = body levels' curr
              in Let Rec (zip vars' vals') body'
            Case a alts -> Case (a levels curr) (map (goAlt levels curr) alts)
        goAlt levels curr alt = case alt of
          Default f -> Default (f levels curr)
          LitAlt i f -> LitAlt i (f levels curr)
          ConAlt t vars f ->
            ConAlt t (annVars vars curr) (f (extend vars curr levels) curr)
-- | Decide whether a node is worth floating out as a maximal free
-- expression. Only the top-most layer may be inspected.
--
-- Variables, operators, constructors and literals are already as cheap
-- as a reference to a let-bound name, so binding them to a fresh name
-- would only add an indirection. Literals in particular have no free
-- variables, so without this exclusion every literal inside a lambda
-- would be floated out.
shouldMFE :: CoreF f a -> Bool
shouldMFE node = case node of
  Var _ -> False
  Op _ -> False
  Con _ _ -> False
  Lit _ -> False
  _ -> True
-- | Wrap each node that is worth floating ('shouldMFE') and whose level
-- differs from the level of its context in @let v = e in v@. The binder
-- carries the node's level, which pushes the level into the binding for
-- the later let-floating pass. Every introduced binder is named @"v"@;
-- they are expected to be made unique by a renaming pass afterwards.
explodeMFE :: AnnotExpr (Int, S.Set Name) (Int, Name) -> Core (Int, Name)
explodeMFE annE = F.cata layer annE 0
  where
    -- Bind @body@ at @level@ unless it already sits at the context level
    -- or is not worth floating.
    float level curr worthIt body
      | worthIt && level /= curr = let_ NoRec [((level, "v"), body)] (var "v")
      | otherwise = body

    layer (AnnotF (level, _) node) curr =
      float level curr (shouldMFE node) (rebuild level curr node)

    rebuild level curr node = case node of
      Var n -> var n
      Op o -> op o
      Lit l -> lit l
      Con t a -> con t a
      Ap fun arg -> ap (fun curr) (arg curr)
      Case scrut alts -> case_ (scrut curr) (map (rebuildAlt curr) alts)
      -- Right-hand sides are processed relative to their binder's level,
      -- the body relative to the level of the let itself.
      Let re binds body ->
        let rhs (binder@(blevel, _), val) = (binder, val blevel)
        in let_ re (map rhs binds) (body level)
      -- A lambda body lives at the level of the lambda's parameters.
      Lam params body -> lam params (body (maximum (map fst params)))

    rebuildAlt curr alt = case alt of
      Default body -> Default (body curr)
      LitAlt l body -> LitAlt l (body curr)
      ConAlt t fields body -> ConAlt t fields (body curr)
-- | Give every binder a fresh numeric name (generated by counting up from
-- "0") and rewrite its occurrences to match; binder levels are left
-- untouched. After this pass binder names are unique, so bindings can be
-- moved around without name capture.
--
-- Variables that are not bound anywhere inside the expression (references
-- to top-level definitions or primitives) keep their original names.
-- NOTE(review): such a name could only clash with a fresh name if it were
-- itself a numeric string.
rename :: Core (Int, Name) -> Core (Int, Name)
rename expr =
  runGenWith (successor $ show . (succ :: Int -> Int) . read) "0"
  $ F.cata go expr M.empty
  where freshExtend vars names = do
          newNames <- traverse (\(_, n) -> (,) n <$> gen) vars
          let names' = M.fromList newNames `M.union` names
          let vars' = map (fmap (names' M.!)) vars
          return (vars', names')
        go e names = case e of
          -- A name missing from the map is free in the whole expression;
          -- looking it up with M.! would crash on any open term.
          Var n -> return $ var (M.findWithDefault n n names)
          Ap l r -> ap <$> l names <*> r names
          Case matchee alts ->
            case_ <$> matchee names <*> mapM (goAlt names) alts
          Let re binds body -> do
            let (vars, vals) = unzip binds
            (vars', names') <- freshExtend vars names
            -- Only a recursive group's right-hand sides see its binders.
            vals' <- case re of
              NoRec -> mapM ($ names) vals
              Rec -> mapM ($ names') vals
            let_ re (zip vars' vals') <$> body names'
          Lam vars body -> do
            (vars', names') <- freshExtend vars names
            lam vars' <$> body names'
          Op o -> return (op o)
          Lit l -> return (lit l)
          Con t a -> return (con t a)
        goAlt names alt = case alt of
          Default a -> Default <$> a names
          LitAlt l a -> LitAlt l <$> a names
          ConAlt t vars a -> do
            (vars', names') <- freshExtend vars names
            ConAlt t vars' <$> a names'
-- | A group of bindings lifted out of its original position by
-- 'floatLets', waiting to be re-installed as a single @let@.
data Block = Block
  { blockLevel :: Int            -- level of the group's binders
  , blockBind  :: Binding        -- recursive or not
  , blockDefns :: [(Name, Expr)] -- the bindings themselves
  }
-- | Float every @let@ outwards as far as the levels on its binders allow,
-- dropping the levels from the result.
--
-- Each @let@ is removed from its position and becomes a 'Block', which
-- travels upwards until it reaches a binder it may depend on:
--
-- * At a lambda or constructor alternative whose binders sit at level
--   @l@, blocks with a level below @l@ cannot mention those binders and
--   keep floating; the rest are installed around the body.
-- * Blocks that reach the top are installed around the whole expression.
--
-- Blocks keep their traversal order, so a block is installed after the
-- blocks whose binders it may use.
--
-- Bindings from different scopes end up side by side, so binder names are
-- assumed to be unique (as produced by a renaming pass).
floatLets :: Core (Int, Name) -> Expr
floatLets = uncurry install . F.cata go
  where go e = case e of
          Var n -> (var n, [])
          Ap (l, lDefs) (r, rDefs) -> (ap l r, lDefs ++ rDefs)
          Op o -> (op o, [])
          Lit l -> (lit l, [])
          Con t a -> (con t a, [])
          Let NoRec binds (body, laterDefs) ->
            let (vars, vals) = unzip binds
                (vals', priorDefs) = concat <$> unzip vals
                block = Block (bindersLevel vars) NoRec
                              (zip (map snd vars) vals')
            -- Blocks floated out of non-recursive right-hand sides cannot
            -- mention the new binders, so they simply go first.
            in (body, priorDefs ++ block : laterDefs)
          Let Rec binds (body, laterDefs) ->
            let (vars, vals) = unzip binds
                (vals', priorDefs) = concat <$> unzip vals
                -- Blocks floated out of a recursive group's right-hand
                -- sides may mention its binders, while the right-hand
                -- sides may mention theirs, so neither order is in scope:
                -- merge them into the group itself.
                level = maximum (bindersLevel vars : map blockLevel priorDefs)
                defns = concatMap blockDefns priorDefs
                          ++ zip (map snd vars) vals'
            in (body, Block level Rec defns : laterDefs)
          Case (matchee, priorDefns) alts ->
            let (alts', defns) = concat <$> unzip (map goAlt alts)
            in (case_ matchee alts', priorDefns ++ defns)
          Lam vars (body, defns) ->
            let (escaping, local) = splitEscaping vars defns
            in (lam (map snd vars) (install body local), escaping)
        goAlt alt = case alt of
          Default (a, defns) -> (Default a, defns)
          LitAlt l (a, defns) -> (LitAlt l a, defns)
          -- Pattern variables scope over this alternative only, so blocks
          -- that may mention them have to stay inside it.
          ConAlt t vars (a, defns) ->
            let (escaping, local) = splitEscaping vars defns
            in (ConAlt t (map snd vars) (install a local), escaping)
        -- Level of a binding group (0 for an empty group).
        bindersLevel vars = maximum (0 : map fst vars)
        -- Split blocks into those strictly shallower than the given
        -- binders (free to float past them) and those that must be
        -- installed beneath them. Without binders every block escapes.
        splitEscaping [] defns = (defns, [])
        splitEscaping vars defns =
          partition ((< maximum (map fst vars)) . blockLevel) defns
        install e blocks =
          foldr (\b -> let_ (blockBind b) (blockDefns b)) e blocks
-- | The whole (work-in-progress) pipeline: split multi-parameter lambdas,
-- label free variables, assign levels, wrap floatable subexpressions in
-- lets, give binders unique names, then float the lets outwards.
doTheThing :: Expr -> Expr
doTheThing expr =
  let curried = separateLambdas expr
      leveled = addLevels (annFree curried)
      exploded = explodeMFE leveled
  in floatLets (rename exploded)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment