Created
August 28, 2019 00:37
-
-
Save mstksg/b92956c17da4026b876f5b218b9ed6e1 to your computer and use it in GitHub Desktop.
backprop, but using mutable variables
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env stack | |
-- stack --install-ghc ghci --package ad --package lens --package vinyl --package reflection --package tagged --package transformers --package vector | |
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE ExistentialQuantification #-} | |
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE FlexibleInstances #-} | |
{-# LANGUAGE GADTs #-} | |
{-# LANGUAGE KindSignatures #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE PatternSynonyms #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE RecordWildCards #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE TypeFamilies #-} | |
{-# LANGUAGE TypeInType #-} | |
{-# LANGUAGE ViewPatterns #-} | |
{-# OPTIONS_GHC -Wall #-} | |
{-# OPTIONS_GHC -Werror=incomplete-patterns #-} | |
{-# OPTIONS_GHC -Wredundant-constraints #-} | |
import Control.Applicative.Backwards | |
import Control.Exception | |
import Control.Lens hiding (Identity(..), Const(..)) | |
import Control.Monad.ST | |
import Control.Monad.Trans.Class | |
import Control.Monad.Trans.State | |
import Data.Bifunctor | |
import Data.Coerce | |
import Data.Foldable | |
import Data.IORef | |
import Data.Kind | |
import Data.Proxy | |
import Data.Reflection | |
import Data.STRef | |
import Data.Tagged | |
import Data.Type.Equality | |
import Data.Vinyl | |
import Data.Vinyl.Functor | |
import Numeric.AD | |
import Numeric.AD.Internal.Reverse (Tape, Reverse) | |
import Numeric.AD.Rank1.Forward (Forward) | |
import System.IO.Unsafe | |
import Type.Reflection | |
import qualified Data.Vector as V | |
import qualified Data.Vinyl.Recursive as VR | |
-- | A differentiable operation. Given all of its inputs, 'opFunc' returns
-- the result together with a function that turns a gradient of the result
-- into gradients for every input.
data Op (as :: [Type]) (a :: Type) = Op
  { opFunc :: HList as -> (a, a -> HList as)
  }

-- | Allocate a fresh accumulator, initialized to zero.
newtype InitFunc (r :: Type -> Type) (a :: Type) = IF
  { runIF :: forall q. ST q (r q)
  }

-- | Overwrite an accumulator so that it holds one.
newtype OneFunc (r :: Type -> Type) (a :: Type) = OF
  { runOF :: forall q. r q -> ST q ()
  }

-- | Add a value into an accumulator.
newtype AddFunc (r :: Type -> Type) (a :: Type) = AF
  { runAF :: forall q. r q -> a -> ST q ()
  }

-- | Read the contents of an accumulator back out as a value.
newtype ReadFunc (r :: Type -> Type) (a :: Type) = RF
  { runRF :: forall q. r q -> ST q a
  }

-- | An 'Op' bundled with the initializer for its result's accumulator.
data OpR :: [Type] -> (Type -> Type, Type) -> Type where
  OpR
    :: { opRInit :: InitFunc r a
       , opROp   :: Op as a
       }
    -> OpR as '(r, a)
-- | Where the value held by a 'BVar' came from.
data BRef (s :: Type)
  = BRInp !Int -- ^ position of a top-level input
  | BRIx !Int  -- ^ position of a node on the tape
  | BRC        -- ^ a constant; it has no source
  deriving Show

-- | A value seen during the forward pass, paired with its origin. The @r@
-- index (unused by the fields) names the accumulator type used for this
-- value's gradient. Both fields are strict.
data BVar s r a = BV
  { _bvRef :: !(BRef s)
  , _bvVal :: !a
  }

-- | Evaluate a 'BVar' to weak head normal form; because both fields are
-- strict this also evaluates its origin and its value.
forceBVar :: BVar s r a -> ()
forceBVar v = v `seq` ()
-- | An @f a@ with the index @a@ hidden, carried together with a runtime
-- representation of @a@ so it can be recovered later.
data SomeF f = forall a. SomeF (TypeRep a) !(f a)

-- | Re-index a two-argument type constructor by a promoted pair.
data Uncur :: (a -> b -> Type) -> (a, b) -> Type where
  Uncur :: { getUncur :: !(f x y) } -> Uncur f '(x, y)

-- | One input of a tape node: the input's 'BVar' (its scope tag @s@ is
-- hidden) and how to add a gradient of type @a@ into that input's
-- accumulator.
data InpRef :: Type -> (Type -> Type, Type) -> Type where
  IR
    :: { _irVar :: !(BVar s r b)
       , _irAdd :: !(AddFunc r a)
       }
    -> InpRef a '(r, b)

-- | An 'InpRef' whose source accumulator and value type are hidden.
data SomeInpRef a = SIR (SomeF (InpRef a))

-- | A recorded operation on the tape: its inputs, its gradient function,
-- and how to initialize and read the accumulator for its own result.
data TapeNode :: [Type] -> (Type -> Type, Type) -> Type where
  TN
    :: { _tnInputs :: !(Rec SomeInpRef as)
       , _tnGrad   :: !(a -> HList as)
       , _tnInit   :: !(InitFunc r a)
       , _tnRead   :: !(ReadFunc r a)
       }
    -> TapeNode as '(r, a)
-- | Debugging aid. For every input of a tape node, list the two halves of
-- its type index and its 'BRef', one item per line.
inspectTN :: TapeNode as ra -> String
inspectTN TN{ _tnInputs = inputs } = unlines (VR.rfoldMap describe inputs)
  where
    describe :: SomeInpRef x -> [String]
    describe (SIR (SomeF (TupRep refRep valRep) (IR var _))) =
      [show refRep, show valRep, show (_bvRef var)]
-- | The tape being recorded. It holds the index to hand out to the next
-- node, and the nodes recorded so far (most recent first).
newtype W = W { wRef :: IORef (Int, [SomeF (Uncur TapeNode)]) }

-- | Push a node onto the tape in @w@ and return a 'BVar' pointing at it.
--
-- The node is consed onto the front of the list; 'fillWengert' reverses
-- the list afterwards, so the index returned here is the node's position
-- in the final tape.
--
-- 'atomicModifyIORef'' is used instead of the lazy variant. That way the
-- new counter and list are evaluated right away, rather than accumulating
-- a chain of deferred updates inside the IORef.
insertNode
  :: forall a as r s. (Typeable as, Typeable r, Typeable a)
  => TapeNode as '(r, a)
  -> a -- ^ val
  -> W
  -> IO (BVar s r a)
insertNode !tn !x !w = fmap mkVar . atomicModifyIORef' (wRef w) $ \(n, t) ->
  let !n' = n + 1
      !t' = SomeF typeRep (Uncur tn) : t
  in ((n', t'), n)
  where
    mkVar :: Int -> BVar s r a
    mkVar i = BV (BRIx i) x

-- | A 'BVar' holding a constant. It has no source, so no gradient flows
-- out of it.
constVar :: a -> BVar s r a
constVar = BV BRC
{-# INLINE constVar #-}
-- | The second components of a type-level list of pairs.
type family Snds xs where
  Snds '[] = '[]
  Snds ('(x, y) ': rest) = y ': Snds rest

-- | Project out a constant value if the 'BVar' refers to one.
bvConst :: BVar s r a -> Maybe a
bvConst v = case _bvRef v of
  BRC -> Just (_bvVal v)
  _ -> Nothing
{-# INLINE bvConst #-}

-- | Run an 'Op' forwards only, throwing away its gradient function.
evalOp :: Op as a -> HList as -> a
evalOp o xs = fst (opFunc o xs)

-- | Map over a record indexed by pairs, keeping only the second half of
-- each index.
getSnds
  :: (forall a b. f a b -> g b)
  -> Rec (Uncur f) as
  -> Rec g (Snds as)
getSnds _ RNil = RNil
getSnds h (Uncur x :& rest) = h x :& getSnds h rest

-- | Wrap an @f b@ so that it also carries an extra, unused index @a@.
data TaggedF f a b = TaggedF { unTaggedF :: f b }
-- | Combine three records with the same index list, element by element.
rzipWith3
  :: (forall a. f a -> g a -> h a -> j a)
  -> Rec f as
  -> Rec g as
  -> Rec h as
  -> Rec j as
rzipWith3 _ RNil RNil RNil = RNil
rzipWith3 k (x :& xs) (y :& ys) (z :& zs) = k x y z :& rzipWith3 k xs ys zs
-- | Record an application of an 'Op' to some 'BVar's.
--
-- If every input is a constant, the op is just evaluated and a constant
-- 'BVar' is returned; nothing goes on the tape. Otherwise the op is run
-- forwards, and a 'TapeNode' is appended to the tape reached through the
-- reflected 'W'. The node holds the op's gradient function, its inputs,
-- and the output's accumulator functions.
liftOp_
  :: forall s ras r b. (Reifies s W, Typeable b, Typeable r, Typeable ras)
  => Rec (Uncur AddFunc) ras  -- ^ how to add a gradient into each input
  -> Op (Snds ras) b          -- ^ the operation
  -> Rec (Uncur (BVar s)) ras -- ^ the inputs
  -> (b -> InitFunc r b)      -- ^ accumulator initializer for the result
  -> ReadFunc r b             -- ^ accumulator reader for the result
  -> IO (BVar s r b)
liftOp_ afs o vs ifb rfb = case rtraverse asConst vs of
  Just consts ->
    pure (constVar (evalOp o (getSnds (Identity . unTagged) consts)))
  Nothing -> record
  where
    record :: IO (BVar s r b)
    record =
      let reps = typeRepRec (typeRep @ras)
          -- forcing the result evaluates the inputs, so they reach the
          -- tape before this node does
          !(!out, !back) = opFunc o (getSnds (Identity . _bvVal) vs)
          !node = TN
            { _tnInputs = getSnds unTaggedF (rzipWith3 toInpRef reps afs vs)
            , _tnGrad   = back
            , _tnInit   = ifb out
            , _tnRead   = rfb
            }
      in withTypeable (recTypeRep (typeRepSnds reps)) $
           insertNode node out (reflect (Proxy @s))

    asConst :: Uncur (BVar s) x -> Maybe (Uncur Tagged x)
    asConst (Uncur v) = Uncur . Tagged <$> bvConst v

    toInpRef
      :: TypeRep x
      -> Uncur AddFunc x
      -> Uncur (BVar s) x
      -> Uncur (TaggedF SomeInpRef) x
    toInpRef rep (Uncur add) (Uncur v) =
      Uncur (TaggedF (SIR (SomeF rep (IR v add))))
-- | Pure interface to 'liftOp_', run via 'unsafePerformIO'. This is like
-- @liftOp@ from the backprop library, except that the caller supplies the
-- add, initialize ("zero") and read functions explicitly.
liftOp
  :: forall s ras r b. (Reifies s W, Typeable b, Typeable r, Typeable ras)
  => Rec (Uncur AddFunc) ras
  -> Op (Snds ras) b
  -> Rec (Uncur (BVar s)) ras
  -> (b -> InitFunc r b)
  -> ReadFunc r b
  -> BVar s r b
liftOp afs o !vs ifb = \rfb -> unsafePerformIO (liftOp_ afs o vs ifb rfb)
{-# INLINE liftOp #-}

-- | 'liftOp' specialised to a single input.
liftOp1
  :: (Reifies s W, Typeable ra, Typeable a, Typeable rb, Typeable b)
  => AddFunc ra a
  -> Op '[a] b
  -> BVar s ra a
  -> (b -> InitFunc rb b)
  -> ReadFunc rb b
  -> BVar s rb b
liftOp1 af o v = liftOp adders o inputs
  where
    adders = Uncur af :& RNil
    inputs = Uncur v :& RNil

-- | 'liftOp' specialised to two inputs.
liftOp2
  :: (Reifies s W, Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
  => AddFunc ra a
  -> AddFunc rb b
  -> Op '[a, b] c
  -> BVar s ra a
  -> BVar s rb b
  -> (c -> InitFunc rc c)
  -> ReadFunc rc c
  -> BVar s rc c
liftOp2 af1 af2 o v1 v2 = liftOp adders o inputs
  where
    adders = Uncur af1 :& Uncur af2 :& RNil
    inputs = Uncur v1 :& Uncur v2 :& RNil
-- | Record a node that projects a part (an @a@) out of a 'BVar' holding a
-- @b@. The gradient of the part is passed back unchanged and added into
-- the parent's accumulator with the supplied 'AddFunc'.
partVar_
  :: forall a b ra rb s. (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc ra a -- ^ initializer for the part's accumulator
  -> ReadFunc ra a -- ^ reader for the part's accumulator
  -> AddFunc rb a  -- ^ adds a gradient of the part into the parent's accumulator
  -> (b -> a)      -- ^ the projection
  -> BVar s rb b
  -> IO (BVar s ra a)
partVar_ ifa rfa afa getA v = insertNode node part (reflect (Proxy @s))
  where
    part = getA (_bvVal v)
    node :: TapeNode '[a] '(ra, a)
    node = TN
      { _tnInputs = SIR (SomeF typeRep (IR v afa)) :& RNil
      , _tnGrad   = \d -> Identity d :& RNil
      , _tnInit   = ifa
      , _tnRead   = rfa
      }

-- | Pure interface to 'partVar_', run via 'unsafePerformIO'.
partVar
  :: forall a b ra rb s. (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc ra a
  -> ReadFunc ra a
  -> AddFunc rb a
  -> (b -> a)
  -> BVar s rb b
  -> BVar s ra a
partVar ifa rfa afa getA = \v -> unsafePerformIO (partVar_ ifa rfa afa getA v)
-- | Record a node that views the focus of a lens on a 'BVar'.
--
-- To add a gradient @dx@ of the focus back into the parent, a zero @b@ is
-- built: a fresh accumulator is allocated with @ifb@ and read with @rfb@.
-- Its focus is set to @dx@ and the result is added to the parent with
-- @afa@.
viewVar_
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc rb b
  -> InitFunc ra a
  -> ReadFunc rb b
  -> ReadFunc ra a
  -> AddFunc rb b
  -> Lens' b a
  -> BVar s rb b
  -> IO (BVar s ra a)
viewVar_ ifb ifa rfb rfa afa l v = partVar_ ifa rfa addFocus (view l) v
  where
    addFocus = AF $ \ref dx -> do
      zeroB <- runIF ifb >>= runRF rfb
      runAF afa ref (set l dx zeroB)

-- | Pure interface to 'viewVar_', run via 'unsafePerformIO'.
viewVar
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc rb b
  -> InitFunc ra a
  -> ReadFunc rb b
  -> ReadFunc ra a
  -> AddFunc rb b
  -> Lens' b a
  -> BVar s rb b
  -> BVar s ra a
viewVar ifb ifa rfb rfa afa l = \v -> unsafePerformIO (viewVar_ ifb ifa rfb rfa afa l v)
-- | Create an empty tape.
initWengert :: IO W
initWengert = W <$> newIORef (0, [])
{-# INLINE initWengert #-}

-- | Change the (phantom) scope tag of a 'BRef' by rebuilding it.
retagBRef :: BRef s -> BRef t
retagBRef = \case
  BRInp i -> BRInp i
  BRIx i  -> BRIx i
  BRC     -> BRC

-- | Run @f@ forwards on the given inputs while recording a tape.
--
-- Returns three things: the tape (oldest node first), the 'BRef' of the
-- output 'BVar', and the output value. The 'BRef' is needed to seed the
-- backward pass: the output is not always the last node on the tape. It
-- can be an input (@f = id@), a constant, or a node recorded earlier than
-- some other node.
fillWengertRef
  :: forall ras rb b. ()
  => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
  -> Rec (Uncur Tagged) ras
  -> IO (V.Vector (SomeF (Uncur TapeNode)), BRef (), b)
fillWengertRef f xs = do
  w <- initWengert
  reify w $ \(Proxy :: Proxy s) -> do
    let !oVar = f (inpRec @s)
    evaluate (forceBVar oVar)
    (_, tp) <- readIORef (wRef w)
    pure (V.fromList (reverse tp), retagBRef (_bvRef oVar), _bvVal oVar)
  where
    -- inputs are numbered 0, 1, 2, ... in record order
    inpRec :: forall s. Rec (Uncur (BVar s)) ras
    inpRec = evalState (rtraverse (state . go) xs) 0
      where
        go :: Uncur Tagged x -> Int -> (Uncur (BVar s) x, Int)
        go (Uncur x) i = (Uncur (BV (BRInp i) (unTagged x)), i + 1)

-- | Like 'fillWengertRef', but without reporting where the output lives.
fillWengert
  :: forall ras rb b. ()
  => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
  -> Rec (Uncur Tagged) ras
  -> IO (V.Vector (SomeF (Uncur TapeNode)), b)
fillWengert f xs = do
  (tp, _, y) <- fillWengertRef f xs
  pure (tp, y)

-- | An accumulator living in the ST thread @s@, indexed by its accumulator
-- type @r@ and gradient type @a@.
newtype RecFor s (r :: Type -> Type) (a :: Type) = RecFor (r s)

-- | Accumulators for one backward pass. There is one per tape node
-- ('_rDelta', in tape order) and one per top-level input ('_rInputs', in
-- input order).
data Runner s = R { _rDelta  :: !(V.Vector (SomeF (Uncur (RecFor s))))
                  , _rInputs :: !(V.Vector (SomeF (Uncur (RecFor s))))
                  }

-- | Allocate a fresh accumulator for every tape node and every input.
initRunner
  :: forall s. ()
  => V.Vector (SomeF (Uncur TapeNode))
  -> V.Vector (SomeF (Uncur InitFunc))
  -> ST s (Runner s)
initRunner stns xs =
  R <$> V.mapM mkDelts stns
    <*> V.mapM mkInps xs
  where
    mkDelts :: SomeF (Uncur TapeNode) -> ST s (SomeF (Uncur (RecFor s)))
    mkDelts (SomeF (TupRep _ tr) (Uncur TN{..})) = do
      r <- runIF _tnInit
      pure . SomeF tr . Uncur $ RecFor r
    mkInps :: SomeF (Uncur InitFunc) -> ST s (SomeF (Uncur (RecFor s)))
    mkInps (SomeF tr (Uncur ifx)) = do
      r <- runIF ifx
      pure . SomeF tr . Uncur $ RecFor r

-- | Run the backward pass over the tape.
--
-- First the accumulator that @out@ points at is set to one: an input
-- accumulator for 'BRInp', or a node accumulator for 'BRIx'. A constant
-- output ('BRC') seeds nothing, so every gradient keeps its initial value.
-- Then each node's gradient is propagated into its inputs, newest node
-- first.
gradRunnerFrom
  :: forall rb b s t. (Typeable rb, Typeable b)
  => OneFunc rb b -- ^ set to be one
  -> BRef t       -- ^ where the output of the function lives
  -> Runner s
  -> V.Vector (SomeF (Uncur TapeNode))
  -> ST s ()
gradRunnerFrom so out R{..} stns = do
  case out of
    BRInp i -> seedAt "gradRunner input" (_rInputs V.! i)
    BRIx i  -> seedAt "gradRunner init" (_rDelta V.! i)
    BRC     -> pure ()
  -- Backwards runs the per-node actions in reverse tape order
  forwards . traverse_ Backwards $ V.zipWith go _rDelta stns
  where
    seedAt :: String -> SomeF (Uncur (RecFor s)) -> ST s ()
    seedAt e sf = case coerceSomeF (typeRep @'(rb, b)) e sf of
      Uncur (RecFor rO) -> runOF so rO

    go :: SomeF (Uncur (RecFor s)) -> SomeF (Uncur TapeNode) -> ST s ()
    go rf (SomeF (TupRep _ trrx) (Uncur TN{..})) =
      case coerceSomeF trrx "gradRunner tape" rf of
        Uncur (RecFor rx) -> do
          d <- runRF _tnRead rx
          rzipWithM_ propagate _tnInputs (_tnGrad d)

    propagate :: SomeInpRef x -> Identity x -> ST s ()
    propagate (SIR (SomeF trb (IR irv ira))) (Identity x) =
      case _bvRef irv of
        BRInp i -> case coerceSomeF trb "propagate input" (_rInputs V.! i) of
          Uncur (RecFor rb) -> runAF ira rb x
        BRIx i -> case coerceSomeF trb "propagate deltas" (_rDelta V.! i) of
          Uncur (RecFor rb) -> runAF ira rb x
        BRC -> pure ()

-- | Run the backward pass, seeding the last node on the tape with one.
--
-- This assumes the output is the most recently recorded node, and it fails
-- on an empty tape. 'gradRunnerFrom' seeds wherever the output actually
-- lives instead.
gradRunner
  :: forall rb b s. (Typeable rb, Typeable b)
  => OneFunc rb b -- ^ set to be one
  -> Runner s
  -> V.Vector (SomeF (Uncur TapeNode))
  -> ST s ()
gradRunner so r stns =
  gradRunnerFrom so (BRIx (V.length (_rDelta r) - 1) :: BRef ()) r stns

-- | Recover the contents of a 'SomeF' at an expected index. If the stored
-- index differs, fail with a message showing both type representations.
coerceSomeF
  :: forall a f. ()
  => TypeRep a
  -> String
  -> SomeF f
  -> f a
coerceSomeF tra e (SomeF tr x)
  | Just HRefl <- tr `eqTypeRep` tra = x
  | otherwise = error $ e ++ " <" ++ show tra ++ "> vs. <" ++ show tr ++ ">"

-- | Evaluate @f@ on the inputs and compute the gradient of its result with
-- respect to every input.
--
-- The 'OneFunc' seeds the output's accumulator. The per-input 'InitFunc's
-- and 'ReadFunc's allocate the input accumulators and read the final
-- gradients out of them. The output may be a tape node, an input or a
-- constant; all three are handled.
backpropN
  :: forall ras rb b. (Typeable ras, Typeable rb, Typeable b)
  => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
  -> OneFunc rb b
  -> Rec (Uncur InitFunc) ras
  -> Rec (Uncur ReadFunc) ras
  -> Rec (Uncur Tagged) ras
  -> (b, Rec (Uncur Tagged) ras)
backpropN f sf ifs rfs !xs = (y, g')
  where
    !(!tp, !out, !y) = unsafePerformIO $ fillWengertRef f xs
    g' :: Rec (Uncur Tagged) ras
    g' = runST $ do
      r <- initRunner tp
         . V.fromList
         . VR.recordToList
         . VR.rzipWith (\tr ifx -> Const (SomeF tr ifx)) (typeRepRec typeRep)
         $ ifs
      gradRunnerFrom sf out r tp
      evalStateT (rzipWithM (pullOut (_rInputs r)) (typeRepRec typeRep) rfs) 0
    pullOut
      :: V.Vector (SomeF (Uncur (RecFor s)))
      -> TypeRep x
      -> Uncur ReadFunc x
      -> StateT Int (ST s) (Uncur Tagged x)
    pullOut inps trx (Uncur rf) = do
      i <- state $ \i' -> (i', i' + 1)
      case coerceSomeF trx "pullOut" (inps V.! i) of
        Uncur (RecFor r) -> Uncur . Tagged <$> lift (runRF rf r)
-- | Differentiate a function of one 'BVar'. Returns the function's result
-- and the gradient with respect to its input.
backprop
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b)
  => (forall s. Reifies s W => BVar s ra a -> BVar s rb b)
  -> OneFunc rb b
  -> InitFunc ra a
  -> ReadFunc ra a
  -> a
  -> (b, a)
backprop f sfb ifa rfa x =
  let (y, grads) = backpropN (f . getUncur . rHead) sfb
        (Uncur ifa :& RNil)
        (Uncur rfa :& RNil)
        (Uncur (Tagged x) :& RNil)
  in (y, unTagged (getUncur (rHead grads)))

-- | Differentiate a function of two 'BVar's. Returns the function's result
-- and the gradients with respect to both inputs.
backprop2
  :: forall a b c ra rb rc. (Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
  => (forall s. Reifies s W => BVar s ra a -> BVar s rb b -> BVar s rc c)
  -> OneFunc rc c
  -> InitFunc ra a
  -> InitFunc rb b
  -> ReadFunc ra a
  -> ReadFunc rb b
  -> a
  -> b
  -> (c, (a, b))
backprop2 f sfc ifa ifb rfa rfb x y =
  let (z, grads) = backpropN uncurried sfc initializers readers inputs
  in (z, splitGrads grads)
  where
    initializers = Uncur ifa :& Uncur ifb :& RNil
    readers = Uncur rfa :& Uncur rfb :& RNil
    inputs = Uncur (Tagged x) :& Uncur (Tagged y) :& RNil
    uncurried :: Reifies s W => Rec (Uncur (BVar s)) '[ '(ra, a), '(rb, b) ] -> BVar s rc c
    uncurried (Uncur vx :& Uncur vy :& RNil) = f vx vy
    splitGrads :: Rec (Uncur Tagged) '[ '(ra, a), '(rb, b) ] -> (a, b)
    splitGrads (Uncur (Tagged dx) :& Uncur (Tagged dy) :& RNil) = (dx, dy)

-- | The gradient part of 'backpropN'.
gradBPN
  :: forall ras rb b. (Typeable ras, Typeable rb, Typeable b)
  => (forall s. Reifies s W => Rec (Uncur (BVar s)) ras -> BVar s rb b)
  -> OneFunc rb b
  -> Rec (Uncur InitFunc) ras
  -> Rec (Uncur ReadFunc) ras
  -> Rec (Uncur Tagged) ras
  -> Rec (Uncur Tagged) ras
gradBPN f sf ifs rfs xs = snd (backpropN f sf ifs rfs xs)

-- | The gradient part of 'backprop'.
gradBP
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b)
  => (forall s. Reifies s W => BVar s ra a -> BVar s rb b)
  -> OneFunc rb b
  -> InitFunc ra a
  -> ReadFunc ra a
  -> a
  -> a
gradBP f sfb ifa rfa x = snd (backprop f sfb ifa rfa x)

-- | The gradient part of 'backprop2'.
gradBP2
  :: forall a b c ra rb rc. (Typeable ra, Typeable a, Typeable rb, Typeable b, Typeable rc, Typeable c)
  => (forall s. Reifies s W => BVar s ra a -> BVar s rb b -> BVar s rc c)
  -> OneFunc rc c
  -> InitFunc ra a
  -> InitFunc rb b
  -> ReadFunc ra a
  -> ReadFunc rb b
  -> a
  -> b
  -> (a, b)
gradBP2 f sfc ifa ifb rfa rfb x y = snd (backprop2 f sfc ifa ifb rfa rfb x y)
-- | Placeholder entry point. Per the stack header at the top, this file is
-- loaded into GHCi, where the definitions are used directly.
main :: IO ()
main = putStrLn "hi"

-- | Build a single-input 'Op' from a function differentiated in forward
-- mode. The input's gradient is the result gradient times the derivative.
op1 :: Num a => (forall s. AD s (Forward a) -> AD s (Forward a)) -> Op '[a] a
op1 f = Op $ \(Identity x :& RNil) ->
  second (\dx dy -> Identity (dx * dy) :& RNil) $ diff' f x

-- | Build a two-input 'Op' from a function differentiated in reverse mode.
--
-- 'grad'' works over a list, so the two inputs are packed into one. The
-- original code unpacked it with incomplete patterns. Here the expected
-- two-element shape is checked explicitly, with descriptive errors. The
-- bindings stay lazy, exactly as before.
op2 :: Num a => (forall s. Reifies s Tape => Reverse s a -> Reverse s a -> Reverse s a) -> Op '[a,a] a
op2 f = Op $ \(Identity x :& Identity y :& RNil) ->
  let (z, partials) = grad' (onPair f) [x, y]
      (dX, dY) = case partials of
        [dx, dy] -> (dx, dy)
        _ -> error "op2: grad' returned the wrong number of partial derivatives"
  in (z, \dZ -> Identity (dZ * dX) :& Identity (dZ * dY) :& RNil)
  where
    -- apply a binary function to a list that must hold exactly two items
    onPair :: (r -> r -> r) -> [r] -> r
    onPair g = \case
      [u, v] -> g u v
      _ -> error "op2: expected exactly two arguments"
-- | An accumulator that keeps the whole gradient in one 'STRef'.
newtype WholeRef a s = WR { getWR :: STRef s a }

-- | Read the current contents of a 'WholeRef'.
readWR :: forall a. ReadFunc (WholeRef a) a
readWR = RF $ coerce (readSTRef @_ @a)

-- | Add a value into a 'WholeRef'.
--
-- This uses the strict 'modifySTRef''. With the lazy 'modifySTRef', every
-- gradient added during the backward pass would be stored as a deferred
-- @(+ x)@ thunk, and the sum would only be computed when read.
addWR :: forall a. Num a => AddFunc (WholeRef a) a
addWR = AF $ \(WR r) x -> modifySTRef' r (+ x)

-- | Allocate a 'WholeRef' holding zero.
initWR :: forall a. Num a => InitFunc (WholeRef a) a
initWR = IF $ WR <$> newSTRef @a 0

-- | Overwrite a 'WholeRef' with one.
oneWR :: forall a. Num a => OneFunc (WholeRef a) a
oneWR = OF $ \(WR r) -> writeSTRef r 1
-- | Lift a one-input 'Op' over 'BVar's backed by 'WholeRef' accumulators.
wholeOp1
  :: (Num a, Typeable a, Reifies s W)
  => Op '[a] a
  -> BVar s (WholeRef a) a
  -> BVar s (WholeRef a) a
wholeOp1 o x = liftOp1 addWR o x (const initWR) readWR

-- | Lift a two-input 'Op' over 'BVar's backed by 'WholeRef' accumulators.
wholeOp2
  :: (Num a, Typeable a, Reifies s W)
  => Op '[a, a] a
  -> BVar s (WholeRef a) a
  -> BVar s (WholeRef a) a
  -> BVar s (WholeRef a) a
wholeOp2 o x y = liftOp2 addWR addWR o x y (const initWR) readWR

-- | Arithmetic on 'WholeRef'-backed 'BVar's. Each operation is recorded on
-- the tape; literals become constants.
instance (Num a, Typeable a, Reifies s W) => Num (BVar s (WholeRef a) a) where
  x + y = wholeOp2 (op2 (+)) x y
  x * y = wholeOp2 (op2 (*)) x y
  x - y = wholeOp2 (op2 (-)) x y
  negate x = wholeOp1 (op1 negate) x
  abs x = wholeOp1 (op1 abs) x
  signum x = wholeOp1 (op1 signum) x
  fromInteger n = constVar (fromInteger n)

-- | Division on 'WholeRef'-backed 'BVar's, recorded on the tape.
instance (Fractional a, Typeable a, Reifies s W) => Fractional (BVar s (WholeRef a) a) where
  x / y = wholeOp2 (op2 (/)) x y
  recip x = wholeOp1 (op1 recip) x
  fromRational q = constVar (fromRational q)
-- | An accumulator for a pair: one accumulator for each component.
data TupleRef ra rb (s :: Type) = TR (ra s) (rb s)

-- | Read both halves of a 'TupleRef', first component first.
readTR :: ReadFunc ra a -> ReadFunc rb b -> ReadFunc (TupleRef ra rb) (a, b)
readTR ra rb = RF $ \(TR rx ry) -> do
  x <- runRF ra rx
  y <- runRF rb ry
  pure (x, y)

-- | Add a pair into a 'TupleRef', component by component.
addTR :: AddFunc ra a -> AddFunc rb b -> AddFunc (TupleRef ra rb) (a, b)
addTR aa ab = AF $ \(TR rx ry) (x, y) -> do
  runAF aa rx x
  runAF ab ry y

-- | Allocate both halves of a 'TupleRef'.
initTR :: InitFunc ra a -> InitFunc rb b -> InitFunc (TupleRef ra rb) (a, b)
initTR ia ib = IF $ do
  rx <- runIF ia
  ry <- runIF ib
  pure (TR rx ry)

-- | Set both halves of a 'TupleRef' to one.
oneTR :: OneFunc ra a -> OneFunc rb b -> OneFunc (TupleRef ra rb) (a, b)
oneTR oa ob = OF $ \(TR rx ry) -> do
  runOF oa rx
  runOF ob ry

-- | Project the first component out of a pair-valued 'BVar'. Gradients of
-- the component are added into the first half of the pair's accumulator.
fstVar
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc ra a
  -> ReadFunc ra a
  -> AddFunc ra a
  -> BVar s (TupleRef ra rb) (a, b)
  -> BVar s ra a
fstVar ifa rfa afa = partVar ifa rfa intoFst fst
  where
    intoFst = AF (\(TR rx _) -> runAF afa rx)

-- | Project the second component out of a pair-valued 'BVar'. Gradients of
-- the component are added into the second half of the pair's accumulator.
sndVar
  :: (Typeable ra, Typeable a, Typeable rb, Typeable b, Reifies s W)
  => InitFunc rb b
  -> ReadFunc rb b
  -> AddFunc rb b
  -> BVar s (TupleRef ra rb) (a, b)
  -> BVar s rb b
sndVar ifb rfb afb = partVar ifb rfb intoSnd snd
  where
    intoSnd = AF (\(TR _ ry) -> runAF afb ry)
-- | Split the runtime representation of a promoted list into a record of
-- its elements' representations.
--
-- A well-kinded 'TypeRep' of kind @[k]@ is always either @'[]@ or a cons.
-- The fallback therefore reports an invariant violation together with the
-- offending representation; previously it was a bare 'undefined'.
typeRepRec :: forall k (as :: [k]). Typeable k => TypeRep as -> Rec TypeRep as
typeRepRec tr
  | Just Refl <- testEquality tr (typeRep @'[]) = RNil
  | App (App c x) xs <- tr
  , Just HRefl <- eqTypeRep c (typeRep @('(:) :: k -> [k] -> [k]))
  = x :& typeRepRec xs
  | otherwise = error $ "typeRepRec: not a promoted list: " ++ show tr

-- | Rebuild the representation of a promoted list from its elements.
recTypeRep :: forall k (as :: [k]). Typeable k => Rec TypeRep as -> TypeRep as
recTypeRep = \case
  RNil -> typeRep
  x :& xs -> App (App (typeRep @('(:))) x) (recTypeRep xs)
-- | Evidence that a type-level pair splits into its two components.
data SplitTup :: (a, b) -> Type where
  SplitTup :: TypeRep x -> TypeRep y -> SplitTup '(x, y)

-- | Split the runtime representation of a promoted pair.
--
-- A 'TypeRep' of a pair kind is always an application of @'(,)@, so the
-- fallback is an invariant violation. Its message now names the function
-- and the offending representation; previously it was just "what".
splitTup
  :: forall a b (xy :: (a, b)). (Typeable a, Typeable b)
  => TypeRep xy
  -> SplitTup xy
splitTup tr = case tr of
  App (App tup x) y
    | Just HRefl <- eqTypeRep tup (typeRep @('(,) :: a -> b -> (a, b)))
    -> SplitTup x y
  _ -> errorWithoutStackTrace $ "splitTup: not a promoted pair: " ++ show tr

-- | Match on (or build) the runtime representation of a promoted pair.
pattern TupRep
  :: forall a b xy.
     (Typeable a, Typeable b)
  => forall (x :: a) (y :: b). (xy ~ '(x, y))
  => TypeRep x
  -> TypeRep y
  -> TypeRep xy
pattern TupRep x y <- (splitTup->SplitTup x y)
  where
    TupRep x y = App (App (typeRep @'(,)) x) y
{-# COMPLETE TupRep #-}

-- | Keep only the second-component representations of a record indexed by
-- promoted pairs.
typeRepSnds
  :: forall (abs :: [(a,b)]). (Typeable a, Typeable b)
  => Rec TypeRep abs
  -> Rec TypeRep (Snds abs)
typeRepSnds = \case
  RNil -> RNil
  TupRep _ y :& xs -> y :& typeRepSnds xs
-- | Walk two records with the same index list together, running an action
-- on each pair of elements, left to right, and discarding the results.
rzipWithM_
  :: forall h f g as. Applicative h
  => (forall a. f a -> g a -> h ())
  -> Rec f as
  -> Rec g as
  -> h ()
rzipWithM_ _ RNil RNil = pure ()
rzipWithM_ f (x :& xs) (y :& ys) = f x y *> rzipWithM_ f xs ys

-- | Walk two records with the same index list together, running an action
-- on each pair of elements, left to right, and collecting the results.
rzipWithM
  :: forall h f g j as. Applicative h
  => (forall a. f a -> g a -> h (j a))
  -> Rec f as
  -> Rec g as
  -> h (Rec j as)
rzipWithM _ RNil RNil = pure RNil
rzipWithM f (x :& xs) (y :& ys) = (:&) <$> f x y <*> rzipWithM f xs ys

-- | The three-record version of 'rzipWithM_'.
rzipWithM3_
  :: forall h f g j as. Applicative h
  => (forall a. f a -> g a -> j a -> h ())
  -> Rec f as
  -> Rec g as
  -> Rec j as
  -> h ()
rzipWithM3_ _ RNil RNil RNil = pure ()
rzipWithM3_ f (x :& xs) (y :& ys) (z :& zs) = f x y z *> rzipWithM3_ f xs ys zs

-- | The only element of a one-element record.
rHead :: Rec f '[a] -> f a
rHead = \case
  x :& RNil -> x
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment