Created
June 2, 2020 23:47
-
-
Save ssadler/77c36e9f99f350c352971a5f4b818b18 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Data.Word | |
import Control.Monad | |
import Control.Monad.Catch | |
import Test.DejaFu | |
import Test.DejaFu.Conc.Internal.Common | |
import Test.DejaFu.Conc.Internal.STM | |
import Control.Monad.Conc.Class | |
import Control.Concurrent.Classy hiding (wait) | |
import Control.Concurrent.Classy.Async | |
import Control.Concurrent.Classy.MVar | |
import qualified Data.Map as Map | |
-- | Key identifying a remote host, represented as a 32-bit word.
type HostAddress =
  Word32

-- | Shared, transactional table mapping each host address to the
-- asynchronous action currently registered for it.
type ReceiverMap m =
  TVar (STM m) (Map.Map HostAddress (Async m ()))

-- | 'Async' specialised to actions running in 'IO'.
type ClassyAsync =
  Async IO
-- | Start the given action asynchronously and return its handle.
--
-- This is 'async' with its type pinned to the 'MonadConc' version.
classyAsync :: MonadConc m => m a -> m (Async m a)
classyAsync action = async action
-- | Allocate a fresh 'ReceiverMap' that starts out with no entries.
newReceiverMap :: MonadConc m => m (ReceiverMap m)
newReceiverMap = atomically $ newTVar Map.empty
-- | Run @act@ as the single registered receiver for @ip@.
--
-- Before @act@ starts, @asnc@ is stored in the map under @ip@; any handle that
-- was previously stored under the same address is cancelled. When @act@
-- finishes (normally or by exception), the entry for @ip@ is removed, but only
-- if it still refers to @asnc@, so a successor's registration is left alone.
--
-- NOTE(review): whether 'cancel' waits for the evicted action to finish, and
-- whether that wait can itself be interrupted, is decided by the Async
-- library and is not visible here -- confirm before relying on strict mutual
-- exclusion between successive receivers.
inboundConnectionLimit
  :: MonadConc m
  => ReceiverMap m  -- ^ shared registry of receivers
  -> HostAddress    -- ^ address this receiver serves
  -> Async m ()     -- ^ handle to register for the address
  -> m a            -- ^ action to run while registered
  -> m a
inboundConnectionLimit mreceivers ip asnc act =
  (evictPrevious >> act) `finally` unregister
  where
    -- Swap our handle into the map in one transaction, then cancel whoever
    -- held the slot before us (if anyone).
    evictPrevious = do
      previous <- atomically $ do
        old <- lookupAsync ip mreceivers
        insertAsync ip asnc mreceivers
        pure old
      mapM_ cancel previous

    -- Remove the entry only when it is still ours.
    unregister = atomically $ do
      current <- lookupAsync ip mreceivers
      when (current == Just asnc) $
        deleteAsync ip mreceivers
-- | Store @asnc@ under @ip@, replacing any handle already stored there.
--
-- Uses the strict 'modifyTVar'' so the updated map is evaluated when it is
-- written, rather than leaving a chain of pending 'Map.insert' thunks inside
-- the variable.
insertAsync :: MonadConc m => HostAddress -> Async m () -> ReceiverMap m -> STM m ()
insertAsync ip asnc t =
  modifyTVar' t (Map.insert ip asnc)
-- | Read the map and return the handle stored under @ip@, if there is one.
lookupAsync :: MonadConc m => HostAddress -> ReceiverMap m -> STM m (Maybe (Async m ()))
lookupAsync ip tmap = fmap (Map.lookup ip) (readTVar tmap)
-- | Remove whatever handle is stored under @ip@; a no-op if none is stored.
--
-- Uses the strict 'modifyTVar'' so the shrunken map is evaluated when it is
-- written, rather than leaving a pending 'Map.delete' thunk inside the
-- variable.
deleteAsync :: MonadConc m => HostAddress -> ReceiverMap m -> STM m ()
deleteAsync ip t =
  modifyTVar' t (Map.delete ip)
-- | Concurrency test program for 'inboundConnectionLimit'.
--
-- Two workers are spawned, both registering under address @0@. While a worker
-- is inside its guarded action it holds the shared counter one higher, and it
-- lowers the counter again on the way out. The setup phase registers an
-- invariant that throws 'TooManyThreads' whenever the counter exceeds one.
testInboundConnectionLimit :: Program (WithSetup (ModelTVar IO Integer)) IO ()
testInboundConnectionLimit = withSetup setup $ \active -> do
  receivers <- atomically $ newTVar Map.empty
  workers <- replicateM 2 $ do
    -- Each worker needs its own handle, which only exists once 'async'
    -- returns, so it is passed in through an MVar.
    handoff <- newEmptyMVar
    worker <- async $ do
      self <- takeMVar handoff
      inboundConnectionLimit receivers 0 self $
        (atomically (modifyTVar active (+ 1)) >> threadDelay 1)
          `finally` atomically (modifyTVar active (subtract 1))
    putMVar handoff worker
    pure worker
  mapM_ waitCatch workers
  where
    -- Create the counter and attach the "at most one active worker" invariant.
    setup :: Program Basic IO (ModelTVar IO Integer)
    setup = do
      counter <- atomically (newTVar 0)
      registerInvariant $
        inspectTVar counter >>= \n ->
          when (n > 1) (throwM TooManyThreads)
      pure counter
-- | Thrown by the test invariant when more than one worker is active at once.
data TooManyThreads
  = TooManyThreads
  deriving Show

instance Exception TooManyThreads
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment