Skip to content

Instantly share code, notes, and snippets.

@stevana
Last active January 25, 2017 16:12
Show Gist options
  • Save stevana/a2bb4640d1552dd69abcf93a30d0814e to your computer and use it in GitHub Desktop.
Save stevana/a2bb4640d1552dd69abcf93a30d0814e to your computer and use it in GitHub Desktop.
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
module IORefPrePostRaceCondition where
import Control.Concurrent (threadDelay)
import Control.Concurrent.ParallelIO.Local
import Control.Concurrent.STM.TChan
import Control.Lens hiding (pre)
import Control.Monad.Free
import Control.Monad.Free.TH
import Control.Monad.State
import Control.Monad.STM
import Data.IORef
import Data.Monoid
import System.IO.Unsafe
import System.Random
import Test.QuickCheck
import Test.QuickCheck.Monadic
import Text.Printf
------------------------------------------------------------------------
-- | A reference to a memory cell, identified by the position at which the
-- cell was allocated (the first allocated cell is @Ref 0@).
newtype Ref = Ref Int
  deriving (Eq, Ord, Num, Read, Show)
-- | Compact, human-oriented rendering, used when printing histories.
class PPShow a where
  -- | Render a value in its short display form.
  ppShow :: a -> String
-- | References are displayed as a dollar sign followed by their index,
-- e.g. @$0@.
instance PPShow Ref where
  ppShow (Ref n) = '$' : show n

-- Template Haskell: generate the 'Wrapped' instance that @op Ref@ relies on.
makeWrapped ''Ref
-- | One primitive memory operation, parameterised over the reference type so
-- that references can be rewritten with 'fmap'.
data MemStepF ref
  = New            -- ^ Allocate a fresh cell.
  | Read ref       -- ^ Read the contents of a cell.
  | Write ref Int  -- ^ Overwrite a cell with the given value.
  | Inc ref        -- ^ Increment a cell by one.
  deriving (Eq, Show, Read, Functor)
-- | A memory step over concrete 'Ref's.
type MemStep = MemStepF Ref

-- | A program: a sequence of memory steps executed in order.
--
-- 'Semigroup' is derived alongside 'Monoid' because it has been a superclass
-- of 'Monoid' since base-4.11 (GHC 8.4); without it the derived 'Monoid'
-- instance is rejected by current compilers. Both instances are the ones of
-- the underlying list, i.e. concatenation of programs.
newtype Mem = Mem [MemStep]
  deriving (Show, Eq, Semigroup, Monoid, Read)

-- Template Haskell: generate the 'Wrapped' instance that @op Mem@ relies on.
makeWrapped ''Mem
------------------------------------------------------------------------
-- | The abstract model of the memory: one 'Int' per allocated cell, where the
-- cell for @Ref i@ lives at list index @i@.
type Model = [Int]

-- | The model before any cell has been allocated.
initModel :: Model
initModel = mempty
-- | Allocation is always allowed.
new_pre :: Model -> Bool
new_pre = const True

-- | Allocation appends a zero-initialised cell to the model. The returned
-- reference is not inspected.
new_next :: Model -> () -> Ref -> Model
new_next model _ _ = model ++ [0]

-- | Allocation has no postcondition to check.
new_post :: Model -> () -> Ref -> Bool
new_post _ _ = const True
-- | Reading requires at least one allocated cell.
read_pre :: Model -> Bool
read_pre = not . null

-- | Reading leaves the model unchanged.
read_next :: Model -> Ref -> Int -> Model
read_next m _ _ = m

-- | The value read must equal the model's value for that cell, and the model
-- must be unchanged by the read.
--
-- The cell is looked up with a total traversal rather than @!!@, so a
-- reference that does not exist in the model makes the postcondition 'False'
-- instead of aborting the whole test run with an index error.
read_post :: Model -> Ref -> Int -> Bool
read_post m ref r =
  m ^? element (op Ref ref) == Just r && read_next m ref r == m
-- | Writing requires at least one allocated cell.
write_pre :: Model -> Bool
write_pre = not . null

-- | Writing replaces the model's value for the cell with the written value.
-- A reference outside the model leaves the model untouched.
write_next :: Model -> (Ref, Int) -> () -> Model
write_next m (ref, i) _ = m & element (op Ref ref) .~ i

-- | After a write, the model's cell must hold the written value.
--
-- The cell is looked up with a total traversal rather than @!!@, so a
-- reference that does not exist in the model makes the postcondition 'False'
-- instead of aborting the whole test run with an index error.
write_post :: Model -> (Ref, Int) -> () -> Bool
write_post m (ref, i) _ =
  write_next m (ref, i) () ^? element (op Ref ref) == Just i
-- | Incrementing requires at least one allocated cell.
inc_pre :: Model -> Bool
inc_pre = not . null

-- | Incrementing bumps the model's value for the cell by one. A reference
-- outside the model leaves the model untouched.
inc_next :: Model -> Ref -> () -> Model
inc_next m ref _ = m & element (op Ref ref) %~ succ

-- | After an increment, the model's cell must be exactly one larger than
-- before.
--
-- Both lookups are total (no @!!@): if the reference does not exist in the
-- model the postcondition is 'False' rather than an index error.
inc_post :: Model -> Ref -> () -> Bool
inc_post m ref _ =
  case (m ^? element ix, inc_next m ref () ^? element ix) of
    (Just before, Just after) -> after == before + 1
    _                         -> False
  where
    ix = op Ref ref
------------------------------------------------------------------------
-- | The start of an operation, as recorded in the history.
data Invocation
  = NewI
  | ReadI Ref
  | WriteI Ref Int
  | IncI Ref
  deriving (Eq, Show, Read)

-- | Invocations are rendered with a leading @>@ and the operation name
-- padded to eight columns.
instance PPShow Invocation where
  ppShow call = case call of
    NewI         -> printf "%-8s" ("> new" :: String)
    ReadI ref    -> printf "%-8s%s" ("> read" :: String) (ppShow ref)
    WriteI ref i -> printf "%-8s%s %s" ("> write" :: String) (ppShow ref) (show i)
    IncI ref     -> printf "%-8s%s" ("> inc" :: String) (ppShow ref)
-- | The completion of an operation, as recorded in the history.
data Response
  = NewR Ref
  | ReadR Int
  | WriteR
  | IncR
  deriving (Eq, Show, Read)

-- | Responses are rendered with a leading @<@; any returned value is shown
-- in square brackets.
instance PPShow Response where
  ppShow reply = case reply of
    NewR ref -> printf "%-8s%s" ("< new" :: String) (bracketed (ppShow ref))
    ReadR i  -> printf "%-8s%s" ("< read" :: String) (bracketed (show i))
    WriteR   -> printf "%-8s" ("< write" :: String)
    IncR     -> printf "%-8s" ("< inc" :: String)
    where
      bracketed :: String -> String
      bracketed s = "[" ++ s ++ "]"
-- | Identifies the thread of execution that produced a history event.
newtype ProcessId = ProcessId Int
  deriving (Eq, Show, Read, Num)

-- | A process id is displayed as its bare number.
instance PPShow ProcessId where
  ppShow (ProcessId n) = show n
-- | A single entry of the execution history: either the start or the end of
-- an operation, tagged with the process that performed it.
data HistoryEvent
  = InvocationEvent Invocation ProcessId
  | ResponseEvent Response ProcessId
  deriving (Eq, Show, Read)

-- | Events are rendered as the padded invocation/response followed by the
-- process id in braces.
instance PPShow HistoryEvent where
  ppShow event = case event of
    InvocationEvent i who -> line (ppShow i) who
    ResponseEvent r who   -> line (ppShow r) who
    where
      line :: String -> ProcessId -> String
      line what who = printf "%-15s%3s" what (" {" ++ ppShow who ++ "}")

-- | The events of a run, in the order they were recorded.
type History = [HistoryEvent]
------------------------------------------------------------------------
-- | The capabilities the implementation under test needs: 'IO' plus access
-- to the mutable 'Env'.
type OurMonad m = (MonadIO m, MonadState Env m)

-- | Runtime state of the implementation under test.
data Env = Env
  { _refs        :: [IORef Int]        -- ^ Allocated cells; @Ref i@ is element @i@.
  , _historyChan :: TChan HistoryEvent -- ^ Channel the history events are written to.
  , _pid         :: ProcessId          -- ^ Process id stamped on recorded events.
  }

-- Template Haskell: generate the lenses 'refs', 'historyChan' and 'pid'.
makeLenses ''Env
-- | The initial environment: no allocated cells, process id 0, and a single
-- process-wide history channel.
--
-- The channel is created with 'unsafePerformIO', so this value is a global
-- mutable resource: every user of 'defaultEnv' writes to the same channel.
-- The @NOINLINE@ pragma is required for that to be reliable; without it the
-- optimiser may inline the definition at its use sites and create a separate
-- channel for each of them.
defaultEnv :: Env
defaultEnv = Env [] (unsafePerformIO newTChanIO) 0
{-# NOINLINE defaultEnv #-}
-- | Record the start of an operation in the history channel, stamped with
-- the current process id.
inv :: OurMonad m => Invocation -> m ()
inv call = do
  who  <- use pid
  chan <- use historyChan
  liftIO . atomically . writeTChan chan $ InvocationEvent call who
-- | Record the completion of an operation in the history channel, stamped
-- with the current process id.
resp :: OurMonad m => Response -> m ()
resp reply = do
  who  <- use pid
  chan <- use historyChan
  liftIO . atomically . writeTChan chan $ ResponseEvent reply who
-- | Pause the current thread for a random delay drawn from the range
-- @(0, 20000)@, passed to 'threadDelay' (microseconds). Used to widen the
-- window in which operations of different threads can interleave.
sleep :: IO ()
sleep = do
  micros <- randomRIO (0, 20000)
  threadDelay micros
-- | The result of executing one step: a freshly allocated cell, an integer
-- that was read, or nothing of interest.
data Type
  = RefT (IORef Int)
  | IntT Int
  | UnitT
-- | Execute a single step against real 'IORef's, recording an invocation
-- event before and a response event after the step.
--
-- This is the implementation under test. 'Inc' is deliberately written as a
-- non-atomic read/sleep/write sequence, and 'Write' carries a commented-out
-- bug toggle. A reference that was never allocated makes the @!!@ lookup
-- fail.
semStep :: OurMonad m => MemStep -> m Type
semStep New = do
  inv NewI
  -- The new cell's reference is its position in the list of cells.
  nextIx <- length <$> use refs
  cell   <- liftIO (newIORef 0)
  refs %= (++ [cell])
  resp (NewR (Ref nextIx))
  return (RefT cell)
semStep (Read ref) = do
  inv (ReadI ref)
  cells <- use refs
  v <- liftIO (readIORef (cells !! op Ref ref))
  liftIO sleep
  resp (ReadR v)
  return (IntT v)
semStep (Write ref i) = do
  inv (WriteI ref i)
  cells <- use refs
  if i `elem` [5..10]
    -- Introduce a bug:
    then store cells i -- + 1
    else store cells i
  liftIO sleep
  resp WriteR
  return UnitT
  where
    store cells v = liftIO (writeIORef (cells !! op Ref ref) v)
semStep (Inc ref) = do
  inv (IncI ref)
  cells <- use refs
  let cell = cells !! op Ref ref
  liftIO $ do
    -- Possible race condition:
    old <- readIORef cell
    sleep
    writeIORef cell (old + 1)
    -- The fix:
    -- modifyIORef' cell succ
  resp IncR
  return UnitT
-- | Execute every step of a program in order, discarding the step results.
sem :: OurMonad m => Mem -> m ()
sem (Mem steps) = mapM_ semStep steps
-- | Run a program starting from the given environment and return the
-- environment it ends in.
runMem :: Env -> Mem -> IO Env
runMem env program = execStateT (sem program) env
-- | Run a program from 'defaultEnv', then print the final contents of every
-- cell followed by the recorded history.
debugMem :: Mem -> IO ()
debugMem program = do
  putStrLn ""
  env <- execStateT (sem program) defaultEnv
  putStrLn ""
  forM_ (zip [0..] (env ^. refs)) $ \(ix, cell) -> do
    v <- readIORef cell
    putStrLn ("$" ++ show ix ++ ": " ++ show v)
  hist <- getChanContents (env ^. historyChan)
  print hist
-- | Remove and return everything currently in the channel, oldest element
-- first. The whole drain happens in a single STM transaction and never
-- blocks: it stops as soon as the channel is empty.
getChanContents :: TChan a -> IO [a]
getChanContents chan = reverse <$> atomically (drain [])
  where
    -- Elements are accumulated newest-first and reversed at the end.
    drain acc =
      tryReadTChan chan >>= maybe (return acc) (\x -> drain (x : acc))
------------------------------------------------------------------------
-- | Generate one step that is allowed in the given model, together with the
-- model after that step. Operations whose precondition fails get weight 0;
-- reads, writes and increments are five times as likely as allocations.
gen1 :: Model -> Gen (MemStep, Model)
gen1 m = frequency
  [ (weight 1 (new_pre m),   new_gen)
  , (weight 5 (read_pre m),  read_gen)
  , (weight 5 (write_pre m), write_gen)
  , (weight 5 (inc_pre m),   inc_gen)
  ]
  where
    weight :: Int -> Bool -> Int
    weight w enabled = if enabled then w else 0

    -- A reference to one of the cells that exist in the model.
    pickRef :: Gen Ref
    pickRef = Ref <$> choose (0, length m - 1)

    new_gen :: Gen (MemStep, Model)
    new_gen = return (New, m ++ [0])

    read_gen :: Gen (MemStep, Model)
    read_gen = do
      ref <- pickRef
      return (Read ref, m)

    write_gen :: Gen (MemStep, Model)
    write_gen = do
      ref <- pickRef
      i <- arbitrary
      return (Write ref i, write_next m (ref, i) ())

    inc_gen :: Gen (MemStep, Model)
    inc_gen = do
      ref <- pickRef
      return (Inc ref, inc_next m ref ())
-- | Generate a program, starting from the empty model, whose length equals
-- the QuickCheck size parameter. This is 'genMem'' specialised to
-- 'initModel'.
genMem :: Gen Mem
genMem = genMem' initModel
------------------------------------------------------------------------
-- | Programs are generated with 'genMem' and shrunk with 'shrinkMem'.
instance Arbitrary Mem where
  arbitrary = genMem
  shrink (Mem steps) = map Mem (shrinkMem steps)
-- | Shrink candidates for a list of steps, in the order they are tried:
-- the empty program, the program without its first step, and finally the
-- first step kept with the remainder shrunk.
--
-- Dropping a 'New' also shifts every reference in the remaining steps down
-- by one (clamped at 0). A 'Write' additionally shrinks its value.
shrinkMem :: [MemStep] -> [[MemStep]]
shrinkMem [] = []
shrinkMem (step : rest) = case step of
  New ->
    [] : map (fmap shiftDown) rest : [ New : rest' | rest' <- shrinkMem rest ]
  Read ref ->
    [] : rest : [ Read ref : rest' | rest' <- shrinkMem rest ]
  Write ref i ->
    [] : rest : [ Write ref i' : rest' | (i', Mem rest') <- shrink (i, Mem rest) ]
  Inc ref ->
    [] : rest : [ Inc ref : rest' | rest' <- shrinkMem rest ]
  where
    shiftDown :: Ref -> Ref
    shiftDown ref = max 0 (ref - 1)
------------------------------------------------------------------------
-- | Turn a monadic property over @StateT Env IO@ into a 'Property' by
-- running it from a fresh environment.
--
-- Each run gets its own history channel. Running directly from 'defaultEnv'
-- would append every recorded event to the single global channel, which no
-- caller of this function drains: the events would pile up for the lifetime
-- of the process and show up in the history of whichever property reads
-- that channel next.
monadicOur :: PropertyM (StateT Env IO) () -> Property
monadicOur = monadic $ \act -> ioProperty $ do
  chan <- newTChanIO
  evalStateT act (defaultEnv & historyChan .~ chan)
-- | Sequential safety: run a generated program step by step and check each
-- step's postcondition against the model as it evolves. Test cases whose
-- preconditions do not hold are discarded via 'pre'.
prop_safety :: Property
prop_safety =
  forAllShrink genMem shrink $ \(Mem steps) -> monadicOur (go [] steps)
  where
    go :: OurMonad m => Model -> [MemStep] -> PropertyM m ()
    go _ [] = return ()
    go model (step : rest) = case step of
      New -> do
        monitor (collect "new")
        pre (new_pre model)
        out <- run (semStep New)
        case out of
          -- The implementation returns an IORef, not a model 'Ref', so
          -- new_post / new_next cannot be applied here; the model is
          -- extended by hand instead.
          RefT _ -> go (model ++ [0]) rest
          _      -> fail "not ref"
      Read ref -> do
        monitor (collect "read")
        pre (read_pre model)
        out <- run (semStep (Read ref))
        case out of
          IntT i -> do
            assert' ("`read_post': read " ++ show ref ++ " ["++ show i ++ "]")
              (read_post model ref i)
            go (read_next model ref i) rest
          _ -> fail "not int"
      Write ref i -> do
        monitor (collect "write")
        pre (write_pre model)
        out <- run (semStep (Write ref i))
        case out of
          UnitT -> do
            assert' "`write_post'" (write_post model (ref, i) ())
            go (write_next model (ref, i) ()) rest
          _ -> fail "not unit"
      Inc ref -> do
        monitor (collect "inc")
        pre (inc_pre model)
        out <- run (semStep (Inc ref))
        case out of
          UnitT -> do
            assert' "`inc_post'" (inc_post model ref ())
            go (inc_next model ref ()) rest
          _ -> fail "not unit"
-- | Like 'assert', but fails with the given label (suffixed with
-- @" failed"@) when the condition does not hold.
assert' :: Monad m => String -> Bool -> PropertyM m ()
assert' label ok
  | ok        = return ()
  | otherwise = fail (label ++ " failed")
-- | Run 'prop_safety' verbosely with the maximum size capped at 10.
testP :: IO ()
testP = verboseCheckWith args prop_safety
  where
    args = stdArgs { maxSize = 10 }
-- | Meta-property about shrinking: run the given property quietly and, if it
-- fails, require the last line of QuickCheck's output to be the expected
-- minimal program. Any other outcome is accepted.
prop_shrink :: Property -> Property
prop_shrink prop = monadicIO $ do
  outcome <- run (quickCheckWithResult stdArgs { chatty = False } prop)
  case outcome of
    Failure { output = report } ->
      assert (last (lines report) == "Mem [New,Write (Ref 0) 5,Read (Ref 0)]")
    _ -> return ()
-- | Check the shrinking behaviour of 'prop_safety' over 500 runs.
testS :: IO ()
testS = quickCheckWith args (prop_shrink prop_safety)
  where
    args = stdArgs { maxSuccess = 500 }
------------------------------------------------------------------------
-- | A parallel test case: a sequential prefix followed by two programs that
-- are run concurrently.
data ParMem = ParMem
  { _prefix :: Mem  -- ^ Run first, on its own.
  , _left   :: Mem  -- ^ Run in parallel with '_right'.
  , _right  :: Mem  -- ^ Run in parallel with '_left'.
  }
  deriving Show

-- Template Haskell: generate the lenses 'prefix', 'left' and 'right'.
makeLenses ''ParMem
-- | Generate a parallel test case: a prefix that establishes the read and
-- write preconditions, then two branches generated independently from the
-- model the prefix ends in.
genParMem :: Gen ParMem
genParMem = do
  (pfx, model) <- genPreCondsMem initModel
  l <- genMem' model
  r <- genMem' model
  return (ParMem pfx l r)
-- | Keep generating steps until both the read and the write precondition
-- hold; return the generated program and the model it ends in.
genPreCondsMem :: Model -> Gen (Mem, Model)
genPreCondsMem model =
  if ready
    then return (Mem [], model)
    else do
      (step, model') <- gen1 model
      (Mem rest, final) <- genPreCondsMem model'
      return (Mem (step : rest), final)
  where
    ready = read_pre model && write_pre model
-- | Generate a program that starts in the given model; its length equals
-- the QuickCheck size parameter.
genMem' :: Model -> Gen Mem
genMem' start = sized (build start)
  where
    -- Generate exactly @n@ steps, threading the model through.
    build :: Model -> Int -> Gen Mem
    build _ 0 = return (Mem [])
    build model n = do
      (step, model') <- gen1 model
      Mem rest <- build model' (n - 1)
      return (Mem (step : rest))
-- | Parallel test cases are shrunk by shrinking the triple of their three
-- programs.
instance Arbitrary ParMem where
  arbitrary = genParMem
  shrink (ParMem p l r) = map rebuild (shrink (p, l, r))
    where
      rebuild (p', l', r') = ParMem p' l' r'
------------------------------------------------------------------------
-- | A rose tree: a value together with any number of subtrees.
data Rose a
  = Rose a [Rose a]
  deriving Show
-- | The invocation events at the front of the history, up to (but not
-- including) the first response event.
takeInvocations :: History -> [HistoryEvent]
takeInvocations = takeWhile isInvocation
  where
    isInvocation :: HistoryEvent -> Bool
    isInvocation InvocationEvent{} = True
    isInvocation _                 = False
-- | Find the first response event of the given process. The result is empty
-- if there is none; otherwise it holds that response paired with the history
-- with the response event removed.
findCorrespondingResp :: ProcessId -> History -> [(Response, History)]
findCorrespondingResp _ [] = []
findCorrespondingResp pid (e : es) = case e of
  ResponseEvent r pid' | pid == pid' -> [(r, es)]
  _ -> [ (res, e : es') | (res, es') <- findCorrespondingResp pid es ]
-- | A completed operation: an invocation paired with its response and the
-- process that performed it.
data Operation
  = Operation Invocation Response ProcessId
  deriving Show
-- | Build the forest of possible sequential orderings of a history.
--
-- Each invocation at the front of the history is a candidate for the next
-- operation: it is paired with the first response of the same process, both
-- events are removed, and the subtrees are built from what is left.
linearTree :: History -> [Rose Operation]
linearTree [] = []
linearTree es =
  concatMap candidate (takeInvocations es)
  where
    candidate :: HistoryEvent -> [Rose Operation]
    candidate e = case e of
      InvocationEvent call who ->
        [ Rose (Operation call reply who) (linearTree rest)
        | (reply, rest) <- findCorrespondingResp who (dropFirst e es)
        ]
      _ -> []

    -- Remove the first occurrence of an element from a list.
    dropFirst :: Eq a => a -> [a] -> [a]
    dropFirst _ [] = []
    dropFirst x (y : ys)
      | y /= x    = y : dropFirst x ys
      | otherwise = ys
-- | Flatten a rose tree in pre-order: the root first, then each subtree
-- from left to right.
fromRose :: Rose a -> [a]
fromRose (Rose root kids) = root : foldr (\kid acc -> fromRose kid ++ acc) [] kids
-- | Keep the trees of 'linearTree' along which the postconditions hold,
-- starting from 'initModel', and flatten each of them with 'fromRose'.
-- An empty result means no ordering of the history satisfies the model.
linearise :: History -> [[Operation]]
linearise = map fromRose . filter (postConditionsHold initModel) . linearTree
  where
    -- Like 'any', except that an empty list counts as success: a node with
    -- no children is the end of an ordering.
    any' :: (a -> Bool) -> [a] -> Bool
    any' _ [] = True
    any' p xs = any p xs

    -- The operation's postcondition must hold in the current model, and at
    -- least one continuation must hold in the model after the operation.
    postConditionsHold :: Model -> Rose Operation -> Bool
    postConditionsHold m (Rose (Operation NewI (NewR ref) _) ops)
      = new_post m () ref && any' (postConditionsHold (new_next m () ref)) ops
    postConditionsHold m (Rose (Operation (ReadI ref) (ReadR i) _) ops)
      = read_post m ref i && any' (postConditionsHold (read_next m ref i)) ops
    postConditionsHold m (Rose (Operation (WriteI ref i) WriteR _) ops)
      = write_post m (ref, i) () && any' (postConditionsHold (write_next m (ref, i) ())) ops
    postConditionsHold m (Rose (Operation (IncI ref) IncR _) ops)
      = inc_post m ref () && any' (postConditionsHold (inc_next m ref ())) ops
    -- An invocation paired with a response of a different operation cannot
    -- be explained by the model. Reject it instead of crashing with a
    -- non-exhaustive pattern match.
    postConditionsHold _ _ = False
------------------------------------------------------------------------
-- | Check that every step's precondition holds when the program is run from
-- 'initModel'. Only allocations extend the model here; the other steps
-- leave it as it is, since the preconditions only look at whether the model
-- is empty.
preConds :: Mem -> Bool
preConds (Mem steps) = go initModel steps
  where
    go :: Model -> [MemStep] -> Bool
    go _ [] = True
    go model (step : rest) = case step of
      New       -> new_pre model   && go (model ++ [0]) rest
      Read _    -> read_pre model  && go model rest
      Write _ _ -> write_pre model && go model rest
      Inc _     -> inc_pre model   && go model rest
-- | Linearisability: run the prefix sequentially, then the two branches on
-- two threads, and require that the recorded history has at least one
-- ordering accepted by the model. Each test case is executed ten times to
-- give the scheduler a chance to produce a bad interleaving.
--
-- Every execution records into its own freshly created history channel.
-- Reading from the global channel of 'defaultEnv' would also pick up events
-- left behind by other properties, or by an earlier execution that was
-- aborted by an exception, and report failures that have nothing to do with
-- the program at hand.
prop_par :: Property
prop_par = forAllShrink genParMem shrink $ monadicIO . replicateM_ 10 . go
  where
    go :: ParMem -> PropertyM IO ()
    go (ParMem pfx l r) = do
      pre $ preConds $ pfx <> l
      pre $ preConds $ pfx <> r
      chan <- run newTChanIO
      e <- run $ runMem (defaultEnv & historyChan .~ chan) pfx
      -- Both branches start from the environment the prefix ended in and
      -- differ only in the process id stamped on their events.
      run $ withPool 2 $ \pool ->
        parallel_ pool
          [ runMem (e & pid .~ 1) l
          , runMem (e & pid .~ 2) r
          ]
      hist <- run $ getChanContents chan
      monitor $ counterexample $ unlines $ map ppShow hist
      assert $ not $ null $ linearise hist
------------------------------------------------------------------------
-- | Run 'prop_par' verbosely with the maximum size capped at 10.
testC :: IO ()
testC = verboseCheckWith args prop_par
  where
    args = stdArgs { maxSize = 10 }
------------------------------------------------------------------------
-- | Sequential safety, checked after the fact: run the whole program, then
-- replay the recorded history against the model with 'checkPostConds'.
--
-- The run records into its own freshly created history channel. Reading
-- from the global channel of 'defaultEnv' would also pick up events left
-- behind by other properties, or by an earlier run that was aborted by an
-- exception, and the replay would fail on events the program never
-- produced.
prop_safety' :: Property
prop_safety' = forAllShrink genMem shrink $ monadicIO . go
  where
    go :: Mem -> PropertyM IO ()
    go mem = do
      pre $ preConds mem
      chan <- run newTChanIO
      _ <- run $ runMem (defaultEnv & historyChan .~ chan) mem
      hist <- run $ getChanContents chan
      checkPostConds initModel hist
-- | Replay a sequential history against the model: every invocation must be
-- immediately followed by the response of the same operation, and that
-- operation's postcondition must hold in the model at that point.
--
-- A history of any other shape fails the property with a message saying so.
-- Previously the failure message was empty, which left nothing to go on.
checkPostConds :: Model -> History -> PropertyM IO ()
checkPostConds _ [] = return ()
checkPostConds m (InvocationEvent call _ : ResponseEvent reply _ : es) =
  case (call, reply) of
    (NewI, NewR ref) -> do
      assert' "new_post" $ new_post m () ref
      checkPostConds (new_next m () ref) es
    (ReadI ref, ReadR i) -> do
      assert' "read_post" $ read_post m ref i
      checkPostConds (read_next m ref i) es
    (WriteI ref i, WriteR) -> do
      assert' "write_post" $ write_post m (ref, i) ()
      checkPostConds (write_next m (ref, i) ()) es
    (IncI ref, IncR) -> do
      assert' "inc_post" $ inc_post m ref ()
      checkPostConds (inc_next m ref ()) es
    _ -> fail "checkPostConds: response does not match the preceding invocation"
checkPostConds _ _ =
  fail "checkPostConds: expected an invocation immediately followed by its response"
-- | Run 'prop_safety'' verbosely with the maximum size capped at 10.
testP' :: IO ()
testP' = verboseCheckWith args prop_safety'
  where
    args = stdArgs { maxSize = 10 }
-- | Check the shrinking behaviour of 'prop_safety'' over 500 runs.
testS' :: IO ()
testS' = quickCheckWith args (prop_shrink prop_safety')
  where
    args = stdArgs { maxSuccess = 500 }
@stevana
Copy link
Author

stevana commented Jan 25, 2017

Example run of testC:

Assertion failed (after 6 tests and 8 shrinks):     
ParMem {_prefix = Mem [New], _left = Mem [Write (Ref 0) 2], _right = Mem [Inc (Ref 0),Read (Ref 0)]}
> new           {0}
< new   [$0]    {0}
> inc   $0      {2}
> write $0 2    {1}
< write         {1}
< inc           {2}
> read  $0      {2}
< read  [1]     {2}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment