Created October 4, 2017 19:25
-
-
Save andrewthad/48024623cf39e672f4f7b4242b25d5b3 to your computer and use it in GitHub Desktop.
Parallel array initialization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE MagicHash #-} | |
{-# OPTIONS_GHC -O2 -Wall -threaded -fforce-recomp #-} | |
import Criterion.Main | |
import Control.Monad (when) | |
import Control.Monad.ST.Unsafe (unsafeDupableInterleaveST,unsafeInterleaveST) | |
import GHC.ST (ST(..)) | |
import GHC.Prim (spark#,seq#) | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Primitive as PV | |
import qualified Data.Vector.Primitive.Mutable as MPV | |
-- | Benchmark driver.
--
-- Before timing anything, each hand-written implementation is checked
-- against 'PV.enumFromN' so that a broken implementation fails loudly
-- instead of producing misleading timings.
--
-- NOTE(review): sparks only run on other cores when the executable uses
-- the threaded RTS and is started with @+RTS -N@; @-threaded@ is a
-- link-time option, so confirm the build/run flags actually in effect.
main :: IO ()
main = do
  let len = 20000000 :: Int
      expected = PV.enumFromN (0 :: Int) len
      -- Fail in IO with "<name> implementation incorrect" when an
      -- implementation disagrees with the library-produced vector.
      check name impl =
        when (impl len /= expected) $
          fail (name ++ " implementation incorrect")
  check "serialIncrementing" serialIncrementing
  check "parallelIncrementing" (parallelIncrementing 8)
  check "parallelIncrementingFixed" parallelIncrementingFixed
  defaultMain
    [ bgroup "incrementing"
        [ bench "library" $ whnf (PV.enumFromN (0 :: Int)) len
        , bench "serial" $ whnf serialIncrementing len
        , bench "parallel (fixed)" $ whnf parallelIncrementingFixed len
        -- , bench "parallel (4 sparks)" $ whnf (parallelIncrementing 4) len
        -- , bench "parallel (8 sparks)" $ whnf (parallelIncrementing 8) len
        ]
    ]
-- | Build @[0 .. n-1]@ as a primitive vector on a single thread by writing
-- each index into freshly allocated (uninitialised) mutable storage.
serialIncrementing :: Int -> PV.Vector Int
serialIncrementing !n = PV.create $ do
  !mv <- MPV.unsafeNew n
  -- Store i at position i, stopping once every slot below n is filled.
  let fill !i
        | i >= n = return ()
        | otherwise = MPV.unsafeWrite mv i i >> fill (i + 1)
  fill 0
  return mv
-- | Build @[0 .. n-1]@ using a fixed four-way split of the index range.
-- The calling thread fills the first ~70% itself. Three ~10% slices are
-- handed to sparks, and the last slice absorbs any rounding remainder.
--
-- Each spark wraps a deferred fill created with
-- 'unsafeDupableInterleaveST'. The slices are forced with 'seqST' after
-- the main slice is written, so every write has happened before
-- 'PV.create' freezes the vector. A dupable deferred computation may be
-- run by more than one thread, so a slice can be filled twice. That is
-- harmless here because each index always receives the same value.
parallelIncrementingFixed :: Int -> PV.Vector Int
parallelIncrementingFixed !n = PV.create $ do
  !mv <- MPV.unsafeNew n
  let tenth = div n 10
      mainEnd = div (n * 7) 10
      sliceA = (mainEnd, mainEnd + tenth)
      sliceB = (snd sliceA, snd sliceA + tenth)
      sliceC = (snd sliceB, n)
      -- Defer filling [lo, hi) and offer the deferred work as a spark.
      spawn (lo, hi) = unsafeDupableInterleaveST (fill lo hi mv) >>= sparkST
  a <- spawn sliceA
  b <- spawn sliceB
  c <- spawn sliceC
  fill 0 mainEnd mv
  -- Make sure every sparked slice has been written, whoever ran it.
  () <- seqST a
  () <- seqST b
  () <- seqST c
  return mv
  where
    -- Write each index in [lo, hi) into its own slot.
    fill :: Int -> Int -> PV.MVector s Int -> ST s ()
    fill !lo !hi !vec
      | lo >= hi = return ()
      | otherwise = MPV.unsafeWrite vec lo lo >> fill (lo + 1) hi vec
-- | Build @[0 .. n-1]@ by splitting the index range into @cores@ slices
-- (via 'makeBounds'). The fill of each slice is deferred with
-- 'unsafeInterleaveST' and sparked. Every slice is then forced with
-- 'seqVector', so all writes complete before 'PV.create' freezes the
-- vector.
--
-- A non-positive @cores@ is treated as 1. Passing it through unchanged
-- would divide by zero (for 0) or could yield no slices at all (for a
-- negative count). In that case the uninitialised storage from
-- 'MPV.unsafeNew' would be returned unwritten.
--
-- Correctness relies on the slices from 'makeBounds' covering every
-- index in @[0, n)@.
parallelIncrementing :: Int -> Int -> PV.Vector Int
parallelIncrementing !cores !n = PV.create $ do
  !mv <- MPV.unsafeNew n
  let slices = makeBounds n (max 1 cores)
      -- Defer filling [lo, hi) and offer the deferred work as a spark.
      spawn (lo, hi) = unsafeInterleaveST (fill lo hi mv) >>= sparkST
  pending <- V.mapM spawn slices
  () <- seqVector pending
  return mv
  where
    -- Write each index in [lo, hi) into its own slot.
    fill :: Int -> Int -> PV.MVector s Int -> ST s ()
    fill !lo !hi !vec
      | lo < hi = MPV.unsafeWrite vec lo lo >> fill (lo + 1) hi vec
      | otherwise = return ()
-- | Force every element of the vector to weak head normal form, in index
-- order, as part of the 'ST' thread (each element goes through 'seqST').
seqVector :: V.Vector () -> ST s ()
seqVector = V.mapM_ seqST
-- | Evaluate the argument to weak head normal form when this 'ST' action
-- runs (not when the action is built), then return it. This is a thin
-- wrapper around the @seq#@ primop.
seqST :: a -> ST s a
seqST a = ST (seq# a)
-- | Offer the argument to the runtime as a spark (a candidate for parallel
-- evaluation) and return the same value unchanged. This is a thin wrapper
-- around the @spark#@ primop.
sparkST :: a -> ST s a
sparkST a = ST (spark# a)
-- | Split the half-open index range @[0, total)@ into @cores@ contiguous,
-- half-open slices @(lo, hi)@ whose sizes differ by at most one.
--
-- The first @total `mod` cores@ slices get one extra element, so work is
-- balanced. The previous @div total cores + 1@ chunk size could leave the
-- trailing slices empty, with @lo@ past @total@.
--
-- A non-positive @cores@ is treated as a single slice, instead of
-- dividing by zero. A negative @total@ is treated as zero. Every returned
-- pair therefore satisfies @0 <= lo <= hi <= max 0 total@, and together
-- the slices cover the range exactly.
makeBounds :: Int -> Int -> V.Vector (Int,Int)
makeBounds total cores =
  let chunks = max 1 cores
      (q, r) = divMod (max 0 total) chunks
      -- Start of slice i: i slices of base size q, plus one extra element
      -- for each earlier slice that received part of the remainder r.
      start i = i * q + min i r
  in V.generate chunks (\i -> (start i, start (i + 1)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.