Created
December 4, 2018 23:07
-
-
Save JakobBruenker/7b31d56b200c00856d15cf94e3eff899 to your computer and use it in GitHub Desktop.
An interface to accelerate that includes matrix sizes in the types. Needs the accelerate and singletons libraries.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# LANGUAGE DataKinds #-} | |
-- {-# LANGUAGE IncoherentInstances #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE TypeSynonymInstances #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE AllowAmbiguousTypes #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module LinearTypesafe where | |
import TypesafeAccelerate | |
import Prelude hiding (zipWith, Num, replicate, (++)) | |
import Data.Array.Accelerate (Num, unindex1, index2, constant) | |
import Data.Proxy | |
import Data.Singletons.Prelude | |
import Data.Singletons.TypeLits | |
-- import qualified Data.Semigroup as S | |
-- | Inner product: multiplies two equally sized vectors element-wise and
-- folds the products with '+' starting from 0.
infixr 8 <.>
(<.>) :: Num e => Vector n e -> Vector n e -> Scalar e
xs <.> ys = fold (+) 0 (zipWith (*) xs ys)
-- | Outer product, expressed as the matrix product of the first vector
-- reshaped to an @n x 1@ matrix with the second vector reshaped to the
-- matching right-hand operand.
infixr 8 ><
(><) :: forall n m e. (KnownNat n, KnownNat m, Num e)
     => Vector n e -> Vector m e -> Matrix n m e
a >< b = column <> reshape b
  where
    column :: Matrix n 1 e
    column = reshape a
-- -- outer product | |
-- infixr 8 >< | |
-- (><) :: Acc (Vector Float) -> Acc (Vector Float) -> Acc (Matrix Float) | |
-- a >< b = reshape (index2 (length a) 1) a <> | |
-- reshape (index2 1 (length b)) b | |
-- | Matrix-vector product: the vector is replicated into an @n x m@ matrix,
-- multiplied element-wise with the matrix, and folded with '+'.
infixr 8 #>
(#>) :: forall n m e. (Num e, KnownNat m, KnownNat n)
     => Matrix n m e -> Vector m e -> Vector n e
mat #> vec = fold (+) 0 (zipWith (*) mat stacked)
  where
    stacked :: Matrix n m e
    stacked = replicate (Proxy :: Proxy [SN n, SAll]) vec
-- | Vector-matrix product: the vector is replicated @m@ times, multiplied
-- element-wise with the transposed matrix, and folded with '+'.
infixr 8 <#
(<#) :: forall n m e. (Num e, KnownNat m, KnownNat n)
     => Vector n e -> Matrix n m e -> Vector m e
vec <# mat = fold (+) 0 (zipWith (*) (replicate slicer vec) (transpose mat))
  where
    slicer = Proxy :: Proxy [SN m, SAll]
-- TODO make foldseg typesafe? | |
-- FIXME did not expect this but `td <> tc` results in "ssFromSsd: Illegal | |
-- combination of slice and shape" | |
-- | Matrix-matrix product.
--
-- Both operands are expanded to @n x (o * m)@ matrices, multiplied
-- element-wise, and reduced with a segmented fold using @o@ segments whose
-- length is the second component of @matrixShape a@.
infixr 8 <>
(<>) :: forall n m o e.
        (KnownNat n, KnownNat m, KnownNat o, KnownNat (o :* m), Num e)
     => Matrix n m e -> Matrix m o e -> Matrix n o e
a <> b = foldSeg (+) 0 products segments
  where
    -- The local signatures are not required; they are kept as documentation.
    segments :: Vector o Int
    segments = fill (constant (snd (matrixShape a)))

    products :: Matrix n (o :* m) e
    products = zipWith (*) lhs rhs

    lhs :: Matrix n (o :* m) e
    lhs = reshape (replicate (Proxy :: Proxy [SAll, SN o, SAll]) a)

    rhs :: Matrix n (o :* m) e
    rhs = replicate (Proxy :: Proxy [SN n, SAll]) (flatten (transpose b))
-- -- matrix matrix product | |
-- -- if columns of a aren't equal in number to rows of b, the larger | |
-- -- array is cropped so they match | |
-- infixr 8 <> | |
-- (<>) :: Acc (Matrix Float) -> Acc (Matrix Float) -> Acc (Matrix Float) | |
-- a <> b = foldSeg (+) 0 mul seg | |
-- where (ra, ca) = let rca = unindex2 $ shape a in (fst rca, snd rca) | |
-- (rb, cb) = let rcb = unindex2 $ shape b in (fst rcb, snd rcb) | |
-- len = min rb ca | |
-- a' = take len a | |
-- b' = transpose . take len $ transpose b | |
-- mul = zipWith (*) repa repb | |
-- repb = replicate (lift $ Z :. ra :. All) . flatten $ transpose b' | |
-- repa = reshape (index2 ra $ len * cb) $ | |
-- replicate (lift $ Z :. All :. cb :. All) a' | |
-- seg = fill (index1 cb) len :: Acc (Segments Int) | |
-- XXX Do we want this? If so, this is an orphan instance, so either newtype | |
-- it or rethink the modules | |
-- type Square n e = Matrix n n e | |
-- Not sure if these instances are useful
-- instance (KnownNat n, KnownNat (n :* n), Num e) => S.Semigroup (Square n e) where | |
-- (<>) = (<>) | |
-- instance (KnownNat n, KnownNat (n :* n), Num e) => Monoid (Square n e) where | |
-- mempty = identity | |
-- mappend = (S.<>) | |
-- | Diagonal matrix: permutes every element @i@ of the vector to position
-- @(i, i)@ of an all-zero @n x m@ matrix.
diagonal :: (Num e, KnownNat n, KnownNat m)
         => Vector (Min n m) e -> Matrix n m e
diagonal = permute const zeros (\ix -> let i = unindex1 ix in index2 i i)
-- | Synonym for 'identity': a diagonal matrix built from a vector of ones.
eye :: (Num e, KnownNat n, KnownNat m, KnownNat (Min n m)) => Matrix n m e
eye = diagonal ones
-- | Identity matrix: a 'diagonal' matrix whose diagonal is filled with 1.
identity :: (Num e, KnownNat n, KnownNat m, KnownNat (Min n m))
         => Matrix n m e
identity = diagonal (fill 1)
-- | Tensor of the requested shape with every element equal to 0.
zeros :: (Num e, KnownShape dims, ShapeLike dims) => Tensor dims e
zeros = generate (const 0)
-- | Tensor of the requested shape with every element equal to 1.
ones :: (Num e, KnownShape dims, ShapeLike dims) => Tensor dims e
ones = generate (const 1)
-- TODO Right now these are defined for matrices, will have to generalize to | |
-- tensors | |
-- if the new dimension is larger by an uneven amount, the padding will be | |
-- larger on the upper/left side | |
-- zeroPad :: forall n m n' m' e. Matrix n m e -> Matrix n' m' e | |
-- zeroPad :: forall n m n' m' e. Matrix n m e -> Matrix n m' e | |
-- zeroPad a = (zeros :: Matrix n (Div2 (m' :- m)) e) ++ a ++ zeros | |
-- zeroPad a = (zeros :: Matrix n (Minus m' m) e) ++ a | |
-- | Prepends an @n x l@ block of zeros to the matrix with '++'.
zeroPadL :: forall n m l e. (KnownNat l, KnownNat n, Num e)
         => Matrix n m e -> Matrix n (l :+ m) e
zeroPadL a = padding ++ a
  where
    padding :: Matrix n l e
    padding = zeros
-- | Appends an @n x r@ block of zeros to the matrix with '++'.
zeroPadR :: forall n m r e. (KnownNat r, KnownNat n, Num e)
         => Matrix n m e -> Matrix n (m :+ r) e
zeroPadR a = a ++ padding
  where
    padding :: Matrix n r e
    padding = zeros
-- | Pads @t@ rows of zeros before the matrix by prepending a zero block to
-- the transposed matrix and transposing back.
zeroPadT :: forall n m t e. (KnownNat t, KnownNat m, Num e)
         => Matrix n m e -> Matrix (t :+ n) m e
zeroPadT a = transpose (padding ++ transpose a)
  where
    padding :: Matrix m t e
    padding = zeros
-- | Pads @b@ rows of zeros after the matrix by appending a zero block to
-- the transposed matrix and transposing back.
zeroPadB :: forall n m b e. (KnownNat b, KnownNat m, Num e)
         => Matrix n m e -> Matrix (n :+ b) m e
zeroPadB a = transpose (transpose a ++ padding)
  where
    padding :: Matrix m b e
    padding = zeros
-- | Pads the matrix with @k@ zeros on every side by chaining 'zeroPadL',
-- 'zeroPadR', 'zeroPadT' and 'zeroPadB'. The proxy only selects @k@.
zeroPadAll :: forall n m k e.
              (KnownNat n, KnownNat k, KnownNat (k :+ m :+ k), Num e)
           => Proxy k -> Matrix n m e
           -> Matrix (k :+ n :+ k) (k :+ m :+ k) e
           -- -> Matrix n (k :+ m :+ k) e
zeroPadAll _ a = zeroPadB paddedTop
  where
    -- Each step is annotated so the amount of padding is fixed to @k@.
    paddedLeft :: Matrix n (k :+ m) e
    paddedLeft = zeroPadL a

    paddedSides :: Matrix n (k :+ m :+ k) e
    paddedSides = zeroPadR paddedLeft

    paddedTop :: Matrix (k :+ n) (k :+ m :+ k) e
    paddedTop = zeroPadT paddedSides
-- TODO | |
-- im2col | |
-- XXX padding? | |
-- XXX using generate appears to be a very bad implementation, since you | |
-- probably cannot parallelize it | |
-- conv2d :: Matrix n m e -> Matrix h w e -> Matrix n m e | |
-- conv2d = reshape $ kernelMat <> imMat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE NoImplicitPrelude #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE AllowAmbiguousTypes #-} | |
{-# LANGUAGE PolyKinds #-} | |
-- {-# LANGUAGE TypeInType #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE TypeFamilyDependencies #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE ConstraintKinds #-} | |
{-# LANGUAGE MultiParamTypeClasses #-} | |
module TypesafeAccelerate where | |
import Prelude (Show, ($), show, (<$>), error, otherwise) | |
import qualified Prelude as P | |
import Data.Kind | |
import Data.Singletons | |
import Data.Singletons.Prelude hiding ((:.), Reverse) | |
import Data.Singletons.Prelude.Enum | |
import Data.Singletons.TypeLits | |
import GHC.TypeLits | |
import qualified Data.Array.Accelerate as A | |
import Data.Array.Accelerate (Array, Acc, Exp, Double, DIM2, Elt, Z(..), Shape, Int, Bool(..), (:.)(..), Num, (.), (+), fromList, use, constant, fromInteger, IsIntegral, FromIntegral, Slice, (?|), Arrays, const, SliceShape, FullShape, Ord) | |
-- TODO To work with -XRebindableSyntax, this module has to export these | |
-- functions, or at the very least the ones that aren't identical to their | |
-- Prelude versions: | |
-- - fromInteger (identical to Prelude) | |
-- - fromRational (identical to Prelude) | |
-- - (==) | |
-- - (-) (identical to Prelude) | |
-- - (>=) | |
-- - negate (identical to Prelude) | |
-- - ifThenElse | |
-- TODO think about (re-)exports in general | |
-- | Per-dimension slice descriptor, intended to be used promoted.
-- 'SlicerShape' turns 'SN' into an 'Int' component, 'SAll' into 'A.All'
-- and 'SAny' into 'A.Any'.
data Slicer :: Type where
  SN   :: (k :: Type) -> Slicer
  SAll :: Slicer
  SAny :: Slicer
-- | Singletons for promoted 'Slicer' values. 'SSN' carries the singleton of
-- its payload together with a 'KnownNat' dictionary.
data instance Sing (s :: Slicer) where
  SSN   :: KnownNat k => Sing k -> Sing (SN k)
  SSAll :: Sing SAll
  SSAny :: Sing SAny

-- | Convenience synonym for the 'Slicer' singleton type.
type SSlicer (s :: Slicer) = Sing s
-- Implicit singletons for each promoted 'Slicer' constructor.
instance SingI SAll where
  sing = SSAll

instance SingI SAny where
  sing = SSAny

-- instance SingI n => SingI (SN n) where
instance KnownNat k => SingI (SN k) where
  sing = SSN sing
-- | Whether a slicer is being used to slice or to replicate a tensor.
data SliceMode
  = Sliced
  | Replicated

-- | Type-level predicate that is 'True' exactly for 'Replicated'.
type family IsReplicated (mode :: SliceMode) where
  IsReplicated Sliced     = False
  IsReplicated Replicated = True
-- | 'Chopping' specialised to slicing.
type Slicing slcr ds = Chopping Sliced slcr ds

-- | 'Chopping' specialised to replication.
type Replicating slcr ds = Chopping Replicated slcr ds

-- | 'ChoppedShape' specialised to slicing.
type SlicedShape slcr ds = ChoppedShape Sliced slcr ds

-- | 'ChoppedShape' specialised to replication.
type ReplicatedShape slcr ds = ChoppedShape Replicated slcr ds

-- | The Accelerate shape type of @ds@ is an instance of 'Shape'.
type ShapeLike ds = Shape (ShapeOf ds)
-- TODO: add nice errors, but be careful because cutting is also used in
-- 'ChoppedShape'

-- | Constraint that a slicer fits a list of dimensions in the given mode.
-- Applies to any slicer: one that starts with 'SAny' is checked by
-- 'Cleaving', every other one by 'Cutting'.
type family Chopping (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
    :: Constraint
  where
    Chopping _  '[]           '[] = ()
    Chopping md (SAny : rest) ds  = (Cleaving md rest ds ~ True)
    Chopping md slcr          ds  = (Cutting md slcr ds ~ True)
-- | Compatibility check for slicers that do not contain 'SAny'.
--
-- When slicing, an 'SN' index must be less than the dimension it consumes;
-- when replicating, 'SN' consumes no dimension. 'SAll' always consumes one
-- dimension. Anything else is rejected.
type family Cutting (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    Cutting _ '[] '[] = True
    -- Cutting mode (SN n : ss) (d : ds) =
    --   (IsReplicated mode :|| n :< d) :&& Cutting mode ss ds
    Cutting Sliced     (SN k : rest) (d : ds) =
      k :< d :&& Cutting Sliced rest ds
    Cutting Replicated (SN k : rest) ds       = Cutting Replicated rest ds
    Cutting md         (SAll : rest) (_ : ds) = Cutting md rest ds
    Cutting _          _             _        = False
-- | Compatibility check for the remainder of a slicer that started with
-- 'SAny': the remaining slicer may match at the current position (via
-- 'Cutting') or after skipping further dimensions.
type family Cleaving (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    Cleaving _  '[]  _   = True
    Cleaving md slcr '[] = Cutting md slcr '[]
    -- Cleaving mode (SN n : ss) (d : ds) =
    --   (IsReplicated mode :|| n :< d) :&& Cutting mode ss ds :||
    --     Cleaving mode (SN n : ss) ds
    Cleaving Sliced (SN k : rest) (d : ds) =
      k :< d :&& Cutting Sliced rest ds :|| Cleaving Sliced (SN k : rest) ds
    Cleaving Replicated (SN k : rest) (d : ds) =
      Cutting Replicated rest (d : ds) :|| Cleaving Replicated (SN k : rest) ds
    Cleaving md (SAll : rest) (_ : ds) =
      Cutting md rest ds :|| Cleaving md (SAll : rest) ds
    Cleaving _ _ _ = False
-- NB: it might make sense to add a catch-all pattern, but it doesn't seem
-- strictly necessary, especially considering there really *isn't* a valid
-- 'ChoppedShape' if none of these patterns match
-- TODO Probably add a type error in that place though

-- | Dimensions that result from applying a slicer to @dims@ in the given
-- mode. Applies to any slicer: one that starts with 'SAny' is handled by
-- 'CleftShape', every other one by 'CutShape'.
type family ChoppedShape
    (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    ChoppedShape md '[]           '[] = '[]
    ChoppedShape md (SAny : rest) ds  = CleftShape md rest ds
    ChoppedShape md slcr          ds  = CutShape md slcr ds
-- | Resulting dimensions for slicers without 'SAny'. Slicing with 'SN'
-- drops a dimension, replicating with 'SN' inserts one of that size, and
-- 'SAll' keeps the dimension it consumes.
type family CutShape (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    CutShape _      '[]           '[]      = '[]
    CutShape Sliced (SN _ : rest) (_ : ds) = CutShape Sliced rest ds
    -- CutShape Replicated (SN n : ss) (d : ds) =
    --   (n :* d) : CutShape Replicated ss ds
    CutShape Replicated (SN k : rest) ds =
      k : CutShape Replicated rest ds
    CutShape md (SAll : rest) (d : ds) = d : CutShape md rest ds
-- | Resulting dimensions for the remainder of a slicer that started with
-- 'SAny': dimensions are copied through until the remaining slicer fits
-- according to 'Cutting', at which point 'CutShape' takes over.
type family CleftShape
    (mode :: SliceMode) (ss :: [Slicer]) (dims :: [Nat])
  where
    CleftShape _ '[] _ = '[]
    -- TODO is this for Sliced too or just for replicated
    CleftShape md slcr '[] = CutShape md slcr '[]
    CleftShape md (SN k : rest) (d : ds) =
      If (Cutting md (SN k : rest) (d : ds))
         (CutShape md (SN k : rest) (d : ds))
         (d : CleftShape md (SN k : rest) ds)
    CleftShape md (SAll : rest) (d : ds) =
      If (Cutting md (SAll : rest) (d : ds))
         (CutShape md (SAll : rest) (d : ds))
         (d : CleftShape md (SAll : rest) ds)
-- | Type-level list reversal.
--
-- This is used instead of the version provided by singletons so that in
-- type signatures it shows up as "Reverse xs", and not as something like
-- "Data.Singletons.Prelude.List.Let6989586621679748992Rev xs xs '[]".
-- Similarly, the case split on the empty list prevents GHC from applying
-- the function to an unknown list.
type family Reverse (xs :: [a]) = (rs :: [a]) where
  Reverse '[]  = '[]
  Reverse list = Rev list '[]

-- | Accumulator-based worker for 'Reverse'.
type family Rev (xs :: [a]) (acc :: [a]) = (rs :: [a]) where
  Rev '[]        done = done
  Rev (y : rest) done = Rev rest (y : done)
-- XXX maybe have this be ATensor or AccTensor and introduce second type
-- Tensor, which has Arrays that are not in Acc

-- | An Accelerate array computation tagged with its type-level dimensions.
data Tensor :: [Nat] -> Type -> Type where
  Tensor :: Acc (Array (ShapeOf dims) e) -> Tensor dims e

-- | Tensor without dimensions.
type Scalar = Tensor '[]

-- | One-dimensional tensor of length @n@.
type Vector n = Tensor '[n]

-- | Two-dimensional tensor with dimensions @n@ and @m@.
type Matrix n m = Tensor '[n, m]
-- | Unwraps the underlying Accelerate array computation.
--
-- We have to use this instead of record syntax because here we can specify
-- the type.
unTensor :: Tensor dims e -> Acc (Array (ShapeOf dims) e)
unTensor tensor = case tensor of
  Tensor arr -> arr
-- | 'unsafeMakeTensor' unsafely creates a 'Tensor' from an Accelerate array.
-- It does not check whether the size of the array matches the dimensions
-- in the type of the 'Tensor'; only the dimensionality is enforced, through
-- the shape type.
unsafeMakeTensor :: Acc (Array (ShapeOf dims) e) -> Tensor dims e
unsafeMakeTensor arr = Tensor arr
-- | 'unsafeUseArray' creates a 'Tensor' from an Accelerate array. If the
-- shape of the array differs from the shape demanded by @dims@, it calls
-- 'error' at runtime.
unsafeUseArray :: forall dims e sh.
                  (KnownShape dims, sh ~ ShapeOf dims,
                   Shape sh, Elt e, P.Eq sh)
               => Array (ShapeOf dims) e -> Tensor dims e
unsafeUseArray array =
  if actual P.== expected
    then Tensor (A.use array)
    else error mismatch
  where
    actual   = A.arrayShape array
    expected = shFromDims (Proxy :: Proxy dims)
    mismatch = "Couldn't match expected shape " P.++ show expected P.++
               " with actual shape " P.++ show actual
-- | Accelerate shape with one 'Int' component per list element.
type family IntShape (dims :: [Nat]) where
  IntShape '[]        = Z
  IntShape (_ : rest) = IntShape rest :. Int

-- | Accelerate shape type for @dims@; the list is reversed first.
type ShapeOf (dims :: [Nat]) = IntShape (Reverse dims)

-- | A singleton for the reversed dimension list is available.
type KnownShape ds = SingI (Reverse ds)

-- | A singleton for the reversed slicer list is available.
type KnownSlicer ss = SingI (Reverse ss)
-- | Data-level mirror of 'IntShape'.
--
-- IntShapeData is a type equivalent to the following data family, if it
-- were closed:
--
-- > data family IntShapeData (dims :: [Nat])
-- > data instance IntShapeData '[] = RISDNil Z
-- > data instance IntShapeData (_:ds) = RISDCons (IntShapeData ds :. Int)
--
-- This makes the similarities to the 'IntShape' type family more obvious.
-- It is necessary because 'IntShape' isn't injective, but that injectivity
-- is required for 'shFromDims'.
data IntShapeData :: [Nat] -> Type where
  RISDNil  :: Z -> IntShapeData '[]
  RISDCons :: (IntShapeData rest :. Int) -> IntShapeData (d : rest)
-- | Accelerate slice type for a slicer applied to @dims@: 'SN' becomes an
-- 'Int' (consuming no dimension), 'SAll' becomes 'A.All' (consuming one),
-- and a lone 'SAny' becomes 'A.Any' over the remaining dimensions.
type family SlicerShape (slcr :: [Slicer]) (dims :: [Nat]) where
  SlicerShape '[]           _        = Z
  SlicerShape '[SAny]       ds       = A.Any (IntShape ds)
  SlicerShape (SN _ : rest) ds       = SlicerShape rest ds :. Int
  SlicerShape (SAll : rest) (_ : ds) = SlicerShape rest ds :. A.All
-- TODO convert to GADT

-- | Data-level mirror of 'SlicerShape', with one instance per equation.
data family SlicerShapeData (slcr :: [Slicer]) (dims :: [Nat])

data instance SlicerShapeData '[] _ =
  SSDNil Z

data instance SlicerShapeData '[SAny] ds =
  SSDAny (A.Any (IntShape ds))

data instance SlicerShapeData (SN _ : rest) ds =
  SSDConsN (SlicerShapeData rest ds :. Int)

data instance SlicerShapeData (SAll : rest) (_ : ds) =
  SSDConsAll (SlicerShapeData rest ds :. A.All)
-- 'Show' instances for every 'SlicerShapeData' instance; each one renders
-- the value using constructor syntax.
instance Show (SlicerShapeData '[] ns) where
  show (SSDNil Z) = "SSDNil Z"

instance Show (SlicerShapeData '[SAny] ns) where
  show (SSDAny A.Any) = "SSDAny A.Any"

instance (Show (SlicerShapeData ss ns))
      => Show (SlicerShapeData (SAll : ss) (n : ns)) where
  show (SSDConsAll (inner :. A.All)) =
    P.concat ["SSDConsAll (", show inner, " :. A.All)"]

instance (Show (SlicerShapeData ss ns))
      => Show (SlicerShapeData (SN n : ss) ns) where
  show (SSDConsN (inner :. ix)) =
    P.concat ["SSDConsN (", show inner, " :. ", show ix, ")"]
-- | Accelerate slice type for a slicer and dimension list given in tensor
-- order; both lists are reversed before 'SlicerShape' is applied.
type SliceOf (ss :: [Slicer]) (ds :: [Nat]) =
  SlicerShape (Reverse ss) (Reverse ds)
-- | Builds the 'IntShapeData' of the reversed dimension list from its
-- singleton. The proxy only selects @dims@; the work is done by
-- 'risdFromDims' at type @Reverse dims@.
isdFromDims :: forall dims. KnownShape dims
            => Proxy dims -> IntShapeData (Reverse dims)
isdFromDims _ = risdFromDims
-- | Builds the 'IntShapeData' for @dims@ (taken as-is, not reversed) by
-- walking the singleton list and demoting every dimension to an 'Int'.
risdFromDims :: forall dims. SingI dims => IntShapeData dims
risdFromDims = case sing :: SList dims of
  SNil -> RISDNil Z
  SCons d rest -> case singInstance rest of
    -- Matching on 'SingInstance' brings the 'SingI' instance for the tail
    -- into scope, which the recursive call needs.
    SingInstance -> RISDCons (risdFromDims :. fromInteger (fromSing d))
-- | Converts an 'IntShapeData' value into the corresponding 'IntShape'.
isFromIsd :: forall dims. SingI dims
          => IntShapeData dims -> IntShape dims
isFromIsd shd = case (sing :: SList dims, shd) of
  (SNil, RISDNil Z) -> Z
  (SCons _ rest, RISDCons (inner :. d)) ->
    case singInstance rest of
      -- Provides the 'SingI' instance for the tail of the list.
      SingInstance -> isFromIsd inner :. d
-- | Term-level Accelerate shape for the type-level dimensions @dims@.
shFromDims :: KnownShape dims => Proxy dims -> ShapeOf dims
shFromDims dimsProxy = isFromIsd (isdFromDims dimsProxy)
-- | Converts a 'SlicerShapeData' value into the corresponding
-- 'SlicerShape'.
--
-- Calls 'error' when the slicer and the dimension list do not line up. The
-- message names this module ("TypesafeAccelerate") so that the failure can
-- be traced to the right place.
ssFromSsd :: forall dims slcr. (SingI dims, SingI slcr)
          => SlicerShapeData slcr dims -> SlicerShape slcr dims
ssFromSsd sld =
  case (sing :: SList slcr, sing :: SList dims, sld) of
    (SNil, _, SSDNil Z) -> Z
    (SCons SSAny SNil, _, SSDAny A.Any) -> A.Any
    -- 'SAll' consumes one dimension, so the dimension list must be
    -- non-empty here.
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance),
     SSDConsAll (ssd :. A.All)) ->
      ssFromSsd ssd :. A.All
    -- 'SN' consumes no dimension; the dimension list is passed on unchanged.
    (SCons (SSN _) (singInstance -> SingInstance), _, SSDConsN (rssd :. d)) ->
      ssFromSsd rssd :. d
    -- This should never happen due to the 'Chopping' constraint in
    -- 'shFromSlicer'
    _ -> error "TypesafeAccelerate.ssFromSsd: Illegal combination of slice \
               \and shape"
-- | Builds the 'SlicerShapeData' for the reversed slicer and reversed
-- dimension list. Both proxies only select types. After the first element
-- the work is delegated to 'ssdFromSlicer'.
--
-- Calls 'error' when the slicer and the dimension list do not line up. The
-- message names this module ("TypesafeAccelerate") so that the failure can
-- be traced to the right place.
rssdFromSlicer :: forall slcr dims. (KnownSlicer slcr, KnownShape dims)
               => Proxy slcr -> Proxy dims
               -> SlicerShapeData (Reverse slcr) (Reverse dims)
rssdFromSlicer _ _ =
  case (sing :: SList (Reverse slcr), sing :: SList (Reverse dims)) of
    (SNil, SNil) -> SSDNil Z
    (SCons SSAny SNil, _) -> SSDAny A.Any
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance)) ->
      SSDConsAll (ssdFromSlicer Proxy Proxy :. A.All)
    (SCons (SSN (fromSing -> d)) (singInstance -> SingInstance), _) ->
      SSDConsN (ssdFromSlicer Proxy Proxy :. fromInteger d)
    -- This should never happen due to the 'Chopping' constraint in
    -- 'shFromSlicer'
    _ -> error "TypesafeAccelerate.rssdFromSlicer: Illegal combination of \
               \slice and shape"
-- | Builds the 'SlicerShapeData' for a slicer and dimension list taken
-- as-is (not reversed). Both proxies only select types.
--
-- Calls 'error' when the slicer and the dimension list do not line up. The
-- message names this module ("TypesafeAccelerate") so that the failure can
-- be traced to the right place.
ssdFromSlicer :: forall slcr dims. (SingI slcr, SingI dims)
              => Proxy slcr -> Proxy dims -> SlicerShapeData slcr dims
ssdFromSlicer _ _ =
  case (sing :: SList slcr, sing :: SList dims) of
    (SNil, _) -> SSDNil Z
    (SCons SSAny SNil, _) -> SSDAny A.Any
    (SCons SSAll (singInstance -> SingInstance),
     SCons _ (singInstance -> SingInstance)) ->
      SSDConsAll (ssdFromSlicer Proxy Proxy :. A.All)
    (SCons (SSN (fromSing -> d)) (singInstance -> SingInstance), _) ->
      SSDConsN (ssdFromSlicer Proxy Proxy :. fromInteger d)
    -- This should never happen due to the 'Chopping' constraint in
    -- 'shFromSlicer'
    _ -> error "TypesafeAccelerate.ssdFromSlicer: Illegal combination of \
               \slice and shape"
-- | Term-level Accelerate slice for a type-level slicer applied to @dims@.
-- The mode proxy only feeds the 'Chopping' constraint.
shFromSlicer :: (KnownShape dims, KnownSlicer slcr,
                 Chopping mode slcr dims)
             => Proxy mode -> Proxy slcr -> Proxy dims -> SliceOf slcr dims
shFromSlicer _ slcrProxy dimsProxy =
  ssFromSsd (rssdFromSlicer slcrProxy dimsProxy)
-- | Renders a tensor as its type-level dimensions followed by the shown
-- Accelerate computation.
instance (Show e, SingI dims, ShapeLike dims, Elt e)
      => Show (Tensor dims e) where
  show t@(Tensor a) =
    P.concat ["Tensor ", show (shapeInts t), " (", show a, ")"]
-- | Type-level length of a list.
type family Length (list :: [a]) :: Nat where
  Length '[]        = 0
  Length (_ : rest) = Succ (Length rest)
-- | Constraint that @idx@ and @dims@ have the same length and that every
-- index is less than the corresponding dimension. A custom type error is
-- reported otherwise.
type family AllLess (idx :: [Nat]) (dims :: [Nat]) :: Constraint where
  AllLess is ds =
    If (Length is :== Length ds)
       (AllLess' is ds is ds)
       (TypeError
         (ShowType is :<>: Text " and " :<>: ShowType ds :<>:
          Text " must have the same length, but" :$$:
          ShowType is :<>: Text " has length " :<>: ShowType (Length is) :$$:
          ShowType ds :<>: Text " has length " :<>: ShowType (Length ds)))

-- | Worker for 'AllLess'. The last two arguments carry the complete lists
-- so they can be shown in the error message.
type family AllLess' (idx :: [Nat]) (dims :: [Nat]) idx' dims' :: Constraint
  where
    AllLess' '[] '[] _ _ = ()
    AllLess' (i : is) (d : ds) allIs allDs =
      If (i :< d)
         (AllLess' is ds allIs allDs)
         (TypeError
           (ShowType i :<>: Text " is not less than " :<>: ShowType d :$$:
            Text "While comparing " :<>: ShowType allIs :<>:
            Text " and " :<>: ShowType allDs))
    AllLess' is ds allIs allDs =
      TypeError (Text "Unexpected case in AllLess'")
-- | Type-level product of a list of naturals; the empty product is 1.
type family Product (ns :: [Nat]) :: Nat where
  Product '[]        = 1
  Product (k : rest) = k :* Product rest
-- | Indexes a tensor at a type-level index. 'AllLess' requires every index
-- component to be less than the matching dimension.
(!) :: (ShapeOf dims ~ ShapeOf idx, ShapeLike idx,
        AllLess idx dims, KnownShape idx, Elt e, Elt (ShapeOf idx))
    => Tensor dims e -> Proxy idx -> Exp e
(!) (Tensor arr) idxProxy = arr A.! constant (shFromDims idxProxy)
-- | Indexes a tensor with a run-time index expression; the index is not
-- checked against the type-level dimensions.
unsafeIndex :: (Elt e, ShapeLike dims)
            => Tensor dims e -> Exp (ShapeOf dims) -> Exp e
unsafeIndex (Tensor arr) ix = arr A.! ix
-- | Linear indexing with a type-level index, which must be smaller than
-- the product of the dimensions.
(!!) :: forall idx dims e.
        (SingI idx, (Product dims :> idx) ~ True,
         ShapeLike dims, Elt e)
     => Tensor dims e -> Proxy idx -> Exp e
(!!) (Tensor arr) _ = arr A.!! position
  where
    position :: Exp Int
    position = constant (fromInteger (fromSing (sing :: SNat idx)))
-- | Linear indexing with a run-time index expression; the index is not
-- checked against the type-level dimensions.
unsafeLinearIndex :: (ShapeLike dims, Elt e)
                  => Tensor dims e -> Exp Int -> Exp e
unsafeLinearIndex (Tensor arr) ix = arr A.!! ix
-- | Extracts the single element of a 'Scalar' via 'A.the'.
the :: Elt e => Scalar e -> Exp e
the = A.the . unTensor
-- | 'True' when the product of the type-level dimensions is 0. The tensor
-- argument is not inspected.
null :: forall dims e. SingI (Product dims) => Tensor dims e -> Bool
null _ = 0 P.== elementCount
  where
    elementCount = fromSing (sing :: SNat (Product dims))
-- | Length of a vector, taken from its type via 'size'.
length :: SingI n => Vector n e -> Int
length vec = size vec
-- | Accelerate shape of a tensor, computed from its type-level dimensions.
-- The tensor argument is not inspected.
shape :: forall dims e. KnownShape dims => Tensor dims e -> ShapeOf dims
shape _ = shFromDims dimsProxy
  where
    dimsProxy :: Proxy dims
    dimsProxy = Proxy
-- | Type-level dimensions of a tensor as a list of 'Int's, in the order
-- they appear in the type. The tensor argument is not inspected.
shapeInts :: forall dims e. SingI dims => Tensor dims e -> [Int]
shapeInts _ = P.map fromInteger (fromSing (sing :: SList dims))
-- | Both dimensions of a matrix as a pair. This is provided so total
-- functions can be used to index into the pair.
matrixShape :: forall n m e. (SingI n, SingI m) => Matrix n m e -> (Int, Int)
matrixShape _ = (firstDim, secondDim)
  where
    firstDim  = fromInteger (fromSing (sing :: SNat n))
    secondDim = fromInteger (fromSing (sing :: SNat m))
-- | Number of elements of a tensor: the product of its type-level
-- dimensions. The tensor argument is not inspected.
size :: forall dims e. SingI (Product dims) => Tensor dims e -> Int
size _ = fromInteger (fromSing (sing :: SNat (Product dims)))
-- | Product of the type-level dimensions as an 'Int'.
shapeSize :: forall dims. SingI (Product dims) => Proxy dims -> Int
shapeSize _ = fromInteger (fromSing total)
  where
    total :: SNat (Product dims)
    total = sing
-- XXX use? not sure how to handle this

-- | Wraps a single expression into a 'Scalar' via 'A.unit'.
unit :: Elt e => Exp e -> Scalar e
unit x = Tensor (A.unit x)
-- | Builds a tensor by applying the function to every index; the extent is
-- taken from the type-level dimensions.
generate :: forall dims e sh.
            (sh ~ ShapeOf dims, Shape sh, Elt e, KnownShape dims)
         => (Exp sh -> Exp e) -> Tensor dims e
generate f = Tensor (A.generate extent f)
  where
    extent :: Exp sh
    extent = constant (shFromDims (Proxy :: Proxy dims))
-- | Tensor whose every element is the given expression.
fill :: forall dims e. (ShapeLike dims, KnownShape dims, Elt e)
     => Exp e -> Tensor dims e
fill x = generate (\_ -> x)
-- | Wraps 'A.enumFromN', taking the extent from the type-level dimensions.
enumFromN :: forall dims e.
             (ShapeLike dims, KnownShape dims,
              FromIntegral Int e, Num e)
          => Exp e -> Tensor dims e
enumFromN start = Tensor (A.enumFromN extent start)
  where
    extent :: Exp (ShapeOf dims)
    extent = constant (shFromDims (Proxy :: Proxy dims))
-- | Wraps 'A.enumFromStepN', taking the extent from the type-level
-- dimensions.
enumFromStepN :: forall dims e.
                 (ShapeLike dims, KnownShape dims,
                  FromIntegral Int e, Num e)
              => Exp e -> Exp e -> Tensor dims e
enumFromStepN start step = Tensor (A.enumFromStepN extent start step)
  where
    extent :: Exp (ShapeOf dims)
    extent = constant (shFromDims (Proxy :: Proxy dims))
-- | Concatenates two tensors with 'A.++'. All dimensions except the last
-- must agree; the last dimension of the result is the sum of the operands'
-- last dimensions.
infixr 5 ++
(++) :: (ShapeOf das ~ (sh :. Int), ShapeOf dbs ~ (sh :. Int),
         ShapeOf dcs ~ (sh :. Int), Init das ~ Init dbs,
         dcs ~ (Init das :++ '[Last das :+ Last dbs]),
         Slice sh, Shape sh, Elt e)
     => Tensor das e -> Tensor dbs e -> Tensor dcs e
(++) (Tensor xs) (Tensor ys) = Tensor (xs A.++ ys)
-- TODO: once using accelerate 1.2, add concatOn and the other lens | |
-- functions | |
-- | Overloadable conditional: @ifThenElse c t e@ selects @t@ or @e@
-- depending on @c@.
class IfThenElse bool a where
  ifThenElse :: bool -> a -> a -> a

-- | Ordinary Haskell conditional.
instance IfThenElse Bool a where
  ifThenElse cond whenTrue whenFalse = case cond of
    True  -> whenTrue
    False -> whenFalse

-- | Array-level conditional on tensors, via '?|'.
instance (ShapeLike dims, Elt a)
      => IfThenElse (Exp Bool) (Tensor dims a) where
  ifThenElse cond (Tensor whenTrue) (Tensor whenFalse) =
    Tensor (cond ?| (whenTrue, whenFalse))

-- | Array-level conditional on raw Accelerate computations, via '?|'.
instance Arrays a => IfThenElse (Exp Bool) (Acc a) where
  ifThenElse cond whenTrue whenFalse = cond ?| (whenTrue, whenFalse)
-- XXX THESE ARE ACTUALLY UNSAFE! (just the lifts, the unlifts are safe. I | |
-- think. It's a bit strange because once you've done it, you can feed in any | |
-- array, even if the size is incorrect. But that doesn't really mean the | |
-- function itself is unsafe, I think?, Anyway, this could be 'fixed' by | |
-- adding KnownShape constraints to unlift, and then throwing an error if the | |
-- size is incorrect. I'm just not sure if it's a good idea to do that.) | |
-- You can give the result any dimension drs if you use these, so be careful | |
-- | Lifts a unary function on Accelerate arrays to tensors. The result
-- dimensions @drs@ are not checked against what the function produces.
liftT :: (ShapeLike das, ShapeLike drs)
      => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf drs) r))
      -> Tensor das a -> Tensor drs r
liftT f (Tensor t) = Tensor (f t)
-- | Lifts a binary function on Accelerate arrays to tensors. The result
-- dimensions @drs@ are not checked against what the function produces.
liftT2 :: (ShapeLike das, ShapeLike dbs, ShapeLike drs)
       => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
           -> Acc (Array (ShapeOf drs) r))
       -> Tensor das a -> Tensor dbs b -> Tensor drs r
liftT2 f (Tensor s) (Tensor t) = Tensor (f s t)
-- | Lifts a ternary function on Accelerate arrays to tensors. The result
-- dimensions @drs@ are not checked against what the function produces.
liftT3 :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, ShapeLike drs)
       => (Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
           -> Acc (Array (ShapeOf dcs) c) -> Acc (Array (ShapeOf drs) r))
       -> Tensor das a -> Tensor dbs b -> Tensor dcs c -> Tensor drs r
liftT3 f (Tensor s) (Tensor t) (Tensor u) = Tensor (f s t u)
-- | Turns a unary tensor function into a function on Accelerate arrays by
-- wrapping the argument and unwrapping the result.
unliftT :: (ShapeLike das, ShapeLike drs, Elt a, Elt r)
        => (Tensor das a -> Tensor drs r)
        -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf drs) r)
unliftT f a = unTensor (f (Tensor a))
-- | Turns a binary tensor function into a function on Accelerate arrays by
-- wrapping the arguments and unwrapping the result.
unliftT2 :: (ShapeLike das, ShapeLike dbs, ShapeLike drs, Elt a, Elt b, Elt r)
         => (Tensor das a -> Tensor dbs b -> Tensor drs r)
         -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
         -> Acc (Array (ShapeOf drs) r)
unliftT2 f a b = unTensor (f (Tensor a) (Tensor b))
-- | Turns a ternary tensor function into a function on Accelerate arrays by
-- wrapping the arguments and unwrapping the result.
unliftT3 :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, ShapeLike drs,
             Elt a, Elt b, Elt c, Elt r)
         => (Tensor das a -> Tensor dbs b -> Tensor dcs c -> Tensor drs r)
         -> Acc (Array (ShapeOf das) a) -> Acc (Array (ShapeOf dbs) b)
         -> Acc (Array (ShapeOf dcs) c) -> Acc (Array (ShapeOf drs) r)
unliftT3 f a b c = unTensor (f (Tensor a) (Tensor b) (Tensor c))
-- | Pipelines two tensor functions with Accelerate's 'A.>->', after
-- unlifting both to functions on Accelerate arrays.
(>->) :: (ShapeLike das, ShapeLike dbs, ShapeLike dcs, Elt a, Elt b, Elt c)
      => (Tensor das a -> Tensor dbs b) -> (Tensor dbs b -> Tensor dcs c)
      -> Tensor das a -> Tensor dcs c
(f >-> g) (Tensor t) = Tensor (pipeline t)
  where
    pipeline = unliftT f A.>-> unliftT g
-- | Wraps 'A.compute'; dimensions and element type are unchanged.
compute :: (ShapeLike dims, Elt e) => Tensor dims e -> Tensor dims e
compute = liftT A.compute
-- | Wraps 'A.indexed': pairs every element with its index.
indexed :: (ShapeLike dims, Elt e)
        => Tensor dims e -> Tensor dims (ShapeOf dims, e)
indexed = liftT A.indexed
-- | Applies a function to every element; wraps 'A.map'.
map :: (Elt b, ShapeLike dims, Elt a)
    => (Exp a -> Exp b) -> Tensor dims a -> Tensor dims b
map f = liftT (A.map f)
-- | Applies a function to every element together with its index; wraps
-- 'A.imap'.
imap :: (sh ~ ShapeOf dims, Elt a, Elt b, Shape sh)
     => (Exp sh -> Exp a -> Exp b) -> Tensor dims a -> Tensor dims b
imap f = liftT (A.imap f)
-- | Combines two tensors of identical dimensions element-wise; wraps
-- 'A.zipWith'.
zipWith :: (Elt a, Elt b, Elt r, ShapeLike dims)
        => (Exp a -> Exp b -> Exp r)
        -> Tensor dims a -> Tensor dims b -> Tensor dims r
-- Technically we could use this simpler version, because we don't need to
-- intersect the shapes, but for now, I will stick with simply wrapping the
-- Accelerate functions
-- zipWith f (Tensor s) (Tensor t) = Tensor $
--   generate (shape s) (\idx -> f (s!idx) (t!idx))
zipWith f = liftT2 (A.zipWith f)
-- | Combines three tensors of identical dimensions element-wise; wraps
-- 'A.zipWith3'.
zipWith3 :: (Elt a, Elt b, Elt c, Elt r, ShapeLike dims)
         => (Exp a -> Exp b -> Exp c -> Exp r)
         -> Tensor dims a -> Tensor dims b -> Tensor dims c -> Tensor dims r
zipWith3 f = liftT3 (A.zipWith3 f)
-- | Combines two tensors of identical dimensions element-wise, also passing
-- the index to the combining function; wraps 'A.izipWith'.
--
-- The context deliberately has no @Elt c@ constraint: no type variable @c@
-- occurs in the type, so such a constraint is ambiguous and cannot be
-- solved at ordinary call sites.
izipWith :: (sh ~ ShapeOf dims, Shape sh, Elt a, Elt b, Elt r)
         => (Exp sh -> Exp a -> Exp b -> Exp r)
         -> Tensor dims a -> Tensor dims b -> Tensor dims r
izipWith f = liftT2 (A.izipWith f)
-- | Combines three tensors of identical dimensions element-wise, also
-- passing the index to the combining function; wraps 'A.izipWith3'.
izipWith3 :: (sh ~ ShapeOf dims, Shape sh, Elt a, Elt b, Elt c, Elt r)
          => (Exp sh -> Exp a -> Exp b -> Exp c -> Exp r)
          -> Tensor dims a -> Tensor dims b -> Tensor dims c -> Tensor dims r
izipWith3 f = liftT3 (A.izipWith3 f)
-- | Pairs up the elements of two tensors of identical dimensions; wraps
-- 'A.zip'.
zip :: (ShapeLike dims, Elt a, Elt b)
    => Tensor dims a -> Tensor dims b -> Tensor dims (a, b)
zip xs ys = liftT2 A.zip xs ys
-- | Combines the elements of three tensors of identical dimensions into
-- triples; wraps 'A.zip3'.
zip3 :: (ShapeLike dims, Elt a, Elt b, Elt c)
     => Tensor dims a -> Tensor dims b -> Tensor dims c
     -> Tensor dims (a, b, c)
zip3 xs ys zs = liftT3 A.zip3 xs ys zs
-- | Splits a tensor of pairs into a pair of tensors; wraps 'A.unzip'.
unzip :: (ShapeLike dims, (Elt a, Elt b))
      => Tensor dims (a, b)
      -> (Tensor dims a, Tensor dims b)
unzip (Tensor pairs) = case A.unzip pairs of
  (ls, rs) -> (Tensor ls, Tensor rs)
-- | Splits a tensor of triples into a triple of tensors; wraps 'A.unzip3'.
unzip3 :: (ShapeLike dims, Elt a, Elt b, Elt c)
       => Tensor dims (a, b, c)
       -> (Tensor dims a, Tensor dims b, Tensor dims c)
unzip3 (Tensor triples) = case A.unzip3 triples of
  (xs, ys, zs) -> (Tensor xs, Tensor ys, Tensor zs)
-- | Reshapes a tensor to @dims'@; the products of the old and the new
-- dimensions must agree at the type level. Wraps 'A.reshape'.
reshape :: forall dims dims' e.
           (Product dims ~ Product dims', ShapeLike dims, ShapeLike dims',
            KnownShape dims', Elt e)
        => Tensor dims e -> Tensor dims' e
reshape = liftT (A.reshape target)
  where
    target :: Exp (ShapeOf dims')
    target = constant (shFromDims (Proxy :: Proxy dims'))
-- | Flattens a tensor into a vector whose length is the product of the
-- dimensions; wraps 'A.flatten'.
flatten :: (ShapeLike dims, Elt e) => Tensor dims e -> Vector (Product dims) e
flatten t = liftT A.flatten t
-- XXX it *might* be possible to get rid of the Proxy here, although I doubt it

-- | Replicates a tensor along new dimensions described by the type-level
-- slicer @ss@; wraps 'A.replicate'. The proxy only selects @ss@.
replicate :: forall sl sh ss dims e.
             (sl ~ SliceOf ss dims, sh ~ ShapeOf (ReplicatedShape ss dims),
              ShapeOf dims ~ SliceShape sl, sh ~ FullShape sl,
              KnownShape dims, KnownSlicer ss, Replicating ss dims,
              Shape sh, Slice sl, Elt e)
          => Proxy ss -> Tensor dims e -> Tensor (ReplicatedShape ss dims) e
replicate _ = liftT (A.replicate slicer)
  where
    slicer :: Exp sl
    slicer = constant (shFromSlicer (Proxy :: Proxy Replicated)
                                    (Proxy :: Proxy ss)
                                    (Proxy :: Proxy dims))
-- | Slices a tensor according to the type-level slicer @ss@; wraps
-- 'A.slice'. The proxy only selects @ss@.
slice :: forall sl ss dims dims' e.
         (sl ~ SliceOf ss dims, dims' ~ SlicedShape ss dims,
          SliceShape sl ~ ShapeOf dims', ShapeOf dims ~ FullShape sl,
          KnownShape dims, KnownSlicer ss, Slicing ss dims,
          Slice sl, Elt e)
      => Tensor dims e -> Proxy ss -> Tensor dims' e
slice (Tensor t) _ = Tensor (A.slice t slicer)
  where
    slicer :: Exp sl
    slicer = constant (shFromSlicer (Proxy :: Proxy Sliced)
                                    (Proxy :: Proxy ss)
                                    (Proxy :: Proxy dims))
-- XXX Should we explicitly make sure that Last dims is at least 1? Probably
-- not strictly necessary because it won't typecheck if that's not the case
-- anyway.

-- | Wraps 'A.init'; at the type level the last dimension becomes its
-- predecessor.
init :: (dims' ~ (Init dims :++ '[Pred (Last dims)]),
         ShapeOf dims' ~ (sh :. Int),
         ShapeOf dims ~ (sh :. Int), Slice sh, Shape sh, Elt e)
     => Tensor dims e -> Tensor dims' e
init t = liftT A.init t
-- | Wraps 'A.tail'; at the type level the last dimension becomes its
-- predecessor.
tail :: (dims' ~ (Init dims :++ '[Pred (Last dims)]),
         ShapeOf dims' ~ (sh :. Int),
         ShapeOf dims ~ (sh :. Int), Slice sh, Shape sh, Elt e)
     => Tensor dims e -> Tensor dims' e
tail t = liftT A.tail t
-- TODO better type error if n is too high (also for drop)
-- | Wraps 'A.take' with a type-level count @n@.
--
-- The last dimension of the result is @n@, and @n@ must not exceed
-- @Last dims@. The count passed to Accelerate is the demoted singleton of @n@.
take :: forall n dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[n]), (n :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI n, Elt e)
     => Tensor dims e -> Tensor dims' e
take = liftT (A.take count)
  where
    count = constant (fromInteger (fromSing (sing :: SNat n)))
-- | Wraps 'A.drop' with a type-level count @n@.
--
-- Dropping @n@ elements leaves @Last dims - n@ of them, so that is the last
-- dimension of the result. (The signature previously reused the one from
-- 'take' and claimed a last dimension of @n@, which only matches the runtime
-- array when @n@ is exactly half of @Last dims@.)
--
-- @n@ must not exceed @Last dims@, which also keeps the type-level
-- subtraction well defined.
--
-- NOTE(review): assumes the type-level @(-)@ is in scope alongside the
-- @(+)@ already used elsewhere in this module — confirm against the imports.
drop :: forall n dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[Last dims - n]), (n :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI n, Elt e)
     => Tensor dims e -> Tensor dims' e
drop = liftT $ A.drop (constant . fromInteger $ fromSing (sing :: SNat n))
-- | Wraps 'A.slit' with a type-level start index @idx@ and length @len@.
--
-- The last dimension of the result is @len@, and @idx + len@ must not exceed
-- @Last dims@. The 'Proxy' argument only carries @idx@; its value is ignored.
slit :: forall idx len dims dims' e sh.
        (ShapeOf dims' ~ (sh :. Int), ShapeOf dims ~ (sh :. Int),
         dims' ~ (Init dims :++ '[len]), (idx + len :<= Last dims) ~ True,
         Slice sh, Shape sh, SingI idx, SingI len, Elt e)
     => Proxy idx -> Tensor dims e -> Tensor dims' e
slit _ = liftT (A.slit start count)
  where
    start = constant (fromInteger (fromSing (sing :: SNat idx)))
    count = constant (fromInteger (fromSing (sing :: SNat len)))
-- | Wraps 'A.permute'.
--
-- Arguments, in order: the combining function, the tensor of default values
-- (which fixes the result dimensions @dims'@), the index mapping from source
-- to destination shape, and the source tensor.
permute :: (sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
            Shape sh, Shape sh', Elt e)
        => (Exp e -> Exp e -> Exp e) -> Tensor dims' e -> (Exp sh -> Exp sh')
        -> Tensor dims e -> Tensor dims' e
permute combine (Tensor defaults) indexMap source =
  liftT (A.permute combine defaults indexMap) source
-- | Wraps 'A.scatter'. The result has the same length (@def@) as the vector
-- of default values given as the second argument.
scatter :: Elt e
        => Vector dst Int -> Vector def e -> Vector src e -> Vector def e
scatter indices defaults source = liftT3 A.scatter indices defaults source
-- | Wraps 'A.backpermute'.
--
-- The result shape is derived from the type-level @dims'@ via 'shFromDims';
-- the supplied function maps result indices to source indices.
backpermute :: forall dims dims' sh' sh e.
               (sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
                KnownShape dims, KnownShape dims', Shape sh', Elt e, Shape sh)
            => (Exp (ShapeOf dims') -> Exp (ShapeOf dims))
            -> Tensor dims e -> Tensor dims' e
backpermute indexMap = liftT (A.backpermute resultShape indexMap)
  where
    resultShape = constant (shFromDims (Proxy :: Proxy dims'))
-- | Wraps 'A.gather'. The result has the dimensions (@idx@) of the index
-- tensor given as the first argument.
gather :: (ShapeLike idx, Elt e)
       => Tensor idx Int -> Vector src e -> Tensor idx e
gather indices source = liftT2 A.gather indices source
-- | Wraps 'A.reverse'; the vector length is unchanged.
reverse :: Elt e => Vector n e -> Vector n e
reverse v = liftT A.reverse v
-- | Wraps 'A.transpose'; an @n@-by-@m@ matrix becomes an @m@-by-@n@ matrix.
transpose :: Elt e => Matrix n m e -> Matrix m n e
transpose m = liftT A.transpose m
-- XXX use unlift to bring Acc of tuple into tuple of Accs | |
-- TODO If we actually want to return a Vector of results here, we don't know | |
-- the length at runtime, which means we need to have SomeVector to store the | |
-- result | |
-- filter :: (Exp e -> Exp Bool) | |
-- -> Tensor dims e -> (Vector e, Tensor (Init dims) Int) | |
-- filter f (Tensor t) = (filtered, nums) | |
-- where (filtered, nums) = A.unlift (A.filter f t) | |
-- | Wraps 'A.fold' with a combining function and a seed value. The result
-- has dimensions @Init dims@, i.e. the last dimension is removed.
fold :: (dims' ~ Init dims, sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
         sh ~ (sh' :. Int), Shape sh, Shape sh', Elt e)
     => (Exp e -> Exp e -> Exp e) -> Exp e
     -> Tensor dims e -> Tensor (Init dims) e
fold f seed t = liftT (A.fold f seed) t
-- TODO nice error message for empty Tensor
-- | Wraps 'A.fold1' (no seed value). The result has dimensions @Init dims@;
-- the @Last dims :> 0@ constraint rules out an empty last dimension at
-- compile time.
fold1 :: (dims' ~ Init dims, sh ~ ShapeOf dims, sh' ~ ShapeOf dims',
          sh ~ (sh' :. Int), (Last dims :> 0) ~ True,
          Shape sh, Shape sh', Elt e)
      => (Exp e -> Exp e -> Exp e)
      -> Tensor dims e -> Tensor (Init dims) e
fold1 f t = liftT (A.fold1 f) t
-- | Wraps 'A.foldAll': reduces a tensor of any dimensions to a 'Scalar',
-- given a combining function and a seed value.
foldAll :: (ShapeLike dims, Elt e)
        => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Scalar e
foldAll f seed t = liftT (A.foldAll f seed) t
-- TODO nice error message for empty Tensor. Also for scans.
-- | Reduce a tensor with @Product dims :> 0@ to a 'Scalar'.
--
-- NOTE(review): despite the name, this takes a seed value and delegates to
-- 'A.foldAll' rather than to a seedless fold. That looks like a copy-paste
-- of 'foldAll'; switching to a seedless fold would remove a parameter and
-- break callers, so the behaviour is left untouched here — confirm which
-- one is intended.
fold1All :: (ShapeLike dims, Elt e, (Product dims :> 0) ~ True)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Scalar e
fold1All f seed t = liftT (A.foldAll f seed) t
-- TODO The sum of elements of the Vector must be less than the last
-- dimension of dims. Is this something we should/could do on the type level?
-- Also for scans.
-- | Wraps 'A.foldSeg' (segmented fold with a seed value).
--
-- The last dimension of the result equals @n@, the length of the segment
-- descriptor vector; all other dimensions are kept.
foldSeg :: (sh ~ ShapeOf dims, dims' ~ (Init dims :++ '[n]),
            sh' ~ ShapeOf dims', sh ~ (ish :. Int), sh' ~ (ish :. Int),
            IsIntegral i, Elt e, Elt i, Shape ish)
        => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Vector n i
        -> Tensor dims' e
foldSeg f seed t segments = liftT2 (A.foldSeg f seed) t segments
-- TODO think about what the restriction here should be to make sure every
-- segment is non-empty. Also for scans.
-- | Wraps 'A.fold1Seg' (segmented fold without a seed value).
--
-- The last dimension of the result equals @n@, the length of the segment
-- descriptor vector; all other dimensions are kept.
fold1Seg :: (sh ~ ShapeOf dims, dims' ~ (Init dims :++ '[n]),
             sh' ~ ShapeOf dims', sh ~ (ish :. Int), sh' ~ (ish :. Int),
             IsIntegral i, Elt e, Elt i, Shape ish)
         => (Exp e -> Exp e -> Exp e) -> Tensor dims e -> Vector n i
         -> Tensor dims' e
fold1Seg f t segments = liftT2 (A.fold1Seg f) t segments
-- | Wraps 'A.all' with the given predicate. The result is a 'Bool' tensor
-- with dimensions @Init dims@.
all :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Elt e)
    => (Exp e -> Exp Bool) -> Tensor dims e -> Tensor (Init dims) Bool
all p t = liftT (A.all p) t
-- | Wraps 'A.any' with the given predicate. The result is a 'Bool' tensor
-- with dimensions @Init dims@.
any :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Elt e)
    => (Exp e -> Exp Bool) -> Tensor dims e -> Tensor (Init dims) Bool
any p t = liftT (A.any p) t
-- | Wraps 'A.and' on a 'Bool' tensor. The result has dimensions @Init dims@.
and :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh')
    => Tensor dims Bool -> Tensor (Init dims) Bool
and t = liftT A.and t
-- | Wraps 'A.or' on a 'Bool' tensor. The result has dimensions @Init dims@.
or :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
       Shape sh')
   => Tensor dims Bool -> Tensor (Init dims) Bool
or t = liftT A.or t
-- | Wraps 'A.sum'. The result has dimensions @Init dims@.
sum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
        Shape sh', Num e)
    => Tensor dims e -> Tensor (Init dims) e
sum t = liftT A.sum t
-- | Wraps 'A.product'. The result has dimensions @Init dims@.
product :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Num e)
        => Tensor dims e -> Tensor (Init dims) e
product t = liftT A.product t
-- | Wraps 'A.minimum'. The result has dimensions @Init dims@.
minimum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Ord e)
        => Tensor dims e -> Tensor (Init dims) e
minimum t = liftT A.minimum t
-- | Wraps 'A.maximum'. The result has dimensions @Init dims@.
maximum :: (sh ~ ShapeOf dims, sh' ~ ShapeOf (Init dims), sh ~ (sh' :. Int),
            Shape sh', Ord e)
        => Tensor dims e -> Tensor (Init dims) e
maximum t = liftT A.maximum t
-- | Wraps 'A.scanl' with a combining function and a seed value. The last
-- dimension of the result is @Last dims + 1@; the others are kept.
scanl :: (dims' ~ (Init dims :++ '[Last dims + 1]),
          ShapeOf dims ~ (sh :. Int), ShapeOf dims' ~ (sh :. Int),
          Shape sh, Elt e)
      => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Tensor dims' e
scanl f seed t = liftT (A.scanl f seed) t
-- | Wraps 'A.scanl1' (no seed value). Dimensions are unchanged; the
-- @Last dims :> 0@ constraint rules out an empty last dimension.
scanl1 :: (ShapeOf dims ~ (sh :. Int), (Last dims :> 0) ~ True,
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e)
       -> Tensor dims e -> Tensor dims e
scanl1 f t = liftT (A.scanl1 f) t
-- | Wraps @A.scanl'@ and splits its paired result with 'A.unlift'.
--
-- Returns a tensor with the input dimensions together with a tensor of
-- dimensions @Init dims@.
scanl' :: (sh ~ ShapeOf (Init dims), ShapeOf dims ~ (sh :. Int),
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
       -> (Tensor dims e, Tensor (Init dims) e)
scanl' f seed (Tensor arr) =
  -- Lazy let-pattern, so the pair is only taken apart on demand.
  let (prefixes, reduced) = A.unlift (A.scanl' f seed arr)
  in  (Tensor prefixes, Tensor reduced)
-- | Wraps 'A.scanr' with a combining function and a seed value. The last
-- dimension of the result is @Last dims + 1@; the others are kept.
scanr :: (dims' ~ (Init dims :++ '[Last dims + 1]),
          ShapeOf dims ~ (sh :. Int), ShapeOf dims' ~ (sh :. Int),
          Shape sh, Elt e)
      => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e -> Tensor dims' e
scanr f seed t = liftT (A.scanr f seed) t
-- | Wraps 'A.scanr1' (no seed value). Dimensions are unchanged; the
-- @Last dims :> 0@ constraint rules out an empty last dimension.
scanr1 :: (ShapeOf dims ~ (sh :. Int), (Last dims :> 0) ~ True,
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e)
       -> Tensor dims e -> Tensor dims e
scanr1 f t = liftT (A.scanr1 f) t
-- | Wraps @A.scanr'@ and splits its paired result with 'A.unlift'.
--
-- Returns a tensor with the input dimensions together with a tensor of
-- dimensions @Init dims@.
scanr' :: (sh ~ ShapeOf (Init dims), ShapeOf dims ~ (sh :. Int),
           Shape sh, Elt e)
       => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
       -> (Tensor dims e, Tensor (Init dims) e)
scanr' f seed (Tensor arr) =
  -- Lazy let-pattern, so the pair is only taken apart on demand.
  let (suffixes, reduced) = A.unlift (A.scanr' f seed arr)
  in  (Tensor suffixes, Tensor reduced)
-- | Wraps 'A.prescanl' with a combining function and a seed value.
-- Dimensions are unchanged.
prescanl :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
         -> Tensor dims e
prescanl f seed t = liftT (A.prescanl f seed) t
-- | Wraps 'A.postscanl' with a combining function and a seed value.
-- Dimensions are unchanged.
postscanl :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
          => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
          -> Tensor dims e
postscanl f seed t = liftT (A.postscanl f seed) t
-- | Wraps 'A.prescanr' with a combining function and a seed value.
-- Dimensions are unchanged.
prescanr :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
         => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
         -> Tensor dims e
prescanr f seed t = liftT (A.prescanr f seed) t
-- | Wraps 'A.postscanr' with a combining function and a seed value.
-- Dimensions are unchanged.
postscanr :: (ShapeOf dims ~ (sh :. Int), Shape sh, Elt e)
          => (Exp e -> Exp e -> Exp e) -> Exp e -> Tensor dims e
          -> Tensor dims e
postscanr f seed t = liftT (A.postscanr f seed) t
-- TODO scans with segments | |
-- TODO stencils | |
-- | Number of dimensions of a tensor, computed purely from its type: the
-- singleton for @Length dims@ is demoted to a term. The tensor argument
-- itself is never inspected.
rank :: forall dims e. SingI (Length dims) => Tensor dims e -> Int
rank _ = fromInteger (fromSing dimCount)
  where
    dimCount = sing :: SNat (Length dims)
-- XXX uhm, you're supposed to use runN, so are these really a good idea?
-- possibly could have some kind of runN function that takes a function from
-- Tensors to a Tensor, as well as (non-acc) arrays as arguments.
-- | Build a tensor from a function on indices.
--
-- The array is created with 'A.fromFunction' using the shape derived from
-- the type-level @dims@, then embedded with 'use' and wrapped in 'Tensor'.
fromFunction :: forall dims sh e.
                (sh ~ ShapeOf dims, KnownShape dims, Shape sh, Elt e)
             => (sh -> e) -> Tensor dims e
fromFunction f = Tensor (use (A.fromFunction shape f))
  where
    shape = shFromDims (Proxy :: Proxy dims)
-- TODO: this only exists in accelerate 1.2 | |
-- fromFunctionM :: (sh ~ ShapeOf dims) | |
-- => (sh -> IO e) -> Tensor dims e | |
-- fromFunctionM f = do | |
-- array <- A.fromFunctionM (shFromDims (Proxy :: Proxy dims)) f | |
-- pure $ Tensor array | |
-- | Build a tensor from a list, using the shape derived from the type-level
-- @dims@.
--
-- NOTE(review): presumably "unsafe" because the list length is not checked
-- against @Product dims@ here — confirm how 'fromList' treats lists that are
-- too short.
unsafeFromList :: forall dims e. (KnownShape dims, ShapeLike dims, Elt e)
               => [e] -> Tensor dims e
unsafeFromList xs = Tensor (use (fromList shape xs))
  where
    shape = shFromDims (Proxy :: Proxy dims)
-- examples | |
-- | Example: a 3-by-2 tensor filled from the list @[1..]@.
ta :: Tensor [3, 2] Double
ta = Tensor (use (fromList shape [1..]))
  where
    shape :: DIM2
    shape = Z :. 3 :. 2
-- | Example: a 3-by-2 tensor filled from the list @[2..]@.
tb :: Tensor [3, 2] Double
tb = Tensor (use (fromList shape [2..]))
  where
    shape :: DIM2
    shape = Z :. 3 :. 2
-- | Example: a 3-by-3 tensor filled from the list @[1..]@.
tc :: Tensor [3, 3] Double
tc = Tensor (use (fromList shape [1..]))
  where
    shape :: DIM2
    shape = Z :. 3 :. 3
-- | Example: a 2-by-3 tensor filled from the list @[2..]@.
td :: Tensor [2, 3] Double
td = Tensor (use (fromList shape [2..]))
  where
    shape :: DIM2
    shape = Z :. 2 :. 3
-- | Example: a 3-by-5 tensor produced by 'enumFromN' with start value 2.
te :: Tensor [3, 5] Double
te = enumFromN start
  where
    start = constant 2
-- | Example: element-wise sum of 'ta' and 'tb'; both operands and the result
-- are 3-by-2.
zipped :: Tensor [3, 2] Double
zipped = zipWith (\x y -> x + y) ta tb
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment