Skip to content

Instantly share code, notes, and snippets.

@mstksg
Last active January 31, 2018 07:48
Show Gist options
  • Save mstksg/05a1aca5f45ff311e522848e2be9e3a3 to your computer and use it in GitHub Desktop.
Save mstksg/05a1aca5f45ff311e522848e2be9e3a3 to your computer and use it in GitHub Desktop.
polymorphic backprop
#!/usr/bin/env stack
-- stack --install-ghc runghc --package type-combinators --package ad --package lens --package vector --package reflection -- -Wall
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
import Control.Applicative.Backwards
import Control.DeepSeq
import Control.Exception
import Control.Lens hiding ((:<), traverse1, Index)
import Control.Monad
import Control.Monad.Primitive
import Data.IORef
import Data.Kind
import Data.Primitive.MutVar
import Data.Proxy
import Data.Reflection
import Data.Type.Combinator
import Data.Type.Conjunction
import Data.Type.Equality
import Data.Type.Index
import Data.Type.Product
import GHC.Generics (Generic)
import Numeric.AD hiding (Grad)
import System.IO.Unsafe
import Type.Class.Higher
import Type.Class.Witness
import Type.Reflection
import Unsafe.Coerce
import qualified Data.Sequence as Seq
import qualified Data.Vector.Mutable as MV
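-- Iteration 11: It is all implicit! Graph nodes are recorded as a side effect
-- of evaluating 'Var' values (via unsafePerformIO in the Num instance and in
-- '.$'), so user code is written with ordinary numeric operations.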
-- | Pair up two products indexed by the same type-level list, element by
-- element.
zipProd
  :: (Prod f :&: Prod g) as
  -> Prod (f :&: g) as
zipProd (Ø :&: Ø)                 = Ø
zipProd ((x :< xs) :&: (y :< ys)) = (x :&: y) :< zipProd (xs :&: ys)
-- | A differentiable operation on a tuple of inputs.  'runOp' returns the
-- result together with a function that maps a gradient of the result to a
-- tuple of gradients, one per input.
data Op as a = Op
  { runOp :: Tuple as -> (a, a -> Tuple as)
  }
-- | A handle to a value in the computation, tagged with the @s@ of the
-- reflected 'Builder' it belongs to:
--
-- * 'VInp'   — the input being differentiated;
-- * 'VIx'    — the node at the given 0-based position on the tape;
-- * 'VConst' — a constant, which receives no gradient.
data Var s a
  = VInp
  | VIx Int
  | VConst a
  deriving Generic

instance NFData a => NFData (Var s a)
-- | A reference from a tape node to one of its inputs.  It holds the
-- variable being read ('_irIx'), that variable's runtime type ('_irTR'),
-- and a lens ('_irUpd') selecting the part of it the node consumes.  The
-- lens is used both to read the input in the forward pass and to add
-- gradient back into it in the backward pass.
data InpRef :: Type -> Type -> Type where
  IR :: Num a
     => { _irIx  :: Var s b
        , _irTR  :: TypeRep b
        , _irUpd :: Lens' b a
        }
     -> InpRef s a
-- | One recorded operation: references to its inputs, and the 'Op' that
-- combines them into a value of type @a@.
data TapeNode :: Type -> Type -> Type where
  TN :: { _tnInputs :: Prod (InpRef s) as
        , _tnOp     :: Op as a
        }
     -> TapeNode s a
-- | A 'TapeNode' whose result type is hidden.  It carries that type's
-- 'TypeRep' and 'Num' dictionary so the runners can recover the type and
-- create zero and one values of it.
data SomeTapeNode :: Type -> Type where
  STN :: forall s a. Num a
      => TypeRep a
      -> TapeNode s a
      -> SomeTapeNode s
-- | The mutable tape of recorded nodes.  Its @s@ tag is hidden
-- existentially; 'insertNode' coerces it back to the caller's @s@.
data Builder = forall s. B
  { bRef :: IORef (Seq.Seq (SomeTapeNode s))
  }
-- | Create a 'Builder' with an empty tape.
initBuilder :: IO Builder
initBuilder = do
  ref <- newIORef Seq.empty
  return (B ref)
-- | Append a node to the tape held by the 'Builder', and return a 'Var'
-- holding that node's 0-based position on the tape.
--
-- The tape's @s@ is hidden existentially inside 'Builder'; 'unsafeCoerce'
-- recovers it at the caller's @s@.  NOTE(review): this relies on the
-- 'Builder' being the one reified for that same @s@ (as 'liftOp' and
-- 'varLens' arrange via 'reflect').  The types do not enforce it.
--
-- The strict 'atomicModifyIORef'' forces the tape on every insertion.
-- This avoids accumulating a chain of unevaluated '|>' thunks inside the
-- 'IORef', and means the returned index is already evaluated.
insertNode
  :: (Num a, Typeable a)
  => TapeNode s a
  -> Builder
  -> IO (Var s a)
insertNode tn (B (unsafeCoerce -> ref)) =
  fmap VIx . atomicModifyIORef' ref $ \tape ->
    (tape Seq.|> STN typeRep tn, Seq.length tape)
-- | An operation with no inputs that always yields the given value.  Its
-- gradient function returns the empty tuple.
op0 :: a -> Op '[] a
op0 c = Op (const (c, const Ø))
-- | Lift a unary 'Num'-polymorphic function into an 'Op'.  Its value and
-- derivative come from 'diff''.  The gradient function scales the
-- incoming gradient by that derivative.
op1 :: Num a => (forall b. Num b => b -> b) -> Op '[a] a
op1 f = Op $ \(x ::< Ø) ->
  let (y, dx) = diff' f x
  in  (y, \d -> (d * dx) ::< Ø)
-- | Lift a binary 'Num'-polymorphic function into an 'Op'.  Both partial
-- derivatives come from a single 'grad'' call.  The gradient function
-- scales the incoming gradient by each of them.
op2 :: Num a => (forall b. Num b => b -> b -> b) -> Op '[a,a] a
op2 f = Op $ \(x ::< y ::< Ø) ->
  -- 'grad'' returns a gradient container shaped like its input container,
  -- so both two-element list patterns below always match.
  let (z, [dx, dy]) = grad' (\[x', y'] -> f x' y') [x, y]
      scale d       = (d * dx) ::< (d * dy) ::< Ø
  in  (z, scale)
-- | The identity operation.  It returns its single input unchanged and
-- passes the incoming gradient straight back to it.
idOp :: Num a => Op '[a] a
idOp = Op $ \(v ::< Ø) -> (v, (::< Ø))
-- | Record an identity node on top of the given output variable.  The last
-- node on the tape is then always a copy of the final output; that is the
-- slot 'evalRunner' reads the result from and 'gradRunner' seeds with
-- gradient 1.  The variable is forced before the node is recorded.
registerOut
  :: (Reifies s Builder, Num a, Typeable a)
  => Var s a
  -> IO ()
registerOut out = void (lift1 idOp $! out)
-- | Record a node applying @o@ to the given input variables on the tape of
-- the reflected 'Builder', and return a 'Var' for the node's result.
--
-- Every input 'Var' is forced (with 'evaluate') before this node is
-- inserted.  Evaluating a 'Var' produced by the 'Num' instance or by '.$'
-- is what records that variable's own node.  Doing it first guarantees
-- that every 'VIx' input refers to an earlier tape position.
--
-- 'evalRunner' walks the tape in order and reads each input's slot.  An
-- input recorded after this node would be read before it was written.
-- The bang on @vs@ alone only forces the outermost 'Prod' cell, not the
-- variables inside it.
liftOp
  :: forall s as b. (Reifies s Builder, Num b, Typeable b, Every Typeable as, Every Num as)
  => Op as b
  -> Prod (Var s) as
  -> IO (Var s b)
liftOp o !vs = do
  inputs <- traverse1 evaluate vs :: IO (Prod (Var s) as)
  let tn = TN { _tnInputs = imap1 go inputs
              , _tnOp     = o
              }
  insertNode tn (reflect (Proxy @s))
  where
    -- Build the reference for the input at index @i@, reading it whole
    -- ('id' lens).  The 'Num' and 'Typeable' dictionaries for its type
    -- come from the 'Every' constraints.
    go :: forall a. Index as a -> Var s a -> InpRef s a
    go i !v = IR v typeRep id \\ every @_ @Num i
                              \\ every @_ @Typeable i
-- | Record a unary 'Op' applied to one variable.  The variable is forced
-- first, so any node it stands for is already on the tape.
lift1
  :: (Reifies s Builder, Num a, Typeable a, Num b, Typeable b)
  => Op '[a] b
  -> Var s a
  -> IO (Var s b)
lift1 o x = x `seq` liftOp o (x :< Ø)
-- | Record a binary 'Op' applied to two variables.  Both are forced first,
-- left then right, so their nodes are already on the tape.
lift2
  :: (Reifies s Builder, Num a, Typeable a, Num b, Typeable b, Num c, Typeable c)
  => Op '[a,b] c
  -> Var s a
  -> Var s b
  -> IO (Var s c)
lift2 o x y = x `seq` y `seq` liftOp o (x :< y :< Ø)
-- | Record a node that focuses on part of a variable through a lens.  The
-- node's input reference carries the lens, and its operation is the
-- identity.
--
-- The input variable is forced (bang pattern) before the node is
-- recorded.  That records the input's own node first, so this node's
-- 'VIx' input refers to an earlier tape position.  Without the bang, a
-- direct caller passing an unevaluated 'Var' would get this node inserted
-- before its input.  ('.$' already forces its argument, so its behaviour
-- is unchanged.)
varLens
  :: forall a b s. (Reifies s Builder, Num b, Typeable b, Num a, Typeable a)
  => Lens' b a
  -> Var s b
  -> IO (Var s a)
varLens l !v = insertNode tn (reflect (Proxy @s))
  where
    tn = TN { _tnInputs = IR v typeRep l :< Ø
            , _tnOp     = idOp
            }
-- | Pure view of 'varLens': focus a variable through a lens, recording a
-- node for the focused part.
--
-- The node is appended when the result is evaluated, via
-- 'unsafePerformIO'.  The input is forced first so its node precedes this
-- one on the tape.  NOTE(review): there is no NOINLINE pragma guarding
-- this 'unsafePerformIO'; confirm that inlining or CSE at call sites is
-- acceptable.
(.$)
  :: forall a b s. (Reifies s Builder, Num b, Typeable b, Num a, Typeable a)
  => Lens' b a
  -> Var s b
  -> Var s a
(.$) l v = v `seq` unsafePerformIO (varLens l v)
-- | Arithmetic on variables records a tape node as a side effect of
-- evaluation (through 'unsafePerformIO').  'fromInteger' builds a
-- 'VConst', which records nothing.  NOTE(review): like '.$', these methods
-- have no NOINLINE guard.
instance (Reifies s Builder, Num a, Typeable a) => Num (Var s a) where
  (+) x y       = unsafePerformIO (lift2 (op2 (+)) x y)
  (-) x y       = unsafePerformIO (lift2 (op2 (-)) x y)
  (*) x y       = unsafePerformIO (lift2 (op2 (*)) x y)
  negate v      = unsafePerformIO (lift1 (op1 negate) v)
  abs v         = unsafePerformIO (lift1 (op1 abs) v)
  signum v      = unsafePerformIO (lift1 (op1 signum) v)
  fromInteger n = VConst (fromInteger n)
-- | The backward step of one node, with its types hidden.  It holds the
-- node's input references, the 'TypeRep' of its output, and the function
-- mapping the output's gradient to a tuple of input gradients.
data Grad :: Type -> Type where
  G :: Prod (InpRef t) as
    -> TypeRep a
    -> (a -> Tuple as)
    -> Grad t
-- | A numeric value of hidden type, paired with its 'TypeRep' so the type
-- can be recovered with 'testEquality'.
data SomeNum :: Type where
  SN :: Num a
     => TypeRep a
     -> a
     -> SomeNum
-- | Per-node mutable state used while running a tape, each vector indexed
-- by tape position:
--
-- * forward values ('_rRes');
-- * accumulated gradients ('_rDelta');
-- * backward steps ('_rGrads').
data Runner t s = R
  { _rRes   :: MV.MVector s SomeNum
  , _rDelta :: MV.MVector s SomeNum
  , _rGrads :: MV.MVector s (Grad t)
  }
-- | Allocate a 'Runner' with one not-yet-written slot per tape node in each
-- of its three vectors.
initRunner
  :: (PrimMonad m, PrimState m ~ s)
  => Seq.Seq (SomeTapeNode t)
  -> m (Runner t s)
initRunner tape = do
  let size = Seq.length tape
  res   <- MV.new size
  delta <- MV.new size
  grads <- MV.new size
  pure (R res delta grads)
-- | Forward pass over a recorded tape.
--
-- Visits the nodes in tape order.  For each node it looks up the values of
-- the node's inputs and runs the node's 'Op'.  It then stores three things
-- at the node's position:
--
-- * the value, in '_rRes';
-- * a zero gradient, in '_rDelta';
-- * the node's backward function, in '_rGrads'.
--
-- Finally it returns the value of the last node, failing with 'error' if
-- that value is not a @b@.
--
-- @x@ is used wherever an input reference is 'VInp'.  A 'VIx' input is
-- read from '_rRes', so it must name a node earlier on the tape.
evalRunner
  :: forall m a b s t. (PrimMonad m, PrimState m ~ s, Typeable a, Typeable b)
  => Runner t s
  -> Seq.Seq (SomeTapeNode t)
  -> a
  -> m b
evalRunner R{..} stn x = do
  ifor_ stn $ \pos (STN rep TN{..}) -> do
    args <- traverse1 (fmap I . fetch) _tnInputs
    let (val, backward) = runOp _tnOp args
    MV.write _rRes   pos (SN rep val)
    MV.write _rDelta pos (SN rep 0)
    MV.write _rGrads pos (G _tnInputs rep backward)
  SN rep final <- MV.read _rRes (MV.length _rRes - 1)
  case testEquality (typeRep @b) rep of
    Just Refl -> return final
    Nothing   -> error "The tape produced something of a different type than expected"
  where
    -- Resolve one input reference to the whole value it points at, then
    -- view the referenced part through the reference's lens.  Type
    -- mismatches are raised lazily, only if the value is used.
    fetch :: forall x. InpRef t x -> m x
    fetch (IR var rep ln) = (^. ln) <$> case var of
      VInp ->
        return $ case testEquality (typeRep @a) rep of
          Just Refl -> x
          Nothing   -> error "Something referred to input but expected wrong type"
      VIx i -> do
        SN rep' y <- MV.read _rRes i
        return $ case testEquality rep rep' of
          Just Refl -> y
          Nothing   -> error "Something referred to internal node but it was wrong type"
      VConst c -> return c
-- | Backward pass over a tape that 'evalRunner' has already run.
--
-- Seeds the last node's gradient with 1, at the output type given by
-- @trOut@.  It then visits the nodes in reverse tape order.  At each node
-- it applies the node's backward function to the node's accumulated
-- gradient, and adds every resulting input gradient into that input:
--
-- * for 'VInp', into the 'MutVar' @dx@, through the reference's lens;
-- * for 'VIx', into that node's '_rDelta' slot;
-- * for 'VConst', nowhere.
--
-- Type mismatches fail with 'error'.
gradRunner
  :: forall m a b s t. (PrimMonad m, PrimState m ~ s, Typeable a, Num b)
  => TypeRep b
  -> Runner t s
  -> Seq.Seq (SomeTapeNode t)
  -> MutVar s a
  -> m ()
gradRunner trOut R{..} stns dx = do
  MV.write _rDelta (Seq.length stns - 1) (SN trOut 1)
  -- 'Backwards' runs the effects collected by 'ifor_' in reverse order,
  -- so nodes are visited from the last one to the first.
  forwards $ ifor_ stns $ \pos (STN _ TN{}) -> Backwards (backpropNode pos)
  where
    -- Push the accumulated gradient of the node at @pos@ into its inputs.
    backpropNode :: Int -> m ()
    backpropNode pos = do
      SN repDelta delta <- MV.read _rDelta pos
      G refs repOut backward <- MV.read _rGrads pos
      case testEquality repDelta repOut of
        Just Refl ->
          void $ traverse1 (\ref -> Proxy <$ accumulate ref)
                           (zipProd (refs :&: backward delta))
        Nothing -> error "Gradient function needed but delta type not matched."

    -- Add one input gradient into wherever that input lives.
    accumulate :: forall x. (InpRef t :&: I) x -> m ()
    accumulate (IR var rep ln :&: I g) = case var of
      VInp -> case testEquality (typeRep @a) rep of
        Just Refl -> modifyMutVar dx (ln +~ g)
        Nothing   -> error "Tried to modify gradient of input node but it was the wrong type"
      VIx i -> MV.modify _rDelta
        (\(SN rep' y) -> case testEquality rep rep' of
            Just Refl -> SN rep' (y & ln +~ g)
            Nothing   -> error "Tried to modify gradient of node but it was the wrong type")
        i
      VConst _ -> return ()
-- | Run @f@ at @x@, returning the result together with the gradient of the
-- result with respect to @x@.
--
-- Steps:
--
-- 1. Reify a fresh 'Builder' as @s@.
-- 2. Evaluate @f VInp@ and add the identity node from 'registerOut'; this
--    records the tape.
-- 3. Run the tape forwards ('evalRunner') and then backwards
--    ('gradRunner'), accumulating the input gradient from zero.
--
-- The whole computation runs under 'unsafePerformIO'.
backprop'
  :: forall a b. (Num a, Typeable a, Num b, Typeable b)
  => (forall s. Reifies s Builder => Var s a -> Var s b)
  -> a
  -> (b, a)
backprop' f x = unsafePerformIO $ do
  builder <- initBuilder
  reify builder $ \(Proxy :: Proxy s) -> do
    out <- evaluate (f VInp :: Var s b)
    registerOut out
    case builder of
      B ref -> do
        tape   <- readIORef ref
        runner <- initRunner tape
        y      <- evalRunner runner tape x :: IO b
        gradV  <- newMutVar (0 :: a)
        gradRunner (typeRep @b) runner tape gradV
        g      <- readMutVar gradV
        return (y, g)
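-- Example: f(x,y) = (2x + y)^2
-- df/dx = 8x + 4y
-- df/dy = 4x + 2y
-- At (x,y) = (2,3): f = 49, df/dx = 28, df/dy = 14, so main is expected to
-- print (49.0,(28.0,14.0)).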
-- | Differentiate f(x, y) = (2x + y)^2 at (2, 3), with the pair of inputs
-- as a single tuple variable, and print the value and gradient.
main :: IO ()
main = print result
  where
    result = backprop' (\v -> (2 * (_1 .$ v) + (_2 .$ v)) ^ (2 :: Integer))
                       (2 :: Double, 3 :: Double)
-- | Component-wise arithmetic on pairs, so a pair can serve as the input
-- (and gradient) type in 'backprop''.  Both operands' tuple constructors
-- are matched strictly.  'fromInteger' copies the literal into both
-- components.
instance (Num a, Num b) => Num (a, b) where
  (a1, b1) + (a2, b2) = (a1 + a2, b1 + b2)
  (a1, b1) - (a2, b2) = (a1 - a2, b1 - b2)
  (a1, b1) * (a2, b2) = (a1 * a2, b1 * b2)
  abs    (p, q)       = (abs p, abs q)
  signum (p, q)       = (signum p, signum q)
  fromInteger n       = (fromInteger n, fromInteger n)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment