Skip to content

Instantly share code, notes, and snippets.

@andrewthad
Created October 4, 2017 19:25
Show Gist options
  • Save andrewthad/48024623cf39e672f4f7b4242b25d5b3 to your computer and use it in GitHub Desktop.
Save andrewthad/48024623cf39e672f4f7b4242b25d5b3 to your computer and use it in GitHub Desktop.
Parallel array initialization
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# OPTIONS_GHC -O2 -Wall -threaded -fforce-recomp #-}
import Criterion.Main
import Control.Monad (when)
import Control.Monad.ST.Unsafe (unsafeDupableInterleaveST,unsafeInterleaveST)
import GHC.ST (ST(..))
import GHC.Prim (spark#,seq#)
import qualified Data.Vector as V
import qualified Data.Vector.Primitive as PV
import qualified Data.Vector.Primitive.Mutable as MPV
-- | Check every hand-written initializer against the library-generated
-- reference vector, then run the Criterion benchmarks.
--
-- NOTE(review): the sparked variants can only run in parallel if the
-- executable uses the threaded runtime with more than one capability
-- (e.g. @+RTS -N@) -- confirm the build and run flags.
main :: IO ()
main = do
  let len = 20000000 :: Int
      expected = PV.enumFromN (0 :: Int) len
      -- Abort with the given message when a candidate differs from the
      -- reference vector.
      verify candidate message = when (candidate /= expected) (fail message)
  verify (serialIncrementing len) "serialIncrementing implementation incorrect"
  verify (parallelIncrementing 8 len) "parallelIncrementing implementation incorrect"
  verify (parallelIncrementingFixed len) "parallelIncrementingFixed implementation incorrect"
  defaultMain
    [ bgroup "incrementing"
        [ bench "library" $ whnf (PV.enumFromN (0 :: Int)) len
        , bench "serial" $ whnf serialIncrementing len
        , bench "parallel (fixed)" $ whnf parallelIncrementingFixed len
        -- , bench "parallel (4 sparks)" $ whnf (parallelIncrementing 4) len
        -- , bench "parallel (8 sparks)" $ whnf (parallelIncrementing 8) len
        ]
    ]
-- | Build the vector @[0, 1, .., n - 1]@ by writing every slot of a freshly
-- allocated mutable vector from a single sequential loop.
serialIncrementing :: Int -> PV.Vector Int
serialIncrementing !n = PV.create $ do
  !buf <- MPV.unsafeNew n
  fillFrom 0 buf
  return buf
  where
    -- Store each index into its own slot, from the given index up to @n - 1@.
    fillFrom :: Int -> PV.MVector s Int -> ST s ()
    fillFrom !i !mv
      | i < n = MPV.unsafeWrite mv i i >> fillFrom (i + 1) mv
      | otherwise = return ()
-- | Build the vector @[0, 1, .., n - 1]@ using a fixed four-way split of the
-- index range.
--
-- The first slice (roughly 70% of the indices) is filled directly by the
-- calling thread. The remaining three slices (two of about 10% each, plus
-- whatever is left up to @n@) are wrapped as deferred actions and handed to
-- 'sparkST' before the direct fill starts; they are forced afterwards so that
-- every slot has been written before the vector is frozen.
parallelIncrementingFixed :: Int -> PV.Vector Int
parallelIncrementingFixed !n = PV.create $ do
  !buf <- MPV.unsafeNew n
  let tenth = div n 10
      cutA = div (n * 7) 10
      cutB = cutA + tenth
      cutC = cutB + tenth
      -- Defer filling the half-open range [lo, hi) and spark the deferred result.
      sparkRange lo hi = unsafeDupableInterleaveST (fill lo hi buf) >>= sparkST
  sliceB <- sparkRange cutA cutB
  sliceC <- sparkRange cutB cutC
  sliceD <- sparkRange cutC n
  -- The largest slice is processed here while the sparks are outstanding.
  fill 0 cutA buf
  -- Force every deferred slice; any that has not run yet is executed now.
  mapM_ seqST [sliceB, sliceC, sliceD]
  return buf
  where
    -- Store each index of the half-open range [i, end) into its own slot.
    fill :: Int -> Int -> PV.MVector s Int -> ST s ()
    fill !i !end !mv
      | i < end = MPV.unsafeWrite mv i i >> fill (i + 1) end mv
      | otherwise = return ()
-- | Build the vector @[0, 1, .., n - 1]@ by splitting the index range into
-- @cores@ chunks (see 'makeBounds'), sparking one deferred fill per chunk and
-- then forcing every chunk before the vector is frozen.
--
-- A non-positive @cores@ is treated as 1. Without that clamp no chunk would
-- be produced, nothing would be written, and the uninitialised buffer from
-- 'MPV.unsafeNew' would be returned as the result.
parallelIncrementing :: Int -> Int -> PV.Vector Int
parallelIncrementing !cores !n = PV.create $ do
  !v <- MPV.unsafeNew n
  -- Always use at least one chunk so the whole range is covered.
  let bounds = makeBounds n (max 1 cores)
  units <- V.forM bounds $ \(lo, hi) ->
    unsafeInterleaveST (go lo hi v) >>= sparkST
  -- Force each chunk in order; any chunk that has not run yet is executed now.
  V.mapM_ seqST units
  return v
  where
    -- Store each index of the half-open range [ix, hi) into its own slot.
    go :: Int -> Int -> PV.MVector s Int -> ST s ()
    go !ix !hi !v
      | ix < hi = MPV.unsafeWrite v ix ix >> go (ix + 1) hi v
      | otherwise = return ()
-- | Evaluate the argument with the @seq#@ primop as a step of the 'ST'
-- computation and return it.
seqST :: a -> ST s a
seqST value = ST (\token -> seq# value token)
-- | Hand the argument to the @spark#@ primop as a step of the 'ST'
-- computation and return it.
sparkST :: a -> ST s a
sparkST value = ST (\token -> spark# value token)
-- | Split the index range @[0, total)@ into contiguous half-open @(lo, hi)@
-- chunks, one per requested core.
--
-- The chunk size is @total / cores@ rounded up, so the chunks differ in size
-- by as little as possible. Both ends of every chunk are clamped to the
-- range, so @lo <= hi@ always holds; trailing chunks are empty when there are
-- more cores than elements.
--
-- A non-positive @cores@ is treated as 1 and a negative @total@ as 0, so the
-- result always covers the whole range and never divides by zero.
makeBounds :: Int -> Int -> V.Vector (Int,Int)
makeBounds total cores =
  let chunks = max 1 cores
      size = max 0 total
      -- Ceiling division: the smallest step for which `chunks` steps cover `size`.
      step = div (size + chunks - 1) chunks
      clamp = min size
  in V.generate chunks (\i -> (clamp (i * step), clamp ((i + 1) * step)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment