Created
June 22, 2011 00:11
-
-
Save 23Skidoo/1039262 to your computer and use it in GitHub Desktop.
Functional vs. imperative nested triangular loops
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module Main | |
where | |
-- See http://stackoverflow.com/questions/6411771/nested-triangular-loop-in-haskell | |
import Control.DeepSeq | |
import Criterion.Main | |
import Data.Array | |
import Data.List (tails) | |
import qualified Data.Vector as V | |
import qualified Data.Vector.Unboxed as U | |
import qualified Data.Vector.Unboxed.Mutable as M | |
-- | Problem size handed to every benchmark in 'main'.  The timings recorded
-- at the end of this file are labelled @n = 100@ and @n = 1000@, not this
-- value.
numElts :: Int
numElts = 10

-- | Factor that scales every velocity increment computed by the solvers.
dt :: Double
dt = 0.001
-- | One body, stored as a plain triple.  Going by the accessor names the
-- components are presumably position (X), velocity (V) and mass (M).
type B = (Double, Double, Double)

-- | Read the first, second and third component of a body respectively.
getX, getV, getM :: B -> Double
getX (px, _, _) = px
getV (_, vel, _) = vel
getM (_, _, mass) = mass

-- | Replace the first, second and third component of a body respectively,
-- leaving the other two components untouched.
setX, setV, setM :: B -> Double -> B
setX (_, vel, mass) px = (px, vel, mass)
setV (px, _, mass) vel = (px, vel, mass)
setM (px, vel, _) mass = (px, vel, mass)
-- | Forcing an unboxed vector to weak head normal form is taken to be enough
-- here.  'rnf' is written out instead of being left to the class default,
-- because that default is not the same in every @deepseq@ release.
--
-- NOTE(review): both instances in this file are orphans, and newer @vector@
-- releases appear to provide their own 'NFData' instances -- confirm against
-- the installed version and drop these if they clash.
instance U.Unbox a => NFData (U.Vector a) where
  rnf v = v `seq` ()

-- | Deep-force every element of a boxed vector, left to right.
instance NFData a => NFData (V.Vector a) where
  rnf = V.foldl' (flip deepseq) ()
-- | Pure solution built on "Data.Array".
--
-- The array is indexed @0 .. n@, so it holds @n + 1@ bodies; body @i@ starts
-- out as @(i * 0.5, i * 2.5, i * 5.5)@.  For every pair @i < j@ one velocity
-- increment is produced for each of the two bodies, and 'accum' adds all
-- increments onto the initial velocities.
pureSolution :: Int -> Array Int B
pureSolution n = accum addToVelocity bs_0 dvs
  where
    -- Initial state of all bodies.
    bs_0 :: Array Int B
    bs_0 =
      array
        (0, n)
        [ (i, (i' * 0.5, i' * 2.5, i' * 5.5))
        | i <- [0 .. n]
        , let i' = fromIntegral i
        ]

    addToVelocity :: B -> Double -> B
    addToVelocity b dv = setV b (getV b + dv)

    -- Velocity increments for every pair i < j, in triangular order.  Only
    -- the velocity component is ever updated, so the first and third
    -- components are read from the initial array rather than from the result
    -- that is still being defined in terms of this list.
    dvs :: [(Int, Double)]
    dvs =
      concat
        [ [(i, dv_i), (j, dv_j)]
        | (i : is) <- tails [0 .. n]
        , j <- is
        , let d = getX (bs_0 ! i) - getX (bs_0 ! j)
              sqr = d * d + 0.01
              dist = sqrt sqr
              mag = dt / (sqr * dist)
              dv_i = -d * getM (bs_0 ! j) * mag
              dv_j = d * getM (bs_0 ! i) * mag
        ]
-- | Pure solution built on boxed "Data.Vector".
--
-- The vector holds bodies @0 .. n@ (@n + 1@ of them); body @i@ starts out as
-- @(i * 0.5, i * 2.5, i * 5.5)@.  For every pair @i < j@ one velocity
-- increment is produced for each of the two bodies, and 'V.accum' adds all
-- increments onto the initial velocities.
pureVectorSolution :: Int -> V.Vector B
pureVectorSolution n = V.accum addToVelocity bs_0 dvs
  where
    -- Initial state of all bodies.
    bs_0 :: V.Vector B
    bs_0 =
      V.fromList
        [ (i' * 0.5, i' * 2.5, i' * 5.5)
        | i <- [0 .. n]
        , let i' = fromIntegral i
        ]

    addToVelocity :: B -> Double -> B
    addToVelocity b dv = setV b (getV b + dv)

    -- Velocity increments for every pair i < j, in triangular order.  Only
    -- the velocity component is ever updated, so the first and third
    -- components are read from the initial vector rather than from the
    -- result that is still being defined in terms of this list.
    dvs :: [(Int, Double)]
    dvs =
      concat
        [ [(i, dv_i), (j, dv_j)]
        | (i : is) <- tails [0 .. n]
        , j <- is
        , let d = getX (bs_0 V.! i) - getX (bs_0 V.! j)
              sqr = d * d + 0.01
              dist = sqrt sqr
              mag = dt / (sqr * dist)
              dv_i = -d * getM (bs_0 V.! j) * mag
              dv_j = d * getM (bs_0 V.! i) * mag
        ]
-- | Imperative solution on an unboxed mutable vector.
--
-- Allocates @n@ slots, fills slot @i@ with @(i * 0.5, i * 2.5, i * 5.5)@,
-- then walks every pair @i < j@ and rewrites the velocity component of both
-- bodies in place.  The mutated vector is returned.
imperativeSolution :: Int -> IO (M.IOVector B)
imperativeSolution n = do
  bodies <- M.new n
  fill bodies 0
  outer bodies 0
  return bodies
  where
    -- Write the initial body into every slot from k up to n - 1.  The
    -- unchecked write is only reached when k < n.
    fill :: M.IOVector B -> Int -> IO ()
    fill vec k =
      if k < n
        then do
          let s = fromIntegral k
          M.unsafeWrite vec k (s * 0.5, s * 2.5, s * 5.5)
          fill vec (k + 1)
        else return ()

    -- Outer loop of the triangular iteration: i runs over 0 .. n - 1.
    outer :: M.IOVector B -> Int -> IO ()
    outer vec i =
      if i < n
        then inner vec i (i + 1) >> outer vec (i + 1)
        else return ()

    -- Inner loop: j runs over i + 1 .. n - 1.
    inner :: M.IOVector B -> Int -> Int -> IO ()
    inner vec i j =
      if j < n
        then applyPair vec i j >> inner vec i (j + 1)
        else return ()

    -- Update the velocities of bodies i and j from their first and third
    -- components; those two components are written back unchanged.
    applyPair :: M.IOVector B -> Int -> Int -> IO ()
    applyPair vec i j = do
      (xI, vI, mI) <- M.read vec i
      (xJ, vJ, mJ) <- M.read vec j
      let d = xI - xJ
          sqr = d * d + 0.01
          mag = dt / (sqr * sqrt sqr)
      M.write vec i (xI, vI - d * mJ * mag, mI)
      M.write vec j (xJ, vJ + d * mI * mag, mJ)
-- Debug variant that prints the result of each solution instead of
-- benchmarking it:
--
-- main = do let n = 10
--           putStrLn "Pure solution:"
--           mapM_ print (assocs (pureSolution (n-1)))
--           putStrLn "\nPure vector solution:"
--           V.forM_ (pureVectorSolution (n-1)) print
--           putStrLn "\nImperative solution:"
--           v <- imperativeSolution n
--           v'<- U.unsafeFreeze v
--           U.forM_ v' print

-- | Benchmark the three solutions on the same number of bodies.
--
-- The pure solutions take the highest index, so they receive @numElts - 1@;
-- the imperative one takes the element count.  The 'IO' action is wrapped in
-- 'whnfIO' explicitly so it does not depend on Criterion accepting a bare
-- 'IO' value as a benchmark.
main :: IO ()
main =
  defaultMain
    [ bench "pureSolution" (nf pureSolution (numElts - 1))
    , bench "pureVectorSolution" (nf pureVectorSolution (numElts - 1))
    , bench "imperativeSolution" (whnfIO (imperativeSolution numElts))
    ]
-- Results: | |
-- n = 100: | |
-- benchmarking pureSolution | |
-- mean: 9.340623 ms, lb NaN s, ub 9.578420 ms, ci 0.950 | |
-- std dev: 1.321060 ms, lb 622.3172 us, ub 2.083924 ms, ci 0.950 | |
-- benchmarking pureVectorSolution | |
-- mean: 9.766345 ms, lb 9.608698 ms, ub NaN s, ci 0.950 | |
-- std dev: 1.260290 ms, lb 583.1293 us, ub 2.690631 ms, ci 0.950 | |
-- benchmarking imperativeSolution | |
-- mean: 455.0359 us, lb 450.4635 us, ub 468.4372 us, ci 0.950 | |
-- std dev: 44.33842 us, lb 3.906925 us, ub 75.85589 us, ci 0.950 | |
-- n = 1000: | |
-- benchmarking pureSolution | |
-- collecting 100 samples, 1 iterations each, in estimated 334.5483 s | |
-- mean: 2.949640 s, lb 2.867693 s, ub 3.005429 s, ci 0.950 | |
-- std dev: 421.1978 ms, lb 343.8233 ms, ub 539.4906 ms, ci 0.950 | |
-- found 4 outliers among 100 samples (4.0%) | |
-- 3 (3.0%) high severe | |
-- variance introduced by outliers: 5.997% | |
-- variance is slightly inflated by outliers | |
-- benchmarking pureVectorSolution | |
-- collecting 100 samples, 1 iterations each, in estimated 280.4593 s | |
-- mean: 2.747359 s, lb 2.709507 s, ub 2.803392 s, ci 0.950 | |
-- std dev: 237.7489 ms, lb 179.3110 ms, ub 311.8813 ms, ci 0.950 | |
-- found 13 outliers among 100 samples (13.0%) | |
-- 7 (7.0%) high mild | |
-- 6 (6.0%) high severe | |
-- variance introduced by outliers: 2.998% | |
-- variance is slightly inflated by outliers | |
-- benchmarking imperativeSolution | |
-- collecting 100 samples, 1 iterations each, in estimated 5.905104 s | |
-- mean: 58.59154 ms, lb 56.79405 ms, ub 60.60033 ms, ci 0.950 | |
-- std dev: 11.70101 ms, lb 9.120100 ms, ub NaN s, ci 0.950 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment