{-
Infinitely differentiable stochastic functions
(GitHub gist dpiponi/c1b53f2b19a96aab521e3166c8816ab7, created February 17, 2018 20:16)
-}
{- LANGUAGE UnicodeSyntax -} | |
import Prelude hiding (sum) | |
import Control.Monad | |
import qualified System.Random as R | |
import qualified Data.Map.Strict as M | |
--
-- Formal power series.
-- A series is just a list of coefficients (index i holds the coefficient
-- of x^i) rather than a new type.
-- This is http://blog.sigfpe.com/2007/11/small-combinatorial-library.html
-- stripped down a bit.
--
-- | Multiplication that short-circuits on a zero right operand: when the
-- right-hand factor is 0 the left-hand factor is never evaluated. The
-- self-referential definition of 'invert' (via 'convolve') depends on this,
-- because it multiplies a coefficient that is still being defined by 0.
(*!) :: (Eq a, Num a) => a -> a -> a
x *! y
  | y == 0 = 0
  | otherwise = x * y

-- | Mirror image of '(*!)': short-circuits on a zero left operand, so the
-- right-hand factor is not evaluated when the left one is 0.
(!*) :: (Eq a, Num a) => a -> a -> a
x !* y
  | x == 0 = 0
  | otherwise = x * y
-- | Coefficient-wise sum of two series. Like 'zipWith', the result is only
-- as long as the shorter argument.
(^+) :: Num a => [a] -> [a] -> [a]
(^+) = zipWith (+)

-- | Coefficient-wise difference of two series, truncated to the shorter one.
(^-) :: Num a => [a] -> [a] -> [a]
(^-) = zipWith (-)
-- | Cauchy product of two power series: coefficient n of the result is
-- sum [a_i * b_(n-i) | i <- [0..n]].
--
-- The first argument is matched lazily (~) so that a series can be defined
-- in terms of its own product, as 'invert' does; the short-circuiting
-- '(*!)' and '(!*)' stop zero coefficients from forcing the other factor.
-- NOTE(review): because of the lazy match, a finite first argument can fail
-- with an irrefutable-pattern error once a missing coefficient is demanded.
convolve :: (Eq a, Num a) => [a] -> [a] -> [a]
convolve ~(x : xs) ys@(y : ys') =
  (x *! y) : (map (x !*) ys' ^+ convolve xs ys)
-- | Composition of power series: compose f g is the series of f(g(x)).
-- The inner series g must have a zero constant term; otherwise the result
-- is not a well-defined formal power series.
-- Evaluated Horner-style: f(g) = f0 + g * (tail f)(g), and since
-- g = 0 : gs, multiplying by g just shifts gs * (tail f)(g) up one place.
compose :: (Eq a, Num a) => [a] -> [a] -> [a]
compose (f : fs) g@(0 : gs) = f : (gs `convolve` compose fs g)
compose _ (g0 : _)
  | g0 /= 0 = error "compose: inner series must have a zero constant term"
compose _ _ = error "compose: ran out of coefficients (finite series)"
-- | Compositional inverse (series reversion): for f = f1 x + f2 x^2 + ...
-- with f1 /= 0, inverse f is the series h satisfying compose f h = x.
-- h is a lazy fixed point: writing f = f1 x + R(x), h = (x - R(h)) / f1,
-- and R(h) has no x^0 or x^1 terms, so its first two coefficients are
-- discarded. Only series of the form 0 : f1 : rest are accepted
-- (anything else is a pattern-match failure).
inverse :: (Eq a, Fractional a) => [a] -> [a]
inverse (0 : f1 : rest) = h
  where
    h = map (recip f1 *) (0 : 1 : higher)
    (_ : _ : higher) = map negate (compose (0 : 0 : rest) h)
-- | Multiplicative inverse of a power series with a nonzero constant term.
-- Solves recipSeries * series = 1 coefficient by coefficient through a
-- lazy fixed point: recipSeries = (1 - recipSeries * (series - c0)) / c0.
-- A zero constant term c0 leads to division by zero in every coefficient.
invert :: (Eq a, Fractional a) => [a] -> [a]
invert series = recipSeries
  where
    (c0 : cs) = series
    recipSeries =
      map (/ c0) ((1 : repeat 0) ^- (recipSeries `convolve` (0 : cs)))
-- | Division of power series. Leading zero coefficients shared by the
-- numerator and the denominator are cancelled first (as in x*a / x*b),
-- then the numerator is multiplied by the multiplicative inverse of the
-- denominator.
-- NOTE(review): if both arguments are entirely zero the cancellation never
-- terminates, and a denominator whose constant term is still zero after
-- cancellation leads to division by zero inside 'invert'.
(^/) :: (Eq a, Fractional a) => [a] -> [a] -> [a]
(0 : num) ^/ (0 : den) = num ^/ den
num ^/ den = num `convolve` invert den
-- | Power series as lists of coefficients (index i holds the coefficient of
-- x^i). Addition and subtraction are coefficient-wise and truncate to the
-- shorter operand; multiplication is the Cauchy product 'convolve'.
-- Integer literals become constant series.
-- NOTE(review): this is an orphan instance on a Prelude type.
instance (Eq r, Num r) => Num [r] where
  x + y = x ^+ y
  x - y = x ^- y
  x * y = x `convolve` y
  fromInteger n = fromInteger n : repeat 0
  negate = map negate
  -- signum looks only at the constant term.
  signum (x : _) = signum x : repeat 0
  signum [] = error "signum: empty power series"
  abs _ = error "Can't form abs of a power series"
-- | Division of power series delegates to '(^/)'; rational literals become
-- constant series.
instance (Eq r, Fractional r) => Fractional [r] where
  (/) = (^/)
  fromRational q = fromRational q : repeat 0
-- | lead prefix s replaces the first (length prefix) coefficients of s with
-- prefix and keeps the rest of s. It does not force s until coefficients
-- past the prefix are demanded. If s is shorter than prefix, the result is
-- just prefix (drop 1 is used instead of the partial tail).
lead :: [a] -> [a] -> [a]
lead [] s = s
lead (p : ps) s = p : lead ps (drop 1 s)

-- | Infix synonym for 'lead': prefix ... s.
(...) :: [a] -> [a] -> [a]
prefix ... s = lead prefix s
-- | s // keep zeroes every coefficient of s whose index (the power of x,
-- starting at 0) does not satisfy keep; the other coefficients pass
-- through unchanged. Only 'Num' is needed, so integer-coefficient series
-- work too.
(//) :: Num a => [a] -> (Integer -> Bool) -> [a]
s // keep = zipWith select [0 ..] s
  where
    select i coeff = if keep i then coeff else 0

-- | Zero the constant (x^0) coefficient of a series, keeping all others.
nonEmpty :: Num a => [a] -> [a]
nonEmpty s = s // (/= 0)
-- | n! for non-negative n. Negative arguments are rejected explicitly,
-- because the plain recursion would never reach its base case.
factorial :: (Ord a, Num a) => a -> a
factorial n
  | n < 0 = error "factorial: negative argument"
  | n == 0 = 1
  | otherwise = n * factorial (n - 1)

-- | count n s is n! times the x^n coefficient of s. For a Taylor series
-- (such as one built from a constant plus 'ε') this is the n-th derivative
-- at the expansion point. Only 'Num' is required.
count :: Num a => Integer -> [a] -> a
count n s
  | n < 0 = error "count: negative index"
  | otherwise = (s !! fromInteger n) * fromInteger (factorial n)
-- | The formal variable x as a power series (0 + 1*x). Adding it to a
-- constant, e.g. 0.5 + ε, gives a series in which every power of the
-- perturbation is tracked, so derivatives of any order can be read off
-- with 'count'.
ε :: Fractional a => [a]
ε = [0, 1] ++ repeat 0
-- | The value and the first n derivatives encoded by a Taylor series.
-- Element k of the result is the k-th derivative (element 0 is the value
-- itself), so the list has n + 1 entries.
derivativesUpTo :: Fractional a => Integer -> [a] -> [a]
derivativesUpTo n s = [count k s | k <- [0 .. n]]

-- | The value followed by the first 10 derivatives (11 entries in total).
derivatives :: Fractional a => [a] -> [a]
derivatives = derivativesUpTo 10
-- | Free monad giving a generic probability interface. A program is either
-- a finished result ('Pure') or a Bernoulli draw with weight p followed by
-- a continuation that receives the outcome as an 'Int' (the interpreter
-- passes 0 or 1). How a weight becomes a probability is left to the
-- interpreter, so p need not be a real number.
--
-- See http://blog.sigfpe.com/2017/06/a-relaxation-technique.html
data Random p a
  = Pure a
  | Bernoulli p (Int -> Random p a)

instance Functor (Random p) where
  fmap f (Pure a) = Pure (f a)
  fmap f (Bernoulli p k) = Bernoulli p (fmap f . k)

-- 'pure' is defined directly, and 'return' is left at its default
-- (return = pure). This is the canonical form; pure = return together with
-- an explicit return is flagged by GHC's -Wnoncanonical-monad-instances.
instance Applicative (Random p) where
  pure = Pure
  (<*>) = ap

instance Monad (Random p) where
  Pure a >>= f = f a
  Bernoulli p k >>= f = Bernoulli p (\x -> k x >>= f)

-- | A single Bernoulli draw with weight p, returning its outcome.
bernoulli :: p -> Random p Int
bernoulli p = Bernoulli p Pure
-- | Multiply every weight in a list of (value, weight) pairs by s.
scale :: Num p => p -> [(a, p)] -> [(a, p)]
scale s = map (fmap (s *))
-- | Merge (key, weight) pairs that share a key by summing their weights.
-- The result is sorted by key.
collect :: (Ord a, Num b) => [(a, b)] -> [(a, b)]
collect pairs = M.toList (M.fromListWith (+) pairs)
-- | Interpreter for the 'Random' free monad with "weighted importance
-- sampling" semantics.
--
-- Each Bernoulli weight p is converted into an actual sampling probability
-- by rule. The interpreter draws an outcome with that probability and
-- multiplies the sample's importance by (weight of the chosen outcome) /
-- (probability with which it was chosen). Only rule p has to be a real
-- number, so p may be an element of some algebraic extension of the reals
-- while still giving meaningful probabilities. For a simple sampling-based
-- interpreter the rule can be the identity.
--
-- Returns the program's result paired with its importance, together with
-- the updated generator.
interpret :: (Fractional p, R.RandomGen g) =>
             (p -> Float) -> Random p a -> g -> ((a, p), g)
interpret _ (Pure a) g = ((a, 1), g)
interpret rule (Bernoulli p k) g =
  let (r, g') = R.random g
      -- This is the key line where a weight is converted into an actual
      -- probability.
      prob = rule p
      -- The degenerate probabilities are handled explicitly, so an outcome
      -- with probability 0 is never chosen. Comparing r against prob alone
      -- can pick it when r lands exactly on an endpoint, and its importance
      -- would then be a division by zero.
      takeOne
        | prob >= 1 = True
        | prob <= 0 = False
        | otherwise = (r :: Float) < prob
      -- Adjust the importance of the sample: weight of the chosen outcome
      -- divided by the probability it was actually drawn with.
      (b, i) =
        if takeOne
          then (1, p / realToFrac prob)
          else (0, (1 - p) / realToFrac (1 - prob))
      ((a, i'), g'') = interpret rule (k b) g'
   in ((a, i * i'), g'')
-- | Estimate the expected value of a 'Random' program from n weighted
-- samples: the mean of (sample value * importance). Returns the estimate
-- and the updated generator.
-- n must be positive. With no samples the mean 0/0 is undefined (and for
-- power-series weights that division would never terminate), so an error
-- is raised instead.
expect :: (Fractional p, R.RandomGen g) =>
          (p -> Float) -> Random p p -> Int -> g -> (p, g)
expect rule r n g
  | n <= 0 = error "expect: the number of samples must be positive"
  | otherwise =
      let (total, g') = sum rule 0 r n g
       in (total / fromIntegral n, g')
-- | Add n weighted samples of r to the running total t. Each draw
-- contributes (sample value * importance). Returns the final total and the
-- updated generator. A non-positive n returns t unchanged (previously a
-- negative n never reached the base case and looped forever).
sum :: (Fractional p, R.RandomGen g) =>
       (p -> Float) -> p -> Random p p -> Int -> g -> (p, g)
sum rule t r n g
  | n <= 0 = (t, g)
  | otherwise =
      let ((a, imp), g') = interpret rule r g
          t' = t + a * imp
       in -- Force the running total to weak head normal form each step so
          -- the outer additions are not left as a chain of n thunks.
          t' `seq` sum rule t' r (n - 1) g'
-- | Example from https://arxiv.org/abs/1802.05098, section 3.3
-- "Simple Failing Example".
--
-- With X ~ Ber(θ) this returns X(1-θ) + (1-X)(1+θ), whose expected value
-- is 1 + θ - 2θ². X only ever takes the integer values 0 and 1, yet running
-- this with a power-series θ (e.g. 0.5 + ε) through 'expect' differentiates
-- that expectation correctly to arbitrarily high order.
ex1 :: Num p => p -> Random p p
ex1 θ = payoff <$> bernoulli θ
  where
    payoff x =
      let v = fromIntegral x
       in v * (1 - θ) + (1 - v) * (1 + θ)
-- | The constant term of a power series, i.e. its value at zero with all
-- derivatives discarded. It serves as the rule that turns a series weight
-- into a plain sampling probability. Pattern matching replaces the partial
-- 'head', so an empty series gets a descriptive error.
(⊥) :: [a] -> a
(⊥) (c : _) = c
(⊥) [] = error "(⊥): empty power series has no constant term"
-- | Estimate the expectation of 'ex1' and its first 10 derivatives at
-- θ = 0.5 from 1000 weighted samples. The constant term of each series
-- weight is used as the sampling probability.
main :: IO ()
main = do
  let θ = 0.5 + ε
      samples = 1000
  estimate <- R.getStdRandom (expect (⊥) (ex1 θ) samples)
  print (derivatives estimate)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment