Last active
January 25, 2017 16:12
-
-
Save stevana/a2bb4640d1552dd69abcf93a30d0814e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE ConstraintKinds #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE FunctionalDependencies #-} | |
{-# LANGUAGE GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE TypeFamilies #-} | |
module IORefPrePostRaceCondition where | |
import Control.Concurrent (threadDelay)
import Control.Monad
import Data.IORef
import Data.Monoid
import Data.Semigroup (Semigroup)
import System.IO.Unsafe
import Text.Printf

import Control.Concurrent.ParallelIO.Local
import Control.Concurrent.STM.TChan
import Control.Lens hiding (pre)
import Control.Monad.Free
import Control.Monad.Free.TH
import Control.Monad.State
import Control.Monad.STM
import System.Random
import Test.QuickCheck
import Test.QuickCheck.Monadic
------------------------------------------------------------------------ | |
-- | Index of an allocated memory cell. The interpreter hands these out in
-- allocation order (the first 'New' yields @Ref 0@, the next @Ref 1@, ...).
newtype Ref = Ref Int
  deriving (Eq, Read, Num, Ord, Show)

-- | Compact rendering used when printing histories.
class PPShow a where
  ppShow :: a -> String

-- | References render as @$n@.
instance PPShow Ref where
  ppShow (Ref n) = '$' : show n

makeWrapped ''Ref
-- | One step of a memory program, parameterised over how cells are named.
data MemStepF ref
  = New            -- allocate a new cell (the interpreter initialises it to 0)
  | Read ref       -- read a cell
  | Write ref Int  -- overwrite a cell
  | Inc ref        -- add one to a cell
  deriving (Show, Eq, Read, Functor)

-- | Steps whose cells are concrete 'Ref' indices.
type MemStep = MemStepF Ref
-- | A sequential memory program: its steps, executed in list order.
--
-- 'Semigroup' is derived next to 'Monoid' because since GHC 8.4 every
-- 'Monoid' instance needs a 'Semigroup' superclass instance; '<>' then
-- concatenates programs (used when checking @prefix <> left@ in 'prop_par').
newtype Mem = Mem [MemStep]
  deriving (Show, Eq, Semigroup, Monoid, Read)

makeWrapped ''Mem
------------------------------------------------------------------------ | |
-- | Reference model of the memory: the value of every allocated cell,
-- indexed by the cell's 'Ref' number.
type Model = [Int]

-- | No cells allocated yet.
initModel :: Model
initModel = []

-- New: always enabled; appends a zeroed cell; nothing to check.

new_pre :: Model -> Bool
new_pre = const True

new_next :: Model -> () -> Ref -> Model
new_next model _ _ = model ++ [0]

new_post :: Model -> () -> Ref -> Bool
new_post _ _ _ = True

-- Read: needs at least one cell; leaves the model alone; the value read
-- must equal the model's value for that cell.

read_pre :: Model -> Bool
read_pre model = not (null model)

read_next :: Model -> Ref -> Int -> Model
read_next model _ _ = model

read_post :: Model -> Ref -> Int -> Bool
read_post model ref@(Ref ix) r =
  (model !! ix == r) && (read_next model ref r == model)

-- Write: needs at least one cell; replaces that cell's value.

write_pre :: Model -> Bool
write_pre model = not (null model)

write_next :: Model -> (Ref, Int) -> () -> Model
write_next model (Ref ix, v) _ = model & element ix .~ v

write_post :: Model -> (Ref, Int) -> () -> Bool
write_post model arg@(Ref ix, v) _ = write_next model arg () !! ix == v

-- Inc: needs at least one cell; bumps that cell by one.

inc_pre :: Model -> Bool
inc_pre model = not (null model)

inc_next :: Model -> Ref -> () -> Model
inc_next model (Ref ix) _ = model & element ix %~ succ

inc_post :: Model -> Ref -> () -> Bool
inc_post model ref@(Ref ix) _ = inc_next model ref () !! ix == model !! ix + 1
------------------------------------------------------------------------ | |
-- | The request half of an operation, as logged in a history.
data Invocation
  = NewI
  | ReadI Ref
  | WriteI Ref Int
  | IncI Ref
  deriving (Show, Read, Eq)

-- | Renders as a left-aligned 8-column tag (@> new@, @> read@, ...)
-- followed by the arguments.
instance PPShow Invocation where
  ppShow invocation = printf "%-8s%s" tag args
    where
      (tag, args) = case invocation of
        NewI         -> ("> new", "")
        ReadI ref    -> ("> read", ppShow ref)
        WriteI ref i -> ("> write", ppShow ref ++ " " ++ show i)
        IncI ref     -> ("> inc", ppShow ref)
-- | The reply half of an operation, as logged in a history.
data Response
  = NewR Ref
  | ReadR Int
  | WriteR
  | IncR
  deriving (Show, Read, Eq)

-- | Renders as a left-aligned 8-column tag (@< new@, @< read@, ...)
-- followed by any returned value in brackets.
instance PPShow Response where
  ppShow response = printf "%-8s%s" tag result
    where
      (tag, result) = case response of
        NewR ref -> ("< new", "[" ++ ppShow ref ++ "]")
        ReadR i  -> ("< read", "[" ++ show i ++ "]")
        WriteR   -> ("< write", "")
        IncR     -> ("< inc", "")
-- | Identifies which thread produced a history event.
newtype ProcessId = ProcessId Int
  deriving (Eq, Num, Show, Read)

-- | Just the bare number.
instance PPShow ProcessId where
  ppShow (ProcessId n) = show n
-- | One entry of a history: an invocation or a response, tagged with the
-- process that produced it.
data HistoryEvent
  = InvocationEvent Invocation ProcessId
  | ResponseEvent Response ProcessId
  deriving (Eq, Show, Read)

-- | The rendered invocation/response padded to 15 columns, then @{pid}@.
instance PPShow HistoryEvent where
  ppShow event = printf "%-15s%3s" body (" {" ++ ppShow who ++ "}")
    where
      (body, who) = case event of
        InvocationEvent i p -> (ppShow i, p)
        ResponseEvent r p   -> (ppShow r, p)

-- | Events in the order they were written to the history channel.
type History = [HistoryEvent]
------------------------------------------------------------------------ | |
-- | What the interpreter needs: IO (for IORefs and the history channel) and
-- state holding an 'Env'.
type OurMonad m = (MonadIO m, MonadState Env m)

-- | Interpreter environment.
data Env = Env
  { _refs        :: [IORef Int]         -- ^ cells allocated so far, indexed by 'Ref'
  , _historyChan :: TChan HistoryEvent  -- ^ where 'inv' and 'resp' log events
  , _pid         :: ProcessId           -- ^ process id stamped on logged events
  }

makeLenses ''Env

-- | No cells, process id 0, and one process-wide history channel.
--
-- The channel is created with 'unsafePerformIO'. The pragma below stops GHC
-- from inlining the definition, which could otherwise create a different
-- channel at different use sites. Callers that need an isolated history
-- should replace 'historyChan' with a channel from 'newTChanIO'.
defaultEnv :: Env
defaultEnv = Env [] (unsafePerformIO newTChanIO) 0
{-# NOINLINE defaultEnv #-}
-- | Log the start of an operation on the history channel, tagged with the
-- current process id.
inv :: OurMonad m => Invocation -> m ()
inv ev = do
  me <- use pid
  chan <- use historyChan
  liftIO . atomically . writeTChan chan $ InvocationEvent ev me

-- | Log the end of an operation on the history channel, tagged with the
-- current process id.
resp :: OurMonad m => Response -> m ()
resp ev = do
  me <- use pid
  chan <- use historyChan
  liftIO . atomically . writeTChan chan $ ResponseEvent ev me
-- | Pause for a random 0–20 ms to widen the window in which concurrent
-- operations overlap.
sleep :: IO ()
sleep = randomRIO (0, 20000) >>= threadDelay
-- | Value produced by interpreting one step.
data Type
  = RefT (IORef Int)  -- ^ cell allocated by 'New'
  | IntT Int          -- ^ value returned by 'Read'
  | UnitT             -- ^ 'Write' and 'Inc' return nothing

-- | Interpret one step against the real IORefs in the environment.
--
-- Each step logs its invocation with 'inv' before touching memory and its
-- response with 'resp' afterwards, so concurrently running steps leave
-- overlapping invocation/response pairs in the history.
semStep :: OurMonad m => MemStep -> m Type
semStep New = do
  inv NewI
  n <- uses refs length
  cell <- liftIO (newIORef 0)
  refs %= (++ [cell])
  -- The new cell's index is the number of cells that existed before it.
  resp (NewR (Ref n))
  return (RefT cell)
semStep (Read ref) = do
  inv (ReadI ref)
  cells <- use refs
  v <- liftIO (readIORef (cells !! op Ref ref))
  liftIO sleep
  resp (ReadR v)
  return (IntT v)
semStep (Write ref i) = do
  inv (WriteI ref i)
  cells <- use refs
  let cell = cells !! op Ref ref
  -- Both branches write i unchanged. The first branch is the hook for
  -- seeding a bug: change it to write i + 1 to corrupt writes of 5..10.
  liftIO $ if i `elem` [5 .. 10]
    then writeIORef cell i -- + 1
    else writeIORef cell i
  liftIO sleep
  resp WriteR
  return UnitT
semStep (Inc ref) = do
  inv (IncI ref)
  cells <- use refs
  let cell = cells !! op Ref ref
  liftIO $ do
    -- Deliberately non-atomic read / sleep / write: two concurrent Incs on
    -- the same cell can both read the old value, losing one increment.
    -- The atomic version would be: modifyIORef' cell succ
    old <- readIORef cell
    sleep
    writeIORef cell (old + 1)
  resp IncR
  return UnitT
-- | Interpret every step of a program in order, discarding the results.
sem :: OurMonad m => Mem -> m ()
sem = mapM_ semStep . op Mem

-- | Run a program from the given environment and return the final one.
runMem :: Env -> Mem -> IO Env
runMem env mem = execStateT (sem mem) env
-- | Run a program once and print the final value of every cell followed by
-- the recorded history.
--
-- A fresh history channel is used. The one in 'defaultEnv' is shared
-- process-wide and may still hold events from earlier runs, which would
-- otherwise be printed as if this program had produced them.
debugMem :: Mem -> IO ()
debugMem m = do
  putStrLn ""
  chan <- newTChanIO
  env <- runMem (defaultEnv & historyChan .~ chan) m
  putStrLn ""
  forM_ (zip [0 :: Int ..] (env ^. refs)) $ \(i, ref) -> do
    v <- readIORef ref
    putStrLn $ "$" ++ show i ++ ": " ++ show v
  hist <- getChanContents chan
  print hist
-- | Remove and return everything currently in the channel, oldest first,
-- in a single STM transaction. It does not block: an empty channel yields
-- an empty list.
getChanContents :: TChan a -> IO [a]
getChanContents chan = atomically drain
  where
    drain = do
      next <- tryReadTChan chan
      case next of
        Nothing -> return []
        Just x  -> (x :) <$> drain
------------------------------------------------------------------------ | |
-- | Generate one step that is enabled in the given model, paired with the
-- model after that step.
--
-- 'New' has weight 1 and 'Read', 'Write' and 'Inc' weight 5 each. A step
-- whose precondition fails gets weight 0, so the three cell operations only
-- appear once at least one cell exists.
gen1 :: Model -> Gen (MemStep, Model)
gen1 m = frequency
  [ (weightIf new_pre 1, newStep)
  , (weightIf read_pre 5, readStep)
  , (weightIf write_pre 5, writeStep)
  , (weightIf inc_pre 5, incStep)
  ]
  where
    weightIf :: (Model -> Bool) -> Int -> Int
    weightIf enabled w
      | enabled m = w
      | otherwise = 0

    -- Index of an existing cell; only used when the model is non-empty.
    someRef :: Gen Ref
    someRef = Ref <$> choose (0, length m - 1)

    newStep, readStep, writeStep, incStep :: Gen (MemStep, Model)
    newStep = return (New, m ++ [0])
    readStep = do
      ref <- someRef
      return (Read ref, m)
    writeStep = do
      ref <- someRef
      i <- arbitrary
      return (Write ref i, write_next m (ref, i) ())
    incStep = do
      ref <- someRef
      return (Inc ref, inc_next m ref ())
-- | Generate a sequential program of exactly @size@ steps, starting from the
-- empty model. Each step is enabled in the model left by the previous ones.
genMem :: Gen Mem
genMem = sized $ \size -> Mem <$> steps size []
  where
    steps :: Int -> Model -> Gen [MemStep]
    steps 0 _ = return []
    steps k model = do
      (step, model') <- gen1 model
      rest <- steps (k - 1) model'
      return (step : rest)
------------------------------------------------------------------------ | |
-- | Generated with 'genMem'; shrunk step-wise with 'shrinkMem'.
instance Arbitrary Mem where
  arbitrary = genMem
  shrink (Mem steps) = Mem <$> shrinkMem steps
-- | Shrink candidates for a program, tried in this order:
--
-- 1. the empty program;
-- 2. the program without its first step (when that step is 'New', every
--    later reference is shifted down by one, clamped at 0);
-- 3. the first step kept and the rest shrunk (for 'Write', the written value
--    is shrunk together with the rest).
shrinkMem :: [MemStep] -> [[MemStep]]
shrinkMem [] = []
shrinkMem (step : rest) = [] : withoutFirst : keepingFirst
  where
    withoutFirst = case step of
      New -> map (fmap (\ref -> max 0 (ref - 1))) rest
      _   -> rest
    keepingFirst = case step of
      Write ref i -> [ Write ref i' : rest' | (i', Mem rest') <- shrink (i, Mem rest) ]
      _           -> [ step : rest' | rest' <- shrinkMem rest ]
------------------------------------------------------------------------ | |
-- | Run a 'PropertyM' over the interpreter's state monad. Each test case
-- starts from 'defaultEnv' with a freshly created history channel.
--
-- Without the fresh channel, every test case would log into the single
-- channel stored in 'defaultEnv'. Properties run through here never read it
-- back, so the events would pile up for the rest of the process and appear
-- in the history of any later run that drains that channel.
monadicOur :: PropertyM (StateT Env IO) () -> Property
monadicOur = monadic $ \action -> ioProperty $ do
  chan <- newTChanIO
  evalStateT action (defaultEnv & historyChan .~ chan)
-- | Sequential safety: run a generated program step by step. Each step
-- checks its precondition, executes it, checks its postcondition against the
-- model, and then advances the model.
prop_safety :: Property
prop_safety = forAllShrink genMem shrink $ monadicOur . go [] . op Mem
  where
    go :: OurMonad m => Model -> [MemStep] -> PropertyM m ()
    go _ [] = return ()
    go m (step : rest) = do
      monitor $ collect (stepName step)
      pre $ enabled step m
      r <- run $ semStep step
      case (step, r) of
        (New, RefT _) ->
          -- new_post is not checked: no model value exists to compare the
          -- returned IORef with. The model simply gains one zeroed cell.
          go (m ++ [0]) rest
        (New, _) -> fail "not ref"
        (Read ref, IntT i) -> do
          assert' ("`read_post': read " ++ show ref ++ " ["++ show i ++ "]")
            $ read_post m ref i
          go (read_next m ref i) rest
        (Read _, _) -> fail "not int"
        (Write ref i, UnitT) -> do
          assert' "`write_post'" $ write_post m (ref, i) ()
          go (write_next m (ref, i) ()) rest
        (Write _ _, _) -> fail "not unit"
        (Inc ref, UnitT) -> do
          assert' "`inc_post'" $ inc_post m ref ()
          go (inc_next m ref ()) rest
        (Inc _, _) -> fail "not unit"

    -- Label used for the step-distribution statistics.
    stepName :: MemStep -> String
    stepName New         = "new"
    stepName (Read _)    = "read"
    stepName (Write _ _) = "write"
    stepName (Inc _)     = "inc"

    -- Precondition guarding each kind of step.
    enabled :: MemStep -> Model -> Bool
    enabled New         = new_pre
    enabled (Read _)    = read_pre
    enabled (Write _ _) = write_pre
    enabled (Inc _)     = inc_pre
-- | Like 'assert', but a failure carries the message @msg ++ " failed"@.
assert' :: Monad m => String -> Bool -> PropertyM m ()
assert' msg ok
  | ok        = return ()
  | otherwise = fail (msg ++ " failed")
-- | Verbosely check 'prop_safety' with programs of at most 10 steps.
testP :: IO ()
testP = verboseCheckWith args prop_safety
  where
    args = stdArgs { maxSize = 10 }
-- | Meta-property: if the given property fails, its shrunk counterexample
-- must be exactly @Mem [New,Write (Ref 0) 5,Read (Ref 0)]@. If it passes,
-- there is nothing to check.
--
-- Assumes the last line of QuickCheck's failure output is the shown
-- counterexample. An empty output now fails the check cleanly instead of
-- crashing inside a partial 'last'.
prop_shrink :: Property -> Property
prop_shrink prop = monadicIO $ do
  result <- run $ quickCheckWithResult (stdArgs { chatty = False }) prop
  case result of
    Failure { output = out } ->
      case reverse (lines out) of
        lastLine : _ -> assert $ lastLine == "Mem [New,Write (Ref 0) 5,Read (Ref 0)]"
        []           -> assert False
    _ -> return ()
-- | Check over 500 runs that 'prop_safety' shrinks to the expected minimal
-- counterexample whenever it fails.
testS :: IO ()
testS = quickCheckWith args (prop_shrink prop_safety)
  where
    args = stdArgs { maxSuccess = 500 }
------------------------------------------------------------------------ | |
-- | A parallel test case: a sequential '_prefix', then '_left' and '_right'
-- run concurrently (see 'prop_par').
data ParMem = ParMem
  { _prefix :: Mem
  , _left   :: Mem
  , _right  :: Mem
  }
  deriving Show

makeLenses ''ParMem
-- | Generate a prefix that makes read/write enabled, then two independent
-- suffixes, both generated from the model the prefix leaves behind.
genParMem :: Gen ParMem
genParMem = do
  (prefix', model) <- genPreCondsMem initModel
  ParMem prefix' <$> genMem' model <*> genMem' model
-- | Generate enabled steps until both 'read_pre' and 'write_pre' hold.
-- Returns the steps and the resulting model; returns no steps if both
-- already hold.
genPreCondsMem :: Model -> Gen (Mem, Model)
genPreCondsMem model
  | ready     = return (Mem [], model)
  | otherwise = do
      (step, model') <- gen1 model
      (Mem steps, final) <- genPreCondsMem model'
      return (Mem (step : steps), final)
  where
    ready = read_pre model && write_pre model
-- | Like 'genMem', but starts from the given model instead of the empty one.
genMem' :: Model -> Gen Mem
genMem' start = sized $ \size -> Mem <$> walk size start
  where
    walk :: Int -> Model -> Gen [MemStep]
    walk 0 _ = return []
    walk k model = do
      (step, model') <- gen1 model
      rest <- walk (k - 1) model'
      return (step : rest)
-- | Generated with 'genParMem'. Shrinking shrinks the three programs as a
-- tuple; cases that become invalid are meant to be filtered by the
-- property's preconditions.
instance Arbitrary ParMem where
  arbitrary = genParMem
  shrink (ParMem p l r) = map rebuild (shrink (p, l, r))
    where
      rebuild (p', l', r') = ParMem p' l' r'
------------------------------------------------------------------------ | |
-- | Rose tree: a value together with an ordered list of subtrees.
data Rose a
  = Rose a [Rose a]
  deriving Show
-- | The invocations at the front of a history, up to the first response.
takeInvocations :: History -> [HistoryEvent]
takeInvocations = takeWhile isInvocation
  where
    isInvocation (InvocationEvent _ _) = True
    isInvocation (ResponseEvent _ _)   = False
-- | Find the first response from the given process. Returns it together
-- with the history minus that one event, or @[]@ if no such response exists.
findCorrespondingResp :: ProcessId -> History -> [(Response, History)]
findCorrespondingResp who history =
  case break fromWho history of
    (before, ResponseEvent r _ : after) -> [(r, before ++ after)]
    _                                   -> []
  where
    fromWho (ResponseEvent _ p)   = who == p
    fromWho (InvocationEvent _ _) = False
-- | A completed operation: an invocation paired with its response and the
-- process that performed it.
data Operation
  = Operation Invocation Response ProcessId
  deriving Show
-- | Every way of ordering the history's completed operations. At each level,
-- any invocation that occurs before the first response may come next,
-- provided its process has a response later on. That invocation and its
-- response are removed, and the rest of the history is expanded recursively.
linearTree :: History -> [Rose Operation]
linearTree [] = []
linearTree es = do
  e@(InvocationEvent i p) <- takeInvocations es
  (r, es') <- findCorrespondingResp p (dropFirst e es)
  return (Rose (Operation i r p) (linearTree es'))
  where
    -- Remove the first element equal to x.
    dropFirst :: Eq a => a -> [a] -> [a]
    dropFirst _ [] = []
    dropFirst x (y : ys)
      | x == y    = ys
      | otherwise = y : dropFirst x ys
-- | All values in the tree, in pre-order (node before its subtrees).
fromRose :: Rose a -> [a]
fromRose (Rose x kids) = x : (kids >>= fromRose)
-- | All linearisations of a history. These are the root-to-leaf orderings
-- enumerated by 'linearTree' along which every operation's postcondition
-- holds when the operations are replayed against the model, starting from
-- 'initModel'.
--
-- The empty history has exactly one, empty, linearisation. Without that
-- case, the empty test case that shrinking can reach would be reported as
-- non-linearisable. The result is lazy, so @null . linearise@ stops at the
-- first valid ordering.
linearise :: History -> [[Operation]]
linearise [] = [[]]
linearise es = concatMap (paths initModel) (linearTree es)
  where
    -- Root-to-leaf paths through one tree along which every step is valid.
    paths :: Model -> Rose Operation -> [[Operation]]
    paths m (Rose o children) = case step m o of
      Nothing -> []
      Just m'
        | null children -> [[o]]
        | otherwise     -> map (o :) (concatMap (paths m') children)

    -- Check one operation's postcondition and, if it holds, advance the
    -- model. A response that does not match its invocation is rejected
    -- rather than crashing with a pattern-match failure.
    step :: Model -> Operation -> Maybe Model
    step m (Operation NewI (NewR ref) _)
      | new_post m () ref = Just (new_next m () ref)
    step m (Operation (ReadI ref) (ReadR i) _)
      | read_post m ref i = Just (read_next m ref i)
    step m (Operation (WriteI ref i) WriteR _)
      | write_post m (ref, i) () = Just (write_next m (ref, i) ())
    step m (Operation (IncI ref) IncR _)
      | inc_post m ref () = Just (inc_next m ref ())
    step _ _ = Nothing
------------------------------------------------------------------------ | |
-- | Whether every step of a program, run from 'initModel', satisfies its
-- precondition and refers only to cells that already exist.
--
-- The range check matters because the interpreter indexes cells with '!!'.
-- Shrinking a 'ParMem' prefix can drop a 'New' without renumbering the
-- references in the parallel suffixes. Without the check, such a case would
-- crash the interpreter and be reported as a counterexample.
preConds :: Mem -> Bool
preConds = go initModel . op Mem
  where
    go :: Model -> [MemStep] -> Bool
    go _ [] = True
    go m (New : ms)         = new_pre m && go (m ++ [0]) ms
    go m (Read ref : ms)    = read_pre m && exists m ref && go m ms
    go m (Write ref _ : ms) = write_pre m && exists m ref && go m ms
    go m (Inc ref : ms)     = inc_pre m && exists m ref && go m ms

    -- The reference names a cell that has already been allocated.
    exists :: Model -> Ref -> Bool
    exists m (Ref i) = 0 <= i && i < length m
-- | Parallel safety. Run the sequential prefix, then run 'left' and 'right'
-- concurrently as processes 1 and 2, and require that the recorded history
-- has at least one linearisation. Each generated case is executed 10 times
-- to give the scheduler more chances to expose a race.
--
-- Each run uses a fresh history channel. With the channel shared through
-- 'defaultEnv', events left over from any earlier run on that channel would
-- be mixed into this history and make it spuriously non-linearisable.
--
-- NOTE(review): each branch runs on its own copy of the environment, so a
-- 'New' in 'left' and a 'New' in 'right' return the same 'Ref' number for
-- different cells. The single model replayed by 'linearise' treats them as
-- one cell. Confirm whether 'New' should be excluded from the parallel
-- suffixes.
prop_par :: Property
prop_par = forAllShrink genParMem shrink $ monadicIO . replicateM_ 10 . go
  where
    go :: ParMem -> PropertyM IO ()
    go (ParMem prefix left right) = do
      -- Both suffixes must be valid continuations of the prefix.
      pre $ preConds $ prefix <> left
      pre $ preConds $ prefix <> right
      chan <- run newTChanIO
      e <- run $ runMem (defaultEnv & historyChan .~ chan) prefix
      run $ withPool 2 $ \pool ->
        parallel_ pool
          [ runMem (e & pid .~ 1) left
          , runMem (e & pid .~ 2) right
          ]
      hist <- run $ getChanContents chan
      monitor $ counterexample $ unlines $ map ppShow hist
      assert $ not $ null $ linearise hist
------------------------------------------------------------------------ | |
-- | Verbosely check 'prop_par' with programs of at most 10 steps.
testC :: IO ()
testC = verboseCheckWith args prop_par
  where
    args = stdArgs { maxSize = 10 }
------------------------------------------------------------------------ | |
-- | Sequential safety, checked after the fact. Run the whole program, then
-- replay its recorded history against the model with 'checkPostConds'.
prop_safety' :: Property
prop_safety' = forAllShrink genMem shrink $ monadicIO . go
  where
    go :: Mem -> PropertyM IO ()
    go mem = do
      pre $ preConds mem
      -- Use a fresh channel. The one in 'defaultEnv' is shared process-wide,
      -- so events left there by earlier runs would otherwise be replayed
      -- here as part of this program's history.
      chan <- run newTChanIO
      _ <- run $ runMem (defaultEnv & historyChan .~ chan) mem
      hist <- run $ getChanContents chan
      checkPostConds initModel hist

-- | Walk a sequential history as consecutive invocation/response pairs.
-- Each pair's postcondition is checked, then the model is advanced. Any
-- other shape of history fails.
checkPostConds :: Model -> History -> PropertyM IO ()
checkPostConds _ [] = return ()
checkPostConds m (InvocationEvent NewI _
                  : ResponseEvent (NewR ref) _ : es) = do
  assert' "new_post" $ new_post m () ref
  checkPostConds (new_next m () ref) es
checkPostConds m (InvocationEvent (ReadI ref) _
                  : ResponseEvent (ReadR i) _ : es) = do
  assert' "read_post" $ read_post m ref i
  checkPostConds (read_next m ref i) es
checkPostConds m (InvocationEvent (WriteI ref i) _
                  : ResponseEvent WriteR _ : es) = do
  assert' "write_post" $ write_post m (ref, i) ()
  checkPostConds (write_next m (ref, i) ()) es
checkPostConds m (InvocationEvent (IncI ref) _
                  : ResponseEvent IncR _ : es) = do
  assert' "inc_post" $ inc_post m ref ()
  checkPostConds (inc_next m ref ()) es
checkPostConds _ _ = fail ""
-- | Verbosely check 'prop_safety'' with programs of at most 10 steps.
testP' :: IO ()
testP' = verboseCheckWith args prop_safety'
  where
    args = stdArgs { maxSize = 10 }

-- | Check over 500 runs that 'prop_safety'' shrinks to the expected minimal
-- counterexample whenever it fails.
testS' :: IO ()
testS' = quickCheckWith args (prop_shrink prop_safety')
  where
    args = stdArgs { maxSuccess = 500 }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Example run of
testC
: