Skip to content

Instantly share code, notes, and snippets.

@lotz84
Created December 21, 2021 10:51
Show Gist options
  • Save lotz84/42fb385171a8f6d559bed052e5710263 to your computer and use it in GitHub Desktop.
Save lotz84/42fb385171a8f6d559bed052e5710263 to your computer and use it in GitHub Desktop.
Matrix chain multiplication implemented with dependent types
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module Main where
import Prelude hiding ((<>))
import Control.Monad.ST
import Data.Foldable
import Data.Traversable
import Data.Proxy
import GHC.TypeLits
import Text.Printf
import Data.Array.Unboxed
import Data.Array.ST
import Data.Time.Clock
import Data.Time.Format
import Numeric.LinearAlgebra.Static
-- Right-associative, so @a :. b :. S c@ parses as @a :. (b :. S c)@.
infixr 5 :.

-- | A non-empty chain of statically sized matrices whose adjacent inner
-- dimensions agree; the product of the whole chain has type @'L' m n@.
data MCList (m :: Nat) (n :: Nat) where
  -- | The last (or only) matrix of the chain.
  S :: L m n -> MCList m n
  -- | Prepend an @m × k@ matrix to a chain that starts at dimension @k@.
  -- The 'KnownNat' evidence for the intermediate dimension @k@ is stored in
  -- the constructor so that pattern matching recovers it.
  (:.) :: KnownNat k => L m k -> MCList k n -> MCList m n
-- | Multiply a chain strictly right to left, @a1 <> (a2 <> (... <> ak))@,
-- without choosing a cheaper parenthesisation. Used in 'main' as the
-- baseline that 'mcm' is compared against.
naiveMcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
naiveMcm chain = case chain of
  S only        -> only
  front :. rest -> front <> naiveMcm rest
-- | Reflect the type-level natural @n@ as a value of any 'Num' type.
-- The type is ambiguous without an explicit application, e.g. @num \@3@.
num :: forall n a. (KnownNat n, Num a) => a
num = fromInteger (natVal (Proxy :: Proxy n))
-- | Dimension sequence of a chain. For matrices of sizes
-- @d0×d1, d1×d2, …, d(k-1)×dk@ the result is @[d0, d1, …, dk]@, one
-- element longer than the number of matrices.
dims :: forall m n. (KnownNat m, KnownNat n) => MCList m n -> [Int]
dims chain = num @m : case chain of
  S _       -> [num @n]
  _ :. rest -> dims rest
-- | Shape of a parenthesisation of a matrix chain. Each 'Leaf' stands for one
-- matrix (taken left to right), and each 'Node' multiplies the product of its
-- left subtree by the product of its right subtree.
--
-- 'minCost' needs the derived 'Ord' because it takes the minimum of
-- @(cost, Tree)@ pairs; the instance breaks ties between equal costs.
data Tree
  = Leaf
  | Node Tree Tree
  deriving (Eq, Ord)
-- | Dynamic programme for matrix-chain ordering, cubic in the number of
-- matrices.
--
-- Matrix @i@ (0-based) has size @d_i × d_(i+1)@, where
-- @[d_0 .. d_k] = 'dims' xs@. A @p×q@ by @q×r@ product costs @p*q*r@.
-- The result is the minimum total cost of evaluating the whole chain,
-- together with a 'Tree' (one 'Leaf' per matrix) that achieves it.
minCost :: (KnownNat m, KnownNat n) => MCList m n -> (Int, Tree)
minCost xs = runST $ do
  let dimList = dims xs
      count   = length dimList - 1  -- number of matrices in the chain
      dimAt :: UArray Int Int
      dimAt   = listArray (0, count) dimList
      -- Every sub-chain (i, j) with i <= j, shortest first, so the entries a
      -- sub-chain depends on are always filled in before it is processed.
      spans   = [ (i, i + len) | len <- [0 .. count - 1], i <- [0 .. count - 1 - len] ]
  memo <- newArray_ ((0, 0), (count - 1, count - 1))
            :: ST s (STArray s (Int, Int) (Int, Tree))
  let -- Cost and shape of sub-chain (i, j) when split after matrix k.
      splitCost i j k = do
        (leftCost, leftTree)   <- readArray memo (i, k)
        (rightCost, rightTree) <- readArray memo (k + 1, j)
        let joinCost = dimAt ! i * dimAt ! (k + 1) * dimAt ! (j + 1)
        pure (leftCost + rightCost + joinCost, Node leftTree rightTree)
      -- Cheapest (cost, tree) for sub-chain (i, j).
      best i j
        | i == j    = pure (0, Leaf)
        | otherwise = minimum <$> mapM (splitCost i j) [i .. j - 1]
  traverse_ (\ij@(i, j) -> best i j >>= writeArray memo ij) spans
  readArray memo (0, count - 1)
-- | A fully parenthesised matrix product. Its shape is chosen at run time
-- ('mcm' builds it from the 'Tree' returned by 'minCost'), while every
-- intermediate dimension is still tracked at the type level.
data MCTree (m :: Nat) (n :: Nat) where
  -- | A single matrix.
  L :: L m n -> MCTree m n
  -- | Product of a left sub-product (@m × k@) and a right one (@k × n@).
  -- The 'KnownNat' evidence for @k@ is stored in the node.
  N :: KnownNat k => MCTree m k -> MCTree k n -> MCTree m n
-- | Evaluate a parenthesised product, multiplying in exactly the order the
-- tree prescribes.
mcmTree :: (KnownNat m, KnownNat n) => MCTree m n -> L m n
mcmTree tree = case tree of
  N left right -> mcmTree left <> mcmTree right
  L leaf       -> leaf
-- | A binary tree that carries a value at every leaf.
data Tree' a = Leaf' a
             | Node' (Tree' a) (Tree' a)

-- | @buildLT xs shape@ fills the leaves of @shape@ with elements of @xs@,
-- left to right. It returns the filled tree together with the elements that
-- were not consumed.
--
-- Calls 'error' (lazily, when the affected leaf is forced) if @xs@ has fewer
-- elements than @shape@ has leaves.
buildLT :: [a] -> Tree' () -> (Tree' a, [a])
buildLT (x:xs) (Leaf' _) = (Leaf' x, xs)
buildLT []     (Leaf' _) =
  error "buildLT: list has fewer elements than the tree has leaves"
buildLT xs (Node' l r) =
  let (l', xs')  = buildLT xs l
      (r', rest) = buildLT xs' r
  in (Node' l' r', rest)
-- | Result of 'buildMCTree'. Either the chain was consumed exactly, or a tree
-- was built for a prefix and a suffix of the chain remains, starting at an
-- existentially hidden dimension @k@.
data SomeTreeList (m :: Nat) (n :: Nat) where
  -- | The whole chain has been turned into a tree.
  NoRest       :: MCTree m n -> SomeTreeList m n
  -- | A tree for an @m × k@ prefix plus the unconsumed @k × n@ remainder.
  SomeTreeList :: KnownNat k => (MCTree m k, MCList k n) -> SomeTreeList m n
-- | Rebuild a matrix chain as a typed 'MCTree' with the shape of a 'Tree',
-- consuming one matrix per 'Leaf' from left to right.
--
-- Returns 'NoRest' when the tree uses up the chain exactly, or
-- 'SomeTreeList' with the still-unconsumed suffix when the tree has fewer
-- leaves than the chain has matrices. Calls 'error' when the tree has more
-- leaves than the chain has matrices.
buildMCTree :: (KnownNat m, KnownNat n) => MCList m n -> Tree -> SomeTreeList m n
buildMCTree (S a) Leaf = NoRest (L a)
buildMCTree (a :. rest) Leaf = SomeTreeList (L a, rest)
buildMCTree chain (Node l r) =
  case buildMCTree chain l of
    -- The left subtree already used up the whole chain, so nothing is left
    -- for the right subtree: the tree has too many leaves.
    NoRest _ ->
      error "buildMCTree: tree has more leaves than the chain has matrices"
    SomeTreeList (l', chain') ->
      case buildMCTree chain' r of
        NoRest r'               -> NoRest (N l' r')
        SomeTreeList (r', rest) -> SomeTreeList (N l' r', rest)
-- | Multiply a matrix chain using an optimal parenthesisation.
--
-- 'minCost' picks the cheapest multiplication order, 'buildMCTree' rebuilds
-- the chain in that shape, and 'mcmTree' evaluates it.
mcm :: (KnownNat m, KnownNat n) => MCList m n -> L m n
mcm xs =
  let (_, parenthesis) = minCost xs
  in case buildMCTree xs parenthesis of
       NoRest tree -> mcmTree tree
       -- Leftover matrices mean the tree had fewer leaves than the chain has
       -- matrices. minCost creates one Leaf per matrix, so this branch should
       -- be unreachable.
       SomeTreeList _ ->
         error "mcm: tree has fewer leaves than the chain has matrices"
-- | Run an action, discard its result, and print the wall-clock time it took
-- as @Time: <seconds>s@. The @%-3Es@ format prints at most three decimal
-- places.
withTime :: IO a -> IO ()
withTime action = do
  t0 <- getCurrentTime
  _ <- action
  t1 <- getCurrentTime
  let elapsed = t1 `diffUTCTime` t0
  putStrLn (formatTime defaultTimeLocale "Time: %-3Ess" elapsed)
-- | Benchmark. Multiplies a chain of four random matrices (100×500, 500×1000,
-- 1000×5000, 5000×10000) twice: once with the optimally ordered 'mcm' and
-- once with the right-to-left 'naiveMcm'. For each, prints the 2-norm of the
-- result and how long computing and printing it took.
main :: IO ()
main = do
  putStrLn "Generating random matrices"
  -- Bang patterns force each matrix (to WHNF) before any timing starts.
  !a <- randn @100 @500
  !b <- randn @500 @1000
  !c <- randn @1000 @5000
  !d <- randn @5000 @10000
  let chain = a :. b :. c :. S d
      -- Print a header, then time computing and printing the result's norm.
      report header multiply = do
        putStrLn header
        withTime . putStrLn $ "norm: " ++ show (norm_2 (multiply chain))
  report "# mcm" mcm
  report "# naiveMcm" naiveMcm
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment