Skip to content

Instantly share code, notes, and snippets.

@vshabanov
Last active November 15, 2025 16:51
Show Gist options
  • Select an option

  • Save vshabanov/528b66ed04b7b8333c8eba5af9eab060 to your computer and use it in GitHub Desktop.

Select an option

Save vshabanov/528b66ed04b7b8333c8eba5af9eab060 to your computer and use it in GitHub Desktop.
Experiments with nested data parallelism
{-# LANGUAGE TypeFamilies #-}
import Control.Concurrent
import Control.Concurrent.Async
import Control.Concurrent.Chan
import Control.Concurrent.MVar
import Control.Concurrent.QSem
import Control.Exception qualified as E
import Control.Monad
import Control.Monad.IO.Class
import Data.Bifunctor (first)
import Data.List
import Data.Ord
import GHC.Stack qualified as ES
import System.IO.Unsafe
-- | Program entry point: runs the parallel-fold experiment 'testFold'.
-- The computed triple is the result of the action and is not printed here.
main :: IO (Int, Int, [Int])
main = testFold
-- | Nested-map experiment on 10 worker threads: for every @start@ in
-- @[1..5]@ map over @[start..5]@ in a nested 'ndpMap', then sum the lengths
-- of the inner results.
testN :: IO Int
testN = evalNDP 10 $ do
  nested <- ndpMap (\ start -> ndpMap pure [start .. 5]) [1 .. 5]
  -- 'pure sum <*>' is kept (rather than 'fmap') so the 'Applicative'
  -- instance is exercised.
  pure sum <*> ndpMap (pure . length) nested
-- | Parallel-fold experiment on 10 worker threads.
--
-- Combines several 5-second pauses with an 'ndpFold' through the
-- applicative operators and returns the number of accumulators produced,
-- their sum, and the accumulators themselves.
testFold :: IO (Int, Int, [Int])
testFold = evalNDP 10 $ do
  let pause = liftIO (threadDelay 5000000)
      -- Each element delays for @x * 100@ microseconds and adds @x * 100@
      -- to the running accumulator.
      step acc x = threadDelay (x * 100) >> pure (x * 100 + acc)
      inputs = replicate 100 1000 <> replicate 3 10000 -- [1..100]
  accs <-
    pause *> pause *> pause *> ndpFold Nothing (pure 0) step inputs
      <* pause -- <* fail "Foo"
  pure (length accs, sum accs, accs)
{- | Nested data parallelism 'Monad'.

Allows running nested computations in parallel on a fixed set of
threads.

Implementation notes:

'NDP' is a simple function which accepts a function to return a
result (a continuation) and returns a list of actions to be performed
in parallel.

There are many ways to order the execution of nested actions:

* "breadth-first" -- actions are pushed to the end of the queue.

    Pros: easier to understand "layered" execution.

    Cons: due to the different speed of parallel execution, different
    "layers" will be mixed.

    Cons: could use too much memory for all the unprocessed top "layer".

* "depth-first" -- actions are pushed to the front of the queue immediately
  (a proper implementation requires passing the queue as a parameter
  instead of returning the action).

    Pros: uses less memory.

    Cons: hard to understand/control the order of execution.

* mixed (this implementation) -- applicative operators concatenate
  actions and they're prepended to the queue in order, making it
  "depth-first" but "breadth-first" on the level of applicative
  operators.

    Uses less memory while having a predictable top-level order of
    operations (lower levels are not very predictable anyway).
-}
newtype NDP a = NDP
  { -- | Run the computation with a continuation that consumes its result;
    -- yields the tasks to be scheduled.
    runNDP :: (a -> IO TaskList) -> IO TaskList
  }
-- | Task list with O(1) append.
--
-- Every task is an 'IO' action which, when run, returns the follow-up
-- tasks it produced.
data TaskList
  = -- | A plain list of tasks.
    List [IO TaskList]
  | -- | Replicate a task to all threads (so we don't need to pass
    -- the number of threads to each 'NDP' action). The optional 'Int'
    -- caps the number of copies (see 'prepend').
    ReplicateToAllThreads (Maybe Int) (IO TaskList)
  -- Disabled idea: a map that doesn't create all tasks upfront:
  --   Map (forall a . (a -> IO TaskList, [a]))

  | -- | O(1) append of two task lists.
    Cat TaskList TaskList
-- | Appending is O(1): an empty 'List' on either side is dropped,
-- otherwise the two operands are joined with a 'Cat' node.
instance Semigroup TaskList where
  lhs <> rhs = case (lhs, rhs) of
    (List [], _) -> rhs
    (_, List []) -> lhs
    _ -> Cat lhs rhs
-- | The empty task list is a 'List' without tasks; the 'Semigroup'
-- instance drops it when it appears on either side of '<>'.
instance Monoid TaskList where
  mempty = List []
-- | Flatten a 'TaskList' and put its tasks, in order, in front of an
-- existing list of tasks.
--
-- The first argument is the number of worker threads: a
-- 'ReplicateToAllThreads' entry expands to that many copies of its task,
-- or to fewer when its optional cap is smaller.
prepend :: Int -> TaskList -> [IO TaskList] -> [IO TaskList]
prepend numThreads = go
  where
    go (List tasks) rest = tasks <> rest
    go (ReplicateToAllThreads cap task) rest =
      replicate copies task <> rest
      where
        copies = maybe numThreads (min numThreads) cap
    -- Flatten the right branch first so the left branch ends up in front.
    go (Cat front back) rest = go front (go back rest)
-- | A task result that schedules nothing further: returns the empty
-- 'TaskList'.
stop :: IO TaskList
stop = return mempty
-- | Update an 'MVar' with a pure function that also selects a follow-up
-- action, then run that action.
--
-- The follow-up runs after 'modifyMVar' has returned, i.e. once the new
-- state has been written back.
joinModify :: MVar a -> (a -> (a, IO b)) -> IO b
joinModify mvar f = do
  followUp <- modifyMVar mvar (pure . f)
  followUp
-- | Mapping a function over an 'NDP' applies it to the result just before
-- the result is handed to the continuation.
instance Functor NDP where
  fmap f (NDP run) = NDP (\ k -> run (k . f))
-- | 'pure' passes its value straight to the continuation.
--
-- '<*>' starts both operands and returns their tasks concatenated
-- (function side first). Each operand stores its result in a shared
-- 'MVar'; the operand that completes the pair applies the function to the
-- argument and calls the continuation.
instance Applicative NDP where
  pure a = NDP ($ a)

  mf <*> mx = NDP $ \ k -> do
    slots <- newMVar (Nothing, Nothing)
    let -- Both halves present: clear the slots and continue with the
        -- applied result. Otherwise keep what we have and do nothing yet.
        settle (Just g, Just y) = ((Nothing, Nothing), k (g y))
        settle pending = (pending, stop)
        gotFunction g = joinModify slots $ \ (_, y) -> settle (Just g, y)
        gotArgument y = joinModify slots $ \ (g, _) -> settle (g, Just y)
    funTasks <- runNDP mf gotFunction
    argTasks <- runNDP mx gotArgument
    pure (funTasks <> argTasks)
-- | Sequencing: run the first computation and, from its continuation,
-- run the computation chosen by its result with the original continuation.
instance Monad NDP where
  NDP run >>= f = NDP (\ k -> run (\ x -> runNDP (f x) k))
-- | An 'IO' action becomes a single task which runs the action and then
-- feeds its result to the continuation.
instance MonadIO NDP where
  liftIO action = NDP (\ k -> pure (List [action >>= k]))
-- | Failure is delegated to 'IO': the lifted task calls 'fail' in 'IO'
-- with the given message.
instance MonadFail NDP where
  fail msg = liftIO (fail msg)
-- | Parallel map operation that can run nested parallel maps.
--
-- Produces one task per list element. Each element's result is recorded
-- together with its position; when the last outstanding element reports
-- in, the results are put back into input order and passed on to the
-- continuation. An empty input yields @[]@ without scheduling any task.
ndpMap :: (a -> NDP b) -> [a] -> NDP [b]
ndpMap f l
  | null l = pure []
  | otherwise = NDP $ \ cont -> do
      -- (number of elements still outstanding, indexed results so far)
      pending <- newMVar (length l, [])
      let collect idx r (remaining, acc)
            | remaining == 1 =
                ((0, []), cont (map snd (sortOn fst done)))
            | otherwise =
                ((remaining - 1, done), stop)
            where
              done = (idx, r) : acc
          task (idx, x) = runNDP (f x) (joinModify pending . collect idx)
      pure $ List (map task (zip [0 ..] l))
-- | Runs as many parallel folds for the list as possible.
-- Folds are IO as it's not clear what to do with folds that produce nested
-- computations.
--
-- Arguments, in order:
--
-- * optional cap on the number of parallel folds ('Nothing' = one per
--   worker thread); values below 1 are treated as 1, because with zero
--   folds nothing would ever consume the list and the continuation would
--   never be called;
-- * action creating a fresh accumulator for each fold that gets started;
-- * the step function;
-- * the input list.
--
-- The result is the list of final accumulators, one per fold that actually
-- started, in no particular order. An empty input yields @[]@.
ndpFold :: Maybe Int -> IO b -> (b -> a -> IO b) -> [a] -> NDP [b]
ndpFold nFolds newAcc f l = if null l then pure [] else NDP $ \ cont -> do
  state <- newMVar (l, 0 :: Int, []) -- (list, nActive, resultAccs)
  let start x = do
        a <- newAcc
        next a x
      -- Fold one element, then either grab the next element or retire.
      next a0 x0 = do
        a <- f a0 x0
        joinModify state $ \ case
          ([], 1, rs) ->
            (([], 0, []), cont (a:rs)) -- ready
          ([], nActive, rs) ->
            (([], nActive-1, a:rs), stop) -- collect accumulators
          (x:xs, nActive, rs) ->
            ((xs, nActive, rs), next a x)
      -- At least one fold must be scheduled for a non-empty list.
      cap = max 1 <$> nFolds
  -- Starter task, replicated to the workers: claims the first element for
  -- a new fold, or does nothing if the list has already been consumed.
  pure $ ReplicateToAllThreads cap $ joinModify state $ \ case
    s@([], _, _) -> (s, stop) -- nothing to do
    (x:xs, nActive, rs) -> ((xs, nActive+1, rs), start x)
-- | Global lock serialising debug output from the worker threads.
--
-- Created with 'unsafePerformIO'; the @NOINLINE@ pragma keeps the
-- definition from being inlined, which could otherwise create more than
-- one 'MVar'.
printLock :: MVar ()
printLock = unsafePerformIO (newMVar ())
{-# NOINLINE printLock #-}
-- | Evaluate an 'NDP' computation on a fixed number of worker threads and
-- return its result.
--
-- The first argument is the number of worker threads; values below 1 are
-- treated as 1, since with no workers queued tasks would never run and the
-- final 'readMVar' would block forever.
--
-- Scheduling: pending tasks live in a shared stack guarded by an 'MVar',
-- and a semaphore counts the queued tasks. A worker that finishes a task
-- pushes all but the first of the follow-up tasks to the front of the
-- stack and immediately continues with the first one. Once the top-level
-- continuation fires, one extra semaphore signal per worker is issued; a
-- worker that wakes up and finds the stack empty exits.
evalNDP :: Int -> NDP a -> IO a
evalNDP nThreads toplevel = do
  let workers = max 1 nThreads
  sem <- newQSem 0
  tasks <- newMVar []
  let pushTasks [] = pure ()
      pushTasks ts = do
        withMVar printLock $ const $ print ("pushing", length ts)
        modifyMVar_ tasks $ pure . (ts <>)
        -- one signal per queued task
        replicateM_ (length ts) $ signalQSem sem
      loopT t = do
        new <- t
        case prepend workers new [] of
          [] -> loop
          (tNew:ts) -> do
            pushTasks ts
            loopT tNew -- immediately recurse into the first task
                       -- without pushing it to the queue
      loop = do
        waitQSem sem
        joinModify tasks $ \ case
          [] -> ([], pure ()) -- exit
          (t:ts) -> (ts, loopT t)
  result <- newEmptyMVar
  toplevelTasks <- runNDP toplevel $ \ a -> do
    -- stop all the threads once the toplevel result is ready
    replicateM_ workers $ signalQSem sem
    putMVar result a
    pure mempty
  pushTasks $ prepend workers toplevelTasks []
  -- The workers return nothing useful, so use the result-discarding variant.
  replicateConcurrently_ workers loop
  readMVar result
{-
data TaskQueue
= TaskQueue
{ numWorkers :: Int
, tasksSem :: QSem
, tasksList :: MVar [IO ()]
}
pushTask :: TaskQueue -> IO () -> IO ()
pushTask queue task = do
modifyMVar_ (tasksList queue) $ pure . (task :)
signalQSem $ tasksSem queue
-- modifyMVar writeChan (tasksChan queue) . Just
pushTasks :: TaskQueue -> [IO ()] -> IO ()
pushTasks queue tasks = do
modifyMVar_ (tasksList queue) $ pure . (tasks <>)
replicateM_ (length tasks) $ signalQSem $ tasksSem queue
-- non-GADT version
newtype NDPx a
= NDPx
{ runNDP
:: TaskQueue
-> (Either String a -> IO ()) -- ^ handle the result
-> IO ()
}
instance Functor NDPx where
fmap f n = NDPx $ \ q c -> runNDP n q (c . fmap f)
instance MonadFail NDPx where
fail msg = NDPx $ \ _ c -> c $ Left msg
instance Applicative NDPx where
pure a = NDPx $ \ _ c -> c (Right a)
mf <*> mx = NDPx $ \ q c -> do
lr <- newMVar (Nothing, Nothing)
let check v = do
case v of
(Just f, Just x) -> c (f <*> x)
_ -> pure ()
pure v
-- TODO: push a block (f:x:...) atomically to make sure
-- 'f' is picked up first
runNDP mf q $ \ f -> modifyMVar_ lr $ \ (_,x) -> check (Just f, x)
runNDP mx q $ \ x -> modifyMVar_ lr $ \ (f,_) -> check (f, Just x)
instance Monad NDPx where
m >>= f = NDPx $ \ p c -> runNDP m p $ \ case
Left e -> c (Left e)
Right x -> runNDP (f x) p c
instance MonadIO NDPx where
liftIO m = NDPx $ \ q c -> pushTask q (tryIO m >>= c)
ndpMap :: (a -> NDPx b) -> [a] -> NDPx [b]
ndpMap f l = if null l then pure [] else NDPx $ \ queue cont -> do
results <- newMVar (length l, [])
pushTasks queue
[runNDP (f x) queue $ \ r ->
join $ modifyMVar results $ \ (left, acc) ->
if left == 1 then do
let act = cont $ sequenceA $ map snd $ sortBy (comparing fst) ((i,r):acc)
pure ((0,[]), act)
else
pure ((left-1, (i,r):acc), pure ())
|(i,x) <- zip [0..] l]
-- | Runs as many parallel folds for the list as possible
-- Folds are IO as it's not clear what to do with folds that produce nested
-- computations.
ndpFold :: Maybe Int -> IO b -> (b -> a -> IO b) -> [a] -> NDPx [b]
ndpFold maxParallelFolds newAcc f l =
if null l then pure [] else NDPx $ \ queue cont -> do
let nFolds =
maybe (numWorkers queue) (min (numWorkers queue)) maxParallelFolds
state <- newMVar (l, 0, []) -- (list, nActive, resultAccs)
let modAct f = join $ modifyMVar state (pure . f)
tryNext m = do
ea' <- tryIO $ m
case ea' of
Left e -> modAct $ \ case
(_, 1, _) ->
(([], 0, []), cont (Left e)) -- last one
(_, nActive, _) ->
(([], nActive-1, [Left e]), pure ()) -- collect accumulators
Right a' -> next a'
start x = tryNext $ do
a <- newAcc
f a x
next acc = modAct $ \ case
([], 1, rs) ->
(([], 0, []), cont $ sequenceA (Right acc:rs)) -- ready
([], nActive, rs) ->
(([], nActive-1, Right acc:rs), pure ()) -- collect accumulators
(x:xs, nActive, rs) ->
((xs, nActive, rs), tryNext $ f acc x)
pushTasks queue $ replicate nFolds $ modAct $ \ case
s@([], _, _) -> (s, pure ()) -- nothing to do
(x:xs, nActive, rs) -> ((xs, nActive+1, rs), start x)
pxeval nThreads toplevel = do
-- tasks <- newChan
-- let queue = TaskQueue nThreads tasks
-- loop = forM
-- task <- readChan tasks
-- forM_ task (>> loop)
sem <- newQSem 0
tasks <- newMVar []
let queue = TaskQueue nThreads sem tasks
loop = do
waitQSem sem
join $ modifyMVar tasks $ \ case
[] -> pure ([], pure())
(t:ts) -> pure (ts, t >> loop)
result <- newEmptyMVar
runNDP toplevel queue $ \ a -> do
-- stop all the threads once the toplevel result is ready
replicateM_ nThreads $ signalQSem sem
-- replicateM_ nThreads (writeChan tasks Nothing)
putMVar result a
replicateConcurrently nThreads loop
readMVar result
-}
{-
data NDP x where
Pure :: a -> NDP a
Map :: (a -> NDP b) -> [a] -> NDP [b]
FMap :: (a -> b) -> NDP a -> NDP b
Bind :: NDP a -> (a -> NDP b) -> NDP b
-- could be used to implement Bind/Map/FMap
-- BindMap :: NDP [a] -> (a -> NDP b) -> NDP [b]
peval nThreads toplevel = do
tasks <- newChan :: IO (Chan (Maybe (IO ())))
let loop = do
task <- readChan tasks
forM_ task (>> loop)
addTask :: NDP a -> (a -> IO ()) -> IO ()
addTask ndp c = case ndp of
Pure x -> c x
FMap f x -> addTask x (c . f)
Map _ [] -> c []
Map f l -> do
results <- newMVar (length l, [])
-- TODO: it's worth adding nested tasks at the beginning
-- to go depth-first instead of breadth-first.
-- It should have better locality and use less memory,
-- though it should work well enough as it is
forM_ (zip [0..] l) $ \ (i,x) -> writeChan tasks $ Just $
addTask (f x) $ \ r ->
join $ modifyMVar results $ \ (left, acc) -> pure $
if left == 1 then
(undefined, c $ map snd $ sortBy (comparing fst) ((i,r):acc))
else
((left-1, (i,r):acc), pure ())
Bind m f -> addTask m $ \ x -> addTask (f x) c
result <- newEmptyMVar
addTask toplevel $ \ a -> do
-- stop all the threads once toplevel result is ready
replicateM_ nThreads (writeChan tasks Nothing)
putMVar result a
replicateConcurrently nThreads loop
readMVar result
-- parfold -- ParMap (replicate nThreads $ listChan elems) (foldRead chan ..)
instance Show (NDP a) where
show = \ case
Pure{} -> "Pure"
Map{} -> "Map"
FMap _ x -> "FMap _ (" <> show x <> ")"
Bind x _ -> "Bind (" <> show x <> ") _"
eval :: NDP a -> a
eval = \ case
Pure x -> x
FMap f x -> f (eval x)
Map f l -> map (eval . f) l
Bind a f -> eval $ f $ eval a
instance Functor NDP where
fmap f = \ case
Pure x -> Pure $ f x
n -> FMap f n
instance Applicative NDP where
pure = Pure
Pure f <*> Pure x = Pure (f x)
f <*> x = FMap (\ [Left f, Right x] -> f x) $
Map (\ case
Left f -> Left <$> f
Right x -> Right <$> x) [Left f, Right x]
instance Monad NDP where
Pure a >>= f = f a
FMap f x >>= g = x >>= (g . f)
x >>= f = Bind x f
-}
{-
data Op a where
Const :: a -> Op a
MapReduce :: [a] -> IO (a -> IO (Op b)) -> ([b] -> Op c) -> Op c
data Op a where
Const :: a -> Op a
MapThen :: [a] -> (a -> IO (Op b)) -> ([b] -> IO (Op c)) -> Op c
MapFold :: [a] -> IO acc -> (acc -> a -> IO acc) -> ([acc] -> IO (Op c)) -> Op c
type Reducer = IO (Maybe a) -> IO (Maybe b)
... = MapThen nettingSets
(\ ns -> do
...
pure $ MapThen (subNettingSets ..)
(\ sns -> do
...
pure $ MapFold trades newAcc
(\ acc trade -> addTrade ...)
combineAccs)
(
parMap (chunksOf ..) $ \ ... newAcc
parFold newAcc combineAccs $ \ acc elem ->
addElem
... = parMap nettingSets $ \ ns ->
parMap ... $ \ sns ->
parMap (chunksOf ..) $ \ ... newAcc
parFold newAcc combineAccs $ \ acc elem ->
addElem
... = parMap nettingSets $ \ ns ->
parMapThen combineSns (\ sns ->
parMapThen combineAccs (chunksOf ..) $ \ ... newAcc
parFold newAcc combineAccs $ \ acc elem ->
addElem
)
nettingSets -> subNettingSets -> (trades | chunksOf trades)
foo pm = do
x <- send pm $ ParMap nettingSets $ \ ns ->
send pm $ ParMap (sns ...) $ \ sns ->
send pm $ ParFold trades newAcc addTrade
wait x --
ParMap l f = ParFold l (pure [])
data Op a where
Const :: a -> Op a
MapReduce :: [a] -> (a -> Op b) -> ([b] -> Op c) -> Op c
runOpSequential :: Op a -> a
runOpSequential = \ case
Const a -> a
MapReduce l f r -> runOpSequential $ r $ map (runOpSequential . f) l
ops =
MapReduce [1..10] (\ i -> MapReduce [i*10,i*20] Const (Const . maximum))
(Const . sum)
data Reducer a b c
= Reducer
{ rInQueue :: [a]
, rMap :: a -> Op b
, rOutQueue :: [b]
, rUnprocessed :: Int
, rReduce :: [b] -> Op c
}
-- data Stack a b where
-- -- Top :: Op a -> Stack a
-- Top :: Reducer a b c -> Stack b c
-- Cons :: Reducer a b c -> Stack c d -> Stack b d
runOpStacked :: Op a -> a
runOpStacked = \ case
Const a -> a
MapReduce l f r -> runStack $ Top $ Reducer l f [] (length l) r
runStack :: Stack a b -> b
runStack
-- runOpParallel :: Int -> Op a -> IO a
-- runOpParallel n = do
-- \ case
-- Const a -> a
-- MapReduce l f r -> r $ map (runOpSequential . f) l
-}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment