Last active
November 15, 2025 16:51
-
-
Save vshabanov/528b66ed04b7b8333c8eba5af9eab060 to your computer and use it in GitHub Desktop.
Experiments with nested data parallelism
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE TypeFamilies #-} | |
| import Control.Concurrent | |
| import Control.Concurrent.Async | |
| import Control.Concurrent.Chan | |
| import Control.Concurrent.MVar | |
| import Control.Concurrent.QSem | |
| import Control.Exception qualified as E | |
| import Control.Monad | |
| import Control.Monad.IO.Class | |
| import Data.Bifunctor (first) | |
| import Data.List | |
| import Data.Ord | |
| import GHC.Stack qualified as ES | |
| import System.IO.Unsafe | |
-- | Entry point: runs the 'testFold' experiment and prints its result.
--
-- 'testFold' returns a tuple; a compiled @main@ that merely returned it
-- would discard the value silently, so it is printed explicitly.
main :: IO ()
main = testFold >>= print
-- | Nested-map smoke test: for every @x@ in @[1..5]@ builds the list
-- @[x..5]@ through an inner 'ndpMap', then sums the lengths of those lists.
testN :: IO Int
testN = evalNDP 10 $ do
  rows <- ndpMap suffix [1 .. 5]
  pure sum <*> ndpMap (pure . length) rows
  where
    -- Inner parallel map over the range starting at @from@.
    suffix from = ndpMap pure [from .. 5]
-- | Fold experiment: combines several five-second pauses with a slow
-- 'ndpFold' using applicative operators and returns the number of
-- accumulators produced, their sum, and the accumulators themselves.
testFold :: IO (Int, Int, [Int])
testFold = evalNDP 10 $ do
  accs <- pause *> pause *> pause *> slowSum <* pause -- <* fail "Foo"
  pure (length accs, sum accs, accs)
  where
    pause = liftIO (threadDelay 5000000)
    -- Each step sleeps proportionally to the item before adding it.
    step acc item = do
      threadDelay (item * 100)
      pure (item * 100 + acc)
    slowSum = ndpFold Nothing (pure 0) step items
    items = replicate 100 1000 <> replicate 3 10000 -- [1..100]
{- | Nested data parallelism 'Monad'.

Allows nested computations to run in parallel on a fixed set of threads.

Implementation notes:

'NDP' is a plain function that accepts a continuation used to return the
result, and itself returns a list of actions to be performed in parallel.

There are many ways to order the execution of nested actions:

* "breadth-first" -- actions are pushed to the end of the queue.

    Pros: the "layered" execution is easier to understand.

    Cons: because parallel tasks run at different speeds, different
    "layers" get mixed anyway.

    Cons: could use too much memory for all of the unprocessed top "layer".

* "depth-first" -- actions are pushed to the front of the queue immediately
  (a proper implementation requires passing the queue as a parameter
  instead of returning the action).

    Pros: uses less memory.

    Cons: the order of execution is hard to understand and control.

* mixed (this implementation) -- applicative operators concatenate
  actions, and these are prepended to the queue in order, making it
  "depth-first" overall but "breadth-first" on the level of applicative
  operators.

    Uses less memory while keeping a predictable top-level order of
    operations (lower levels are not very predictable anyway).
-}
newtype NDP a = NDP
  { runNDP :: (a -> IO TaskList) -> IO TaskList
  }
-- | A batch of tasks that supports O(1) append.
data TaskList
  = -- | A plain list of tasks.
    List [IO TaskList]
  | -- | Replicate a task to all threads (so the number of threads does not
    -- have to be passed to each 'NDP' action). The optional 'Int' caps the
    -- number of copies.
    ReplicateToAllThreads (Maybe Int) (IO TaskList)
  -- Disabled alternative: a map that doesn't create all tasks upfront.
  --     Map (forall a . (a -> IO TaskList, [a]))
  | -- | O(1) append of two task lists.
    Cat TaskList TaskList
-- | Appending wraps both sides in 'Cat', except that an empty 'List' on
-- either side is dropped.
instance Semigroup TaskList where
  lhs <> rhs = case (lhs, rhs) of
    (List [], _) -> rhs
    (_, List []) -> lhs
    _ -> Cat lhs rhs
-- | The empty task list is a 'List' with no tasks.
instance Monoid TaskList where
  mempty = List mempty
-- | Flatten a 'TaskList' and put its tasks, in order, in front of an
-- existing list of tasks.
--
-- A 'ReplicateToAllThreads' entry expands to one copy per thread, or to
-- fewer copies when its optional cap is smaller than the thread count.
prepend :: Int -> TaskList -> [IO TaskList] -> [IO TaskList]
prepend numThreads = go
  where
    go (List ts) rest = ts ++ rest
    go (ReplicateToAllThreads limit task) rest =
      replicate copies task ++ rest
      where
        copies = maybe numThreads (min numThreads) limit
    go (Cat front back) rest = go front (go back rest)
-- | A task outcome that schedules no further work.
stop :: IO TaskList
stop = return (List [])
-- | Update the 'MVar' with a pure function that also selects a follow-up
-- action, then run that action after the 'MVar' has been updated.
joinModify :: MVar a -> (a -> (a, IO b)) -> IO b
joinModify mvar f = do
  followUp <- modifyMVar mvar $ \ old -> pure (f old)
  followUp
-- | Mapping applies the function to the result just before it is handed
-- to the continuation.
instance Functor NDP where
  fmap f (NDP run) = NDP $ \ k -> run (k . f)
-- | '<*>' returns the tasks of both operands together. Each operand stores
-- its result in a shared 'MVar'; whichever arrives second finds the other
-- one present and invokes the continuation with the applied result.
instance Applicative NDP where
  pure a = NDP ($ a)
  mf <*> mx = NDP $ \ k -> do
    slots <- newMVar (Nothing, Nothing)
    let -- Fire the continuation once both halves are present.
        settle pair = case pair of
          (Just f, Just x) -> ((Nothing, Nothing), k (f x))
          _ -> (pair, stop)
        gotFun f = joinModify slots $ \ (_, argSlot) -> settle (Just f, argSlot)
        gotArg x = joinModify slots $ \ (funSlot, _) -> settle (funSlot, Just x)
    funTasks <- runNDP mf gotFun
    argTasks <- runNDP mx gotArg
    pure (funTasks <> argTasks)
-- | Bind runs the first computation and feeds its result to the function,
-- passing the original continuation on to the computation it returns.
instance Monad NDP where
  NDP run >>= f = NDP $ \ k -> run (\ x -> runNDP (f x) k)
-- | An 'IO' action becomes a single task that runs the action and then
-- passes its result to the continuation.
instance MonadIO NDP where
  liftIO action = NDP $ \ k -> pure (List [k =<< action])
-- | Failure is delegated to 'fail' in 'IO', lifted into 'NDP'.
instance MonadFail NDP where
  fail msg = liftIO (fail msg)
-- | Parallel map whose element function may itself run nested parallel
-- maps. One task is created per element; the results are handed to the
-- continuation in the order of the input list once the last one arrives.
ndpMap :: (a -> NDP b) -> [a] -> NDP [b]
ndpMap _ [] = pure []
ndpMap f items = NDP $ \ cont -> do
  -- (number of results still missing, indexed results collected so far)
  pending <- newMVar (length items, [])
  let ordered = map snd . sortOn fst
      collect slot = joinModify pending $ \ (remaining, acc) ->
        let acc' = slot : acc
         in if remaining == 1
              then ((0, []), cont (ordered acc'))
              else ((remaining - 1, acc'), stop)
      task i x = runNDP (f x) $ \ r -> collect (i, r)
  pure $ List $ zipWith task [0 ..] items
-- | Runs as many parallel folds over the list as possible.
--
-- Folds are in 'IO' as it's not clear what to do with folds that produce
-- nested computations.
--
-- Arguments: an optional cap on the number of parallel folds, an action
-- creating a fresh accumulator, the folding step, and the input list.
-- The result is the list of final accumulators, one per fold that managed
-- to pick up at least one element (an empty input gives an empty list).
--
-- A cap below 1 is treated as 1: with zero fold tasks nothing would ever
-- invoke the continuation and the whole evaluation would hang.
ndpFold :: Maybe Int -> IO b -> (b -> a -> IO b) -> [a] -> NDP [b]
ndpFold _ _ _ [] = pure []
ndpFold nFolds newAcc f l = NDP $ \ cont -> do
  state <- newMVar (l, 0, []) -- (list, nActive, resultAccs)
  let start x = do
        a <- newAcc
        next a x
      next a0 x0 = do
        a <- f a0 x0
        joinModify state $ \ case
          ([], 1, rs) ->
            (([], 0, []), cont (a : rs)) -- last active fold: ready
          ([], nActive, rs) ->
            (([], nActive - 1, a : rs), stop) -- collect accumulators
          (x : xs, nActive, rs) ->
            ((xs, nActive, rs), next a x) -- keep folding
      -- One copy of this runs per thread; it claims the first element for
      -- a new fold, or does nothing when the list is already exhausted.
      worker = joinModify state $ \ case
        s@([], _, _) -> (s, stop) -- nothing to do
        (x : xs, nActive, rs) -> ((xs, nActive + 1, rs), start x)
  pure $ ReplicateToAllThreads (max 1 <$> nFolds) worker
-- | Global lock used to serialise debug output from worker threads.
--
-- Created with 'unsafePerformIO'; the @NOINLINE@ pragma keeps the
-- definition from being inlined, which could otherwise produce more than
-- one lock.
printLock :: MVar ()
printLock = unsafePerformIO (newMVar ())
{-# NOINLINE printLock #-}
-- | Evaluate an 'NDP' computation on a fixed number of worker threads and
-- return its result.
--
-- The first argument is the number of worker threads; values below 1 are
-- treated as 1, because with no workers nothing would run the queued tasks
-- and the final 'readMVar' would block forever.
--
-- Tasks live in a shared list guarded by an 'MVar'; a 'QSem' is signalled
-- once per queued task. When the top-level result is ready the semaphore
-- is signalled once more per worker, so each worker eventually finds the
-- queue empty and exits.
evalNDP :: Int -> NDP a -> IO a
evalNDP nThreads toplevel = do
  let workers = max 1 nThreads
  sem <- newQSem 0
  tasks <- newMVar []
  let pushTasks [] = pure ()
      pushTasks ts = do
        withMVar printLock $ const $ print ("pushing", length ts)
        -- Queue first, signal afterwards, so a woken worker finds the task.
        modifyMVar_ tasks $ pure . (ts <>)
        replicateM_ (length ts) $ signalQSem sem
      loopT t = do
        new <- t
        case prepend workers new [] of
          [] -> loop
          (tNew : ts) -> do
            pushTasks ts
            loopT tNew -- immediately recurse into the first task
                       -- without pushing it to the queue
      loop = do
        waitQSem sem
        joinModify tasks $ \ case
          [] -> ([], pure ()) -- exit
          (t : ts) -> (ts, loopT t)
  result <- newEmptyMVar
  toplevelTasks <- runNDP toplevel $ \ a -> do
    -- stop all the threads once the toplevel result is ready
    replicateM_ workers $ signalQSem sem
    putMVar result a
    pure mempty
  pushTasks $ prepend workers toplevelTasks []
  -- The workers' unit results are of no interest, so don't collect them.
  replicateConcurrently_ workers loop
  readMVar result
| {- | |
| data TaskQueue | |
| = TaskQueue | |
| { numWorkers :: Int | |
| , tasksSem :: QSem | |
| , tasksList :: MVar [IO ()] | |
| } | |
| pushTask :: TaskQueue -> IO () -> IO () | |
| pushTask queue task = do | |
| modifyMVar_ (tasksList queue) $ pure . (task :) | |
| signalQSem $ tasksSem queue | |
| -- modifyMVar writeChan (tasksChan queue) . Just | |
| pushTasks :: TaskQueue -> [IO ()] -> IO () | |
| pushTasks queue tasks = do | |
| modifyMVar_ (tasksList queue) $ pure . (tasks <>) | |
| replicateM_ (length tasks) $ signalQSem $ tasksSem queue | |
| -- non-GADT version | |
| newtype NDPx a | |
| = NDPx | |
| { runNDP | |
| :: TaskQueue | |
| -> (Either String a -> IO ()) -- ^ handle the result | |
| -> IO () | |
| } | |
| instance Functor NDPx where | |
| fmap f n = NDPx $ \ q c -> runNDP n q (c . fmap f) | |
| instance MonadFail NDPx where | |
| fail msg = NDPx $ \ _ c -> c $ Left msg | |
| instance Applicative NDPx where | |
| pure a = NDPx $ \ _ c -> c (Right a) | |
| mf <*> mx = NDPx $ \ q c -> do | |
| lr <- newMVar (Nothing, Nothing) | |
| let check v = do | |
| case v of | |
| (Just f, Just x) -> c (f <*> x) | |
| _ -> pure () | |
| pure v | |
| -- TODO: push a block (f:x:...) atomically to make sure | |
| -- 'f' is picked up first | |
| runNDP mf q $ \ f -> modifyMVar_ lr $ \ (_,x) -> check (Just f, x) | |
| runNDP mx q $ \ x -> modifyMVar_ lr $ \ (f,_) -> check (f, Just x) | |
| instance Monad NDPx where | |
| m >>= f = NDPx $ \ p c -> runNDP m p $ \ case | |
| Left e -> c (Left e) | |
| Right x -> runNDP (f x) p c | |
| instance MonadIO NDPx where | |
| liftIO m = NDPx $ \ q c -> pushTask q (tryIO m >>= c) | |
| ndpMap :: (a -> NDPx b) -> [a] -> NDPx [b] | |
| ndpMap f l = if null l then pure [] else NDPx $ \ queue cont -> do | |
| results <- newMVar (length l, []) | |
| pushTasks queue | |
| [runNDP (f x) queue $ \ r -> | |
| join $ modifyMVar results $ \ (left, acc) -> | |
| if left == 1 then do | |
| let act = cont $ sequenceA $ map snd $ sortBy (comparing fst) ((i,r):acc) | |
| pure ((0,[]), act) | |
| else | |
| pure ((left-1, (i,r):acc), pure ()) | |
| |(i,x) <- zip [0..] l] | |
| -- | Runs as many parallel folds for the list as possible | |
| -- Folds are IO as it's not clear what to do with folds that produce nested | |
| -- computations. | |
| ndpFold :: Maybe Int -> IO b -> (b -> a -> IO b) -> [a] -> NDPx [b] | |
| ndpFold maxParallelFolds newAcc f l = | |
| if null l then pure [] else NDPx $ \ queue cont -> do | |
| let nFolds = | |
| maybe (numWorkers queue) (min (numWorkers queue)) maxParallelFolds | |
| state <- newMVar (l, 0, []) -- (list, nActive, resultAccs) | |
| let modAct f = join $ modifyMVar state (pure . f) | |
| tryNext m = do | |
| ea' <- tryIO $ m | |
| case ea' of | |
| Left e -> modAct $ \ case | |
| (_, 1, _) -> | |
| (([], 0, []), cont (Left e)) -- last one | |
| (_, nActive, _) -> | |
| (([], nActive-1, [Left e]), pure ()) -- collect accumulators | |
| Right a' -> next a' | |
| start x = tryNext $ do | |
| a <- newAcc | |
| f a x | |
| next acc = modAct $ \ case | |
| ([], 1, rs) -> | |
| (([], 0, []), cont $ sequenceA (Right acc:rs)) -- ready | |
| ([], nActive, rs) -> | |
| (([], nActive-1, Right acc:rs), pure ()) -- collect accumulators | |
| (x:xs, nActive, rs) -> | |
| ((xs, nActive, rs), tryNext $ f acc x) | |
| pushTasks queue $ replicate nFolds $ modAct $ \ case | |
| s@([], _, _) -> (s, pure ()) -- nothing to do | |
| (x:xs, nActive, rs) -> ((xs, nActive+1, rs), start x) | |
| pxeval nThreads toplevel = do | |
| -- tasks <- newChan | |
| -- let queue = TaskQueue nThreads tasks | |
| -- loop = forM | |
| -- task <- readChan tasks | |
| -- forM_ task (>> loop) | |
| sem <- newQSem 0 | |
| tasks <- newMVar [] | |
| let queue = TaskQueue nThreads sem tasks | |
| loop = do | |
| waitQSem sem | |
| join $ modifyMVar tasks $ \ case | |
| [] -> pure ([], pure()) | |
| (t:ts) -> pure (ts, t >> loop) | |
| result <- newEmptyMVar | |
| runNDP toplevel queue $ \ a -> do | |
| -- stop all the threads once the toplevel result is ready | |
| replicateM_ nThreads $ signalQSem sem | |
| -- replicateM_ nThreads (writeChan tasks Nothing) | |
| putMVar result a | |
| replicateConcurrently nThreads loop | |
| readMVar result | |
| -} | |
| {- | |
| data NDP x where | |
| Pure :: a -> NDP a | |
| Map :: (a -> NDP b) -> [a] -> NDP [b] | |
| FMap :: (a -> b) -> NDP a -> NDP b | |
| Bind :: NDP a -> (a -> NDP b) -> NDP b | |
| -- could be used to implement Bind/Map/FMap | |
| -- BindMap :: NDP [a] -> (a -> NDP b) -> NDP [b] | |
| peval nThreads toplevel = do | |
| tasks <- newChan :: IO (Chan (Maybe (IO ()))) | |
| let loop = do | |
| task <- readChan tasks | |
| forM_ task (>> loop) | |
| addTask :: NDP a -> (a -> IO ()) -> IO () | |
| addTask ndp c = case ndp of | |
| Pure x -> c x | |
| FMap f x -> addTask x (c . f) | |
| Map _ [] -> c [] | |
| Map f l -> do | |
| results <- newMVar (length l, []) | |
| -- TODO: it is worth adding nested tasks at the beginning | |
| -- to go depth-first instead of breadth-first. | |
| -- It should have better locality and use less memory, | |
| -- though it should work well enough as it is. | |
| forM_ (zip [0..] l) $ \ (i,x) -> writeChan tasks $ Just $ | |
| addTask (f x) $ \ r -> | |
| join $ modifyMVar results $ \ (left, acc) -> pure $ | |
| if left == 1 then | |
| (undefined, c $ map snd $ sortBy (comparing fst) ((i,r):acc)) | |
| else | |
| ((left-1, (i,r):acc), pure ()) | |
| Bind m f -> addTask m $ \ x -> addTask (f x) c | |
| result <- newEmptyMVar | |
| addTask toplevel $ \ a -> do | |
| -- stop all the threads once toplevel result is ready | |
| replicateM_ nThreads (writeChan tasks Nothing) | |
| putMVar result a | |
| replicateConcurrently nThreads loop | |
| readMVar result | |
| -- parfold -- ParMap (replicate nThreads $ listChan elems) (foldRead chan ..) | |
| instance Show (NDP a) where | |
| show = \ case | |
| Pure{} -> "Pure" | |
| Map{} -> "Map" | |
| FMap _ x -> "FMap _ (" <> show x <> ")" | |
| Bind x _ -> "Bind (" <> show x <> ") _" | |
| eval :: NDP a -> a | |
| eval = \ case | |
| Pure x -> x | |
| FMap f x -> f (eval x) | |
| Map f l -> map (eval . f) l | |
| Bind a f -> eval $ f $ eval a | |
| instance Functor NDP where | |
| fmap f = \ case | |
| Pure x -> Pure $ f x | |
| n -> FMap f n | |
| instance Applicative NDP where | |
| pure = Pure | |
| Pure f <*> Pure x = Pure (f x) | |
| f <*> x = FMap (\ [Left f, Right x] -> f x) $ | |
| Map (\ case | |
| Left f -> Left <$> f | |
| Right x -> Right <$> x) [Left f, Right x] | |
| instance Monad NDP where | |
| Pure a >>= f = f a | |
| FMap f x >>= g = x >>= (g . f) | |
| x >>= f = Bind x f | |
| -} | |
| {- | |
| data Op a where | |
| Const :: a -> Op a | |
| MapReduce :: [a] -> IO (a -> IO (Op b)) -> ([b] -> Op c) -> Op c | |
| data Op a where | |
| Const :: a -> Op a | |
| MapThen :: [a] -> (a -> IO (Op b)) -> ([b] -> IO (Op c)) -> Op c | |
| MapFold :: [a] -> IO acc -> (acc -> a -> IO acc) -> ([acc] -> IO (Op c)) -> Op c | |
| type Reducer = IO (Maybe a) -> IO (Maybe b) | |
| ... = MapThen nettingSets | |
| (\ ns -> do | |
| ... | |
| pure $ MapThen (subNettingSets ..) | |
| (\ sns -> do | |
| ... | |
| pure $ MapFold trades newAcc | |
| (\ acc trade -> addTrade ...) | |
| combineAccs) | |
| ( | |
| parMap (chunksOf ..) $ \ ... newAcc | |
| parFold newAcc combineAccs $ \ acc elem -> | |
| addElem | |
| ... = parMap nettingSets $ \ ns -> | |
| parMap ... $ \ sns -> | |
| parMap (chunksOf ..) $ \ ... newAcc | |
| parFold newAcc combineAccs $ \ acc elem -> | |
| addElem | |
| ... = parMap nettingSets $ \ ns -> | |
| parMapThen combineSns (\ sns -> | |
| parMapThen combineAccs (chunksOf ..) $ \ ... newAcc | |
| parFold newAcc combineAccs $ \ acc elem -> | |
| addElem | |
| ) | |
| nettingSets -> subNettingSets -> (trades | chunksOf trades) | |
| foo pm = do | |
| x <- send pm $ ParMap nettingSets $ \ ns -> | |
| send pm $ ParMap (sns ...) $ \ sns -> | |
| send pm $ ParFold trades newAcc addTrade | |
| wait x -- | |
| ParMap l f = ParFold l (pure []) | |
| data Op a where | |
| Const :: a -> Op a | |
| MapReduce :: [a] -> (a -> Op b) -> ([b] -> Op c) -> Op c | |
| runOpSequential :: Op a -> a | |
| runOpSequential = \ case | |
| Const a -> a | |
| MapReduce l f r -> runOpSequential $ r $ map (runOpSequential . f) l | |
| ops = | |
| MapReduce [1..10] (\ i -> MapReduce [i*10,i*20] Const (Const . maximum)) | |
| (Const . sum) | |
| data Reducer a b c | |
| = Reducer | |
| { rInQueue :: [a] | |
| , rMap :: a -> Op b | |
| , rOutQueue :: [b] | |
| , rUnprocessed :: Int | |
| , rReduce :: [b] -> Op c | |
| } | |
| -- data Stack a b where | |
| -- -- Top :: Op a -> Stack a | |
| -- Top :: Reducer a b c -> Stack b c | |
| -- Cons :: Reducer a b c -> Stack c d -> Stack b d | |
| runOpStacked :: Op a -> a | |
| runOpStacked = \ case | |
| Const a -> a | |
| MapReduce l f r -> runStack $ Top $ Reducer l f [] (length l) r | |
| runStack :: Stack a b -> b | |
| runStack | |
| -- runOpParallel :: Int -> Op a -> IO a | |
| -- runOpParallel n = do | |
| -- \ case | |
| -- Const a -> a | |
| -- MapReduce l f r -> r $ map (runOpSequential . f) l | |
| -} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment