Created
September 24, 2019 17:18
-
-
Save lotz84/78474ac9ee307d50376e025093316d0f to your computer and use it in GitHub Desktop.
Tensor implementation using representable functors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE IncoherentInstances #-} | |
{-# LANGUAGE LiberalTypeSynonyms #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
{-# LANGUAGE PolyKinds #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeFamilyDependencies #-} | |
{-# LANGUAGE TypeOperators #-} | |
module Main where | |
import Data.Coerce | |
import Data.Functor.Identity | |
import Data.Function (fix) | |
import Data.Functor.Product | |
import Data.Functor.Compose | |
import Data.Maybe | |
import Data.Proxy | |
import Data.Type.Bool | |
import GHC.Natural | |
import GHC.TypeLits | |
import Data.Finite (Finite) | |
import Data.List.Split (chunksOf) | |
import qualified Data.Vector.Sized as V | |
-- | A functor @f@ whose contents can be addressed by a fixed index type
-- @'Log' f@. 'index' reads an @f a@ as a function @Log f -> a@, and
-- 'tabulate' builds an @f a@ from such a function.
class Functor f => Representable f where
  -- | The index type used to address positions in @f@.
  type Log f

  -- | Look up the element stored at the given index.
  index :: f a -> (Log f -> a)

  -- | Build a container by computing each element from its index.
  tabulate :: (Log f -> a) -> f a

  -- | The container in which every position holds its own index.
  positions :: f (Log f)

  -- These two defaults are defined in terms of each other, so an instance
  -- must supply at least one of them. Without the MINIMAL pragma, an
  -- instance defining neither would compile silently and loop at runtime.
  tabulate h = fmap h positions
  positions = tabulate id
  {-# MINIMAL index, (tabulate | positions) #-}

-- | Functions out of @r@ are indexed by @r@ itself: indexing is application
-- and the positions container is the identity function.
instance Representable ((->) r) where
  type Log ((->) r) = r
  index = ($)
  positions = id
-- | Swap an outer functor with an inner representable one. The result at
-- index @pos@ collects, from every inner @g@ container, the element at @pos@.
distribute :: (Functor f, Representable g) => f (g a) -> g (f a)
distribute outer = tabulate (\pos -> fmap (\inner -> index inner pos) outer)
-- | 'Identity' holds exactly one value, so its only index is @()@.
instance Representable Identity where
  type Log Identity = ()
  index (Identity x) _ = x
  tabulate f = Identity (f ())
-- | Two values of the same type.
data Diag a = Diag a a

-- | Apply the function to both components.
instance Functor Diag where
  fmap g d = case d of
    Diag x y -> Diag (g x) (g y)
-- | 'Diag' is indexed by 'Bool': 'False' selects the first component and
-- 'True' selects the second.
instance Representable Diag where
  type Log Diag = Bool
  index (Diag x _) False = x
  index (Diag _ y) True = y
  tabulate h = Diag (h False) (h True)
-- | A length-@n@ sized vector is indexed by @'Finite' n@. Building one from
-- an index function is exactly 'V.generate'.
instance KnownNat n => Representable (V.Vector n) where
  type Log (V.Vector n) = Finite n
  index v i = V.index v i
  tabulate = V.generate
-- | An infinite stream. There is no empty constructor, so every stream has a
-- head and a tail.
data Stream a = Cons a (Stream a)

-- | Apply the function to every element, lazily.
instance Functor Stream where
  fmap f = go
    where
      go (Cons x rest) = Cons (f x) (go rest)

-- | @sIndex s n@ returns the element at zero-based position @n@ of @s@.
sIndex :: Stream a -> Natural -> a
sIndex (Cons x rest) n
  | n == 0 = x
  | otherwise = sIndex rest (n - 1)
-- | A 'Stream' is indexed by the natural numbers. Its positions stream is
-- 0, 1, 2, ..., built as a fixed point: start at 0 and continue with the
-- same stream shifted up by one.
instance Representable Stream where
  type Log Stream = Natural
  index s k = sIndex s k
  positions = fix (Cons 0 . fmap (+ 1))
-- | A product of two representable functors is indexed by a choice of side:
-- @Left@ indices address the first component and @Right@ indices address
-- the second.
instance (Representable f, Representable g) => Representable (Product f g) where
  type Log (Product f g) = Either (Log f) (Log g)
  index (Pair fa ga) = either (index fa) (index ga)
  positions = Pair (tabulate Left) (tabulate Right)
-- | A composition @f (g a)@ is indexed by a pair: first the outer index into
-- @f@, then the inner index into @g@.
instance (Representable f, Representable g) => Representable (Compose f g) where
  type Log (Compose f g) = (Log f, Log g)
  index (Compose fga) (i, j) = index (index fga i) j
  tabulate h = Compose (tabulate (\i -> tabulate (\j -> h (i, j))))
-- | A tensor whose shape is a type-level list of dimensions, innermost
-- dimension first. The empty shape is a single value in 'Identity'.
-- @Tensor (d ': ds)@ is a @Tensor ds@ whose elements are length-@d@ sized
-- vectors. The injectivity annotation lets GHC recover the shape from the
-- tensor type.
type family Tensor (shape :: [Nat]) = t | t -> shape where
  Tensor '[] = Identity
  Tensor (d ': ds) = Compose (Tensor ds) (V.Vector d)

-- | A rank-1 tensor with @len@ entries.
type Vector len e = Tensor '[len] e

-- | A matrix of @rows@ outer entries, each a vector of @cols@ entries.
-- Because shapes are listed innermost-first, the column count comes first
-- in the shape list.
type Matrix rows cols e = Tensor '[cols, rows] e
-- | Containers that can be built from a flat list of elements.
class FromList f where
  fromList :: [a] -> f a

-- | A rank-0 tensor holds exactly one element. Any other list length is
-- rejected with an explicit message instead of a non-exhaustive pattern
-- failure. The messages never compute the list's length, so an infinite
-- input still fails promptly instead of hanging.
instance FromList Identity where
  fromList [x] = Identity x
  fromList [] =
    error "fromList @Identity: expected exactly one element, got an empty list"
  fromList _ =
    error "fromList @Identity: expected exactly one element, got more than one"
-- | Build a tensor whose innermost dimension is @n@. The flat list is split
-- into consecutive chunks of length @n@, each chunk becomes a sized vector,
-- and the outer structure @t@ is built from the list of those vectors.
--
-- If the input length is not a multiple of @n@, the last chunk is short and
-- a descriptive error is raised instead of an opaque @fromJust@ failure.
-- NOTE(review): with @n = 0@, @chunksOf 0@ yields an endless list of empty
-- chunks, so zero-sized innermost dimensions are not really supported here.
instance forall n t. (KnownNat n, FromList t) => FromList (Compose t (V.Vector n)) where
  fromList elems =
    let n = fromIntegral $ natVal (Proxy @n)
        toVector chunk =
          fromMaybe
            ( error
                ( "fromList: input length is not a multiple of the vector dimension "
                    ++ show n
                    ++ " (last chunk has "
                    ++ show (length chunk)
                    ++ " elements)"
                )
            )
            (V.fromList chunk)
     in Compose . fromList . map toVector $ chunksOf n elems
-- | Combine two containers of the same representable shape element-wise:
-- the result at each index applies @op@ to the two inputs at that index.
liftR2 :: Representable f => (a -> b -> c) -> f a -> f b -> f c
liftR2 op xs ys = tabulate combine
  where
    combine k = op (index xs k) (index ys k)
-- | Inner product of two vectors of equal length: multiply entries pairwise
-- and sum the products.
dot :: (KnownNat n, Num a) => Vector n a -> Vector n a -> a
dot u v = sum (liftR2 (*) u v)
-- | Swap the two dimensions of a matrix. The outer vector of rows is
-- distributed over the inner vectors, so entry (i, j) moves to (j, i).
transpose :: (KnownNat m, KnownNat n) => Matrix m n a -> Matrix n m a
transpose mat = Compose (Compose (distribute <$> getCompose (getCompose mat)))
-- | Matrix product. Entry (row, col) of the result is the dot product of
-- row @row@ of the left matrix with row @col@ of the transposed right
-- matrix, which is column @col@ of the right matrix.
matmul :: (KnownNat p, KnownNat q, KnownNat r, Num a)
       => Matrix p q a -> Matrix q r a -> Matrix p r a
matmul (Compose rowsA) b = tabulate entry
  where
    -- Transposing once up front lets each column of @b@ be read as a row.
    Compose colsB = transpose b
    -- Rewrap a raw sized vector as a rank-1 'Tensor' so 'dot' accepts it.
    asVector = Compose . Identity
    entry (((), row), col) =
      asVector (index rowsA ((), row)) `dot` asVector (index colsB ((), col))
-- | Apply a function to every entry of a tensor, keeping its shape.
unary :: Functor (Tensor ns) => (a -> b) -> Tensor ns a -> Tensor ns b
unary f t = f <$> t
-- | Shapes whose dimensions are all known at the type level, so a tensor of
-- that shape can be filled from a single value.
class Shapely ns where
  -- | Build a tensor of shape @ns@ in which every entry is the given value.
  replicateT :: a -> Tensor ns a

instance Shapely '[] where
  replicateT = Identity

-- | Fill the innermost dimension @n@ with 'V.replicate', then recurse on the
-- remaining dimensions. Only @KnownNat n@ is needed, so the redundant
-- @Representable (Tensor ns)@ constraint is not required. Dropping it only
-- widens the set of shapes this instance applies to.
instance (KnownNat n, Shapely ns) => Shapely (n ': ns) where
  replicateT a = Compose (replicateT (V.replicate a))
-- | Dimension-wise maximum of two shapes, compared position by position from
-- the innermost dimension. When one shape runs out, the remaining
-- dimensions of the other are kept unchanged.
type family Max xs ys where
  Max '[] '[] = '[]
  Max (d ': ds) '[] = d ': Max ds '[]
  Max '[] (e ': es) = e ': Max es '[]
  Max (d ': ds) (e ': es) = (If (d <=? e) e d) ': Max ds es
-- | @Alignable ns ms@ means a tensor of shape @ns@ can be converted, by
-- broadcasting, into a tensor of shape @ms@.
class Alignable (ns :: [Nat]) (ms :: [Nat]) where
  align :: Tensor ns a -> Tensor ms a

-- | Scalars align to scalars unchanged.
instance Alignable '[] '[] where
  align t = t
-- | A matching innermost dimension is kept as-is, and the remaining
-- dimensions are aligned recursively. This overlaps with the broadcasting
-- instance when the dimension is 1; the file enables IncoherentInstances.
instance Alignable xs ys => Alignable (x ': xs) (x ': ys) where
  align = Compose . align . getCompose
-- | Broadcast a size-1 innermost dimension to size @y@. The single element
-- of each length-1 vector is replicated @y@ times, then the remaining
-- dimensions are aligned recursively. 'V.replicate' relies on the
-- @KnownNat y@ constraint to know the target length.
instance (KnownNat y, Functor (Tensor xs), Alignable xs ys) => Alignable (1 ': xs) (y ': ys) where
  align = Compose . align . fmap (V.replicate . V.head) . getCompose
-- | A scalar is broadcast to any non-empty shape by filling every entry with
-- it.
instance (KnownNat n, Shapely ns, Representable (Tensor ns)) => Alignable '[] (n ': ns) where
  align (Identity s) = replicateT s
-- | Combine two tensors element-wise after broadcasting both to the common
-- shape @Max xs ys@.
binary :: (Max xs ys ~ zs, Alignable xs zs, Alignable ys zs, Representable (Tensor zs))
       => (a -> b -> c)
       -> Tensor xs a -> Tensor ys b -> Tensor zs c
binary f lhs rhs = liftR2 f lhs' rhs'
  where
    lhs' = align lhs
    rhs' = align rhs
-- | Placeholder entry point. It prints the 'show'n greeting string, quotes
-- included.
main :: IO ()
main = putStrLn (show "Hello Representable Functor")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
...
dependencies:
- finite-typelits
- split
- vector-sized
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment