Skip to content

Instantly share code, notes, and snippets.

@shangaslammi
Created December 21, 2011 16:04
Show Gist options
  • Save shangaslammi/1506540 to your computer and use it in GitHub Desktop.
Save shangaslammi/1506540 to your computer and use it in GitHub Desktop.
Statically Type-Checked Vector and Matrix Algebra
{-# LANGUAGE TypeFamilies, MultiParamTypeClasses, FlexibleInstances, EmptyDataDecls, OverlappingInstances, FlexibleContexts #-}
module Data.Algebra where
import Prelude hiding (Num(..), (/))
import qualified Prelude as P
import Control.Applicative
import Control.Arrow (first, second)
import Data.Ratio (Ratio, (%))
import Data.Array.Unboxed
import Data.List (transpose)
-- | Heterogeneous addition. The result type is computed from both operand
-- types through the associated family 'AddResult'.
class Add a b where
  type AddResult a b
  (+) :: a -> b -> AddResult a b

-- | Any Prelude 'P.Num' type adds to itself with the Prelude operator.
instance P.Num n => Add n n where
  type AddResult n n = n
  x + y = x P.+ y
-- | Heterogeneous multiplication. 'MulResult' names the type of the product.
class Mul a b where
  type MulResult a b
  (*) :: a -> b -> MulResult a b

-- | Any Prelude 'P.Num' type multiplies with itself via the Prelude operator.
instance P.Num n => Mul n n where
  type MulResult n n = n
  x * y = x P.* y
-- | Heterogeneous division. 'DivResult' names the type of the quotient.
class Div a b where
  type DivResult a b
  (/) :: a -> b -> DivResult a b

-- | Floating-point division, delegated to the Prelude's 'P./'.
instance Div Float Float where
  type DivResult Float Float = Float
  x / y = x P./ y

-- | Dividing two 'Int's is exact: the result is a 'Ratio' built with '%',
-- not a truncated integer.
instance Div Int Int where
  type DivResult Int Int = Ratio Int
  x / y = x % y
-- Fixities for the replacement operators. They use the same precedence and
-- associativity as the Prelude operators they shadow, so expressions like
-- a + b * c group as usual.
infixl 6 +
infixl 7 *, /
-- | A vector of 'Float' components whose dimension is carried in the phantom
-- type @d@ (e.g. 'D2', 'D3').
--
-- The raw 'Vector' constructor does not check that the array's size matches
-- @d@. Callers using it directly must keep them consistent.
newtype Vector d = Vector { vecArray :: UArray Int Float }
-- | Type-level natural numbers in unary form, starting at one. 'D1' is one
-- and @Succ d@ is @d + 1@.
-- These are empty types; they appear only as phantom dimension indices.
data D1
data Succ d

-- | Convenience aliases for common dimensions.
type D2 = Succ D1
type D3 = Succ D2
-- | Build a two-dimensional vector from its components, indexed 1..2.
vector2d :: Float -> Float -> Vector D2
vector2d x y = Vector (listArray (1, 2) components)
  where
    components = [x, y]

-- | Build a three-dimensional vector from its components, indexed 1..3.
vector3d :: Float -> Float -> Float -> Vector D3
vector3d x y z = Vector (listArray (1, 3) components)
  where
    components = [x, y, z]
-- | Scalar times vector: every component is multiplied by the scalar.
instance Mul Float (Vector d) where
  type MulResult Float (Vector d) = Vector d
  s * Vector arr = Vector (amap (P.* s) arr)
-- | A vector is shown as the list of its components.
instance Show (Vector d) where
  show (Vector arr) = show (elems arr)

-- | A matrix is shown one row per line, each row rendered as a vector.
instance Show (Matrix m n) where
  show mat = unlines (map show (rows mat))
-- | Component-wise sum of two vectors of the same type-level dimension.
-- The result reuses the index bounds of the left operand.
instance Add (Vector d) (Vector d) where
  type AddResult (Vector d) (Vector d) = Vector d
  Vector xs + Vector ys = Vector (listArray (bounds xs) summed)
    where
      summed = zipWith (P.+) (elems xs) (elems ys)
-- | A matrix of 'Float's with @m@ rows and @n@ columns, both carried as
-- phantom types.
--
-- Entries are stored in a 2-D unboxed array indexed by (row, column).
newtype Matrix m n = Matrix { matArray :: UArray (Int,Int) Float }
-- | Matrix product: an @m@-by-@p@ matrix times a @p@-by-@n@ matrix gives an
-- @m@-by-@n@ matrix.
--
-- The matching inner dimension is enforced by the types. At runtime the
-- sizes are read from the operands' array bounds, and the inner length comes
-- from the left operand.
instance Mul (Matrix m p) (Matrix p n) where
  type MulResult (Matrix m p) (Matrix p n) = Matrix m n
  Matrix lhs * Matrix rhs =
    Matrix (listArray ((1, 1), (rowsL, colsR)) entries)
    where
      (_, (rowsL, inner)) = bounds lhs
      (_, (_, colsR)) = bounds rhs
      -- Row-major order, matching how 'listArray' fills (row, col) indices.
      entries = [dot i j | i <- [1 .. rowsL], j <- [1 .. colsR]]
      dot i j = sum [lhs ! (i, k) P.* rhs ! (k, j) | k <- [1 .. inner]]
-- | View a vector as a single-row (1-by-d) matrix.
-- Entry (1, j) of the result is component j of the vector.
row :: Vector d -> Matrix D1 d
row (Vector v) = Matrix (ixmap ((1, 1), (1, len)) snd v)
  where
    (_, len) = bounds v
-- | View a vector as a single-column (d-by-1) matrix.
-- Entry (i, 1) of the result is component i of the vector.
col :: Vector d -> Matrix d D1
col (Vector v) = Matrix (ixmap ((1, 1), (len, 1)) fst v)
  where
    -- 'bounds' returns (low, high). The number of rows is the upper bound;
    -- taking the lower bound here would collapse the column to one entry.
    (_, len) = bounds v
-- | Split a matrix into its rows, top to bottom, each as a vector of length @n@.
rows :: Matrix m n -> [Vector n]
rows (Matrix mat) =
  [ Vector (ixmap (1, nCols) ((,) r) mat) | r <- [1 .. nRows] ]
  where
    (_, (nRows, nCols)) = bounds mat
-- | Prepend an element to a container, growing its type-level size by one.
class Cons e c where
  type ConsResult e c
  (+:) :: e -> c -> ConsResult e c

infixr 5 +:

-- | Prepend a component to a vector, giving a vector one dimension larger.
instance Cons Float (Vector d) where
  type ConsResult Float (Vector d) = Vector (Succ d)
  x +: Vector v = Vector (listArray (lo, succ hi) (x : elems v))
    where
      (lo, hi) = bounds v

-- | Prepend a row to a matrix, giving a matrix with one more row.
-- Elements are concatenated in row-major order, so the new row comes first.
instance Cons (Vector n) (Matrix m n) where
  type ConsResult (Vector n) (Matrix m n) = Matrix (Succ m) n
  Vector r +: Matrix mat =
    Matrix (listArray (lo, (succ lastRow, lastCol)) (elems r ++ elems mat))
    where
      (lo, (lastRow, lastCol)) = bounds mat
-- | Concatenate two containers. 'ConcatResult' names the combined type.
class Concat a b where
  type ConcatResult a b
  (++:) :: a -> b -> ConcatResult a b

-- | Type-level addition on the unary naturals built from 'D1' and 'Succ'.
type family AddT a b
type instance AddT D1 b = Succ b
type instance AddT (Succ a) b = Succ (AddT a b)

-- | Concatenate two vectors. The result's dimension is the type-level sum,
-- and its index range is 1 up to the sum of both upper bounds.
instance Concat (Vector a) (Vector b) where
  type ConcatResult (Vector a) (Vector b) = Vector (AddT a b)
  Vector xs ++: Vector ys =
    Vector (listArray (1, hiX P.+ hiY) (elems xs ++ elems ys))
    where
      (_, hiX) = bounds xs
      (_, hiY) = bounds ys
-- | A phantom index token. @Get d@ names position @d@ (a type-level
-- natural) and carries no runtime data.
data Get d = Safe

-- | Step a type-level index down by one. The argument is never inspected.
getPrev :: Get (Succ d) -> Get d
getPrev _ = Safe

-- | Reflect a type-level index token to its runtime 'Int' position (1-based).
class GetIndex g where
  getIndex :: g -> Int

instance GetIndex (Get D1) where
  getIndex _ = 1

instance GetIndex (Get d) => GetIndex (Get (Succ d)) where
  getIndex g = getIndex (getPrev g) P.+ 1
-- | Indexing that only type-checks when the type-level index is within the
-- container's type-level dimension.
class SafeGet g c where
  type GetElem c
  (!!!) :: c -> g -> GetElem c

-- Instance for all types (Get i) and (Vector d) for which i <= d.
instance (GetIndex (Get i), IsLessEq i d ~ TTrue) => SafeGet (Get i) (Vector d) where
  type GetElem (Vector d) = Float
  (!!!) (Vector arr) idx = arr ! getIndex idx

-- Instance for index pairs (Get i, Get j) into (Matrix m n) with i <= m and j <= n.
instance ( GetIndex (Get i), GetIndex (Get j)
         , IsLessEq i m ~ TTrue, IsLessEq j n ~ TTrue )
    => SafeGet (Get i, Get j) (Matrix m n) where
  type GetElem (Matrix m n) = Float
  (!!!) (Matrix arr) (ri, ci) = arr ! (getIndex ri, getIndex ci)
-- | Type-level "true". No false value is defined: for a comparison that does
-- not hold, no 'IsLessEq' equation below applies (e.g. @IsLessEq (Succ x) D1@),
-- so a constraint @IsLessEq x y ~ TTrue@ fails to type-check.
data TTrue

-- | Type-level less-than-or-equal on the unary naturals built from 'D1' and
-- 'Succ'. 'IsLessEq' @x y@ reduces to 'TTrue' when @x <= y@.
class TLessEq x y where
  type IsLessEq x y

-- D1 <= D1.
instance TLessEq D1 D1 where
  type IsLessEq D1 D1 = TTrue

-- D1 <= Succ d for every d, since D1 is the smallest natural.
instance TLessEq D1 (Succ d) where
  type IsLessEq D1 (Succ d) = TTrue

-- Succ x <= Succ y reduces to x <= y.
instance TLessEq x y => TLessEq (Succ x) (Succ y) where
  type IsLessEq (Succ x) (Succ y) = IsLessEq x y
-- | Dimension placeholder used once 'toArb' has erased a container's static size.
data Arb

-- | Forget the type-level dimensions of a vector or matrix. The underlying
-- array is passed through unchanged.
class ToArb a where
  type ArbResult a
  toArb :: a -> ArbResult a

instance ToArb (Vector d) where
  type ArbResult (Vector d) = Vector Arb
  toArb = Vector . vecArray

instance ToArb (Matrix m n) where
  type ArbResult (Matrix m n) = Matrix Arb Arb
  toArb = Matrix . matArray
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment