Created
December 21, 2011 16:04
-
-
Save shangaslammi/1506540 to your computer and use it in GitHub Desktop.
Statically Type-Checked Vector and Matrix Algebra
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeFamilies, MultiParamTypeClasses, FlexibleInstances, EmptyDataDecls, OverlappingInstances, FlexibleContexts #-} | |
module Data.Algebra where | |
import Prelude hiding (Num(..), (/)) | |
import qualified Prelude as P | |
import Control.Applicative | |
import Control.Arrow (first, second) | |
import Data.Ratio (Ratio, (%)) | |
import Data.Array.Unboxed | |
import Data.List (transpose) | |
-- | Addition between two (possibly different) operand types.
class Add l r where
  -- Type of the sum of an @l@ and an @r@.
  type AddResult l r
  -- Add the left operand to the right operand.
  (+) :: l -> r -> AddResult l r

-- | Adding two values of the same Prelude numeric type delegates to the
-- Prelude's own addition and yields that same type.
instance P.Num n => Add n n where
  type AddResult n n = n
  x + y = x P.+ y
-- | Multiplication between two (possibly different) operand types.
class Mul l r where
  -- Type of the product of an @l@ and an @r@.
  type MulResult l r
  -- Multiply the left operand by the right operand.
  (*) :: l -> r -> MulResult l r

-- | Multiplying two values of the same Prelude numeric type delegates to the
-- Prelude's own multiplication and yields that same type.
instance P.Num n => Mul n n where
  type MulResult n n = n
  x * y = x P.* y
-- | Division whose result type depends on the operand types.
class Div n d where
  -- Type of the quotient of an @n@ by a @d@.
  type DivResult n d
  -- Divide the left operand by the right operand.
  (/) :: n -> d -> DivResult n d

-- | 'Float' division is the Prelude's fractional division.
instance Div Float Float where
  type DivResult Float Float = Float
  x / y = x P./ y

-- | Dividing two 'Int's builds a 'Ratio' rather than truncating.
instance Div Int Int where
  type DivResult Int Int = Ratio Int
  n / d = n % d
-- Operator fixities: all three associate to the left, and the multiplicative
-- operators (level 7) bind tighter than addition (level 6).
infixl 6 +
infixl 7 *
infixl 7 /
-- | A vector of 'Float's stored in an unboxed array. The parameter @d@ is a
-- phantom type: it does not occur in the representation and only tags the
-- vector with a type-level dimension.
newtype Vector d = Vector
  { vecArray :: UArray Int Float -- the underlying element storage
  }
-- | Type-level dimension "one". Uninhabited; used only as a type tag.
data D1

-- | Type-level successor of the dimension @d@. Uninhabited as well.
data Succ d

-- | Type-level dimension "two".
type D2 = Succ D1

-- | Type-level dimension "three".
type D3 = Succ (Succ D1)
-- | Build a two-dimensional vector from its components, stored at
-- indices 1 and 2.
vector2d :: Float -> Float -> Vector D2
vector2d x y = Vector (listArray (1, 2) [x, y])
-- | Build a three-dimensional vector from its components, stored at
-- indices 1 through 3.
vector3d :: Float -> Float -> Float -> Vector D3
vector3d x y z = Vector (listArray (1, 3) [x, y, z])
-- | Scalar-times-vector: every element is multiplied by the scalar, and the
-- dimension tag of the vector is preserved.
instance Mul Float (Vector d) where
  type MulResult Float (Vector d) = Vector d
  k * Vector arr = Vector (amap (\e -> e P.* k) arr)
-- | A vector is rendered as the list of its elements.
instance Show (Vector d) where
  show (Vector arr) = show (elems arr)
-- | A matrix is rendered one row per line, each row shown as a vector.
instance Show (Matrix m n) where
  show mat = unlines [show r | r <- rows mat]
-- | Element-wise sum of two vectors carrying the same dimension tag.
-- The result reuses the index bounds of the left operand.
instance Add (Vector d) (Vector d) where
  type AddResult (Vector d) (Vector d) = Vector d
  Vector xs + Vector ys = Vector (listArray (bounds xs) sums)
    where
      sums = [x P.+ y | (x, y) <- zip (elems xs) (elems ys)]
-- | A matrix of 'Float's stored in an unboxed array indexed by
-- @(row, column)@ pairs. The parameters @m@ and @n@ are phantom types that
-- tag the matrix with its type-level row and column dimensions.
newtype Matrix m n = Matrix
  { matArray :: UArray (Int, Int) Float -- the underlying cell storage
  }
-- | Matrix product. The shared tag @p@ forces the column dimension of the
-- left operand to match the row dimension of the right operand at the type
-- level.
instance Mul (Matrix m p) (Matrix p n) where
  type MulResult (Matrix m p) (Matrix p n) = Matrix m n
  Matrix lhs * Matrix rhs =
    Matrix (listArray ((1, 1), (rowCount, colCount)) cells)
    where
      -- Result shape: rows of the left operand by columns of the right one;
      -- the summation length is the left operand's column count.
      (_, (rowCount, inner)) = bounds lhs
      (_, (_, colCount)) = bounds rhs
      -- Cells are produced row by row, matching the element order that
      -- 'listArray' expects for pair indices.
      cells = [dot i j | i <- [1 .. rowCount], j <- [1 .. colCount]]
      -- Dot product of row i of the left operand with column j of the right.
      dot i j = sum [lhs ! (i, k) P.* rhs ! (k, j) | k <- [1 .. inner]]
-- | View a vector as a matrix with a single row.
row :: Vector d -> Matrix D1 d
row (Vector v) = Matrix (ixmap ((1, 1), (1, width)) snd v)
  where
    -- The upper index bound of the vector becomes the column count.
    (_, width) = bounds v
-- | View a vector as a matrix with a single column.
--
-- Cell @(i, 1)@ of the result is element @i@ of the vector, and the number
-- of rows is the vector's upper index bound.
col :: Vector d -> Matrix d D1
col (Vector v) = Matrix (ixmap ((1, 1), (height, 1)) fst v)
  where
    -- 'bounds' returns (lower, upper); the row count is the UPPER bound.
    -- Taking the lower bound here would always yield a 1x1 matrix, and the
    -- index map must return the row index rather than a constant, otherwise
    -- every cell would read the vector's first element.
    (_, height) = bounds v
-- | Split a matrix into its rows, from first to last, each as a vector
-- tagged with the matrix's column dimension.
rows :: Matrix m n -> [Vector n]
rows (Matrix mat) = [Vector (sliceRow r) | r <- [1 .. rowCount]]
  where
    (_, (rowCount, colCount)) = bounds mat
    -- Row r re-indexed as a one-dimensional array over the column indices.
    sliceRow r = ixmap (1, colCount) (\c -> (r, c)) mat
-- | Prepending an element to a container, where the result type may differ
-- from the container type.
class Cons el coll where
  -- Type of the container after prepending an @el@.
  type ConsResult el coll
  -- Prepend the left operand to the right operand.
  (+:) :: el -> coll -> ConsResult el coll

-- Right-associative, so several elements can be chained onto one container.
infixr 5 +:
-- | Prepend a scalar to a vector; the dimension tag grows by one.
instance Cons Float (Vector d) where
  type ConsResult Float (Vector d) = Vector (Succ d)
  x +: Vector v = Vector (listArray (lo, succ hi) (x : elems v))
    where
      -- Keep the lower bound and extend the upper bound by one slot.
      (lo, hi) = bounds v
-- | Prepend a vector to a matrix as a new first row; the row dimension tag
-- grows by one while the column tag is unchanged.
instance Cons (Vector n) (Matrix m n) where
  type ConsResult (Vector n) (Matrix m n) = Matrix (Succ m) n
  Vector newRow +: Matrix body =
    Matrix (listArray (lo, (succ lastRow, lastCol)) cells)
    where
      -- Same lower corner and column count, one additional row.
      (lo, (lastRow, lastCol)) = bounds body
      -- The new row's elements come first, followed by the existing cells.
      cells = elems newRow ++ elems body
-- | Concatenation of two containers, where the result type is computed from
-- both operand types.
class Concat l r where
  -- Type of the concatenation of an @l@ with an @r@.
  type ConcatResult l r
  -- Concatenate the left operand with the right operand.
  (++:) :: l -> r -> ConcatResult l r
-- | Type-level addition of two dimension tags, defined by recursion on the
-- first argument.
type family AddT x y

-- Base case: one plus y is the successor of y.
type instance AddT D1 y = Succ y

-- Recursive case: (x + 1) + y is the successor of (x + y).
type instance AddT (Succ x) y = Succ (AddT x y)
-- | Concatenate two vectors; the result's dimension tag is the type-level
-- sum of the operands' tags.
instance Concat (Vector a) (Vector b) where
  type ConcatResult (Vector a) (Vector b) = Vector (AddT a b)
  Vector xs ++: Vector ys = Vector (listArray (1, total) (elems xs ++ elems ys))
    where
      (_, hiX) = bounds xs
      (_, hiY) = bounds ys
      -- The new upper bound is the sum of both upper bounds.
      total = hiX P.+ hiY
-- | An index token whose phantom parameter @d@ carries the index as a
-- type-level dimension. It has a single value, 'Safe'.
data Get d = Safe

-- | Step a token down to its predecessor index. Only the type changes; the
-- argument is never inspected.
getPrev :: Get (Succ d) -> Get d
getPrev = const Safe
-- | Conversion of a type-level index token into a runtime 'Int'.
class GetIndex g where
  getIndex :: g -> Int

-- | The token for dimension one is index 1.
instance GetIndex (Get D1) where
  getIndex _ = 1

-- | The token for a successor dimension is one more than its predecessor's
-- index.
instance GetIndex (Get d) => GetIndex (Get (Succ d)) where
  getIndex g = getIndex (getPrev g) P.+ 1
-- | Indexing a container @coll@ with an index token @ix@.
class SafeGet ix coll where
  -- Element type obtained by indexing into @coll@.
  type GetElem coll
  -- Look up the element of the container selected by the token.
  (!!!) :: coll -> ix -> GetElem coll
-- | Vector indexing, available for every token type @Get i@ and vector type
-- @Vector d@ for which @i <= d@ holds at the type level.
instance (GetIndex (Get i), IsLessEq i d ~ TTrue) => SafeGet (Get i) (Vector d) where
  type GetElem (Vector d) = Float
  (!!!) (Vector arr) token = arr ! getIndex token
-- | Matrix indexing with a pair of tokens (row, column). The instance
-- requires @i <= m@ and @j <= n@ to hold at the type level.
instance (GetIndex (Get i), GetIndex (Get j), IsLessEq i m ~ TTrue, IsLessEq j n ~ TTrue)
  => SafeGet (Get i, Get j) (Matrix m n) where
  type GetElem (Matrix m n) = Float
  (!!!) (Matrix arr) (rowTok, colTok) = arr ! (getIndex rowTok, getIndex colTok)
-- | Type-level truth value. Uninhabited; used only as a type tag.
data TTrue

-- | Type-level less-or-equal on dimension tags. 'IsLessEq' reduces to
-- 'TTrue' for the pairs covered by the instances below; no instance exists
-- for @Succ a@ against 'D1'.
class TLessEq a b where
  type IsLessEq a b

-- | 'D1' is equal to 'D1'.
instance TLessEq D1 D1 where
  type IsLessEq D1 D1 = TTrue

-- | 'D1' is less than any successor dimension.
instance TLessEq D1 (Succ b) where
  type IsLessEq D1 (Succ b) = TTrue

-- | Two successor dimensions are ordered the same way as their predecessors.
instance TLessEq a b => TLessEq (Succ a) (Succ b) where
  type IsLessEq (Succ a) (Succ b) = IsLessEq a b
-- | Placeholder dimension tag that 'toArb' substitutes for a value's
-- concrete dimension tags. Uninhabited; used only as a type tag.
data Arb

-- | Re-tagging a value so that its dimension parameters become 'Arb'.
class ToArb a where
  -- Type of the value after its dimension tags are replaced.
  type ArbResult a
  -- Re-tag the value; the stored data is passed through unchanged.
  toArb :: a -> ArbResult a

-- | A vector of any dimension becomes a @Vector Arb@ over the same array.
instance ToArb (Vector d) where
  type ArbResult (Vector d) = Vector Arb
  toArb = Vector . vecArray

-- | A matrix of any shape becomes a @Matrix Arb Arb@ over the same array.
instance ToArb (Matrix m n) where
  type ArbResult (Matrix m n) = Matrix Arb Arb
  toArb = Matrix . matArray
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment