Last active
December 17, 2015 06:48
-
-
Save sjoerdvisscher/5567699 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeFamilies, GADTs, TypeOperators #-} | |
module LinearMap where | |
import Prelude hiding ((.)) | |
import Data.Category | |
import Data.Category.Limit | |
-- | The direct sum of two spaces, represented as an ordinary pair.
type a :* b = (,) a b
infix 7 :&&
infix 6 :||

-- | A linear map from @a@ to @b@ with scalars @s@, stored in block form.
--
-- Objects ('Obj' from data-category) are identity maps, so 'Zero' carries
-- the identity maps of its source and target.
data LM s a b where
  -- | The zero map between two objects.
  Zero  :: Obj (LM s) a -> Obj (LM s) b -> LM s a b
  -- | Multiplication by a scalar on the one-dimensional space @s@.
  Elt   :: s -> LM s s s
  -- | Juxtaposition: acts on a pair of inputs and shares one output space.
  (:||) :: LM s a c -> LM s b c -> LM s (a :* b) c
  -- | Stacking: one input space, the two results are paired up.
  (:&&) :: LM s a c -> LM s a d -> LM s a (c :* d)
-- | Two maps are equal exactly when their dense matrices ('toMatrix') agree,
-- regardless of how each one is built out of constructors.
instance (Eq s, Num s) => Eq (LM s a b) where
  (==) f g = toMatrix f == toMatrix g
-- | Maps are ordered by comparing their dense matrices ('toMatrix')
-- lexicographically.
instance (Ord s, Num s) => Ord (LM s a b) where
  compare f g = compare (toMatrix f) (toMatrix g)
-- | Dense matrix of a linear map.
--
-- Each inner list is one row. Rows correspond to output coordinates
-- (stacking with ':&&' appends rows), and columns correspond to input
-- coordinates (juxtaposition with ':||' appends columns to every row).
-- A 'Zero' map is expanded into explicit zeros through the terminal and
-- initial objects.
toMatrix :: Num s => LM s a b -> [[s]]
toMatrix m = case m of
  Elt x    -> [[x]]
  top :&& bottom -> toMatrix top ++ toMatrix bottom
  left :|| right -> zipWith (++) (toMatrix left) (toMatrix right)
  Zero dom cod   -> toMatrix (initialize cod . terminate dom)
-- | Evaluate a linear map on a concrete vector.
--
-- The vector is first turned into a map out of the scalar space ('toLM').
-- That map is composed with @m@, and the result is read back ('fromLM').
apply :: Num s => LM s a b -> a -> b
apply m v = fromLM (m . toLM (src m) v)
-- | Embed a vector of the space described by an object as a linear map out
-- of the scalar space (a single column).
--
-- The object determines how a pair value is split: its components are
-- embedded separately and then stacked.
toLM :: Num s => Obj (LM s) a -> a -> LM s s a
toLM obj = case obj of
  Zero inner _ -> toLM inner
  Elt _        -> Elt
  left :|| _   -> toLM (tgt left)
  top :&& bottom ->
    \(x, y) -> toLM (tgt top) x &&& toLM (tgt bottom) y
-- | Read a vector back out of a linear map from the scalar space.
--
-- This is the inverse of 'toLM'. A 'Zero' map reads back as the zero vector
-- of its target. A juxtaposition cannot have the scalar type as its source
-- unless @s@ is itself a pair type, so that case raises an error.
fromLM :: Num s => LM s s a -> a
fromLM m = case m of
  Elt x          -> x
  Zero _ cod     -> fromLM (initialize cod)
  top :&& bottom -> (fromLM top, fromLM bottom)
  _ :|| _        -> error "fromLM: unexpected argument"
-- | The identity on the one-dimensional scalar space: scaling by one.
idS :: Num s => Obj (LM s) s
idS = Elt (fromInteger 1)
-- | Pointwise sum of two linear maps with the same source and target.
--
-- 'Zero' is the additive identity. A map whose source and target are both
-- pairs can be built in either of two outer shapes:
--
-- * a juxtaposition ':||', as produced by '|||', '+++' or 'src';
-- * a stacking ':&&', as produced by '&&&' or '***'.
--
-- The two operands can therefore disagree in shape. In that case the ':||'
-- operand is kept as is, and the other operand is restricted to each half
-- of the source by precomposing with 'inj1' and 'inj2'. Recursion is on a
-- strictly smaller source type, so it terminates.
(+^+) :: Num s => LM s a b -> LM s a b -> LM s a b
Zero{} +^+ f = f
f +^+ Zero{} = f
Elt s +^+ Elt t = Elt (s + t)
(f :|| g) +^+ (h :|| k) = (f +^+ h) ||| (g +^+ k)
(f :&& g) +^+ (h :&& k) = (f +^+ h) &&& (g +^+ k)
-- Mixed shapes: split the other operand along the juxtaposed source.
(f :|| g) +^+ h =
  (f +^+ (h . inj1 (src f) (src g))) ||| (g +^+ (h . inj2 (src f) (src g)))
h +^+ (f :|| g) =
  (f +^+ (h . inj1 (src f) (src g))) ||| (g +^+ (h . inj2 (src f) (src g)))
-- Only reachable when the scalar type is itself a pair type, e.g. when 'Elt'
-- meets a block map.
_ +^+ _ = error "(+^+) for LM s a b: unexpected combination"
-- | Multiply every entry of a linear map by a scalar.
--
-- The scaling is pushed through the block structure. 'Zero' stays 'Zero'.
(*^^) :: Num s => s -> LM s a b -> LM s a b
k *^^ m = case m of
  Zero dom cod   -> Zero dom cod
  Elt t          -> Elt (k * t)
  left :|| right -> (k *^^ left) ||| (k *^^ right)
  top :&& bottom -> (k *^^ top) &&& (k *^^ bottom)
-- | Linear maps over @s@ form a category. Objects are identity maps, and
-- composition is block matrix multiplication.
instance Num s => Category (LM s) where
  -- Source object of a map.
  src m = case m of
    Zero dom _ -> dom
    Elt _      -> idS
    l :|| r    -> src l +++ src r
    l :&& _    -> src l

  -- Target object of a map.
  tgt m = case m of
    Zero _ cod -> cod
    Elt _      -> idS
    l :|| _    -> tgt l
    l :&& r    -> tgt l *** tgt r

  -- These equations overlap and are tried top to bottom, so the order matters.
  Zero _ cod . inner = Zero (src inner) cod
  outer . Zero dom _ = Zero dom (tgt outer)
  Elt x . Elt y = Elt (x * y)
  (top :&& bottom) . inner = (top . inner) &&& (bottom . inner)
  outer . (left :|| right) = (outer . left) ||| (outer . right)
  -- Row block times column block: the sum of the two partial products.
  (left :|| right) . (top :&& bottom) = (left . top) +^+ (right . bottom)
  _ . _ = error "(.) for LM s: unexpected combination"
-- | The scalar space @s@ plays the role of the terminal object.
-- 'terminate' always yields the zero map into it.
--
-- NOTE(review): @s@ is not terminal in the strict sense, because many linear
-- maps @a -> s@ exist (for example any 'Elt'). This instance picks the zero
-- map, since there is no zero-dimensional object to use instead.
instance Num s => HasTerminalObject (LM s) where
  type TerminalObject (LM s) = s

  terminalObject = idS

  terminate obj = case obj of
    Zero dom _ -> Zero dom idS
    Elt _      -> Elt 0
    l :|| r    -> terminate (src l) ||| terminate (src r)
    l :&& r    -> terminate (tgt l) ||| terminate (tgt r)
-- | The scalar space @s@ also plays the role of the initial object.
-- 'initialize' always yields the zero map out of it.
instance Num s => HasInitialObject (LM s) where
  type InitialObject (LM s) = s

  initialObject = idS

  initialize obj = case obj of
    Zero _ cod -> Zero idS cod
    Elt _      -> Elt 0
    l :|| r    -> initialize (src l) &&& initialize (src r)
    l :&& r    -> initialize (tgt l) &&& initialize (tgt r)
-- | Binary products are pairs. The projections are juxtapositions of an
-- identity with a zero block, and '&&&' is stacking.
instance Num s => HasBinaryProducts (LM s) where
  type BinaryProduct (LM s) a b = a :* b

  proj1 x y = x ||| Zero y x
  proj2 x y = Zero x y ||| y

  -- Block-diagonal map, built as a stack of two juxtaposed rows.
  l *** r = (l ||| Zero (src r) (tgt l)) &&& (Zero (src l) (tgt r) ||| r)
  (&&&) = (:&&)
-- | Binary coproducts are also pairs. The injections are stackings of an
-- identity with a zero block, and '|||' is juxtaposition.
instance Num s => HasBinaryCoproducts (LM s) where
  type BinaryCoproduct (LM s) a b = a :* b

  inj1 x y = x &&& Zero x y
  inj2 x y = Zero y x &&& y

  -- Block-diagonal map, built as a juxtaposition of two stacked columns.
  l +++ r = (l &&& Zero (src l) (tgt r)) ||| (Zero (src r) (tgt l) &&& r)
  (|||) = (:||)

-- Products and coproducts coincide, so '***' and '+++' both build the direct
-- sum of two maps. They differ only in the outer constructor they produce.
-- | Transpose a linear map.
--
-- Source and target swap, and juxtapositions and stackings are exchanged.
-- Scalars and the objects carried by 'Zero' are left unchanged.
transpose :: LM s a b -> LM s b a
transpose m = case m of
  Zero dom cod   -> Zero cod dom
  Elt x          -> Elt x
  left :|| right -> transpose left :&& transpose right
  top :&& bottom -> transpose top :|| transpose bottom
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeFamilies, TypeOperators #-} | |
{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, FlexibleInstances #-} | |
module LinearMapApi where | |
import Prelude hiding (id) | |
import Data.VectorSpace | |
import Data.Category.Limit | |
import LinearMap | |
-- | Types whose identity map in @cat@ can be produced without arguments,
-- that is, objects known from their type alone.
class IsObj cat a where
  id :: cat a a
-- Scalar spaces: the identity is multiplication by one.
instance IsObj (LM Int) Int where
  id = Elt 1

instance IsObj (LM Integer) Integer where
  id = Elt 1

instance IsObj (LM Double) Double where
  id = Elt 1

instance IsObj (LM Float) Float where
  id = Elt 1
-- | Pairs of objects: the identity is the direct sum of the component
-- identities.
instance (Num s, IsObj (LM s) a, IsObj (LM s) b) => IsObj (LM s) (a, b) where
  id = (***) id id
-- | Linear maps between two fixed objects form an additive group.
instance (IsObj (LM s) a, IsObj (LM s) b, Num s) => AdditiveGroup (LM s a b) where
  -- The zero map between the two objects. The previous definition called
  -- 'zeroL', which neither LinearMap nor the imported libraries define.
  zeroV = Zero id id
  -- Scale by -1 directly with '*^^', the same primitive the VectorSpace
  -- instance's '*^' is defined as.
  negateV = ((-1) *^^)
  (^+^) = (+^+)
-- | Linear maps between two fixed objects form a vector space over @s@.
-- Scalar multiplication multiplies every entry.
instance (Num s, IsObj (LM s) a, IsObj (LM s) b) => VectorSpace (LM s a b) where
  type Scalar (LM s a b) = s
  k *^ m = k *^^ m
-- | Projection of a pair onto its first component.
fstL :: (Num s, IsObj (LM s) a, IsObj (LM s) b) => LM s (a :* b) a
fstL = id `proj1` id
-- | Projection of a pair onto its second component.
sndL :: (Num s, IsObj (LM s) a, IsObj (LM s) b) => LM s (a :* b) b
sndL = id `proj2` id
-- | Injection into the first component of a pair (the second one is zero).
leftL :: (Num s, IsObj (LM s) a, IsObj (LM s) b) => LM s a (a :* b)
leftL = id `inj1` id
-- | Injection into the second component of a pair (the first one is zero).
rightL :: (Num s, IsObj (LM s) a, IsObj (LM s) b) => LM s b (a :* b)
rightL = id `inj2` id
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment