Skip to content

Instantly share code, notes, and snippets.

@dpiponi
Created February 17, 2018 20:16
Show Gist options
  • Save dpiponi/c1b53f2b19a96aab521e3166c8816ab7 to your computer and use it in GitHub Desktop.
Save dpiponi/c1b53f2b19a96aab521e3166c8816ab7 to your computer and use it in GitHub Desktop.
Infinitely differentiable stochastic functions
{-# LANGUAGE UnicodeSyntax #-}
import Prelude hiding (sum)
import Control.Monad
import qualified System.Random as R
import qualified Data.Map.Strict as M
--
-- Define formal power series
-- I'm just using lists of coefficients rather than defining a new type.
-- This is just http://blog.sigfpe.com/2007/11/small-combinatorial-library.html
-- stripped down a bit.
--
-- Multiplication that short-circuits on a zero right operand: when the
-- second argument equals 0 the answer is 0 and the first argument is
-- never evaluated. Otherwise this is ordinary (*).
(*!) :: (Eq a, Num a) => a -> a -> a
lhs *! rhs
  | rhs == 0 = 0
  | otherwise = lhs * rhs
-- Multiplication that short-circuits on a zero left operand: when the
-- first argument equals 0 the answer is 0 and the second argument is
-- never evaluated. Otherwise this is ordinary (*).
(!*) :: (Eq a, Num a) => a -> a -> a
lhs !* rhs
  | lhs == 0 = 0
  | otherwise = lhs * rhs
-- Coefficient-wise sum of two series. The result stops at the end of the
-- shorter argument.
(^+) :: Num a => [a] -> [a] -> [a]
xs ^+ ys = [x + y | (x, y) <- zip xs ys]
-- Coefficient-wise difference of two series. The result stops at the end
-- of the shorter argument.
(^-) :: Num a => [a] -> [a] -> [a]
xs ^- ys = [x - y | (x, y) <- zip xs ys]
-- Product of two power series given as coefficient lists.
--
-- With left = a + z*as and right = b + z*bs the product is
--   a*b + z*(a*bs + as*right)
-- which is exactly the recursion below.
--
-- The left operand is matched lazily and (*!) returns 0 without looking
-- at `a` when `b` is 0. Together these let a series appear on the left of
-- its own definition (as `invert` does with r `convolve` (0:xs)).
--
-- The right operand must be non-empty; there is no clause for [].
convolve :: (Eq a, Num a) => [a] -> [a] -> [a]
~(a:as) `convolve` (b:bs) =
  (a *! b) : ((map (a !*) bs) ^+ (as `convolve` (b:bs)))
-- Composition of power series: the first series evaluated at the second.
--
-- The second series must have constant term 0 (the only pattern
-- accepted). Writing g = z*gs and F = f + z*fs gives
--   F(g) = f + g * fs(g) = f + z * (gs * fs(g))
-- so the head is f and the tail is gs convolved with the recursive
-- composition of fs with g.
compose :: (Eq a, Num a) => [a] -> [a] -> [a]
compose (f:fs) (0:gs) = f : (gs `convolve` inner)
  where
    inner = compose fs (0:gs)
-- Compositional inverse of a series 0 + f*z + ... : the series x that
-- satisfies compose (0:f:fs) x = z.
--
-- Splitting the input into f*z plus its terms of order two and above,
-- (0:0:fs), the defining equation becomes
--   f*x + compose (0:0:fs) x = z
-- so x = (z - compose (0:0:fs) x) / f. The composed series has zero
-- coefficients at orders 0 and 1, so z minus it is 0:1:g with g the
-- negated coefficients from order two on. x and g are defined in terms of
-- each other and are produced lazily, coefficient by coefficient.
--
-- Only series with constant term 0 are accepted; f is inverted with
-- recip.
inverse :: (Eq a, Fractional a) => [a] -> [a]
inverse (0:f:fs) = x
  where
    x = map (recip f *) (0:1:g)
    _:_:g = map negate (compose (0:0:fs) x)
-- Multiplicative inverse of a power series: the series r with r*x = 1.
--
-- With x = x0 + z*xs the equation r*x = 1 rearranges to
--   r = (1 - r * (0:xs)) / x0
-- which is used directly as a lazy, self-referential definition. The
-- convolution's right operand starts with 0, so each coefficient of r
-- depends only on earlier ones.
--
-- Every coefficient is divided by the constant term x0.
invert :: (Eq a, Fractional a) => [a] -> [a]
invert x = r
  where
    r = map (/x0) ((1:repeat 0) ^- (r `convolve` (0:xs)))
    x0:xs = x
-- Quotient of two power series.
--
-- While both series start with a zero coefficient, that common factor of
-- z is cancelled by dropping the leading zero from each. After that the
-- numerator is multiplied by the multiplicative inverse of the
-- denominator.
(^/) :: (Eq a, Fractional a) => [a] -> [a] -> [a]
(0:num) ^/ (0:den) = num ^/ den
num ^/ den = num `convolve` invert den
-- Lists of coefficients as a numeric type, so ordinary arithmetic
-- expressions can be written over power series.
--
--  * (+), (-) and negate act coefficient-wise; (+) and (-) stop at the
--    end of the shorter operand.
--  * (*) is the series product computed by `convolve`.
--  * An integer literal becomes the constant series n, 0, 0, ...
--  * signum looks only at the leading coefficient and returns it as a
--    constant series.
--  * abs is not defined for power series and always raises an error.
instance (Eq r, Num r) => Num [r] where
  x + y = zipWith (+) x y
  x - y = zipWith (-) x y
  x * y = x `convolve` y
  fromInteger n = fromInteger n : repeat 0
  negate = map negate
  signum (x:_) = signum x : repeat 0
  abs _ = error "Can't form abs of a power series"
-- Division and rational literals for power series.
--
--  * (/) is the series quotient (^/).
--  * A rational literal becomes the constant series q, 0, 0, ...
instance (Eq r, Fractional r) => Fractional [r] where
  x / y = x ^/ y
  fromRational q = fromRational q : repeat 0
-- Overwrite the leading coefficients of a series: the result starts with
-- every element of the first list and then continues with the second
-- list from the position where the first one ended.
--
-- The second list is not inspected until the prefix has been emitted, so
-- it may be defined in terms of the result. If the second list is shorter
-- than the prefix, the result is just the prefix (`drop 1` is used rather
-- than `tail`, which would fail on an empty list).
lead :: [a] -> [a] -> [a]
lead [] x = x
lead (a : as) x = a : lead as (drop 1 x)
-- Infix spelling of `lead`: prefix ... series starts with the elements of
-- prefix and continues with series from that position on.
(...) :: [a] -> [a] -> [a]
(...) prefix series = prefix `lead` series
-- Filter a series by coefficient index: coefficients whose index
-- (counting from 0) satisfies the predicate are kept, all others are
-- replaced by 0.
(//) :: Fractional a => [a] -> (Integer -> Bool) -> [a]
series // keep =
  [if keep index then coeff else 0 | (index, coeff) <- zip [0 ..] series]
-- Replace the constant (index 0) coefficient of a series by 0 and keep
-- every other coefficient.
nonEmpty :: Fractional a => [a] -> [a]
nonEmpty series = series // (\index -> index /= 0)
-- n! by direct recursion. The recursion only stops when the argument
-- reaches exactly 0, so it does not terminate for negative arguments.
factorial :: (Eq a, Num a) => a -> a
factorial n
  | n == 0 = 1
  | otherwise = n * factorial (n - 1)
-- The n-th coefficient of a series multiplied by n!. Reading the series
-- as a Taylor expansion, this is the n-th derivative at the expansion
-- point. The list must have at least n+1 elements.
count :: Fractional a => Integer -> [a] -> a
count n a = coefficient * fromInteger (factorial n)
  where
    coefficient = a !! fromInteger n
-- ε is an infinitesimal all of whose powers we track: as a power series
-- it is the variable itself, with coefficients 0, 1, 0, 0, ...
ε :: Fractional a => [a]
ε = [0, 1] ++ repeat 0
--
-- List the derivatives of order 0 through 10 of the argument
-- (eleven values; the first one is the value itself).
--
derivatives :: Fractional a => [a] -> [a]
derivatives a = [count order a | order <- [0 .. 10]]
--
-- Free monad giving generic probability interface.
-- Interpreter below.
--
-- See http://blog.sigfpe.com/2017/06/a-relaxation-technique.html
--
-- A program of type `Random p a` is one of:
--   Pure a         - a finished computation with result a.
--   Bernoulli p k  - a request for one binary draw with parameter p,
--                    followed by the continuation k applied to the
--                    outcome. The interpreter in this file passes 1 or 0.
-- The parameter type p is left abstract; the interpreter decides how it
-- is turned into an actual probability.
--
data Random p a = Pure a | Bernoulli p (Int -> Random p a)
-- Map a function over the final result of a random program. A Bernoulli
-- node keeps its parameter and the function is pushed into whatever the
-- continuation produces.
instance Functor (Random p) where
  fmap f (Pure a) = Pure (f a)
  fmap f (Bernoulli p k) = Bernoulli p (fmap f . k)
-- Applicative and Monad structure of the free monad.
--
-- `pure` wraps a value in Pure; (<*>) is derived from the Monad instance
-- via `ap`. `return` is left at its class default (`pure`), which is the
-- canonical arrangement and avoids a definition cycle between `pure` and
-- `return`.
instance Applicative (Random p) where
  pure = Pure
  (<*>) = ap

-- Sequencing: a finished computation hands its result to the next step;
-- a Bernoulli node keeps its parameter and the next step is attached
-- after its continuation.
instance Monad (Random p) where
  Pure a >>= f = f a
  Bernoulli p k >>= f = Bernoulli p (\outcome -> k outcome >>= f)
-- A single binary draw with the given parameter whose outcome is
-- returned unchanged as the result of the program.
bernoulli :: p -> Random p Int
bernoulli weight = Bernoulli weight Pure
-- Multiply the weight (second component) of every pair by s, leaving the
-- first components untouched.
scale :: Num p => p -> [(a, p)] -> [(a, p)]
scale s = map rescale
  where
    rescale (a, p) = (a, s * p)
-- Merge pairs that share the same first component by adding up their
-- second components. The result is listed in ascending key order.
collect :: (Ord a, Num b) => [(a, b)] -> [(a, b)]
collect pairs = M.toList (M.fromListWith (+) pairs)
--
-- Interpreter for our free monad.
-- This provides "weighted importance sampling" semantics with
-- a rule to specify how a weight is converted to a probability.
-- The interpreter generates random samples that also carry an
-- importance.
-- For a simple sampling based interpreter the rule can be the
-- identity.
--
-- Arguments: the rule, the program to run, and a random generator.
-- Result: the program's result paired with the accumulated importance,
-- plus the advanced generator.
--
interpret :: (Fractional p, R.RandomGen g) =>
             (p -> Float) -> Random p a -> g -> ((a, p), g)
-- A finished program consumes no randomness and has importance 1.
interpret _ (Pure a) g = ((a, 1), g)
interpret rule (Bernoulli p f) g = ((a, i * i'), g'')
  where
    (r, g') = R.random g
    --
    -- This is the key line where a weight is converted
    -- into an actual probability.
    -- This makes it possible for p to be an element of
    -- some kind of algebraic extension of the reals
    -- while still giving meaningful probabilities.
    -- We adjust the "importance" of the sample accordingly.
    --
    prob = rule p
    -- Outcome 1 is drawn with probability prob and weighted by p/prob;
    -- outcome 0 is weighted by (1-p)/(1-prob). The strict comparison
    -- (with prob >= 1 handled explicitly) guarantees that a branch whose
    -- probability is zero is never selected, so neither weight is ever
    -- computed with a zero denominator, even if r is exactly 0 or 1.
    (b, i)
      | prob >= 1 || (r :: Float) < prob = (1, p / realToFrac prob)
      | otherwise = (0, (1 - p) / realToFrac (1 - prob))
    -- Continue with the rest of the program on the advanced generator;
    -- importances multiply along the path.
    ((a, i'), g'') = interpret rule (f b) g'
--
-- Compute expected values taking into account weights.
--
-- Runs the program n times with the given rule, adds up the
-- importance-weighted results and divides by n, so n must be non-zero.
-- Returns the estimate together with the advanced generator.
--
expect :: (Fractional p, R.RandomGen g) =>
          (p -> Float) -> Random p p -> Int -> g -> (p, g)
expect rule r n g = (total / fromIntegral n, g')
  where
    -- `sum` is this file's sampling accumulator, not the Prelude
    -- function (which is hidden in the import list).
    (total, g') = sum rule 0 r n g
-- Accumulate n importance-weighted samples of a random program.
--
-- Arguments: the rule passed on to `interpret`, the running total t, the
-- program r, the number of samples still to draw, and the generator.
-- Each step adds (result * importance) of one run to the total and
-- continues with the advanced generator.
--
-- A count of zero or less draws nothing and returns the total as is, so
-- a negative count cannot recurse forever. The new total is forced to
-- weak head normal form before recursing, so the accumulator does not
-- build up as an unevaluated chain of additions across iterations.
sum :: (Fractional p, R.RandomGen g) =>
       (p -> Float) -> p -> Random p p -> Int -> g -> (p, g)
sum rule t r n g
  | n <= 0 = (t, g)
  | otherwise =
      let ((a, imp), g') = interpret rule r g
          t' = t + a * imp
       in t' `seq` sum rule t' r (n - 1) g'
-- Example from https://www.arxiv-vanity.com/papers/1802.05098/
-- See section 3.3 "Simple Failing Example"
--
-- We're able to correctly differentiate the expected value of
--
-- X(1-θ)+(1-X)(1+θ)
--
-- arbitrarily many times even though X is sampled from Ber(θ)
-- and so only ever takes integer values.
--
ex1 :: Num p => p -> Random p p
ex1 θ = payoff <$> bernoulli θ
  where
    -- The drawn outcome is an Int; convert it once to the parameter
    -- type before plugging it into the formula above.
    payoff outcome =
      let x = fromIntegral outcome
       in x * (1 - θ) + (1 - x) * (1 + θ)
--
-- Get just the value of a power series evaluated at zero killing
-- all of the derivatives: this is the first coefficient of the list.
-- The list must be non-empty.
--
(⊥) :: [a] -> a
(⊥) series = head series
-- Estimate the expected value of the example at θ = 0.5 from 1000
-- samples and print its derivatives with respect to θ.
--
-- The parameter is passed as the power series 0.5+ε, so the estimate is
-- itself a power series. (⊥) is used as the rule, so sampling uses only
-- the constant term 0.5 as the actual probability.
main :: IO ()
main = do
  --
  -- Evaluate derivative from example at θ = 0.5
  --
  e <- R.getStdRandom (expect (⊥) (ex1 (0.5 + ε)) 1000)
  print $ derivatives e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment