Last active
January 31, 2018 07:48
-
-
Save mstksg/05a1aca5f45ff311e522848e2be9e3a3 to your computer and use it in GitHub Desktop.
polymorphic backprop
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env stack | |
| -- stack --install-ghc runghc --package type-combinators --package ad --package lens --package vector --package reflection -- -Wall | |
| {-# LANGUAGE BangPatterns #-} | |
| {-# LANGUAGE DataKinds #-} | |
| {-# LANGUAGE DeriveGeneric #-} | |
| {-# LANGUAGE ExistentialQuantification #-} | |
| {-# LANGUAGE FlexibleContexts #-} | |
| {-# LANGUAGE GADTs #-} | |
| {-# LANGUAGE KindSignatures #-} | |
| {-# LANGUAGE LambdaCase #-} | |
| {-# LANGUAGE RankNTypes #-} | |
| {-# LANGUAGE RecordWildCards #-} | |
| {-# LANGUAGE ScopedTypeVariables #-} | |
| {-# LANGUAGE TemplateHaskell #-} | |
| {-# LANGUAGE TupleSections #-} | |
| {-# LANGUAGE TypeApplications #-} | |
| {-# LANGUAGE TypeFamilies #-} | |
| {-# LANGUAGE TypeOperators #-} | |
| {-# LANGUAGE UndecidableInstances #-} | |
| {-# LANGUAGE ViewPatterns #-} | |
| {-# OPTIONS_GHC -fno-warn-orphans #-} | |
| import Control.Applicative.Backwards | |
| import Control.DeepSeq | |
| import Control.Exception | |
| import Control.Lens hiding ((:<), traverse1, Index) | |
| import Control.Monad | |
| import Control.Monad.Primitive | |
| import Data.IORef | |
| import Data.Kind | |
| import Data.Primitive.MutVar | |
| import Data.Proxy | |
| import Data.Reflection | |
| import Data.Type.Combinator | |
| import Data.Type.Conjunction | |
| import Data.Type.Equality | |
| import Data.Type.Index | |
| import Data.Type.Product | |
| import GHC.Generics (Generic) | |
| import Numeric.AD hiding (Grad) | |
| import System.IO.Unsafe | |
| import Type.Class.Higher | |
| import Type.Class.Witness | |
| import Type.Reflection | |
| import Unsafe.Coerce | |
| import qualified Data.Sequence as Seq | |
| import qualified Data.Vector.Mutable as MV | |
| -- Iteration 11: It is all implicit! | |
-- | Zip a pair of products that share the same index list @as@ into a single
-- product whose elements are the paired-up components.
zipProd
    :: (Prod f :&: Prod g) as
    -> Prod (f :&: g) as
zipProd (Ø :&: Ø)                 = Ø
zipProd ((x :< xs) :&: (y :< ys)) = (x :&: y) :< zipProd (xs :&: ys)
-- | An operation from a tuple of inputs @as@ to a result @a@.
--
-- Running it yields the result together with a function that maps a value of
-- the result type back to a tuple shaped like the inputs.  In this file that
-- function is applied to a node's accumulated delta to obtain the
-- contributions for each of its inputs.
data Op as a = Op
    { runOp :: Tuple as -> (a, a -> Tuple as)
    }
-- | A handle on a value of type @a@ inside a computation being recorded.
-- The @s@ parameter is phantom; elsewhere in this file it is tied to a
-- 'Builder' through a @Reifies s Builder@ constraint.
data Var s a
    = VInp       -- ^ The input of the function being differentiated.
    | VIx Int    -- ^ The result of the tape node stored at this index.
    | VConst a   -- ^ A constant carried inline.
  deriving Generic

-- | No methods are written out: the instance relies on the class defaults
-- together with the derived 'Generic' instance.
instance NFData a => NFData (Var s a)
-- | A reference to one input of a tape node, seen at type @a@.
--
-- The referenced variable may hold a larger value of some hidden type @b@;
-- the lens selects the @a@ that the node actually consumes.
data InpRef (s :: Type) (a :: Type) where
    IR :: Num a
       => { _irIx  :: Var s b      -- ^ The variable being referenced.
          , _irTR  :: TypeRep b    -- ^ Runtime type of the referenced value.
          , _irUpd :: Lens' b a    -- ^ Focus from the referenced value onto the part used.
          }
       -> InpRef s a
-- | One recorded step: an operation together with references to its inputs.
-- The list of input types @as@ is hidden; only the result type @a@ is
-- exposed.
data TapeNode (s :: Type) (a :: Type) where
    TN :: { _tnInputs :: Prod (InpRef s) as   -- ^ Where each input comes from.
          , _tnOp     :: Op as a              -- ^ The operation applied to those inputs.
          }
       -> TapeNode s a
-- | A 'TapeNode' whose result type is hidden, packaged with the 'Num'
-- dictionary and runtime 'TypeRep' of that result type so it can be
-- recovered later with 'testEquality'.
data SomeTapeNode (s :: Type) where
    STN :: forall s a. Num a
        => TypeRep a
        -> TapeNode s a
        -> SomeTapeNode s
-- | The mutable tape under construction: a reference to the sequence of
-- nodes recorded so far.  The tag @s@ of the stored nodes is existentially
-- hidden.
data Builder where
    B :: { bRef :: IORef (Seq.Seq (SomeTapeNode s)) }
      -> Builder
-- | Allocate a fresh 'Builder' whose tape is empty.
initBuilder :: IO Builder
initBuilder = do
    tapeRef <- newIORef Seq.empty
    return (B tapeRef)
-- | Append a node to the builder's tape and return a 'VIx' variable holding
-- the position at which the node was stored (the tape length before the
-- append).
--
-- The update uses the strict 'atomicModifyIORef'' so that neither the new
-- tape nor the returned index is left behind as a thunk; with the lazy
-- variant every append stacks another suspended update inside the reference
-- and every returned index keeps the previous tape alive until forced.
-- Only the sequence itself is forced (to WHNF), not the stored node.
insertNode
    :: (Num a, Typeable a)
    => TapeNode s a
    -> Builder
    -> IO (Var s a)
insertNode tn (B ref) = VIx <$> atomicModifyIORef' (unsafeCoerce ref) push
  where
    -- 'Builder' hides the tag of the nodes it stores, so the reference has to
    -- be coerced to the caller's @s@.  The tag is never inspected at runtime
    -- in this file, so the coercion only relabels the type; it is sound as
    -- long as a builder is only ever used with a single tag.
    push nodes = (nodes Seq.|> STN typeRep tn, Seq.length nodes)
-- | An operation with no inputs that always produces the given value; its
-- backward function ignores its argument and returns the empty tuple.
op0 :: a -> Op '[] a
op0 x = Op (const (x, \_ -> Ø))
-- | Lift a numeric function of one argument into an 'Op'.
--
-- The result and the derivative at the input are both obtained from
-- @diff'@; the backward function multiplies the incoming value by that
-- derivative.
op1 :: Num a => (forall b. Num b => b -> b) -> Op '[a] a
op1 f = Op $ \(x ::< Ø) ->
    let (y, dydx)  = diff' f x
        backward d = (d * dydx) ::< Ø
    in  (y, backward)
-- | Lift a numeric function of two arguments into an 'Op'.
--
-- The result and both partial derivatives come from @grad'@ applied to a
-- two-element list; the backward function scales the incoming value by each
-- partial derivative in turn.
op2 :: Num a => (forall b. Num b => b -> b -> b) -> Op '[a,a] a
op2 f = Op $ \(x ::< y ::< Ø) ->
    let -- The list patterns are partial; they rely on the gradient having
        -- the same two-element shape as the input list.
        (z, [dzdx, dzdy]) = grad' (\[u, v] -> f u v) [x, y]
        backward d        = (d * dzdx) ::< (d * dzdy) ::< Ø
    in  (z, backward)
-- | The identity operation: the input is returned unchanged and the backward
-- function passes its argument straight through.
idOp :: Num a => Op '[a] a
idOp = Op $ \(x ::< Ø) -> (x, (::< Ø))
-- | Record the given variable as the input of a final identity node and
-- discard the resulting variable.  The variable is forced before it is
-- handed to 'lift1'.
registerOut
    :: (Reifies s Builder, Num a, Typeable a)
    => Var s a
    -> IO ()
registerOut v = void $ lift1 idOp (v `seq` v)
-- | Record an operation applied to the given variables as a new node on the
-- tape of the builder reflected from @s@, and return the variable that
-- refers to the node's result.
--
-- Each input is referenced whole (through the identity lens).  The 'Num' and
-- 'Typeable' evidence required for each input reference is pulled out of the
-- @Every@ constraints at that input's index.
liftOp
    :: forall s as b. (Reifies s Builder, Num b, Typeable b, Every Typeable as, Every Num as)
    => Op as b
    -> Prod (Var s) as
    -> IO (Var s b)
liftOp op !vars = insertNode node (reflect (Proxy :: Proxy s))
  where
    node = TN (imap1 toRef vars) op
    toRef :: forall a. Index as a -> Var s a -> InpRef s a
    toRef ix !var = IR var typeRep id \\ every @_ @Num ix
                                     \\ every @_ @Typeable ix
-- | Record a one-input operation; the input variable is forced first.
lift1
    :: (Reifies s Builder, Num a, Typeable a, Num b, Typeable b)
    => Op '[a] b
    -> Var s a
    -> IO (Var s b)
lift1 op !v = liftOp op $ v :< Ø
-- | Record a two-input operation; both input variables are forced first.
lift2
    :: (Reifies s Builder, Num a, Typeable a, Num b, Typeable b, Num c, Typeable c)
    => Op '[a,b] c
    -> Var s a
    -> Var s b
    -> IO (Var s c)
lift2 op !v !w = liftOp op $ v :< w :< Ø
-- | Record an identity node whose single input is the part of the given
-- variable selected by the lens, and return the variable for that node.
varLens
    :: forall a b s. (Reifies s Builder, Num b, Typeable b, Num a, Typeable a)
    => Lens' b a
    -> Var s b
    -> IO (Var s a)
varLens l v = insertNode node (reflect (Proxy :: Proxy s))
  where
    node = TN (IR v typeRep l :< Ø) idOp
-- | Pure-looking wrapper around 'varLens': focus a variable through a lens.
--
-- The node is recorded as a side effect via 'unsafePerformIO' when the
-- result is demanded; the variable argument is forced beforehand.
(.$)
    :: forall a b s. (Reifies s Builder, Num b, Typeable b, Num a, Typeable a)
    => Lens' b a
    -> Var s b
    -> Var s a
(.$) l !v = unsafePerformIO $ varLens l v
-- | Arithmetic on variables records a node on the tape.
--
-- Every method except 'fromInteger' performs the recording through
-- 'unsafePerformIO', so the node is inserted when the resulting variable is
-- demanded.  'fromInteger' is pure and produces an inline constant.
instance (Reifies s Builder, Num a, Typeable a) => Num (Var s a) where
    (+) x y       = unsafePerformIO (lift2 (op2 (+)) x y)
    (-) x y       = unsafePerformIO (lift2 (op2 (-)) x y)
    (*) x y       = unsafePerformIO (lift2 (op2 (*)) x y)
    negate x      = unsafePerformIO (lift1 (op1 negate) x)
    abs x         = unsafePerformIO (lift1 (op1 abs) x)
    signum x      = unsafePerformIO (lift1 (op1 signum) x)
    fromInteger n = VConst (fromInteger n)
-- | The backward half of an evaluated node: the node's input references,
-- the runtime type of its result, and the function turning a value of the
-- result type into one value per input.  Both the input list and the result
-- type are hidden.
data Grad (t :: Type) where
    G :: Prod (InpRef t) as
      -> TypeRep a
      -> (a -> Tuple as)
      -> Grad t
-- | A value of some hidden numeric type, tagged with its runtime 'TypeRep'
-- so the type can be recovered with 'testEquality'.
data SomeNum where
    SN :: Num a
       => TypeRep a
       -> a
       -> SomeNum
-- | Mutable working storage for running a tape, with one slot per node in
-- each vector.
data Runner t s = R
    { _rRes   :: MV.MVector s SomeNum    -- ^ Result computed for each node.
    , _rDelta :: MV.MVector s SomeNum    -- ^ Delta accumulated for each node.
    , _rGrads :: MV.MVector s (Grad t)   -- ^ Backward function captured for each node.
    }
-- | Allocate a 'Runner' with one slot per node of the given tape.  The
-- vectors come straight from @MV.new@ and are not filled in here.
initRunner
    :: (PrimMonad m, PrimState m ~ s)
    => Seq.Seq (SomeTapeNode t)
    -> m (Runner t s)
initRunner stns = do
    results <- MV.new slots
    deltas  <- MV.new slots
    grads   <- MV.new slots
    return (R results deltas grads)
  where
    slots = Seq.length stns
-- | Forward pass: evaluate every node of the tape in order, given the input
-- @x@, and return the value stored in the last result slot.
--
-- For each node this writes its result, a zero delta of the same type, and
-- its backward function (paired with its input references) into the runner.
-- The final result is checked against the expected type @b@ at runtime;
-- mismatches, here and in the input lookups, are reported with 'error'.
evalRunner
    :: forall m a b s t. (PrimMonad m, PrimState m ~ s, Typeable a, Typeable b)
    => Runner t s
    -> Seq.Seq (SomeTapeNode t)
    -> a
    -> m b
evalRunner (R results deltas grads) stn x = do
    ifor_ stn $ \i (STN tr (TN inputs op)) -> do
      args <- traverse1 (fmap I . findInps) inputs
      let (out, backward) = runOp op args
      MV.write results i (SN tr out)
      MV.write deltas  i (SN tr 0)
      MV.write grads   i (G inputs tr backward)
    -- The value of the whole computation lives in the last slot.
    SN tr y <- MV.read results (MV.length results - 1)
    case testEquality (typeRep @b) tr of
      Nothing   -> error "The tape produced something of a different type than expected"
      Just Refl -> return y
  where
    -- Fetch the value an input reference points at, then focus it through
    -- the reference's lens.
    findInps :: forall x. InpRef t x -> m x
    findInps (IR v tr ln) = (^. ln) <$> case v of
        VInp     -> pure $ case testEquality (typeRep @a) tr of
          Nothing   -> error "Something referred to input but expected wrong type"
          Just Refl -> x
        VIx i    -> do
          SN tr' y <- MV.read results i
          pure $ case testEquality tr tr' of
            Nothing   -> error "Something referred to internal node but it was wrong type"
            Just Refl -> y
        VConst y -> pure y
-- | Backward pass over a tape that 'evalRunner' has already evaluated.
--
-- The delta of the last node is seeded with @1@ (at the output type given by
-- the 'TypeRep').  The nodes are then visited from last to first: each
-- node's delta is fed to its stored backward function and the resulting
-- per-input values are added, through each input's lens, either to the delta
-- of the node that input refers to or, for references to the function
-- input, to the accumulator held in the 'MutVar'.  Constant inputs are
-- skipped.
--
-- Runtime type mismatches are reported with 'error'.
--
-- The function yields a genuine @()@: the original ended its do-block with
-- @return undefined@, a bottom that was at best discarded and at worst
-- returned to the caller as the unit result.
gradRunner
    :: forall m a b s t. (PrimMonad m, PrimState m ~ s, Typeable a, Num b)
    => TypeRep b
    -> Runner t s
    -> Seq.Seq (SomeTapeNode t)
    -> MutVar s a
    -> m ()
gradRunner trOut R{..} stns dx = do
    MV.write _rDelta (n - 1) (SN trOut 1)
    -- 'Backwards' reverses the order in which the per-node effects are
    -- sequenced, so the tape is walked from the output towards the input.
    forwards . ifor_ stns $ \i (STN _ TN{}) -> Backwards $ do
      SN tr1 delt <- MV.read _rDelta i
      G irs tr2 f <- MV.read _rGrads i
      case testEquality tr1 tr2 of
        Nothing   -> error "Gradient function needed but delta type not matched."
        Just Refl ->
          void . traverse1 (fmap (const Proxy) . go) $ zipProd (irs :&: f delt)
  where
    n = Seq.length stns
    -- Add one input's contribution to wherever that input lives.
    go :: forall x. (InpRef t :&: I) x -> m ()
    go (IR v tr ln :&: I d) = case v of
      VInp     -> case testEquality (typeRep @a) tr of
        Just Refl -> modifyMutVar dx (ln +~ d)
        Nothing   -> error "Tried to modify gradient of input node but it was the wrong type"
      VIx i    -> flip (MV.modify _rDelta) i $ \case
        SN tr' y -> case testEquality tr tr' of
          Just Refl -> SN tr' $ y & ln +~ d
          Nothing   -> error "Tried to modify gradient of node but it was the wrong type"
      VConst _ -> return ()
-- | Run a function written against 'Var' on a concrete input, returning its
-- result together with the gradient with respect to that input.
--
-- A fresh 'Builder' is created for every call.  The function is applied to
-- 'VInp' and forced with 'evaluate', which records its nodes on the tape;
-- the output is then registered as a final node.  Afterwards the tape is
-- evaluated forwards on @x@ and backwards into an accumulator that starts
-- at @0@.  The whole procedure is wrapped in 'unsafePerformIO'.
backprop'
    :: forall a b. (Num a, Typeable a, Num b, Typeable b)
    => (forall s. Reifies s Builder => Var s a -> Var s b)
    -> a
    -> (b, a)
backprop' f x = unsafePerformIO $ do
    builder@(B tapeRef) <- initBuilder
    reify builder $ \(Proxy :: Proxy s) ->
      evaluate (f VInp :: Var s b) >>= registerOut
    tape       <- readIORef tapeRef
    runner     <- initRunner tape
    out :: b   <- evalRunner runner tape x
    gradRef    <- newMutVar (0 :: a)
    gradRunner (typeRep @b) runner tape gradRef
    gradIn     <- readMutVar gradRef
    return (out, gradIn)
-- Example: differentiate a function of a pair.
--
--   f(x,y) = (2x + y)^2
--   df/dx  = 8x + 4y
--   df/dy  = 4x + 2y
--
-- At (2, 3) these formulas give f = 49, df/dx = 28 and df/dy = 14, which is
-- what the printed result should show if the computed gradient matches them.
main :: IO ()
main = print (backprop' f (2 :: Double, 3 :: Double))
  where
    f :: Reifies s Builder => Var s (Double, Double) -> Var s Double
    f p = (2 * (_1 .$ p) + (_2 .$ p)) ^ (2 :: Integer)
-- | Orphan instance: pairs of numbers are numbers, with every operation
-- applied component-wise.  An integer literal becomes a pair holding that
-- literal in both components.
instance (Num a, Num b) => Num (a, b) where
    (+) (a1, b1) (a2, b2) = (a1 + a2, b1 + b2)
    (*) (a1, b1) (a2, b2) = (a1 * a2, b1 * b2)
    (-) (a1, b1) (a2, b2) = (a1 - a2, b1 - b2)
    abs (a, b)            = (abs a, abs b)
    signum (a, b)         = (signum a, signum b)
    fromInteger n         = (fromInteger n, fromInteger n)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment