{-
Created October 27, 2016 00:57.

GitHub gist jtobin/497e688359c17d1fdf9215868a300b55:
Probabilistic programming using comonads.

Note from GitHub: this file contains bidirectional Unicode text that may be
interpreted or compiled differently than what appears below. To review it,
open the file in an editor that reveals hidden Unicode characters.
-}
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE RecordWildCards #-} | |
import Control.Comonad | |
import Control.Comonad.Cofree | |
import Control.Monad | |
import Control.Monad.ST | |
import Control.Monad.Free | |
import Data.Bits | |
import Data.Dynamic | |
import Data.Maybe | |
import Data.Void | |
import qualified Data.Vector as V | |
import Data.Word | |
import qualified System.Random.MWC as MWC | |
import System.Random.MWC.Probability (Prob) | |
import qualified System.Random.MWC.Probability as Prob | |
-- | One primitive instruction of a probabilistic program.
--
-- The three sampling constructors carry a distribution's parameters and a
-- continuation that receives the draw; 'DiracF' ends the program with a
-- payload of type @a@.
data ModelF a r
  = BernoulliF Double (Bool -> r)        -- ^ success probability, continuation
  | BetaF Double Double (Double -> r)    -- ^ two shape parameters, continuation
  | NormalF Double Double (Double -> r)  -- ^ mean, standard deviation, continuation
  | DiracF a                             -- ^ terminal payload
  deriving (Functor)
-- | A free-monad program over 'ModelF' whose 'DiracF' terminals carry @a@.
type Program a = Free (ModelF a)

-- | A program that is polymorphic in the 'DiracF' payload type.
type Model b = forall t. Program t b

-- | A program that can only finish through 'DiracF': it has no 'Pure' leaves
-- because 'Void' has no values.
type Terminating a = Free (ModelF a) Void

-- | A program annotated at every instruction with a 'Node'.
type Execution a = Cofree (ModelF a) Node
-- | Per-instruction state stored at each position of an 'Execution'.
data Node = Node
  { nodeCost    :: Double     -- ^ log density assigned to 'nodeValue'
  , nodeValue   :: Dynamic    -- ^ value drawn (or carried) at this instruction
  , nodeSeed    :: MWC.Seed   -- ^ generator seed stored for this instruction
  , nodeHistory :: [Dynamic]  -- ^ previously recorded values, most recent first
  } deriving (Show)
-- primitive terms ------------------------------------------------------------ | |
-- | Draw from a beta distribution with the given shape parameters.
beta :: Double -> Double -> Program a Double
beta shape1 shape2 = liftF $ BetaF shape1 shape2 id
-- | Flip a coin that comes up 'True' with the given probability.
bernoulli :: Double -> Program a Bool
bernoulli prob = liftF $ BernoulliF prob id
-- | Draw from a normal distribution with the given mean and standard
-- deviation.
normal :: Double -> Double -> Program a Double
normal mu sigma = liftF $ NormalF mu sigma id
-- | Terminate the program with the given payload.
dirac :: a -> Program a b
dirac payload = liftF $ DiracF payload
-- sampling ------------------------------------------------------------------- | |
-- | Interpret a program as a sampler. Every sampling instruction draws from
-- its distribution and passes the draw to its continuation; a 'DiracF'
-- terminal yields its payload.
toSampler :: Program a a -> Prob IO a
toSampler = iterM interpret
  where
    interpret (BernoulliF p next) = Prob.bernoulli p >>= next
    interpret (BetaF a b next)    = Prob.beta a b >>= next
    interpret (NormalF m s next)  = Prob.normal m s >>= next
    interpret (DiracF payload)    = pure payload
-- | Run a sampler once using a generator seeded from system randomness.
simulate :: Prob IO a -> IO a
simulate = MWC.withSystemRandom . MWC.asGenIO . Prob.sample
-- densities ------------------------------------------------------------------ | |
-- | Log probability of the outcome @x@ under a Bernoulli distribution with
-- success probability @p@.
--
-- Returns @log 0@ (negative infinity) when @p@ lies outside @[0, 1]@.
-- Only the log of the selected branch is evaluated. The degenerate cases
-- @p = 0@ and @p = 1@ therefore give @0@ or @-Infinity@, rather than the
-- NaN that @0 * log 0@ would produce.
logDensityBernoulli :: Double -> Bool -> Double
logDensityBernoulli p x
  | p < 0 || p > 1 = log 0
  | x              = log p
  | otherwise      = log (1 - p)
-- | Unnormalized log density of @x@ under a beta distribution with shape
-- parameters @a@ and @b@.
--
-- The @- log B(a, b)@ normalizing term is omitted.
-- NOTE(review): that term only cancels in acceptance ratios when @a@ and @b@
-- are the same in both executions being compared. Confirm that this holds
-- wherever @a@/@b@ come from other random choices.
--
-- Returns @log 0@ (negative infinity) outside the open support @(0, 1)@.
-- It also returns @log 0@ for non-positive shape parameters, for which the
-- distribution is undefined.
logDensityBeta :: Double -> Double -> Double -> Double
logDensityBeta a b x
  | x <= 0 || x >= 1 = log 0
  | a <= 0 || b <= 0 = log 0
  | otherwise        = (a - 1) * log x + (b - 1) * log (1 - x)
-- | Log density of @x@ under a normal distribution with mean @m@ and
-- standard deviation @s@. The constant @-(1/2) log (2 pi)@ is omitted.
--
-- Returns @log 0@ (negative infinity) for a non-positive @s@.
logDensityNormal :: Double -> Double -> Double -> Double
logDensityNormal m s x
  | s <= 0    = log 0
  | otherwise = negate (log s) - sq (x - m) / (2 * sq s)
  where
    sq v = v * v
-- | Log density of @x@ under a point mass at @a@: @0@ when they are equal,
-- negative infinity otherwise.
logDensityDirac :: Eq a => a -> a -> Double
logDensityDirac a x = if a == x then 0 else log 0
-- execution: initializing ---------------------------------------------------- | |
-- | 'executeGeneric' with a fixed seed pair, so the initial execution of a
-- given program is always the same.
execute :: Typeable a => Terminating a -> Execution a
execute = executeGeneric defaultSeeds
  where
    defaultSeeds = (42, 108512)
-- | Annotate every instruction of a terminating program with an initial
-- 'Node'.
--
-- At each level the seed pair is advanced once with 'xorshift'. The output
-- word seeds that instruction's generator, and the advanced pair is passed
-- to every continuation below it.
executeGeneric
  :: Typeable a => (Word32, Word32) -> Terminating a -> Execution a
executeGeneric = go
  where
    go seeds = \case
      Pure impossible -> absurd impossible
      Free instruction ->
        let (seeds', word) = xorshift seeds
            node = initialize (MWC.toSeed (V.singleton word)) instruction
        in  node :< fmap (go seeds') instruction
-- | Run a sampler purely. A generator is restored from @seed@, one value is
-- drawn, and the draw is returned as a 'Dynamic' together with the
-- generator's saved state.
--
-- Raises an error if the saved state equals @seed@.
samplePurely
  :: Typeable a => Prob (ST s) a -> Prob.Seed -> ST s (Dynamic, Prob.Seed)
samplePurely prog seed = do
  gen <- MWC.restore seed
  draw <- MWC.asGenST (Prob.sample prog) gen
  seed' <- MWC.save gen
  when (seed' == seed) $
    error "a generator failed to step!"
  pure (toDyn draw, seed')
-- | Build the initial 'Node' for one instruction.
--
-- A sampling instruction draws its value starting from @seed@. It records
-- the draw's log density and the generator state after drawing. 'DiracF'
-- stores its payload with cost 0 and the unchanged seed. Histories start
-- empty.
initialize :: Typeable a => MWC.Seed -> ModelF a b -> Node
initialize seed = \case
    BernoulliF p _ -> draw (Prob.bernoulli p) (logDensityBernoulli p)
    BetaF a b _    -> draw (Prob.beta a b) (logDensityBeta a b)
    NormalF m s _  -> draw (Prob.normal m s) (logDensityNormal m s)
    DiracF a       -> Node 0 (toDyn a) seed mempty
  where
    -- Sample purely from @seed@ and score the draw with @density@.
    draw :: Typeable v => (forall s. Prob (ST s) v) -> (v -> Double) -> Node
    draw sampler density = runST $ do
      (value, seed') <- samplePurely sampler seed
      pure (Node (density (unsafeFromDyn value)) value seed' mempty)
-- execution: scoring and running --------------------------------------------- | |
-- | Sum of 'nodeCost' along the path an execution takes. Each continuation
-- is followed with its node's stored value. The terminal 'DiracF' node's
-- cost is not added.
score :: Execution a -> Double
score = go 0
  where
    go !total (Node {..} :< instruction) =
      let total' = total + nodeCost
      in  case instruction of
            BernoulliF _ next -> go total' (next (unsafeFromDyn nodeValue))
            BetaF _ _ next    -> go total' (next (unsafeFromDyn nodeValue))
            NormalF _ _ next  -> go total' (next (unsafeFromDyn nodeValue))
            DiracF _          -> total
-- | Number of instructions along the path an execution takes, including the
-- terminal 'DiracF' node. Each continuation is followed with its node's
-- stored value.
depth :: Execution a -> Int
depth = go 0
  where
    go !count (Node {..} :< instruction) = case instruction of
      BernoulliF _ next -> go (succ count) (next (unsafeFromDyn nodeValue))
      BetaF _ _ next    -> go (succ count) (next (unsafeFromDyn nodeValue))
      NormalF _ _ next  -> go (succ count) (next (unsafeFromDyn nodeValue))
      DiracF _          -> succ count
-- | Advance an execution by one instruction, feeding the current node's own
-- value to the continuation. A 'DiracF' execution is returned unchanged.
step :: Typeable a => Execution a -> Execution a
step execution = case execution of
  Node {nodeValue = current} :< _ -> stepWithInput current execution
-- | Advance an execution by one instruction, feeding @value@ to the
-- continuation. The value is cast with 'unsafeFromDyn', so it must have the
-- type the instruction expects. A 'DiracF' execution is returned unchanged.
stepWithInput :: Typeable a => Dynamic -> Execution a -> Execution a
stepWithInput value prog = case unwrap prog of
    BernoulliF _ k -> feed k
    BetaF _ _ k    -> feed k
    NormalF _ _ k  -> feed k
    DiracF _       -> prog
  where
    feed :: Typeable t => (t -> r) -> r
    feed k = k (unsafeFromDyn value)
-- | Keep stepping an execution with its stored values until it reaches
-- 'DiracF', then return that terminal's payload.
run :: Typeable a => Execution a -> a
run prog
  | DiracF result <- unwrap prog = result
  | otherwise                    = run (step prog)
-- | Feed @value@ to the first instruction, then 'run' the remainder with the
-- stored values.
runWithInput :: Typeable a => Dynamic -> Execution a -> a
runWithInput value prog = run (stepWithInput value prog)
-- | Advance the stored seed of every node in the tree with 'stepGenerator'.
stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators tree = tree =>> stepGenerator
-- | Advance a node's seed by drawing, and discarding, one Beta(1, 1) sample.
-- Every other field is kept.
stepGenerator :: Cofree f Node -> Node
stepGenerator (node :< _) = runST $ do
  (_, seed') <- samplePurely (Prob.beta 1 1) (nodeSeed node)
  pure (node {nodeSeed = seed'})
-- mcmc: perturb -------------------------------------------------------------- | |
-- | Propose a new execution by redrawing every node with 'perturbNode'.
perturb :: Execution a -> Execution a
perturb execution = execution =>> perturbNode
-- | Redraw one node's value from its instruction's distribution and rescore
-- it. The draw starts from the node's stored seed, and the history is kept.
-- 'DiracF' nodes are returned unchanged.
perturbNode :: Execution a -> Node
perturbNode (node :< instruction) = case instruction of
    BernoulliF p _ -> redraw (Prob.bernoulli p) (logDensityBernoulli p)
    BetaF a b _    -> redraw (Prob.beta a b) (logDensityBeta a b)
    NormalF m s _  -> redraw (Prob.normal m s) (logDensityNormal m s)
    DiracF _       -> node
  where
    -- Sample purely from the node's seed and score the new draw.
    redraw :: Typeable v => (forall s. Prob (ST s) v) -> (v -> Double) -> Node
    redraw sampler density = runST $ do
      (value, seed') <- samplePurely sampler (nodeSeed node)
      pure $! Node (density (unsafeFromDyn value)) value seed' (nodeHistory node)
-- mcmc: markov chain --------------------------------------------------------- | |
-- | Approximate the posterior of a prior program given observations. This
-- runs a Metropolis-Hastings-style chain over the program's executions.
--
-- * @epochs@ — number of steps to run; a non-positive count runs none.
-- * @obs@ — the observations; each one is scored with @ll@.
-- * @prior@ — the prior program. It is terminated with 'dirac' and then
--   executed with 'execute'.
-- * @ll@ — log likelihood of one observation given the program's value.
--
-- Each step:
--
-- 1. Proposes 'perturb' of the current execution.
-- 2. Scores both executions as 'score' plus the summed log likelihood of
--    @obs@.
-- 3. Accepts with a 'bernoulli' draw of probability 'moveProbability'.
--
-- On rejection the current execution's seeds are advanced instead. Node
-- values are recorded with 'snapshot' after every step. The result is the
-- final execution, returned inside the program monad because the
-- accept/reject coin flips are program instructions.
invert
  :: (Eq a, Typeable a, Typeable b)
  => Int -> [a] -> Model b -> (b -> a -> Double)
  -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute (prior >>= dirac)) where
  loop n current
    -- '<=' rather than '==' so that a negative epoch count stops immediately
    -- instead of counting down (effectively) forever.
    | n <= 0 = return current
    | otherwise = do
        -- 'perturb' redraws every node from its own distribution and seed.
        let proposal = perturb current
            valueAtCurrent = run current
            valueAtProposal = run proposal
            currentLl = ll valueAtCurrent
            proposalLl = ll valueAtProposal
            currentContribution = sum (fmap currentLl obs)
            proposalContribution = sum (fmap proposalLl obs)
            currentScore = score current + currentContribution
            proposalScore = score proposal + proposalContribution
            -- Forward/backward correction terms. Each is -log of the source
            -- execution's depth plus the prior score of the destination.
            -- NOTE(review): 'perturb' redraws every node, so confirm this
            -- is the intended proposal correction.
            fw = negate (log (fromIntegral (depth current))) + score proposal
            bw = negate (log (fromIntegral (depth proposal))) + score current
            prob = moveProbability currentScore proposalScore bw fw
        accept <- bernoulli prob
        -- On rejection, advance the seeds so the next proposal differs.
        let next = if accept then proposal else stepGenerators current
        loop (pred n) (snapshot next)
-- | Acceptance probability @exp (min 0 (proposal - current + bw - fw))@.
-- A NaN result, for example from subtracting two negative infinities, is
-- mapped to 0.
moveProbability :: Double -> Double -> Double -> Double -> Double
moveProbability current proposal bw fw
  | isNaN acceptance = 0
  | otherwise        = acceptance
  where
    acceptance = exp (min 0 (proposal - current + bw - fw))
-- | Record the present value of every node in the tree in that node's
-- history (see 'snapshotValue').
snapshot :: Functor f => Cofree f Node -> Cofree f Node
snapshot tree = tree =>> snapshotValue
-- | Prepend a node's current value to its history, leaving every other field
-- unchanged.
snapshotValue :: Cofree f Node -> Node
snapshotValue (node :< _) =
  node {nodeHistory = nodeValue node : nodeHistory node}
-- Data.Bits.Extended --------------------------------------------------------- | |
-- | A pure xorshift+ step.
--
-- Given the two-word state @(s0, s1)@, returns the next state @(s1, t')@.
-- It also returns the output word, which is the sum of the two new state
-- words.
--
-- See: https://en.wikipedia.org/wiki/Xorshift.
--
-- NOTE(review): the shift amounts 23/17/26 appear to be the xorshift128+
-- constants for 64-bit words. Callers in this file instantiate @t@ at
-- 'Word32'; confirm the resulting generator quality is acceptable.
xorshift :: (Bits t, Num t) => (t, t) -> ((t, t), t)
xorshift (s0, s1) = ((s1, t'), t' + s1)
  where
    t  = s0 `xor` (s0 `shiftL` 23)
    t' = t `xor` (t `shiftR` 17) `xor` s1 `xor` (s1 `shiftR` 26)
-- Data.Dynamic.Extended ------------------------------------------------------ | |
-- | Cast a 'Dynamic' to the requested type.
--
-- This is partial: if the dynamic value has a different type, it raises an
-- error that names the actual type, instead of the uninformative
-- @fromJust: Nothing@.
unsafeFromDyn :: Typeable a => Dynamic -> a
unsafeFromDyn dyn = fromMaybe mismatch (fromDynamic dyn)
  where
    mismatch = error $
      "unsafeFromDyn: value of type " ++ show (dynTypeRep dyn)
        ++ " does not match the requested type"
-- test / illustration -------------------------------------------------------- | |
-- | Posterior over the 'Bool' produced by the prior @p ~ Beta(3, 2)@,
-- @bernoulli p@, after 1000 steps of 'invert'.
--
-- Observations are scored under Normal(-2, 0.5) when the 'Bool' is 'True'
-- and Normal(2, 0.5) otherwise.
posterior1 :: Model (Execution Bool)
posterior1 = invert 1000 observations prior likelihood
  where
    observations = [-1.7, -1.8, -2.01, -2.4, 1.9, 1.8]
    prior = beta 3 2 >>= bernoulli
    likelihood isLeft = logDensityNormal (if isLeft then negate 2 else 2) 0.5
-- | Two-component Gaussian mixture.
--
-- A weight is drawn from Beta(a, b), and a 'bernoulli' of that weight picks
-- the component: Normal(-2, 0.5) on 'True', Normal(2, 0.5) on 'False'.
mixture :: Double -> Double -> Model Double
mixture a b = do
  weight <- beta a b
  goLeft <- bernoulli weight
  normal (if goLeft then negate 2 else 2) 0.5
-- | A "geometric" program whose coin is hard-wired to 'False'.
--
-- The probability argument is ignored. Because @accept@ is always 'False',
-- the loop is @succ <$> go@ with no instruction in between, so evaluating
-- the result diverges.
-- NOTE(review): presumably a deliberately pathological ("troll") example; a
-- real geometric would draw @accept <- bernoulli p@.
trollGeometric :: Double -> Model Int
trollGeometric _ = go
  where
    go = pure False >>= \accept ->
      if accept then pure 1 else succ <$> go
-- | Sample the 'posterior1' chain once and write two traces as 'show'n
-- lists, most recent value first:
--
-- * "post_p1_raw.dat" — history of the root node, the @Beta(3, 2)@ draw.
-- * "post_b1_raw.dat" — history of the node reached after one 'step', the
--   'bernoulli' draw.
analysis1 :: IO ()
analysis1 = do
  root@(node@Node {} :< _) <- simulate (toSampler posterior1)
  let pTrace = map unsafeFromDyn (nodeHistory node) :: [Double]
  writeFile "post_p1_raw.dat" (show pTrace)
  let stepped = extract (step root)
      bTrace = map unsafeFromDyn (nodeHistory stepped) :: [Bool]
  writeFile "post_b1_raw.dat" (show bTrace)
-- | Entry point: run the 'analysis1' example.
main :: IO ()
main = do
  analysis1
-- GitHub page footer: sign up for free to join this conversation on GitHub.
-- Already have an account? Sign in to comment.