Created
December 21, 2021 10:51
-
-
Save lotz84/42fb385171a8f6d559bed052e5710263 to your computer and use it in GitHub Desktop.
Matrix chain multiplication implemented with dependent types
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE AllowAmbiguousTypes #-} | |
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeApplications #-} | |
module Main where | |
import Prelude hiding ((<>)) | |
import Control.Monad.ST | |
import Data.Foldable | |
import Data.Traversable | |
import Data.Proxy | |
import GHC.TypeLits | |
import Text.Printf | |
import Data.Array.Unboxed | |
import Data.Array.ST | |
import Data.Time.Clock | |
import Data.Time.Format | |
import Numeric.LinearAlgebra.Static | |
-- | A non-empty chain of matrices whose adjacent dimensions are checked at
-- the type level: an @MCList m n@ multiplies out to an @m × n@ matrix.
data MCList m n where
  -- | The final matrix of the chain.
  S :: L m n -> MCList m n
  -- | Prepend an @m × k@ matrix to a chain that starts with @k@ rows.
  -- The 'KnownNat' evidence for the inner dimension @k@ is stored with the
  -- constructor so it can be recovered when the chain is taken apart.
  (:.) :: KnownNat k => L m k -> MCList k n -> MCList m n

infixr 5 :.
-- | Multiply a matrix chain exactly as written, associating to the right:
-- @A :. B :. S C@ is computed as @A <> (B <> C)@. No attempt is made to
-- pick a cheaper parenthesization.
naiveMcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
naiveMcm chain = case chain of
  S lastMatrix -> lastMatrix
  headMatrix :. rest -> headMatrix <> naiveMcm rest
-- | Reflect the type-level natural @n@ as a value of any 'Num' type.
-- Intended to be used with a type application, e.g. @num \@3 :: Int@.
num :: forall n a. (KnownNat n, Num a) => a
num = fromInteger (natVal (Proxy :: Proxy n))
-- | The dimension list of a chain. For matrices of sizes
-- @d0 × d1, d1 × d2, …, d(k-1) × dk@ this is @[d0, d1, …, dk]@, which has
-- one more entry than there are matrices.
dims :: forall m n. (KnownNat m, KnownNat n) => MCList m n -> [Int]
dims chain = num @m : case chain of
  S _ -> [num @n]
  _ :. rest -> dims rest
-- | The shape of a full parenthesization of a matrix chain. Each 'Leaf'
-- stands for one matrix, and each 'Node' multiplies the products of its two
-- subtrees.
--
-- Constructor order is significant. The derived 'Ord' puts 'Leaf' before
-- 'Node', and 'minCost' relies on that ordering to break ties between
-- equally cheap splits.
data Tree
  = Leaf
  | Node Tree Tree
  deriving (Eq, Ord)
-- | Find the cheapest parenthesization of a matrix chain using the classic
-- O(n³) dynamic program over the @n@ matrices of the chain.
--
-- Returns the minimal number of scalar multiplications together with the
-- 'Tree' describing the chosen parenthesization. When several splits cost
-- the same, 'minimum' on @(Int, Tree)@ breaks the tie using 'Tree''s 'Ord'.
minCost :: (KnownNat m, KnownNat n) => MCList m n -> (Int, Tree)
minCost xs = runST $ do
  let ds = dims xs
      -- number of matrices; 'dims' has one more entry than that
      n = length ds - 1
      -- matrix i has shape (costs ! i) × (costs ! (i + 1))
      costs = listArray (0, n) ds :: UArray Int Int
      -- sub-chains (i, j) ordered by increasing length, so every shorter
      -- sub-chain is solved before any chain that contains it
      indices =
        [(x, x + offset) | offset <- [0 .. n - 1], x <- [0 .. n - 1 - offset]]
  -- table ! (i, j) holds the best (cost, tree) for multiplying matrices i..j
  table <- newArray_ ((0, 0), (n - 1, n - 1)) :: ST s (STArray s (Int, Int) (Int, Tree))
  for_ indices $ \(i, j) ->
    if i == j
      then writeArray table (i, j) (0, Leaf)
      else do
        candidates <- for [i .. j - 1] $ \k -> do
          (cik, tik) <- readArray table (i, k)
          (ckj, tkj) <- readArray table (k + 1, j)
          -- product i..k is costs!i × costs!(k+1) and product k+1..j is
          -- costs!(k+1) × costs!(j+1), so joining them costs their product
          pure (cik + ckj + costs ! i * costs ! (k + 1) * costs ! (j + 1), Node tik tkj)
        -- Force the winner before storing it. Storing the bare thunk
        -- `minimum candidates` would keep every candidate list alive until
        -- the final answer is demanded (O(n³) retained thunks).
        let best@(bestCost, _) = minimum candidates
        bestCost `seq` writeArray table (i, j) best
  readArray table (0, n - 1)
-- | A matrix chain arranged as an explicit multiplication tree. The shared
-- inner dimension of every product is checked at the type level.
data MCTree m n where
  -- | A single matrix.
  L :: L m n -> MCTree m n
  -- | The product of an @m × k@ subtree and a @k × n@ subtree; the
  -- 'KnownNat' evidence for @k@ is stored with the node.
  N :: KnownNat k => MCTree m k -> MCTree k n -> MCTree m n
-- | Multiply out a multiplication tree, performing the products in exactly
-- the grouping its 'N' nodes prescribe.
mcmTree :: (KnownNat m, KnownNat n) => MCTree m n -> L m n
mcmTree tree = case tree of
  L leaf -> leaf
  N left right -> mcmTree left <> mcmTree right
-- | A binary tree with a payload at every leaf. 'buildLT' takes a
-- @Tree' ()@ as a bare shape and fills its leaves from a list.
data Tree' a
  = Leaf' a
  | Node' (Tree' a) (Tree' a)
-- | Fill the leaves of a shape tree, left to right, with elements taken from
-- the front of a list. Returns the filled tree together with the unused
-- suffix of the list.
--
-- Calls 'error' if the list has fewer elements than the shape has leaves.
buildLT :: [a] -> Tree' () -> (Tree' a, [a])
buildLT (a:as) (Leaf' _) = (Leaf' a, as)
-- Previously an unmatched pattern; report the actual problem instead.
buildLT [] (Leaf' _) = error "buildLT: list has fewer elements than the tree has leaves"
buildLT as (Node' l r) =
  let (l', as') = buildLT as l
      (r', rest) = buildLT as' r
   in (Node' l' r', rest)
-- | The result of building an 'MCTree' from a prefix of a matrix chain.
-- Either the whole chain was used, or a tree for the prefix is returned
-- together with the remaining chain. In the second case the split
-- dimension @k@ is existential, and its 'KnownNat' evidence is packed along
-- with it.
data SomeTreeList m n where
  -- | The tree used every matrix of the chain.
  NoRest :: MCTree m n -> SomeTreeList m n
  -- | A tree for the @m × k@ prefix and the unused @k × n@ rest of the chain.
  SomeTreeList :: KnownNat k => (MCTree m k, MCList k n) -> SomeTreeList m n
-- | Arrange a matrix chain according to a parenthesization 'Tree', taking
-- one matrix from the front of the chain for each 'Leaf', left to right.
--
-- Returns 'NoRest' when the tree consumed the whole chain. Otherwise it
-- returns 'SomeTreeList' with the unused suffix. Calls 'error' if the tree
-- has more leaves than the chain has matrices.
buildMCTree :: (KnownNat m, KnownNat n) => MCList m n -> Tree -> SomeTreeList m n
buildMCTree (S a) Leaf = NoRest (L a)
buildMCTree (a :. b) Leaf = SomeTreeList (L a, b)
buildMCTree as (Node l r) =
  case buildMCTree as l of
    -- The left subtree already used every matrix, leaving none for the
    -- right subtree: the tree has more leaves than the chain.
    NoRest _ -> error "buildMCTree: tree has more leaves than the matrix chain"
    SomeTreeList (l', as') ->
      case buildMCTree as' r of
        NoRest r' -> NoRest (N l' r')
        SomeTreeList (r', rest) -> SomeTreeList (N l' r', rest)
-- | Multiply a matrix chain using the cheapest parenthesization found by
-- 'minCost'.
--
-- 'minCost' builds its tree with one 'Leaf' per matrix, so the
-- 'SomeTreeList' branch is not expected to be reached. If it is, it reports
-- the inconsistency.
mcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
mcm xs =
  let (_, parenthesis) = minCost xs
   in case buildMCTree xs parenthesis of
        NoRest tree -> mcmTree tree
        -- Matrices were left over, so the tree has fewer leaves than the chain.
        SomeTreeList _ -> error "mcm: parenthesization tree has fewer leaves than the matrix chain"
-- | Run an action and print the elapsed wall-clock time as
-- @Time: <seconds>s@, with up to three decimal places. The action's result
-- is discarded.
--
-- The result is forced to weak head normal form before the end time is
-- read. An action that only returns an unevaluated value (for example
-- @pure (expensive x)@) therefore still has that outermost evaluation
-- timed. Deeper structure is not forced.
withTime :: IO a -> IO ()
withTime action = do
  start <- getCurrentTime
  result <- action
  end <- result `seq` getCurrentTime
  putStrLn $ formatTime defaultTimeLocale "Time: %-3Ess" (diffUTCTime end start)
-- | Benchmark 'mcm' against 'naiveMcm' on a fixed chain of four random
-- matrices (100×500, 500×1000, 1000×5000, 5000×10000). For each, print the
-- 2-norm of the product and the time taken to compute and print it.
main :: IO ()
main = do
  putStrLn "Generating random matrices"
  -- Bang patterns force each matrix (to WHNF) before any timing starts.
  !a <- randn @100 @500
  !b <- randn @500 @1000
  !c <- randn @1000 @5000
  !d <- randn @5000 @10000
  let chain = a :. b :. c :. S d
      bench :: String -> (MCList 100 10000 -> L 100 10000) -> IO ()
      bench header multiply = do
        putStrLn header
        withTime (putStrLn ("norm: " ++ show (norm_2 (multiply chain))))
  bench "# mcm" mcm
  bench "# naiveMcm" naiveMcm
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment