Last active
March 22, 2021 12:20
-
-
Save bitonic/640f789a6d879c8186ca71b237e633fa to your computer and use it in GitHub Desktop.
Simple reverse AD
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE FlexibleContexts #-} | |
{-# LANGUAGE UndecidableInstances #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE RecordWildCards #-} | |
{-# LANGUAGE ScopedTypeVariables #-} | |
{-# LANGUAGE TypeApplications #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE BangPatterns #-} | |
{-# OPTIONS_GHC -Wall #-} | |
import Data.IORef | |
import Data.Reflection | |
import qualified Data.Vector as V | |
import Data.Proxy | |
import System.IO.Unsafe (unsafePerformIO) | |
import Control.Exception (evaluate) | |
import qualified Data.Vector.Unboxed.Mutable as VUM | |
import Control.Monad | |
import Data.Foldable | |
import Control.DeepSeq | |
-- | Single-argument primitive operations that can be recorded on the tape.
data UnaryOp = Sin | Abs | Signum | Exp | Log
  deriving (Show, Eq)
-- | Interpret a 'UnaryOp' as the corresponding ordinary function on values.
evalUnaryOp :: Floating a => UnaryOp -> a -> a
evalUnaryOp Sin = sin
evalUnaryOp Abs = abs
evalUnaryOp Signum = signum
evalUnaryOp Exp = exp
evalUnaryOp Log = log
-- | Local derivative of a 'UnaryOp' evaluated at the operand @x@.
--
-- During the reverse sweep the adjoint of the result is multiplied by this
-- factor and added to the adjoint of the operand.
unaryGradientWeight :: Floating a => UnaryOp -> a -> a
unaryGradientWeight op x = case op of
  Sin -> cos x
  -- d|x|/dx equals signum x away from zero. At zero this yields the
  -- subgradient 0 rather than the NaN that abs x / x (i.e. 0/0) produced.
  Abs -> signum x
  -- signum is piecewise constant, so its derivative is 0 (the jump at 0 is
  -- ignored).
  Signum -> 0
  Exp -> exp x
  Log -> 1 / x
-- | Two-argument primitive operations that can be recorded on the tape.
data BinaryOp = Plus | Times | Divide
  deriving (Show, Eq)
-- | Interpret a 'BinaryOp' as the corresponding ordinary operator on values.
evalBinaryOp :: Fractional a => BinaryOp -> a -> a -> a
evalBinaryOp Plus = (+)
evalBinaryOp Times = (*)
evalBinaryOp Divide = (/)
-- | Partial derivatives of a 'BinaryOp' with respect to its left and right
-- operands, evaluated at @(x, y)@.
--
-- The reverse sweep scales the result's adjoint by each component and adds
-- it to the adjoint of the corresponding operand.
binaryGradientWeights :: Floating a => BinaryOp -> a -> a -> (a, a)
binaryGradientWeights Plus _ _ = (1, 1)
binaryGradientWeights Times x y = (y, x)
binaryGradientWeights Divide x y = (1 / y, negate (x / (y * y)))
-- | The recorded operations of a tape, as a linked list with the most
-- recently recorded cell first.
--
-- Each cell keeps the tape indices and the values of its operands, which is
-- what the reverse sweep needs to compute local derivatives.
--
-- NOTE(review): UNPACK on the enumeration-typed op fields may be ignored
-- (with a warning) by GHC versions older than 9.6 — confirm for the
-- toolchain in use.
data Cells a
  = Nil
  | Lift
      { _liftTail :: Cells a
        -- ^ older cells
      }
  | Unary
      { _unaryIndex :: {-# UNPACK #-} !Int
        -- ^ tape index of the operand
      , _unaryValue :: a
        -- ^ value of the operand
      , _unaryOp :: {-# UNPACK #-} !UnaryOp
      , _unaryTail :: Cells a
        -- ^ older cells
      }
  | Binary
      { _binaryIndex1 :: {-# UNPACK #-} !Int
        -- ^ tape index of the left operand
      , _binaryValue1 :: a
        -- ^ value of the left operand
      , _binaryIndex2 :: {-# UNPACK #-} !Int
        -- ^ tape index of the right operand
      , _binaryValue2 :: a
        -- ^ value of the right operand
      , _binaryOp :: {-# UNPACK #-} !BinaryOp
      , _binaryTail :: Cells a
        -- ^ older cells
      }
  deriving (Eq, Show)
-- | Current state of a tape: the next index to hand out, plus everything
-- recorded so far.
data Head a = Head
  { headCounter :: {-# UNPACK #-} !Int
    -- ^ index that the next recorded cell will receive
  , headCells :: Cells a
    -- ^ recorded cells, most recent first
  }
  deriving (Eq, Show)
-- | A tape: one mutable reference holding its current 'Head'.
newtype Tape a = Tape
  { unTape :: IORef (Head a)
  }
-- | Create a tape with no recorded cells.
--
-- The counter starts at @numVars@, so indices @0 .. numVars - 1@ remain free
-- for the input variables and the first recorded cell gets index @numVars@.
newTape :: Int -> IO (Tape a)
newTape numVars = do
  ref <- newIORef (Head {headCounter = numVars, headCells = Nil})
  return (Tape ref)
-- | A value tracked for reverse-mode differentiation.
--
-- @s@ is a phantom type that ties the value to the tape it was recorded on.
data Reverse s a = Reverse
  { _index :: {-# UNPACK #-} !Int
    -- ^ position of this value on the tape
  , _value :: a
    -- ^ the ordinary (primal) value
  }
  deriving (Eq, Show)
-- | Record one cell on the tape reified as @s@ and return its index.
--
-- @cell@ is a 'Cells' constructor that is missing only its tail. It is
-- applied to the cells recorded so far, which makes it the newest cell. The
-- returned index is the counter value before the increment.
--
-- The effect runs through 'unsafePerformIO' so that arithmetic on 'Reverse'
-- can stay pure. The cell is recorded only when the returned index is forced.
newReflect :: Reifies s (Tape a) => Proxy s -> (Cells a -> Cells a) -> Int
newReflect p cell =
  unsafePerformIO $
    -- The strict variant evaluates the new 'Head' (and therefore its strict
    -- counter) before storing it. This replaces the hand-written seq chain
    -- and avoids leaving lazy update thunks inside the IORef.
    atomicModifyIORef' (unTape (reflect p)) $ \(Head count cells) ->
      (Head (count + 1) (cell cells), count)
-- | Put a constant on the tape.
--
-- The value is forced first, then a 'Lift' cell is recorded for it. A 'Lift'
-- cell has no operands, so no gradient propagates past it.
lift :: Reifies s (Tape a) => Proxy s -> a -> Reverse s a
lift p x = x `seq` Reverse (newReflect p Lift) x
-- | Apply a 'BinaryOp' to two tracked values.
--
-- The result value is computed and forced first. Then a 'Binary' cell,
-- holding both operands' indices and values, is recorded on the tape.
binary ::
     forall s a.
     (Floating a, Reifies s (Tape a))
  => BinaryOp
  -> Reverse s a
  -> Reverse s a
  -> Reverse s a
binary op (Reverse ixx x) (Reverse ixy y) =
  let result = evalBinaryOp op x y
      cellIx = newReflect (Proxy @s) (Binary ixx x ixy y op)
  in result `seq` Reverse cellIx result
-- | Apply a 'UnaryOp' to a tracked value.
--
-- The result value is computed and forced first. Then a 'Unary' cell,
-- holding the operand's index and value, is recorded on the tape.
unary ::
     forall s a.
     (Floating a, Reifies s (Tape a))
  => UnaryOp
  -> Reverse s a
  -> Reverse s a
unary op (Reverse ix x) =
  let result = evalUnaryOp op x
      cellIx = newReflect (Proxy @s) (Unary ix x op)
  in result `seq` Reverse cellIx result
-- | Arithmetic on tracked values: every operation is recorded on the tape,
-- and literals become 'Lift' cells.
instance (Reifies s (Tape a), Floating a) => Num (Reverse s a) where
  (+) = binary Plus
  (*) = binary Times
  abs = unary Abs
  signum = unary Signum
  -- Negation is recorded as multiplication by the constant -1.
  negate x = x * lift (Proxy @s) (-1)
  fromInteger n = lift (Proxy @s) (fromInteger n)
-- | Division on tracked values; rational literals become 'Lift' cells.
instance (Reifies s (Tape a), Floating a) => Fractional (Reverse s a) where
  x / y = binary Divide x y
  fromRational r = lift (Proxy @s) (fromRational r)
-- | Transcendental functions on tracked values.
--
-- Only 'sin', 'exp' and 'log' are tape primitives. The other methods are
-- built from those primitives and the arithmetic instances, so the chain rule
-- produces their derivatives automatically. Previously the instance left
-- these methods undefined, so calling e.g. 'cos' (or 'tan'/'tanh', which
-- default through it) crashed at runtime.
instance (Reifies s (Tape a), Floating a) => Floating (Reverse s a) where
  pi = lift (Proxy @s) pi
  sin = unary Sin
  exp = unary Exp
  log = unary Log
  -- Shifted sine. It can differ from the underlying cos in the last bits,
  -- and it loses accuracy for large |x|.
  cos x = sin (x + pi / 2)
  sinh x = (exp x - exp (negate x)) / 2
  cosh x = (exp x + exp (negate x)) / 2
  asinh x = log (x + sqrt (x * x + 1))
  acosh x = log (x + sqrt (x * x - 1))
  atanh x = (log (1 + x) - log (1 - x)) / 2
  -- The recorded primitives cannot express these. Fail with a descriptive
  -- message instead of GHC's generic missing-method error.
  asin _ = error "Reverse: asin is not supported"
  acos _ = error "Reverse: acos is not supported"
  atan _ = error "Reverse: atan is not supported"
-- | One flattened tape entry, indexed by position during the reverse sweep.
--
-- 'Var'' marks an input variable. The other constructors mirror 'Cells'
-- without the tail.
data Cell a
  = Var'
  | Lift'
  | Unary' {-# UNPACK #-} !Int a {-# UNPACK #-} !UnaryOp
    -- ^ operand index, operand value, operation
  | Binary' {-# UNPACK #-} !Int a {-# UNPACK #-} !Int a {-# UNPACK #-} !BinaryOp
    -- ^ left index, left value, right index, right value, operation
  deriving (Eq, Show)
-- | Given a function \(f : \mathbb{R}^n \to \mathbb{R}^m\) and a point,
-- return @f@ evaluated at that point together with its Jacobian: one row per
-- output, one column per input.
--
-- @f@ runs once on a fresh tape that records every operation. Then one
-- reverse sweep over that tape is made per output. The tape is private to
-- this call, which is why the result can be exposed as a pure value.
--
-- Calls 'error' when @args@ is empty.
grad ::
     forall a.
     (VUM.Unbox a, Floating a, NFData a)
  => (forall s. Reifies s (Tape a) => [Reverse s a] -> [Reverse s a])
  -> [a]
  -> ([a], [[a]])
grad f args = unsafePerformIO $ do
  when (null args) $
    error "No variables provided"
  let numVars = length args
  tape <- newTape numVars
  -- Inputs occupy indices 0 .. numVars-1. Fully forcing the outputs makes
  -- sure every cell they depend on has been written to the tape.
  (result, resultIndices) <-
    evaluate
      (force
         (reify tape (\(_ :: Proxy s) ->
            unzip
              [ (_value, _index)
              | Reverse{..} <- f (zipWith (Reverse @s) [0 ..] args)
              ])))
  head' <- readIORef (unTape tape)
  -- The tape is the same for every output, so flatten it once: the input
  -- variables first, then the recorded cells from oldest to newest. Each
  -- cell's position then equals the index it was assigned.
  let cells = V.fromList (replicate numVars Var' ++ reverse (go (headCells head')))
      numCells = V.length cells
      -- One reverse sweep seeded at a single output: returns that output's
      -- row of the Jacobian.
      getGradients :: Int -> IO [a]
      getGradients resultIndex = do
        gradients :: VUM.IOVector a <- VUM.replicate numCells 0
        VUM.write gradients resultIndex 1
        -- Operands always have smaller indices than their results, so a
        -- descending sweep sees each cell's final adjoint before using it.
        for_ [numCells - 1, numCells - 2 .. 0] $ \cellIx -> do
          cellGrad <- VUM.read gradients cellIx
          case cells V.! cellIx of
            Var' -> return ()
            Lift' -> return ()
            Unary' ix x op ->
              VUM.modify gradients ((cellGrad * unaryGradientWeight op x) +) ix
            Binary' ixx x ixy y op -> do
              let (gradwx, gradwy) = binaryGradientWeights op x y
              VUM.modify gradients ((cellGrad * gradwx) +) ixx
              VUM.modify gradients ((cellGrad * gradwy) +) ixy
        mapM (VUM.read gradients) [0 .. numVars - 1]
  jacobian <- mapM getGradients resultIndices
  return (result, jacobian)
  where
    -- Flatten the linked tape (newest first) into a list of cells, dropping
    -- the tails.
    go :: Cells a -> [Cell a]
    go = \case
      Nil -> []
      Lift cells -> Lift' : go cells
      Unary{..} -> Unary' _unaryIndex _unaryValue _unaryOp : go _unaryTail
      Binary{..} ->
        Binary' _binaryIndex1 _binaryValue1 _binaryIndex2 _binaryValue2 _binaryOp
          : go _binaryTail
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment