Created
May 30, 2019 17:01
-
-
Save cdepillabout/e9cfb0ee13986d98d2ed410f4628f343 to your computer and use it in GitHub Desktop.
A half-working, dependently typed tensor library based on the `linear` package
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE DataKinds #-} | |
{-# LANGUAGE DefaultSignatures #-} | |
{-# LANGUAGE DerivingStrategies #-} | |
{-# LANGUAGE EmptyCase #-} | |
{-# LANGUAGE ExistentialQuantification #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE InstanceSigs #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE NoStarIsType #-} | |
{-# LANGUAGE PolyKinds #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE StandaloneDeriving #-} | |
{-# LANGUAGE TemplateHaskell #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
module Foo where | |
import Data.Constraint | |
import Data.Foldable | |
import Data.Kind | |
import Data.Singletons | |
import Data.Singletons.Prelude (Fmap, Product) | |
import Data.Singletons.TypeLits () | |
import qualified Data.Vector as Vector | |
import GHC.TypeLits (type (+), KnownNat, Nat, SomeNat(SomeNat), natVal, someNatVal) | |
import Linear | |
import Linear.V | |
import Unsafe.Coerce (unsafeCoerce) | |
----------------- Fin Stuff --------------------------- | |
-- | An index into a dimension of size @n@.
--
-- The wrapped 'Int' is meant to lie in @[0, n)@. The raw 'Fin' constructor
-- performs no check; use 'fin' to build one with a bounds check.
newtype Fin (n :: k)
  = Fin { unFin :: Int }
  deriving stock (Show)
-- | Smart constructor for 'Fin'.
--
-- Returns @Just (Fin i)@ exactly when @0 <= i@ and @i@ is smaller than the
-- reflected dimension of @n@; returns 'Nothing' otherwise. Negative indices
-- are rejected because they would produce a 'Fin' that makes 'indexVec'
-- fail.
fin :: forall k (n :: k). Dim n => Int -> Maybe (Fin n)
fin i
  | i >= 0 && i < reflectDim (Proxy @n) = Just (Fin i)
  | otherwise = Nothing
-- | Like 'fin', but calls 'error' (message @"finUnsafe"@) instead of
-- returning 'Nothing' when the index is rejected.
finUnsafe :: forall k (n :: k). Dim n => Int -> Fin n
finUnsafe = maybe (error "finUnsafe") id . fin
---------------- HList stuff -------------------------- | |
-- | A heterogeneous list. The type-level index lists the type of each
-- element in order; e.g. @ConsHList 'a' (ConsHList True EmptyHList)@ has
-- type @HList '[Char, Bool]@.
--
-- NOTE(review): the index is declared poly-kinded (@[k]@), but 'ConsHList'
-- stores a value of type @x@, so any non-empty list forces @k ~ Type@.
data HList (as :: [k]) where
  -- | The empty list.
  EmptyHList :: HList '[]
  -- | Prepend an element, extending the index with its type.
  ConsHList :: x -> HList xs -> HList (x ': xs)
---------------- Vector Stuff -------------------------- | |
-- | A vector of length @n@ (the reflected dimension) in which every
-- element is the given value.
replicateVec :: forall n a. Dim n => a -> V n a
replicateVec x = V (Vector.replicate len x)
  where
    len = reflectDim (Proxy @n)
-- | Reify a runtime 'Int' as a type-level 'Nat' and pass a proxy for it
-- (with a 'Dim' instance) to the continuation.
--
-- Calls 'error' when the argument is negative, because no 'Nat' exists
-- for it. The message names this function and the offending value.
reifyDim' :: Int -> (forall k (n :: k). Dim n => Proxy n -> r) -> r
reifyDim' i f =
  case someNatVal (fromIntegral i) of
    Nothing -> error ("reifyDim': negative dimension " ++ show i)
    Just (SomeNat proxy) -> f @Nat proxy
-- | Look up the element at a 'Fin' position.
--
-- Uses the bounds-checked 'Vector.!', so a 'Fin' that is out of range
-- (which the raw 'Fin' constructor permits) raises an exception.
indexVec :: forall k (n :: k) a. Fin n -> V n a -> a
indexVec ix (V xs) = xs Vector.! unFin ix
-- | Build a length-@n@ vector from the first @n@ elements of a list, where
-- @n@ is the reflected dimension.
--
-- Returns 'Nothing' when the list has fewer than @n@ elements. Elements
-- past the first @n@ are ignored. Only that prefix of the list is
-- consumed, so this also works on infinite lists.
fromListVec :: forall k (n :: k) a. Dim n => [a] -> Maybe (V n a)
fromListVec as =
  let di = reflectDim (Proxy @n)
      -- 'fromListN' stops after 'di' elements instead of materialising the
      -- whole input list first.
      vec = Vector.fromListN di as
  in if Vector.length vec == di then Just (V vec) else Nothing
-- | Like 'fromListVec', but calls 'error' (message @"fromListVecUnsafe"@)
-- when the list is too short.
fromListVecUnsafe :: forall k (n :: k) a. Dim n => [a] -> V n a
fromListVecUnsafe = maybe (error "fromListVecUnsafe") id . fromListVec
-- | Drop the first @m@ elements of a vector typed as length @m + n@,
-- leaving a vector typed as length @n@.
--
-- The underlying vector's actual length is not checked against the type.
dropVec :: forall m n a. KnownNat m => V (m + n) a -> V n a
dropVec (V xs) = V (Vector.drop prefixLen xs)
  where
    prefixLen = fromIntegral (natVal (Proxy @m))
-- | Build a length-@n@ vector by applying a function to every index
-- @0 .. n - 1@, each wrapped as a 'Fin'.
genVec :: forall n a. Dim n => (Fin n -> a) -> V n a
genVec f = V (Vector.generate size (f . finUnsafe))
  where
    size = reflectDim (Proxy @n)
---------------- Matrix Stuff -------------------------- | |
-- | A tensor whose shape is the type-level list @ns@. Its elements are
-- stored in a single flat vector indexed by @Product ns@.
newtype Matrix (ns :: [k]) (a :: Type)
  = Matrix { unMatrix :: V (Product ns) a }
  deriving stock (Show)
-- | Equality of two same-shaped matrices, by comparing their underlying
-- vectors.
eqMatrix :: Eq a => Matrix ns a -> Matrix ns a -> Bool
eqMatrix m1 m2 = unMatrix m1 == unMatrix m2
-- | Apply a function to every element, keeping the shape.
fmapMatrix :: (a -> b) -> Matrix ns a -> Matrix ns b
fmapMatrix f = Matrix . fmap f . unMatrix
-- | A matrix of shape @ns@ whose every element is the given value.
--
-- 'replicateVec' needs @Dim (Product ns)@, which is derived from the
-- available @Dims ns@ through the entailment 'prodDimFromDims'.
replicateMatrix :: forall (ns :: [k]) a. Dims ns => a -> Matrix ns a
replicateMatrix x =
  Matrix (replicateVec @(Product ns) x \\ prodDimFromDims @_ @ns)
-- | Presumably meant to build a matrix of shape @ns@ by calling the function
-- once per element, passing that element's per-dimension indices as an
-- @HList (Fmap Fin ns)@. The flat vector has @product@ of the reflected
-- dimensions elements.
--
-- NOTE(review): unfinished. The element function body is a typed hole
-- (@f _@), so this does not compile yet. The flat index @i@ still has to be
-- split into one 'Fin' per dimension. Singletons' 'Fmap' may also need a
-- defunctionalization symbol (e.g. @TyCon1 Fin@) rather than bare 'Fin' —
-- TODO confirm.
genMatrix :: forall k (ns :: [k]) a. Dims ns => (HList (Fmap Fin ns) -> a) -> Matrix ns a | |
genMatrix f = | |
Matrix $ | |
V $ | |
Vector.generate | |
(product (reflectDims (Proxy @ns))) | |
(\i -> f _) | |
-- | Entailment: if every dimension of @ns@ is known ('Dims'), then so is
-- their product (@Dim (Product ns)@).
--
-- The product is computed at runtime from 'reflectDims' and reified as a
-- fresh type with 'reifyDim'. That type's @Dict (Dim _)@ is then
-- 'unsafeCoerce'd into the dictionary required here.
--
-- NOTE(review): this is sound only if a 'Dim' dictionary's runtime
-- representation does not depend on the kind or identity of its index.
-- That is presumably true because 'reflectDim' ignores its proxy argument.
-- It also assumes @Product ns@ really denotes the same number. GHC checks
-- neither assumption.
prodDimFromDims :: forall k (ns :: [k]). Dims ns :- Dim (Product ns)
prodDimFromDims =
  Sub (reifyDim (product (reflectDims (Proxy @ns))) castDict)
  where
    -- Reinterpret the dictionary of the reified type @m@ as a dictionary
    -- for whatever @x@ the caller expects (here @Product ns@).
    castDict :: forall m x. Dim m => Proxy m -> Dict (Dim x)
    castDict _ = unsafeCoerce (Dict :: Dict (Dim m))
-- | Entailment: from @Dims ns@, recover a 'Dim' instance for every element
-- of @ns@ (as @AllC Dim ns@).
--
-- The steps are:
--
-- 1. Reflect @ns@ to a runtime @[Int]@.
-- 2. Reify the head of that list as a fresh 'Nat' @m@ ('reifyDimNat') and
--    the tail as a fresh @[Nat]@ @ms@ ('reifyDimsNat').
-- 3. Recurse on @ms@, combine the results with 'proveAll', and
--    'unsafeCoerce' the resulting @Dict (AllC Dim (m ': ms))@ to
--    @Dict (AllC Dim ns)@.
--
-- Each recursive call works on a reflected list one element shorter, so
-- the recursion terminates. The empty case coerces @Dict ()@, which
-- matches @AllC Dim '[]@.
--
-- NOTE(review): the coercions assume two things, neither checked by GHC:
-- the reified 'Nat's reflect the same numbers as the original indices, and
-- the nested 'AllC' dictionaries have the same runtime shape whatever the
-- index kind.
allDimFromDims :: forall (ns :: [k]). Dims ns :- AllC Dim ns
allDimFromDims =
  Sub $
    case reflectDims (Proxy @ns) of
      [] -> unsafeCoerce (Dict :: Dict ())
      (d : ds) -> reifyDimNat d (withHead ds)
  where
    -- The head dimension has been reified as @m@; now reify the tail.
    withHead :: forall (m :: Nat). Dim m => [Int] -> Proxy m -> Dict (AllC Dim ns)
    withHead ds Proxy = reifyDimsNat ds withTail
      where
        -- The tail has been reified as @ms@; recurse on it, add @Dim m@,
        -- and coerce the result back to the caller's index @ns@.
        withTail :: forall (ms :: [Nat]). Dims ms => Proxy ms -> Dict (AllC Dim ns)
        withTail Proxy =
          case allDimFromDims @_ @ms of
            Sub (Dict :: Dict (AllC Dim ms)) ->
              case proveAll (Proxy @Dim) (Proxy @m) (Proxy @ms) of
                Sub Dict -> unsafeCoerce (Dict :: Dict (AllC Dim (m ': ms)))
-- | Smoke test for 'allDimFromDims'. Starting from only @Dims '[n, m, o]@,
-- it recovers @Dim m@ and returns the middle dimension.
testtest :: forall n m o. Dims '[n, m, o] => Proxy '[n, m, o] -> Int
testtest _ = reflectDim (Proxy @m) \\ allDimFromDims @_ @'[n, m, o]
-- | Unary (Peano) naturals. They are promoted with DataKinds and given
-- 'Dim' instances so they can serve as type-level dimensions.
data MyNat = MyZero | MySucc MyNat
-- | The promoted zero reflects to dimension 0.
instance Dim 'MyZero where
  reflectDim _ = 0
-- | A promoted successor reflects to one more than its predecessor.
instance Dim n => Dim ('MySucc n) where
  reflectDim _ = 1 + reflectDim (Proxy :: Proxy n)
-- | Type-level unary numerals one through four.
type MyOne = 'MySucc 'MyZero
type MyTwo = 'MySucc ('MySucc 'MyZero)
type MyThree = 'MySucc ('MySucc ('MySucc 'MyZero))
type MyFour = 'MySucc ('MySucc ('MySucc ('MySucc 'MyZero)))
-- | Types (or type-level lists) that determine a list of dimension sizes,
-- which can be reflected to the value level through any proxy.
class Dims ns where
  reflectDims :: p ns -> [Int]
-- | A vector contributes a single dimension: its length @n@.
instance Dim n => Dims (V n a) where
  reflectDims _ = pure (reflectDim (Proxy :: Proxy n))
-- | A matrix reports the dimensions of its shape index @ns@.
instance Dims ns => Dims (Matrix ns a) where
  reflectDims _ = reflectDims (Proxy :: Proxy ns)
-- | The empty shape has no dimensions.
instance Dims '[] where
  reflectDims = const []
-- | A non-empty shape lists its head dimension followed by the rest.
instance (Dim n, Dims ns) => Dims (n ': ns) where
  reflectDims _ = reflectDim (Proxy :: Proxy n) : reflectDims (Proxy :: Proxy ns)
-- | The runtime dimensions of a matrix, read from its type; the matrix
-- value itself is never inspected.
dims :: forall ns a. Dims ns => Matrix ns a -> [Int]
dims = const (reflectDims (Proxy :: Proxy ns))
-- | Reify a runtime list of sizes as a type-level @[Nat]@ and pass a proxy
-- for it, with its 'Dims' instance, to the continuation.
--
-- The head is reified with 'reifyDimNat' and the tail recursively. How
-- negative sizes behave is up to 'reifyDimNat' — TODO confirm.
reifyDimsNat :: forall r. [Int] -> (forall (ns :: [Nat]). Dims ns => Proxy ns -> r) -> r
reifyDimsNat [] cont = cont @'[] Proxy
reifyDimsNat (size : rest) cont = reifyDimNat size onHead
  where
    -- @size@ is now the type @m@; reify the remaining sizes.
    onHead :: forall (m :: Nat). KnownNat m => Proxy m -> r
    onHead Proxy = reifyDimsNat rest onTail
      where
        -- The remaining sizes are now @ms@; hand @m ': ms@ to the caller.
        onTail :: forall (ms :: [Nat]). Dims ms => Proxy ms -> r
        onTail Proxy =
          case knownNatToDim @m of
            Sub Dict -> cont (Proxy :: Proxy (m ': ms))
-- | Reify a runtime list of sizes as a type-level list and pass a proxy for
-- it, with its 'Dims' instance, to the continuation.
--
-- The list is always reified at kind @[Nat]@. This therefore delegates to
-- 'reifyDimsNat' rather than duplicating its recursion. The continuation
-- is polymorphic in the kind and can only observe the 'Dims' dictionary.
reifyDims :: forall r. [Int] -> (forall k (ns :: [k]). Dims ns => Proxy ns -> r) -> r
reifyDims sizes f = reifyDimsNat sizes (\proxy -> f @Nat proxy)
-- | Entailment: every 'KnownNat' is a 'Dim'. The 'Dict' is presumably
-- resolved through linear's KnownNat-based 'Dim' instance — confirm
-- against the linear version in use.
knownNatToDim :: forall n. KnownNat n :- Dim n
knownNatToDim = Sub (Dict :: Dict (Dim n))
-- | @AllC c xs@ is the conjunction of @c x@ for every @x@ in @xs@, built as
-- right-nested constraint pairs that end in @()@.
type family AllC (c :: k -> Constraint) (xs :: [k]) :: Constraint where
  AllC _ '[] = ()
  AllC c (x ': xs) = (c x, AllC c xs)
-- | Trivial entailment: @constraint n@ together with @AllC constraint ns@
-- is, by the definition of 'AllC', exactly @AllC constraint (n ': ns)@.
-- The three proxies only fix the types involved; they are never inspected.
proveAll
  :: forall (n :: k) (ns :: [k]) (constraint :: k -> Constraint) proxy1 proxy2 proxy3
   . proxy1 constraint
  -> proxy2 n
  -> proxy3 ns
  -> (constraint n, AllC constraint ns) :- AllC constraint (n ': ns)
proveAll = \_constraintProxy _headProxy _tailProxy -> Sub Dict
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Development shell providing a GHC with the constraints, linear and
# singletons libraries, built from a pinned nixpkgs revision.
let
  # nixpkgs-unstable as of 2019/05/30, pinned by revision and hash.
  pinnedSrc = builtins.fetchTarball {
    url = "https://github.com/NixOS/nixpkgs/archive/eccb90a2d99.tar.gz";
    sha256 = "0ffa84mp1fgmnqx2vn43q9pypm3ip9y67dkhigsj598d8k1chzzw";
  };
  pkgs = import pinnedSrc {};
  ghc = pkgs.haskellPackages.ghcWithPackages (hp: [
    hp.constraints
    hp.linear
    hp.singletons
  ]);
in
pkgs.mkShell {
  buildInputs = [ ghc ];
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment