Last active
March 17, 2017 21:50
-
-
Save folkertdev/79d65ab5c655e0047544a51f74bcf77e to your computer and use it in GitHub Desktop.
Use a flat array to calculate the optimal order for a series of matrix multiplications with STUArray and the ST Monad
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Control.Monad (forM_)
import Control.Monad.ST
import Data.Array.IArray as Array
import Data.Array.MArray as MArray
import Data.Array.ST as Array.ST
import Data.Array.Unboxed
import qualified Data.List as List
import Data.STRef
import Data.Word (Word32)
{-| Fill a triangular dynamic-programming table over @problems@ and return
the value stored for the full interval @(1, length problems)@.

The table is kept in a flat unboxed array addressed through 'ix':

* the cell @(k, k)@ is initialised with the @k@-th element of @problems@;
* every longer interval @(i, j)@ receives
  @edgeCost + minimum [ max (cell (i, q)) (cell (q + 1, j)) | q <- [i .. j - 1] ]@.

Intervals are processed in order of increasing length, so both halves of
every split have already been written when they are read.

The first argument is the cost added once per combined interval; it and the
elements of @problems@ are converted to 'Word32' with 'fromIntegral'.
An empty list yields @0@.
-}
solve :: Int -> [Int] -> Word32
solve _ [] = 0
solve edgeCost_ problems = table Array.! at (1, size)
  where
    size :: Int
    size = length problems

    edgeCost :: Word32
    edgeCost = fromIntegral edgeCost_

    -- Flat-array index of a table cell.
    at :: (Int, Int) -> Int
    at = ix size

    -- Interval start points, one diagonal after another: 123 12 1
    starts :: [Int]
    starts = concat . reverse $ List.inits [1 .. size - 1]

    -- Matching interval end points: 234 34 4
    ends :: [Int]
    ends = concat $ List.tails [2 .. size]

    table :: UArray Int Word32
    table = runSTUArray $ do
      -- actual api in https://hackage.haskell.org/package/array-0.5.1.1/docs/Data-Array-MArray-Safe.html
      -- strategy described in http://code.ouroborus.net/fp-syd/past/2015/2015-05-Sutton-DynamicProgramming.pdf
      cache <- MArray.newArray (1, (size * (size + 1)) `div` 2) (0 :: Word32)

      -- Store the problems themselves in the single-element intervals.
      forM_ (zip [1 ..] problems) $ \(index, problem) ->
        writeArray cache (at (index, index)) (fromIntegral problem)

      -- Cost of splitting (i, j) into (i, q) and (q + 1, j).
      let splitCost i j q =
            max <$> readArray cache (at (i, q))
                <*> readArray cache (at (q + 1, j))

      forM_ (zip starts ends) $ \(i, j) -> do
        -- i < j for every pair, so the list of split points is never empty
        -- and 'minimum' is safe here.
        splits <- traverse (splitCost i j) [j - 1, j - 2 .. i]
        writeArray cache (at (i, j)) (minimum splits + edgeCost)

      return cache
{-| Point to index.

Maps the cell @(i, j)@ of a triangular table with @n@ columns to its
position in a flat, 1-based array. Column @n@ occupies indices @1 .. n@,
column @n - 1@ the following @n - 1@ indices, and so on, so column @j@
starts after @n + (n - 1) + ... + (j + 1)@ slots.

That sum is @(n * (n + 1) - j * (j + 1)) / 2@, which gives the closed form
used here. Both products are even, so the division is exact. For
@j <= n@ the result equals the recursive definition
@ix n (i, j) = if j == n then i else n + ix (n - 1) (i, j)@; for @j > n@
that recursion never terminated, whereas this version returns a value
outside the table.
-}
ix :: Int -> (Int, Int) -> Int
ix n (i, j) = i + (n * (n + 1) - j * (j + 1)) `div` 2
{-| Index to point.

Walks the columns of the triangular table from the widest (@n@) downwards:
if the remaining index fits into the current column it is the row of the
result, otherwise the column width is subtracted and the search moves one
column to the left.

TODO: find a closed form for this.
-}
coix :: Int -> Int -> (Int, Int)
coix n i = go n i n
  where
    -- width: size of the column under inspection
    -- row:   index still to be placed
    -- col:   column number paired with the row in the result
    go width row col
      | row <= width = (row, col)
      | otherwise = go (width - 1) (row - width) (col - 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment