Last active
November 19, 2016 05:43
-
-
Save notogawa/d32595f2eb79a3cf6bcdd4bd97add3d3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE TypeOperators #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE DataKinds #-} | |
import Data.Singletons | |
import Data.Singletons.Prelude.Enum | |
import Data.Singletons.Prelude.List | |
import Data.Singletons.Prelude.Num | |
import GHC.TypeLits | |
-- | A shape-indexed N-dimensional array, modelled on numpy's @ndarray@.
--
-- The type-level list @shape@ gives the size of every axis. Only the runtime
-- singleton of that shape is stored; the element data is omitted from this
-- sketch (everything except the shape is left out).
data NDArray (shape :: [Nat]) a
  = NDArray (Sing shape)
-- | Give an array a new shape, like @numpy.reshape@
-- (<https://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html>).
--
-- The constraint @Product from ~ Product to@ makes the type checker reject any
-- target shape whose element count (product of its dimensions) differs from
-- the source's. The result simply carries the supplied target-shape
-- singleton; the source array is not inspected.
reshape :: Product from ~ Product to => Sing to -> NDArray from a -> NDArray to a
reshape newShape _ = NDArray newShape
-- Accepted by the type checker: 2 * 3 * 4 == 3 * 8 == 24 elements.
reshapeable :: NDArray '[2,3,4] a -> NDArray '[3,8] a
reshapeable arr = reshape (sing :: Sing '[3, 8]) arr
-- Rejected by the type checker: the products of the dimensions differ
-- (2 * 3 * 4 = 24 versus 3 * 3 * 4 = 36).
-- unreshapeable :: NDArray '[2,3,4] a -> NDArray '[3,3,4] a
-- unreshapeable = reshape sing -- Couldn't match type ‘24’ with ‘36’
-- | Generalised dot product, like @numpy.dot@
-- (<https://docs.scipy.org/doc/numpy/reference/generated/numpy.dot.html>).
--
-- The last axis of the left operand is contracted with the second-to-last
-- axis of the right operand. The constraint @Last xs ~ Last (Init ys)@
-- requires those two sizes to agree. The result shape is the left operand's
-- axes without its last, then the right operand's axes without its last two,
-- then the right operand's last axis.
--
-- NOTE(review): the constraint needs @Init ys@ to be non-empty, so (unlike
-- numpy) a one-dimensional right operand does not appear to be supported.
-- The @Num a@ constraint is unused by this shape-only sketch.
dot :: (Last xs ~ Last (Init ys), Num a) =>
       NDArray xs a -> NDArray ys a -> NDArray (Init xs :++ Init (Init ys) :++ '[Last ys]) a
dot (NDArray xs) (NDArray ys) = NDArray (leading %:++ middle %:++ trailing)
  where
    -- the left operand's axes minus the contracted last one
    leading = sInit xs
    -- the right operand's axes minus its last two
    middle = sInit (sInit ys)
    -- the right operand's last axis, as a one-element shape
    trailing = SCons (sLast ys) SNil
-- Accepted by the type checker: the left operand's last axis (4) matches the
-- right operand's second-to-last axis (4); the result shape is
-- [2,3] ++ [3] ++ [2].
dottable :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,2] a -> NDArray '[2,3,3,2] a
dottable lhs rhs = lhs `dot` rhs
-- Rejected by the type checker: the declared result type is wrong (the
-- result's last axis would be the right operand's last axis, 4, not 2).
-- undottable1 :: Num a => NDArray '[2,3,4] a -> NDArray '[3,4,4] a -> NDArray '[2,3,3,2] a
-- undottable1 = dot -- Couldn't match type ‘4’ with ‘2’
-- Rejected by the type checker: the argument types are incompatible (the
-- left operand's last axis is 4, the right operand's second-to-last is 3).
-- undottable2 :: Num a => NDArray '[2,3,4] a -> NDArray '[4,3,2] a -> NDArray '[2,3,4,2] a
-- undottable2 = dot -- Couldn't match type ‘4’ with ‘3’
-- | Permute the axes of an array, like @numpy.transpose@
-- (<https://docs.scipy.org/doc/numpy/reference/generated/numpy.transpose.html>).
--
-- The result's i-th dimension is @shape !! (axes !! i)@. The constraint
-- @Sort axes ~ EnumFromTo 0 (Length shape - 1)@ accepts @axes@ only when it is
-- a permutation of @[0 .. n-1]@ for an @n@-dimensional input.
transpose :: Sort axes ~ EnumFromTo 0 (Length shape - 1) =>
             Sing axes -> NDArray shape a ->
             NDArray (Map ((:!!$$) shape) axes) a
transpose axes (NDArray dims) = NDArray (sMap lookupDim axes)
  where
    -- singleton function \i -> dims !! i, defunctionalised so sMap can use it
    lookupDim = singFun1 (symbolFor dims) (dims %:!!)
    -- selects the defunctionalisation symbol that singFun1 should build
    symbolFor :: Sing (s :: [Nat]) -> Proxy (Apply (:!!$) s)
    symbolFor _ = Proxy
-- Accepted by the type checker: [1,0,2] is a permutation of [0,1,2], and
-- mapping (shape !!) over it gives [3,2,4].
transposable :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
transposable perm arr = transpose perm arr
-- Rejected by the type checker: axes contains a value not less than the
-- length of shape.
-- untransposable1 :: Sing '[1,0,3] -> NDArray '[2,3,4] a -> NDArray '[3,2,4] a
-- untransposable1 = transpose -- Couldn't match type ‘3’ with ‘2’
-- Rejected by the type checker: axes contains the same element more than once.
-- untransposable2 :: Sing '[1,0,0] -> NDArray '[2,3,4] a -> NDArray '[3,2,2] a
-- untransposable2 = transpose -- Couldn't match type ‘1’ with ‘2’
-- Rejected by the type checker: the length of axes differs from the length
-- of shape.
-- untransposable3 :: Sing '[1,0] -> NDArray '[2,3,4] a -> NDArray '[3,2] a
-- untransposable3 = transpose -- Couldn't match type ‘'[]’ with ‘'[2]’
-- Rejected by the type checker: the declared result type does not match.
-- untransposable4 :: Sing '[1,0,2] -> NDArray '[2,3,4] a -> NDArray '[3,2,5] a
-- untransposable4 = transpose -- Couldn't match type ‘4’ with ‘5’
-- With reshape, dot and transpose available, it should be possible to define
-- tensordot. Written the straightforward way (the way one would like it to
-- work), however, it fails to type-check and GHC prints roughly 1200 lines of
-- errors; the per-line counts are noted below.
-- NOTE(review): presumably the constraints demanded by transpose, reshape and
-- dot cannot be discharged for the abstract shapes xs and ys -- confirm.
-- https://docs.scipy.org/doc/numpy/reference/generated/numpy.tensordot.html

-- | Contract the axes @ns@ of the first array against the axes @ms@ of the
-- second, like @numpy.tensordot@. The result shape is the first array's
-- non-contracted axes (in order) followed by the second array's.
--
-- The constraints require @ns@ and @ms@ to be duplicate-free and the
-- contracted dimensions of both arrays to have equal sizes.
--
-- Strategy: transpose x so its contracted axes come last and y so its
-- contracted axes come first, flatten both to matrices, multiply them with
-- 'dot', then reshape the product back to the free axes.
tensordot :: (Num a, ns ~ Nub ns, ms ~ Nub ms,
              Map ((:!!$$) xs) ns ~ Map ((:!!$$) ys) ms) =>
             NDArray xs a -> NDArray ys a -> (Sing ns, Sing ms) ->
             NDArray (Map ((:!!$$) xs) (EnumFromTo 0 (Length xs - 1) :\\ ns) :++
                      Map ((:!!$$) ys) (EnumFromTo 0 (Length ys - 1) :\\ ms)) a
tensordot x@(NDArray xs) y@(NDArray ys) (ns, ms) = result
  where
    -- singleton list [0 .. n-1]
    range n = sEnumFromTo (sing :: Sing 0) (n %:- (sing :: Sing 1))
    -- indices of the non-contracted (free) axes of x and of y
    notinns = range (sLength xs) %:\\ ns
    notinms = range (sLength ys) %:\\ ms
    -- x with contracted axes moved last; y with contracted axes moved first
    tx = transpose (notinns %:++ ns) x -- about 130 lines of type errors
    ty = transpose (ms %:++ notinms) y -- about 130 lines of type errors
    -- dimsIn dims idxs: the sizes of dims at the indices idxs
    dimsIn dims = sMap (singFun1 (toProxy dims) (dims %:!!))
      where
        toProxy :: Sing (shape :: [Nat]) -> Proxy (Apply (:!!$) shape)
        toProxy _ = Proxy
    -- sizes of the free axes of x and of y
    -- (the stray `where` that used to end this binding is removed, so rtx and
    -- rty are unambiguously siblings that `result` can see)
    (oldxs, oldys) = (dimsIn xs notinns, dimsIn ys notinms)
    -- flatten to matrices: [free(x), contracted] and [contracted, free(y)]
    rtx = reshape (SCons (sProduct oldxs) $ SCons (sProduct $ dimsIn xs ns) SNil) tx -- about 280 lines of type errors
    rty = reshape (SCons (sProduct $ dimsIn ys ms) $ SCons (sProduct oldys) SNil) ty -- about 280 lines of type errors
    -- matrix product, reshaped to free(x) ++ free(y)
    result = reshape (oldxs %:++ oldys) (rtx `dot` rty) -- about 400 lines of type errors
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment