Skip to content

Instantly share code, notes, and snippets.

@lotz84
Created September 24, 2019 17:18
Show Gist options
  • Save lotz84/78474ac9ee307d50376e025093316d0f to your computer and use it in GitHub Desktop.
Save lotz84/78474ac9ee307d50376e025093316d0f to your computer and use it in GitHub Desktop.
Tensor implementation using representable functors
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE IncoherentInstances #-}
{-# LANGUAGE LiberalTypeSynonyms #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilyDependencies #-}
{-# LANGUAGE TypeOperators #-}
module Main where
import Data.Coerce
import Data.Functor.Identity
import Data.Function (fix)
import Data.Functor.Product
import Data.Functor.Compose
import Data.Maybe
import Data.Proxy
import Data.Type.Bool
import GHC.Natural
import GHC.TypeLits
import Data.Finite (Finite)
import Data.List.Split (chunksOf)
import qualified Data.Vector.Sized as V
-- | A functor @f@ is representable when every @f a@ can be viewed as a total
-- lookup table from the index type @'Log' f@ to @a@. 'index' and 'tabulate'
-- are intended to be mutually inverse.
--
-- Instances must define 'index' plus at least one of 'tabulate' or
-- 'positions'. The two defaults below are written in terms of each other, so
-- an instance that defines neither would recurse forever at runtime; the
-- MINIMAL pragma makes GHC warn about such instances at compile time.
class Functor f => Representable f where
  -- | The index (\"logarithm\") type of @f@.
  type Log f

  -- | Look up the element stored at a position.
  index :: f a -> (Log f -> a)

  -- | Build a structure by evaluating a function at every position.
  tabulate :: (Log f -> a) -> f a
  tabulate h = fmap h positions

  -- | The structure whose every slot holds its own position.
  positions :: f (Log f)
  positions = tabulate id

  {-# MINIMAL index, (tabulate | positions) #-}
-- | A function out of @r@ is already a lookup table indexed by @r@.
instance Representable ((->) r) where
  type Log ((->) r) = r
  index g x = g x
  positions = \x -> x
-- | Push an outer functor inside a representable one: the result at position
-- @i@ collects the @i@-th element of every inner structure.
distribute :: (Functor f, Representable g) => f (g a) -> g (f a)
distribute fga = tabulate column
  where
    column i = fmap (\ga -> index ga i) fga
-- | 'Identity' holds exactly one element, so the unit type is its index; the
-- index argument is ignored.
instance Representable Identity where
  type Log Identity = ()
  index (Identity x) _ = x
  positions = Identity ()
-- | Two values of the same type.
data Diag a = Diag a a

-- | Map the function over both components.
instance Functor Diag where
  fmap f d = case d of
    Diag l r -> Diag (f l) (f r)
-- | 'Diag' is indexed by 'Bool': 'False' selects the first component and
-- 'True' the second.
instance Representable Diag where
  type Log Diag = Bool
  index (Diag l _) False = l
  index (Diag _ r) True  = r
  positions = Diag False True
-- | A length-@n@ sized vector is indexed by @'Finite' n@; building one
-- delegates to 'V.generate'.
instance KnownNat n => Representable (V.Vector n) where
  type Log (V.Vector n) = Finite n
  index v i = V.index v i
  tabulate = V.generate
-- | An infinite stream. The only constructor carries a tail, so there is no
-- empty case.
data Stream a = Cons a (Stream a)

-- | Map lazily over every element of the stream.
instance Functor Stream where
  fmap f = go
    where
      go (Cons x rest) = Cons (f x) (go rest)
-- | Element at a zero-based position. It walks one cell per step, so looking
-- up position @n@ takes @n@ steps.
sIndex :: Stream a -> Natural -> a
sIndex (Cons x rest) n
  | n == 0    = x
  | otherwise = sIndex rest (n - 1)
-- | Streams are indexed by 'Natural'; 'positions' is the stream 0, 1, 2, ...
instance Representable Stream where
  type Log Stream = Natural
  index = sIndex
  positions = countFrom 0
    where
      countFrom k = Cons k (countFrom (k + 1))
-- | A product is indexed by the tagged union of its halves' indices: 'Left'
-- addresses the first half and 'Right' the second.
instance (Representable f, Representable g) => Representable (Product f g) where
  type Log (Product f g) = Either (Log f) (Log g)
  index (Pair fa ga) e = either (index fa) (index ga) e
  positions = Pair lefts rights
    where
      lefts  = tabulate Left
      rights = tabulate Right
-- | A composite is indexed by a pair. The outer index selects an inner
-- structure, and the inner index selects an element of it.
instance (Representable f, Representable g) => Representable (Compose f g) where
  type Log (Compose f g) = (Log f, Log g)
  index (Compose outer) (i, j) = index (index outer i) j
  tabulate h = Compose (tabulate (\i -> tabulate (\j -> h (i, j))))
-- | A tensor whose dimension sizes are listed in @xs@.
--
-- The head of the list is the innermost dimension: @Tensor (n ': ns) a@ is a
-- @Tensor ns@ whose elements are length-@n@ vectors of @a@. The empty shape
-- is a single 'Identity' cell.
--
-- The injectivity annotation (@r -> xs@) declares that the resulting functor
-- determines the shape list, so GHC can infer @ns@ from @Tensor ns@, e.g. in
-- recursive instance methods.
type family Tensor (xs :: [Nat]) = r | r -> xs where
  Tensor '[] = Identity
  Tensor (n ': ns) = Compose (Tensor ns) (V.Vector n)
-- | A one-dimensional tensor of length @n@.
type Vector n a = Tensor '[n] a
-- | An @m@-by-@n@ matrix: @m@ rows, each a length-@n@ vector. Because
-- 'Tensor' lists the innermost dimension first, the shape list is @'[n, m]@.
type Matrix m n a = Tensor '[n, m] a
-- | Build a structure from a flat list of elements.
class FromList f where
  fromList :: [a] -> f a
-- | Base case: a rank-0 tensor holds exactly one element.
--
-- Any other list length is a shape mismatch. It is reported explicitly
-- instead of failing with a bare non-exhaustive-pattern error. 'null' is used
-- rather than 'length' so that an infinite input does not hang the error path.
instance FromList Identity where
  fromList [a] = Identity a
  fromList xs =
    error $
      "fromList @Identity: expected exactly one element, got "
        ++ (if null xs then "none" else "more than one")
-- | Build a tensor from a flat list in which the innermost dimension varies
-- fastest. The list is cut into consecutive chunks of @n@ elements, each
-- chunk becomes one length-@n@ vector, and the resulting list of vectors is
-- handed to the outer 'FromList' instance.
--
-- A trailing chunk shorter than @n@ means the input length is not a multiple
-- of @n@. That is reported with a descriptive error rather than the
-- uninformative @fromJust@ failure.
instance forall n t. (KnownNat n, FromList t) => FromList (Compose t (V.Vector n)) where
  fromList xs = Compose . fromList . map toVector $ chunksOf size xs
    where
      -- Length of the innermost vectors, taken from the type-level @n@.
      size :: Int
      size = fromIntegral $ natVal (Proxy @n)

      toVector chunk =
        fromMaybe
          ( error $
              "fromList @(Compose t (Vector n)): chunk of length "
                ++ show (length chunk)
                ++ " does not fit a vector of length "
                ++ show size
          )
          (V.fromList chunk)
-- | Combine two representable structures position by position.
liftR2 :: Representable f => (a -> b -> c) -> f a -> f b -> f c
liftR2 f fa fb = tabulate combine
  where
    combine i = f (index fa i) (index fb i)
-- | Inner product: multiply matching entries, then sum the products.
dot :: (KnownNat n, Num a) => Vector n a -> Vector n a -> a
dot u v = sum (liftR2 (*) u v)
-- | Swap rows and columns. The @m@ row vectors (wrapped in 'Identity') are
-- distributed over their @n@ positions, giving @n@ vectors of length @m@.
transpose :: (KnownNat m, KnownNat n) => Matrix m n a -> Matrix n m a
transpose mat = case mat of
  Compose (Compose wrapped) -> Compose (Compose (distribute <$> wrapped))
-- | Product of a @p@-by-@q@ and a @q@-by-@r@ matrix. Entry @(row, col)@ of
-- the result is the dot product of row @row@ of the left matrix with row
-- @col@ of the transposed right matrix (i.e. column @col@ of the right one).
matmul :: (KnownNat p, KnownNat q, KnownNat r, Num a)
  => Matrix p q a -> Matrix q r a -> Matrix p r a
matmul (Compose a) b' = tabulate entry
  where
    Compose bT = transpose b'
    -- Re-wrap a raw row as a rank-1 tensor so 'dot' accepts it.
    asVector = Compose . Identity
    entry (((), row), col) =
      asVector (index a ((), row)) `dot` asVector (index bT ((), col))
-- | Apply a function to every element of a tensor, keeping its shape.
unary :: Functor (Tensor ns) => (a -> b) -> Tensor ns a -> Tensor ns b
unary f = fmap f
-- | Shapes for which a tensor can be filled with copies of a single value.
class Shapely ns where
  replicateT :: a -> Tensor ns a
-- | The empty shape is a single 'Identity' cell holding the value.
instance Shapely '[] where
  replicateT = Identity
-- | Fill one innermost length-@n@ vector with the value, then replicate that
-- vector across the remaining @ns@ dimensions.
instance (KnownNat n, Shapely ns, Representable (Tensor ns)) => Shapely (n ': ns) where
  replicateT a = Compose (replicateT innermost)
    where
      innermost = tabulate (const a)
-- | Broadcast result shape used by 'binary': the elementwise maximum of two
-- shape lists, aligned from the head (the innermost dimension of a 'Tensor').
-- When one list runs out, the rest of the other is copied unchanged.
--
-- Max itself does not check compatibility; for example, sizes 2 and 3 yield
-- 3. A mismatch of that kind is rejected later because no 'Alignable'
-- instance matches it.
type family Max xs ys where
  Max '[] '[] = '[]
  Max (x:xs) '[] = x ': Max xs '[]
  Max '[] (y:ys) = y ': Max ys '[]
  Max (x:xs) (y:ys) = (If (x <=? y) y x) ': Max xs ys
-- | Reshape a tensor of shape @ns@ into shape @ms@ by broadcasting. Equal
-- dimensions are kept, size-1 dimensions are stretched, and a rank-0 tensor
-- is replicated to fill a whole shape.
class Alignable (ns :: [Nat]) (ms :: [Nat]) where
  align :: Tensor ns a -> Tensor ms a
-- | Two empty shapes: the tensor is returned unchanged.
instance Alignable '[] '[] where
  align t = t
-- | Matching innermost dimension: keep it and align the remaining
-- dimensions.
instance Alignable xs ys => Alignable (x ': xs) (x ': ys) where
  align = Compose . align . getCompose
-- | Broadcast a size-1 innermost dimension up to size @y@. Each length-1
-- vector is replaced by @y@ copies of its only element, and then the
-- remaining dimensions are aligned recursively. The target length comes from
-- the @KnownNat y@ constraint through 'V.replicate'.
--
-- This overlaps with the @(x ': xs) (x ': ys)@ instance when @y ~ 1@. Under
-- IncoherentInstances either may be chosen, and both leave the contents
-- unchanged in that case.
instance (KnownNat y, Functor (Tensor xs), Alignable xs ys) => Alignable (1 ': xs) (y ': ys) where
  align = Compose . align . fmap (V.replicate . V.head) . getCompose
-- | A rank-0 tensor broadcasts to any non-empty shape by replicating its
-- single element.
instance (KnownNat n, Shapely ns, Representable (Tensor ns)) => Alignable '[] (n ': ns) where
  align (Identity a) = replicateT a
-- | Combine two tensors elementwise after broadcasting both to the common
-- shape @'Max' xs ys@.
binary :: (Max xs ys ~ zs, Alignable xs zs, Alignable ys zs, Representable (Tensor zs))
  => (a -> b -> c)
  -> Tensor xs a -> Tensor ys b -> Tensor zs c
binary f xs ys = liftR2 f xs' ys'
  where
    xs' = align xs
    ys' = align ys
-- | Placeholder entry point. It only prints a greeting (shown as a quoted
-- string literal) and does not exercise the tensor code.
main :: IO ()
main = putStrLn (show "Hello Representable Functor")
...
dependencies:
- finite-typelits
- split
- vector-sized
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment