Created
April 30, 2016 04:36
-
-
Save jozefg/5e0a8908eb019c38b0de5097acf1d984 to your computer and use it in GitHub Desktop.
A simple register allocator that performs liveness analysis at the basic-block level and properly handles spilling. For brevity, it does not do coalescing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts, ConstraintKinds, RecordWildCards #-} | |
module RegAlloc where | |
import Control.Applicative | |
import Control.Monad | |
import Data.Set (Set) | |
import Data.Map (Map) | |
import Control.Monad.Writer.Strict | |
import Control.Monad.State.Strict | |
import qualified Data.Set as S | |
import qualified Data.Map as M | |
-- | Names for jump targets. 'Fresh' labels are the ones 'labelBlocks'
-- invents for blocks that do not start with a label; 'User' labels are
-- everything else (integer literals become 'User' labels).
data Label
  = User Int
  | Fresh Int
  deriving (Eq, Show, Ord)
-- Horrible. Exclusively for testing: lets integer literals stand in for
-- user labels, e.g. @Label 1@ or @Jmp 2@.
instance Num Label where
  fromInteger = User . fromInteger
  -- Arithmetic on labels is meaningless. These methods exist only to
  -- silence missing-method warnings, so fail with a clear message
  -- rather than being defined in terms of themselves and looping.
  (+) = error "Num Label: (+) is not supported"
  (*) = error "Num Label: (*) is not supported"
  (-) = error "Num Label: (-) is not supported"
  abs = error "Num Label: abs is not supported"
  signum = error "Num Label: signum is not supported"
-- | Temporaries (virtual registers) are identified by plain integers.
type Temp = Int
-- | Comparison operators usable in a conditional jump ('CJmp').
data RelOp
  = Eq
  | Neq
  | Leq
  | Geq
  | Lt
  | Gt
  deriving (Eq, Show)
-- | Binary arithmetic operators usable in an assignment ('Asgn').
data Op
  = Add
  | Sub
  | Mul
  | Div
  deriving (Eq, Show)
-- | Instruction operands.
--
-- The explicit "in stack" operands are for use in loads and stores
-- and help when we have to do spilling. Ideally we would compile them
-- away into addressings of %esp in real code, but for now it is simpler
-- to have that idiom represented explicitly.
data Operand
  = Temp Temp    -- ^ a temporary
  | Lit Int      -- ^ an integer literal
  | InStack Int  -- ^ a numbered stack slot
  deriving (Eq, Show)
-- | A simple, quasi-assembly-like language: just jumps, conditional
-- jumps, load, store, and basic math.
data Code
  = Label Label
    -- ^ Marks a jump target.
  | Asgn Temp Operand Op Operand
    -- ^ Defines the temp from the two operands combined with the operator.
  | Copy Temp Operand
    -- ^ Defines the temp from the operand.
  | Jmp Label
    -- ^ Unconditional jump.
  | CJmp Operand RelOp Operand Label Label
    -- ^ Compares the operands and continues at one of the two labels.
    -- NOTE(review): which label is the "true" target is not fixed by
    -- this file -- confirm with the consumer of this IR.
  | Load Temp Operand
    -- ^ Defines the temp by reading from the address operand.
  | Store Operand Operand
    -- ^ Destination first, value second (this is how the spill rewrite
    -- builds it).
  deriving (Eq, Show)
-- | Extract the temporary named by an operand, if it names one.
-- Literals and stack slots yield 'Nothing'.
operand2temp :: Operand -> Maybe Temp
operand2temp oper = case oper of
  Temp t -> Just t
  Lit _ -> Nothing
  InStack _ -> Nothing
-- | Collect the elements of any foldable container into a set.
toSet :: (Foldable f, Ord a) => f a -> Set a
toSet = foldr S.insert S.empty
-- | Things that read and write temporaries, as needed by the liveness
-- analysis.
class InstrLike a where
  -- | Temporaries read (used) by the item.
  gen :: a -> Set Temp
  -- | Temporaries written (defined) by the item.
  kill :: a -> Set Temp
-- | A single instruction kills the temp it defines (if any) and
-- generates every temp appearing among its operands.
instance InstrLike Code where
  kill instr = case instr of
    Asgn t _ _ _ -> S.singleton t
    Copy t _ -> S.singleton t
    Load t _ -> S.singleton t
    _ -> S.empty

  gen instr = case instr of
    Asgn _ l _ r -> temps [l, r]
    Copy _ o -> temps [o]
    Jmp _ -> S.empty
    CJmp l _ r _ _ -> temps [l, r]
    Load _ o -> temps [o]
    Store to o -> temps [to, o]
    Label _ -> S.empty
    where
      -- Literals and stack slots contribute nothing.
      temps ops = foldMap (toSet . operand2temp) ops
-- | A sequence generates the temps that are read before being written
-- within the sequence, and kills every temp any element kills.
instance InstrLike a => InstrLike [a] where
  gen = foldr exposeUses S.empty
    where
      -- Walk from the back: uses from later elements survive only if
      -- this element does not define them first.
      exposeUses item usedLater = gen item `S.union` (usedLater S.\\ kill item)
  kill = foldMap kill
-- | @swap new old t@ yields @new@ when @t@ is @old@, and @t@ otherwise.
swap :: Temp -> Temp -> Temp -> Temp
swap new old t
  | t == old = new
  | otherwise = t
-- | Apply 'swap' to an operand; only 'Temp' operands can change.
swapOper :: Temp -> Temp -> Operand -> Operand
swapOper new old oper = case oper of
  Temp t -> Temp (swap new old t)
  _ -> oper
-- | @replace new old instr@ renames every occurrence of the temp @old@
-- in @instr@ (both defined and used positions) to @new@. Labels and
-- jump targets are left alone.
replace :: Temp -> Temp -> Code -> Code
replace new old instr = case instr of
  Label l -> Label l
  Asgn t l o r -> Asgn (tmp t) (opnd l) o (opnd r)
  Copy t o -> Copy (tmp t) (opnd o)
  Jmp l -> Jmp l
  CJmp l o r t1 t2 -> CJmp (opnd l) o (opnd r) t1 t2
  Load t o -> Load (tmp t) (opnd o)
  Store to o -> Store (opnd to) (opnd o)
  where
    tmp = swap new old
    opnd = swapOper new old
-- | Split code into basic blocks.
--
-- A block starts at a 'Label' (or at the first instruction following a
-- jump) and ends at, and includes, the next 'Jmp' or 'CJmp', or just
-- before the next 'Label'. Concatenating the result gives back the
-- input.
basicBlocks :: [Code] -> [[Code]]
basicBlocks = go
  where
    go [] = []
    go (c : cs)
      -- A jump that opens a block (e.g. directly after another jump)
      -- also closes it; without this case the instructions following it
      -- would be glued onto the same block.
      | endBB c = [c] : go cs
      | otherwise =
          case break isBoundary cs of
            (bb, []) -> [c : bb]
            (bb, end : rest)
              -- A label belongs to the next block ...
              | startBB end -> (c : bb) : go (end : rest)
              -- ... whereas a jump terminates the current one.
              | otherwise -> (c : bb ++ [end]) : go rest

    isBoundary i = startBB i || endBB i

    startBB (Label _) = True
    startBB _ = False

    endBB (Jmp _) = True
    endBB (CJmp _ _ _ _ _) = True
    endBB _ = False
-- | Make sure every basic block starts with a label. Blocks that already
-- do are kept as they are; the others get @Fresh 0@, @Fresh 1@, ...
-- prepended in order of appearance.
labelBlocks :: [[Code]] -> [[Code]]
labelBlocks = number 0
  where
    number _ [] = []
    number n (block : others) = case block of
      Label _ : _ -> block : number n others
      _ -> (Label (Fresh n) : block) : number (n + 1) others
-- | Tags for nodes in the control flow graph. The distinguished 'Start'
-- and 'End' tags are just for consistency; every other node is the label
-- of a basic block.
data CfgNode
  = Start
  | End
  | CfgLabel Label
  deriving (Eq, Show, Ord)
-- | The representation of control flow graphs: the graph part is just a
-- big set of edges, and code labels are mapped to the corresponding
-- basic blocks in a separate map.
data Cfg = Cfg
  { cfgBlocks :: Map Label [Code]        -- ^ basic blocks keyed by label
  , cfg :: Set (CfgNode, CfgNode)        -- ^ directed (from, to) edges
  } deriving (Eq, Show)
-- | Convert basic blocks to our naive representation of a control flow
-- graph.
--
-- Every block must begin with a 'Label' (see 'labelBlocks'); otherwise
-- 'lbl' calls 'error'. 'Start' is connected to the first block. A block
-- ending in a jump is connected to its target(s); any other block falls
-- through to the next block, or to 'End' if it is the last one.
mkCfg :: [[Code]] -> Cfg
mkCfg [] = Cfg M.empty S.empty
mkCfg bbs@(firstBB : _) = Cfg blockMap (S.fromList (entryEdge : blockEdges))
  where
    blockMap = M.fromList [(lbl bb, bb) | bb <- bbs]
    entryEdge = (Start, CfgLabel (lbl firstBB))

    -- Fall-through successor of each block, in block order.
    fallThroughs = map (CfgLabel . lbl) (drop 1 bbs) ++ [End]
    blockEdges = concat (zipWith edgesFrom bbs fallThroughs)

    edgesFrom bb next =
      let src = CfgLabel (lbl bb)
      in case last bb of
           Jmp l2 -> [(src, CfgLabel l2)]
           CJmp _ _ _ l2 l3 -> [(src, CfgLabel l2), (src, CfgLabel l3)]
           _ -> [(src, next)]

    lbl :: [Code] -> Label
    lbl (Label i : _) = i
    lbl _ = error "lbl: Not called on an annotated basic block"
-- | Liveness facts for one CFG node.
data Live = Live
  { liveIn :: Set Temp   -- ^ temps live on entry to the node
  , liveOut :: Set Temp  -- ^ temps live on exit from the node
  } deriving (Eq, Show)
-- | Look up the code of a CFG node. 'Start' and 'End' carry no code; a
-- 'CfgLabel' whose label is missing from the map is a runtime error
-- (from 'M.!').
getCode :: Map Label [Code] -> CfgNode -> [Code]
getCode blocks node = case node of
  CfgLabel l -> blocks M.! l
  _ -> []
-- | Block-level liveness analysis.
--
-- Starts with empty live-in/live-out sets for every node and runs a
-- worklist until nothing changes:
--
-- * live-out of a node is the union of the live-in sets of its
--   successors;
-- * live-in is @gen block `union` (liveOut \\\\ kill block)@.
--
-- Whenever a node's facts change, its predecessors are queued again.
-- An edge pointing at a label with no block makes the lookup fail at
-- runtime.
liveness :: Cfg -> Map CfgNode Live
liveness (Cfg blocks graph) = fixpoint (S.fromList nodes) startMap
  where
    nodes = Start : End : map CfgLabel (M.keys blocks)
    startMap = M.fromList [(n, Live S.empty S.empty) | n <- nodes]

    successors n = S.fromList [to | (from, to) <- S.toList graph, from == n]
    predecessors n = S.fromList [from | (from, to) <- S.toList graph, to == n]

    fixpoint pending liveMap =
      case S.minView pending of
        Nothing -> liveMap
        Just (node, pending') ->
          let code = getCode blocks node
              outSet = foldMap (liveIn . (liveMap M.!)) (successors node)
              inSet = gen code `S.union` (outSet S.\\ kill code)
              updated = Live inSet outSet
              requeue
                | updated == liveMap M.! node = S.empty
                | otherwise = predecessors node
          in fixpoint (S.union requeue pending') (M.insert node updated liveMap)
-- | An interference graph, stored as a set of (temp, temp) edges.
data InterferenceGraph
  = IGraph (Set (Temp, Temp))
  deriving Show
-- | Actually build an interference graph using the results of liveness
-- analysis.
--
-- Each block is walked backwards from its live-out set. At every
-- instruction, each temp it defines is connected to each temp live
-- after it. The final edge set is made symmetric and self-edges are
-- dropped.
interfers :: Cfg -> Map CfgNode Live -> InterferenceGraph
interfers (Cfg blocks _) liveMap = IGraph (symmetricNoLoops rawEdges)
  where
    rawEdges = M.foldMapWithKey blockEdges liveMap

    -- Only real blocks contribute edges; Start and End have no code.
    blockEdges (CfgLabel l) live =
      execWriter (foldM_ step (liveOut live) (reverse (blocks M.! l)))
    blockEdges _ _ = S.empty

    -- Emit edges for one instruction and return the set live before it.
    step liveAfter instr = do
      tell (S.fromList [(d, v) | d <- S.toList defs, v <- S.toList liveAfter])
      return ((liveAfter S.\\ defs) `S.union` gen instr)
      where
        defs = kill instr

    symmetricNoLoops edges =
      S.filter (\(a, b) -> a /= b) (edges `S.union` S.map flipEdge edges)
    flipEdge (a, b) = (b, a)
-- | Counters threaded through allocation.
data AllocState = AllocState
  { freshCount :: Int  -- ^ the most recently issued temp number
  , stackDepth :: Int  -- ^ the most recently issued stack slot number
  } deriving Show
-- | Constraint for monads that carry an 'AllocState'.
type AllocM m = MonadState AllocState m
-- | Bump the temp counter and return its new value.
freshTemp :: AllocM m => m Temp
freshTemp = do
  next <- gets ((+ 1) . freshCount)
  modify (\st -> st { freshCount = next })
  return next
-- | Bump the stack depth and return the new slot as an 'InStack'
-- operand.
freshAddr :: AllocM m => m Operand
freshAddr = do
  depth <- gets ((+ 1) . stackDepth)
  modify (\st -> st { stackDepth = depth })
  return (InStack depth)
-- | Run an allocation action. The given temp seeds the fresh-temp
-- counter (so the first fresh temp is one above it) and the stack depth
-- starts at 0.
runAllocM :: Temp -> State AllocState a -> a
runAllocM t action =
  evalState action (AllocState { freshCount = t, stackDepth = 0 })
-- | Rewrite a CFG so that the temp @t@ lives in a fresh stack slot.
--
-- Every instruction mentioning @t@ is rewritten to use a brand-new
-- temp instead: the value is loaded from the slot just before the
-- instruction if it reads @t@, and stored back just after it if it
-- writes @t@. The edge set is passed through unchanged.
rewriteSpill :: AllocM m => Cfg -> Temp -> m Cfg
rewriteSpill (Cfg blocks graph) t = do
  slot <- freshAddr
  rewritten <- traverse (fmap concat . mapM (spillInstr slot)) blocks
  return (Cfg rewritten graph)
  where
    spillInstr slot instr
      | not (readsT || writesT) = return [instr]
      | otherwise = do
          t' <- freshTemp
          return $
            [Load t' slot | readsT]
              ++ [replace t' t instr]
              ++ [Store slot (Temp t') | writesT]
      where
        readsT = t `S.member` gen instr
        writesT = t `S.member` kill instr
-- | Number of edges in the graph whose first component is the given
-- temp.
degree :: Temp -> InterferenceGraph -> Int
degree t (IGraph g) = length [() | (a, _) <- S.toList g, a == t]
-- | Working state of the simplify phase.
data CoalesceState = CoalesceState
  { toSpill :: [Temp]                   -- ^ temps with degree >= the number of colors
  , toSimplify :: [Temp]                -- ^ temps waiting to be removed from the graph
  , handled :: [Temp]                   -- ^ temps already removed, most recent first
  , interferGraph :: Set (Temp, Temp)   -- ^ current interference edges
  , degrees :: Map Temp Int             -- ^ current degree of each temp
  } deriving Show
-- | Every temp that is read or written by some instruction of the CFG.
allTemps :: Cfg -> Set Temp
allTemps graph = foldMap blockTemps (cfgBlocks graph)
  where
    blockTemps block = foldMap (\i -> gen i `S.union` kill i) block
-- | Set up the initial simplify state: temps whose degree is at least
-- @kColors@ go on the spill list, the rest on the simplify list, and
-- nothing has been handled yet.
buildLists :: Cfg -> InterferenceGraph -> Int -> CoalesceState
buildLists cfg ig@(IGraph g) kColors =
  CoalesceState
    { toSpill = M.keys highDegree
    , toSimplify = M.keys lowDegree
    , handled = []
    , interferGraph = g
    , degrees = dGraph
    }
  where
    dGraph = M.fromList [(t, degree t ig) | t <- S.toList (allTemps cfg)]
    highDegree = M.filter (>= kColors) dGraph
    lowDegree = M.filter (< kColors) dGraph
-- | Greedily color temps in the order given, using colors
-- @1 .. kColors@.
--
-- Each temp gets the smallest color not already taken by a colored
-- neighbor. A temp with no color left is added to the spilled set and
-- stays uncolored. Returns the coloring together with the spilled temps.
assignColors :: InterferenceGraph -> Int -> [Temp] -> (Map Temp Int, Set Temp)
assignColors (IGraph graph) kColors = go M.empty S.empty
  where
    palette = S.fromList [1 .. kColors]

    go colors spilled [] = (colors, spilled)
    go colors spilled (t : rest) =
      case S.minView (palette S.\\ taken) of
        Nothing -> go colors (S.insert t spilled) rest
        Just (c, _) -> go (M.insert t c colors) spilled rest
      where
        adjacent = [v | (u, v) <- S.toList graph, u == t]
        -- Colors of the neighbors that have been colored so far.
        taken = S.fromList [c | v <- adjacent, Just c <- [M.lookup v colors]]
-- | One step of the simplify phase: remove the first temp on the
-- simplify list from the interference graph and push it onto the
-- handled stack.
--
-- Each neighbor still in the graph has its degree decremented. Neighbors
-- whose degree was exactly @kColors@ drop below the threshold, so they
-- move from the spill list to the simplify list. With an empty simplify
-- list the state is returned unchanged.
doSimplify :: Int -> CoalesceState -> CoalesceState
doSimplify _ c@CoalesceState{toSimplify = []} = c
doSimplify kColors (CoalesceState toSpill (t : toSimplify) handled g degrees) =
  CoalesceState toSpill' toSimplify' (t : handled) igraph' degrees'
  where
    neighbors = S.map snd $ S.filter ((== t) . fst) g
    -- Strict fold so decrements do not pile up as thunks.
    degrees' = S.foldl' (flip (M.adjust pred)) degrees neighbors
    -- Judged against the degrees *before* the decrement.
    saved = S.filter ((== kColors) . (degrees M.!)) neighbors
    toSpill' = filter (`S.notMember` saved) toSpill
    toSimplify' = S.toList saved ++ toSimplify
    -- Drop every edge touching the removed temp, so that later steps
    -- only see neighbors that are still in the graph.
    igraph' = S.filter (\(a, b) -> a /= t && b /= t) g
-- | Allocate registers for a CFG with @kColors@ colors.
--
-- Builds the interference graph, runs the simplify phase to obtain an
-- ordering of the temps, and colors them in that order. If some temps
-- could not be colored, each is spilled with 'rewriteSpill' and the
-- whole process restarts on the rewritten CFG.
--
-- NOTE(review): nothing bounds the number of restarts; presumably this
-- does not terminate if @kColors@ is too small for the program --
-- confirm.
registerAlloc :: AllocM m => Int -> Cfg -> m (Cfg, Map Temp Int)
registerAlloc kColors cfg
  | S.null spills = return (cfg, coloring)
  | otherwise = foldM rewriteSpill cfg spills >>= registerAlloc kColors
  where
    igraph = interfers cfg (liveness cfg)
    tempStack = handled (simplifyAll (buildLists cfg igraph kColors))
    (coloring, spills) = assignColors igraph kColors tempStack

    -- Keep simplifying; when the simplify list runs dry, promote the
    -- next spill candidate, and stop once both lists are empty.
    simplifyAll st =
      case (toSimplify st, toSpill st) of
        (_ : _, _) -> simplifyAll (doSimplify kColors st)
        ([], t : rest) -> simplifyAll st { toSimplify = [t], toSpill = rest }
        ([], []) -> st
-- | Pure entry point: run 'registerAlloc' with @kColors@ colors.
--
-- The fresh-temp counter is seeded with the largest temp already in the
-- CFG, so new temps never clash with existing ones. A CFG with no temps
-- at all seeds the counter with 0 rather than crashing on the maximum
-- of an empty set.
alloc :: Int -> Cfg -> (Cfg, Map Temp Int)
alloc kColors cfg = runAllocM seed (registerAlloc kColors cfg)
  where
    temps = allTemps cfg
    seed
      | S.null temps = 0
      | otherwise = S.findMax temps
-- | A small example CFG: an initial copy, a loop at label 1 ending in a
-- conditional jump, and a store at label 2.
smallTest :: Cfg
smallTest = mkCfg (labelBlocks (basicBlocks program))
  where
    program =
      [ Copy 1 (Temp 3)
      , Label 1
      , Asgn 2 (Temp 1) Add (Lit 1)
      , Asgn 3 (Temp 2) Add (Temp 3)
      , Asgn 1 (Temp 2) Mul (Lit 2)
      , CJmp (Temp 1) Lt (Lit 20) 1 2
      , Label 2
      , Store (Lit 0) (Temp 2)
      ]
-- | Print every basic block of the CFG, one instruction per line, with
-- an empty line after each block.
displayBlocks :: Cfg -> IO ()
displayBlocks graph =
  forM_ (cfgBlocks graph) $ \block ->
    mapM_ print block >> putStrLn ""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment