-- Functional vs. imperative nested triangular loops
-- Gist 23Skidoo/1039262 by @23Skidoo, created June 22, 2011.
module Main
where
-- See http://stackoverflow.com/questions/6411771/nested-triangular-loop-in-haskell
import Control.DeepSeq
import Criterion.Main
import Data.Array
import Data.List (tails)
import qualified Data.Vector as V
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as M
-- | Size parameter for the benchmarks in 'main'.  The imperative solution
-- receives it directly as an element count; the pure solutions receive
-- @numElts - 1@ as their highest index.
numElts :: Int
numElts = 10

-- | Step size scaling every pairwise velocity increment
-- (@mag = dt / (sqr * dist)@ in each solution).
dt :: Double
dt = 1.0e-3
-- | A body as a plain triple whose components are accessed through the
-- functions below: X (presumably position), V (presumably velocity) and
-- M (presumably mass).  It is kept as a bare tuple so it can be stored in
-- the unboxed mutable vector used by 'imperativeSolution'.
type B = (Double, Double, Double)

-- | Read one component of a body.
getX, getV, getM :: B -> Double
getX b = case b of (x, _, _) -> x
getV b = case b of (_, v, _) -> v
getM b = case b of (_, _, m) -> m

-- | Return a copy of the body with one component replaced.
setX, setV, setM :: B -> Double -> B
setX b new = case b of (_, v, m) -> (new, v, m)
setV b new = case b of (x, _, m) -> (x, new, m)
setM b new = case b of (x, v, _) -> (x, v, new)
-- | Orphan 'NFData' instances for vectors.  'main' forces the result of
-- 'pureVectorSolution' (a boxed vector) with 'nf', which needs the boxed
-- instance; the unboxed instance is not used by the benchmarks visible here.
--
-- The unboxed instance has no body and relies on the class's default 'rnf'.
-- NOTE(review): newer releases of vector ship their own NFData instances,
-- in which case these orphans would clash — confirm against the vector
-- version in use.
instance U.Unbox a => NFData (U.Vector a)

-- | Force every element of a boxed vector to normal form, left to right,
-- by threading a strict unit accumulator through 'V.foldl''.
instance NFData a => NFData (V.Vector a) where
  rnf = V.foldl' (flip deepseq) ()
-- | Pure solution on an immutable 'Array'.
--
-- Builds bodies indexed @0 .. n@ (so @n@ is the highest index, not the
-- count), body @i@ starting at @(0.5*i, 2.5*i, 5.5*i)@.  For every pair
-- @i < j@ it computes @d = x_i - x_j@, @sqr = d*d + 0.01@ and
-- @mag = dt / (sqr * sqrt sqr)@, then adds @-d * m_j * mag@ to the velocity
-- of @i@ and @d * m_i * mag@ to the velocity of @j@.  All increments are
-- applied in one 'accum'.  A negative @n@ yields an empty array.
pureSolution :: Int -> Array Int B
pureSolution n = accum addToV bs_0 dvs
  where
    -- Add an increment to the velocity component only.
    addToV :: B -> Double -> B
    addToV b dv = setV b (getV b + dv)

    bs_0 :: Array Int B
    bs_0 = array (0, n) [ (i, (i' * 0.5, i' * 2.5, i' * 5.5))
                        | i <- [0 .. n], let i' = fromIntegral i ]

    -- Velocity increments, two per pair, emitted as (i, dv_i) then
    -- (j, dv_j) with i ascending and j ascending.  The order matters
    -- because floating-point accumulation is order-sensitive.
    -- Positions and masses are read from the initial array: 'accum' only
    -- changes velocities, so the values are identical, and reading them from
    -- the result array instead would tie a knot whose termination depends
    -- on 'accum' leaving element values unevaluated.
    dvs :: [(Int, Double)]
    dvs = [ upd
          | (i:is) <- tails [0 .. n]
          , j <- is
          , let bI   = bs_0 ! i
                bJ   = bs_0 ! j
                d    = getX bI - getX bJ
                sqr  = d * d + 0.01
                dist = sqrt sqr
                mag  = dt / (sqr * dist)
          , upd <- [(i, -d * getM bJ * mag), (j, d * getM bI * mag)] ]
-- | Pure solution on an immutable boxed 'V.Vector'; same computation as
-- the 'Array'-based pure solution.
--
-- Builds bodies indexed @0 .. n@ (so @n@ is the highest index, not the
-- count), body @i@ starting at @(0.5*i, 2.5*i, 5.5*i)@.  For every pair
-- @i < j@ it computes @d = x_i - x_j@, @sqr = d*d + 0.01@ and
-- @mag = dt / (sqr * sqrt sqr)@, then adds @-d * m_j * mag@ to the velocity
-- of @i@ and @d * m_i * mag@ to the velocity of @j@.  All increments are
-- applied in one 'V.accum'.  A negative @n@ yields an empty vector.
pureVectorSolution :: Int -> V.Vector B
pureVectorSolution n = V.accum addToV bs_0 dvs
  where
    -- Add an increment to the velocity component only.
    addToV :: B -> Double -> B
    addToV b dv = setV b (getV b + dv)

    bs_0 :: V.Vector B
    bs_0 = V.fromList [ (i' * 0.5, i' * 2.5, i' * 5.5)
                      | i <- [0 .. n], let i' = fromIntegral i ]

    -- Velocity increments, two per pair, emitted as (i, dv_i) then
    -- (j, dv_j) with i ascending and j ascending.  The order matters
    -- because floating-point accumulation is order-sensitive.
    -- Positions and masses are read from the initial vector: 'V.accum' only
    -- changes velocities, so the values are identical, and reading them from
    -- the result vector instead would tie a knot whose termination depends
    -- on 'V.accum' leaving element values unevaluated.
    dvs :: [(Int, Double)]
    dvs = [ upd
          | (i:is) <- tails [0 .. n]
          , j <- is
          , let bI   = bs_0 V.! i
                bJ   = bs_0 V.! j
                d    = getX bI - getX bJ
                sqr  = d * d + 0.01
                dist = sqrt sqr
                mag  = dt / (sqr * dist)
          , upd <- [(i, -d * getM bJ * mag), (j, d * getM bI * mag)] ]
-- | Imperative solution: the same pairwise velocity update, performed in
-- place on a mutable unboxed vector.
--
-- Unlike the pure solutions, @n@ here is the element count (indices
-- @0 .. n-1@).  Body @k@ starts at @(0.5*k, 2.5*k, 5.5*k)@; then for every
-- pair @i < j@ (i ascending, then j ascending) both bodies are read, and
-- their velocities are updated and written back before the next pair.
imperativeSolution :: Int -> IO (M.IOVector B)
imperativeSolution n = do
    bodies <- M.new n
    fill bodies 0
    outer bodies 0
    return bodies
  where
    -- Initialise slots k .. n-1.
    fill :: M.IOVector B -> Int -> IO ()
    fill vec k =
      if k >= n
        then return ()
        else do
          let k' = fromIntegral k
          M.unsafeWrite vec k (k' * 0.5, k' * 2.5, k' * 5.5)
          fill vec (k + 1)

    -- Outer loop over the first index of each pair.
    outer :: M.IOVector B -> Int -> IO ()
    outer vec i
      | i >= n    = return ()
      | otherwise = inner vec i (i + 1) >> outer vec (i + 1)

    -- Inner loop over the second index, strictly greater than the first.
    inner :: M.IOVector B -> Int -> Int -> IO ()
    inner vec i j
      | j >= n    = return ()
      | otherwise = kick vec i j >> inner vec i (j + 1)

    -- Read bodies i and j, then write back both with updated velocities;
    -- position and mass are copied through unchanged.
    kick :: M.IOVector B -> Int -> Int -> IO ()
    kick vec i j = do
      (xI, vI, mI) <- M.read vec i
      (xJ, vJ, mJ) <- M.read vec j
      let d   = xI - xJ
          sqr = d * d + 0.01
          mag = dt / (sqr * sqrt sqr)
      M.write vec i (xI, vI - d * mJ * mag, mI)
      M.write vec j (xJ, vJ + d * mI * mag, mJ)
-- Alternative main that prints the results instead of benchmarking them:
--
-- main = do let n = 10
--           putStrLn "Pure solution:"
--           mapM_ print (assocs (pureSolution (n-1)))
--           putStrLn "\nPure vector solution:"
--           V.forM_ (pureVectorSolution (n-1)) print
--           putStrLn "\nImperative solution:"
--           v <- imperativeSolution n
--           v'<- U.unsafeFreeze v
--           U.forM_ v' print

-- | Run the criterion benchmarks for the three solutions.
--
-- The pure solutions take the highest index, hence @numElts - 1@, while
-- 'imperativeSolution' takes the element count.  The pure results are forced
-- to normal form with 'nf'.  The imperative solution is an 'IO' action, so it
-- is wrapped in 'whnfIO', which runs the action on each iteration.  Passing a
-- bare 'IO' action to 'bench' relies on a Benchmarkable instance that
-- criterion no longer provides (removed in 1.0).
main :: IO ()
main = defaultMain
  [ bench "pureSolution"       (nf pureSolution (numElts - 1))
  , bench "pureVectorSolution" (nf pureVectorSolution (numElts - 1))
  , bench "imperativeSolution" (whnfIO (imperativeSolution numElts))
  ]
-- Results:
-- n = 100:
-- benchmarking pureSolution
-- mean: 9.340623 ms, lb NaN s, ub 9.578420 ms, ci 0.950
-- std dev: 1.321060 ms, lb 622.3172 us, ub 2.083924 ms, ci 0.950
-- benchmarking pureVectorSolution
-- mean: 9.766345 ms, lb 9.608698 ms, ub NaN s, ci 0.950
-- std dev: 1.260290 ms, lb 583.1293 us, ub 2.690631 ms, ci 0.950
-- benchmarking imperativeSolution
-- mean: 455.0359 us, lb 450.4635 us, ub 468.4372 us, ci 0.950
-- std dev: 44.33842 us, lb 3.906925 us, ub 75.85589 us, ci 0.950
-- n = 1000:
-- benchmarking pureSolution
-- collecting 100 samples, 1 iterations each, in estimated 334.5483 s
-- mean: 2.949640 s, lb 2.867693 s, ub 3.005429 s, ci 0.950
-- std dev: 421.1978 ms, lb 343.8233 ms, ub 539.4906 ms, ci 0.950
-- found 4 outliers among 100 samples (4.0%)
-- 3 (3.0%) high severe
-- variance introduced by outliers: 5.997%
-- variance is slightly inflated by outliers
-- benchmarking pureVectorSolution
-- collecting 100 samples, 1 iterations each, in estimated 280.4593 s
-- mean: 2.747359 s, lb 2.709507 s, ub 2.803392 s, ci 0.950
-- std dev: 237.7489 ms, lb 179.3110 ms, ub 311.8813 ms, ci 0.950
-- found 13 outliers among 100 samples (13.0%)
-- 7 (7.0%) high mild
-- 6 (6.0%) high severe
-- variance introduced by outliers: 2.998%
-- variance is slightly inflated by outliers
-- benchmarking imperativeSolution
-- collecting 100 samples, 1 iterations each, in estimated 5.905104 s
-- mean: 58.59154 ms, lb 56.79405 ms, ub 60.60033 ms, ci 0.950
-- std dev: 11.70101 ms, lb 9.120100 ms, ub NaN s, ci 0.950
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment