Created
October 28, 2016 07:18
-
-
Save jtobin/4312880f9fa7b63ebbe5b84b9aa60ff5 to your computer and use it in GitHub Desktop.
Tweaking comonadic inference for performance
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# OPTIONS_GHC -Wall #-} | |
{-# OPTIONS_GHC -fno-warn-type-defaults #-} | |
{-# LANGUAGE BangPatterns #-} | |
{-# LANGUAGE DeriveFunctor #-} | |
{-# LANGUAGE LambdaCase #-} | |
{-# LANGUAGE RankNTypes #-} | |
{-# LANGUAGE RecordWildCards #-} | |
import Control.Comonad | |
import Control.Comonad.Cofree | |
import qualified Control.Foldl as L | |
import Control.Monad.Free | |
import Data.Bits | |
import Data.Random | |
import qualified Data.Random.Distribution.Bernoulli as RF | |
import qualified Data.Random.Distribution.Beta as RF | |
import qualified Data.Random.Distribution.Normal as RF | |
import Data.Void | |
import Data.Word | |
import System.Random.Mersenne.Pure64 | |
-- language types ------------------------------------------------------------- | |
-- | Base functor of the modelling language.
--
-- Every sampling constructor carries its distribution parameters and a
-- continuation that receives the sampled value.  'DiracF' carries a value of
-- type @a@ and has no continuation.
data ModelF a r
  = BernoulliF {-# UNPACK #-} !Double (Bool -> r)
  | BetaF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
  | NormalF {-# UNPACK #-} !Double {-# UNPACK #-} !Double (Double -> r)
  | DiracF !a
  deriving (Functor)
-- | A program is a free monad over the 'ModelF' instruction set.
type Program a = Free (ModelF a)

-- | A model places no constraint on the 'DiracF' payload type.
type Model b = forall a. Program a b

-- | A program whose 'Pure' case is uninhabited.
type Terminating a = Program a Void

-- | A cofree tree over 'ModelF' annotated with a 'Node' at every level.
type Execution a = Cofree (ModelF a) Node
-- boring mechanical types ---------------------------------------------------- | |
-- | A value stored at a node of an execution: a boolean draw, a
-- floating-point draw, or nothing at all.
data Value
  = VBool !Bool
  | VDouble {-# UNPACK #-} !Double
  | VEmpty
  deriving (Eq, Show)
-- | The annotation attached to every level of an execution.
data Node = Node
  { nodeCost :: {-# UNPACK #-} !Double
    -- ^ Cost recorded for 'nodeValue'.
  , nodeValue :: !Value
    -- ^ The value held at this node.
  , nodePrng :: !PureMT
    -- ^ Generator state owned by this node.
  }
  deriving (Show)
-- | Two-word state consumed and produced by 'xorshift'.
data Seed
  = Seed
      {-# UNPACK #-} !Word64
      {-# UNPACK #-} !Word64

-- | The fixed seed from which every execution is initialised.
defaultSeed :: Seed
defaultSeed = Seed 42 108512
-- primitive terms ------------------------------------------------------------ | |
-- | Beta-distributed term.  Calls 'error' when either parameter is negative.
beta :: Double -> Double -> Program a Double
beta a b =
  if any (< 0) [a, b]
    then error "out of bounds"
    else liftF (BetaF a b id)
-- | Bernoulli-distributed term.  A probability below 0 is replaced by 0 and
-- one above 1 is replaced by 1; anything else is passed through unchanged.
bernoulli :: Double -> Program a Bool
bernoulli p = liftF (BernoulliF clamped id)
  where
    -- Explicit comparisons (rather than min/max) so that a NaN argument is
    -- passed through exactly as given.
    clamped =
      if p < 0
        then 0
        else if p > 1 then 1 else p
-- | Normally-distributed term with parameters @m@ and @s@.  Calls 'error'
-- when @s@ is negative.
normal :: Double -> Double -> Program a Double
normal m s =
  if s < 0
    then error "negative variance"
    else liftF (NormalF m s id)
-- | Lift a value into a program as a 'DiracF' instruction.
dirac :: a -> Program a b
dirac = liftF . DiracF
-- densities ------------------------------------------------------------------ | |
-- | Log mass of a Bernoulli outcome.
--
-- Returns @log 0@ (negative infinity) when @p@ lies outside @[0, 1]@.
-- Otherwise the result is @log p@ for 'True' and @log (1 - p)@ for 'False'.
--
-- The branch is selected directly instead of evaluating
-- @b * log p + (1 - b) * log (1 - p)@: at the endpoints that expression
-- multiplies 0 by negative infinity and yields NaN for the certain outcome
-- (@p == 0@ with 'False', @p == 1@ with 'True'), whose log mass is 0.
logDensityBernoulli :: Double -> Bool -> Double
logDensityBernoulli p x
  | p < 0 || p > 1 = log 0
  | x = log p
  | otherwise = log (1 - p)
-- | Log density of a beta variate, without a normalising term.
--
-- Evaluates to @log 0@ when @x@ is not strictly inside @(0, 1)@ or when
-- either parameter is negative.
logDensityBeta :: Double -> Double -> Double -> Double
logDensityBeta a b x
  | offSupport || badParams = log 0
  | otherwise = (a - 1) * log x + (b - 1) * log (1 - x)
  where
    offSupport = x <= 0 || x >= 1
    badParams = a < 0 || b < 0
-- | Log density of a normal variate with parameters @m@ and @s@, without the
-- parameter-independent constant.  Evaluates to @log 0@ when @s@ is not
-- strictly positive.
logDensityNormal :: Double -> Double -> Double -> Double
logDensityNormal m s x =
  if s <= 0
    then log 0
    else negate (log s) - squaredDeviation / (2 * s ^ 2)
  where
    squaredDeviation = (x - m) ^ 2
-- | Log mass of a point mass at @a@: 0 when @x@ equals @a@, @log 0@ otherwise.
logDensityDirac :: Eq a => a -> a -> Double
logDensityDirac a x = if a == x then 0 else log 0
-- sampling ------------------------------------------------------------------- | |
-- | Interpret a program as a random variable by sampling each instruction
-- with the corresponding distribution and feeding the draw to its
-- continuation.  A 'DiracF' instruction yields its payload.
toSampler :: Program a a -> RVar a
toSampler = iterM interpret
  where
    interpret instruction = case instruction of
      DiracF x -> return x
      BernoulliF p k -> k =<< RF.bernoulli p
      BetaF a b k -> k =<< RF.beta a b
      NormalF m s k -> k =<< RF.normal m s
-- execution: initializing ---------------------------------------------------- | |
-- | Unfold a terminating program into an execution tree.
--
-- Starting from 'defaultSeed', each level advances the seed with 'xorshift'.
-- The generated word seeds the generator handed to 'initialize' for that
-- level, and every continuation below it is unfolded from the successor seed.
execute :: Terminating a -> Execution a
execute = unfold defaultSeed
  where
    unfold seed = \case
      Pure impossible -> absurd impossible
      Free layer ->
        let (seed', word) = xorshift seed
        in initialize (pureMT word) layer :< fmap (unfold seed') layer
-- | Build the initial 'Node' for an instruction.
--
-- Sampling instructions draw a value with the supplied generator and record
-- the log density of the draw, the draw itself, and the generator state left
-- after sampling.  'DiracF' draws nothing: its node has cost 0, holds
-- 'VEmpty', and keeps the generator untouched.
initialize :: PureMT -> ModelF a b -> Node
initialize prng = \case
  DiracF _ -> Node 0 VEmpty prng
  BernoulliF p _ ->
    fresh (logDensityBernoulli p) VBool (sampleState (RF.bernoulli p) prng)
  BetaF a b _ ->
    fresh (logDensityBeta a b) VDouble (sampleState (RF.beta a b) prng)
  NormalF m s _ ->
    fresh (logDensityNormal m s) VDouble (sampleState (RF.normal m s) prng)
  where
    -- Assemble a node from a (draw, generator) pair.
    fresh :: (x -> Double) -> (x -> Value) -> (x, PureMT) -> Node
    fresh density wrap (draw, gen) =
      Node
        { nodeCost = density draw
        , nodeValue = wrap draw
        , nodePrng = gen
        }
-- execution: scoring and running --------------------------------------------- | |
-- | Sum the node costs along the path selected by the stored node values.
--
-- At each sampling node the stored value is fed to the continuation to pick
-- the next level.  The walk stops at a 'DiracF' node, whose own cost is not
-- added.
score :: Execution a -> Double
score = walk 0
  where
    walk !total (Node cost value _ :< layer) =
      let total' = total + cost
      in case layer of
           DiracF _ -> total
           BernoulliF _ k -> walk total' (k (asBool value))
           BetaF _ _ k -> walk total' (k (asDouble value))
           NormalF _ _ k -> walk total' (k (asDouble value))

    -- Lazy projections: the stored value is only inspected if the
    -- continuation actually demands its argument.
    asBool ~(VBool b) = b
    asDouble ~(VDouble d) = d
-- | Count the nodes along the path selected by the stored node values,
-- including the final 'DiracF' node.
depth :: Execution a -> Int
depth = walk 0
  where
    walk !count (Node _ value _ :< layer) =
      let count' = succ count
      in case layer of
           DiracF _ -> count'
           BernoulliF _ k -> walk count' (k (asBool value))
           BetaF _ _ k -> walk count' (k (asDouble value))
           NormalF _ _ k -> walk count' (k (asDouble value))

    -- Lazy projections: the stored value is only inspected if the
    -- continuation actually demands its argument.
    asBool ~(VBool b) = b
    asDouble ~(VDouble d) = d
-- | Advance one level, using the value stored at the current node as input.
step :: Execution a -> Execution a
step prog = case prog of
  Node _ stored _ :< _ -> stepWithInput stored prog
-- | Advance one level by feeding the given value to the current
-- continuation.  An execution sitting at a 'DiracF' node is returned as is.
stepWithInput :: Value -> Execution a -> Execution a
stepWithInput value prog = case unwrap prog of
  DiracF _ -> prog
  BernoulliF _ k -> k boolInput
  BetaF _ _ k -> k doubleInput
  NormalF _ _ k -> k doubleInput
  where
    -- Lazy pattern bindings: a mismatched 'Value' only fails if the
    -- continuation demands its argument.
    VBool boolInput = value
    VDouble doubleInput = value
-- | Step through an execution until a 'DiracF' node is reached and return
-- its payload.
run :: Execution a -> a
run prog
  | DiracF result <- unwrap prog = result
  | otherwise = run (step prog)
-- | Take the first step with the supplied value, then 'run' to completion.
runWithInput :: Value -> Execution a -> a
runWithInput value prog = run (stepWithInput value prog)
-- | Apply 'stepGenerator' at every position of the tree.
stepGenerators :: Functor f => Cofree f Node -> Cofree f Node
stepGenerators tree = tree =>> stepGenerator
-- | Advance the generator of the root node by one 'randomInt' draw, keeping
-- its cost and value.
stepGenerator :: Cofree f Node -> Node
stepGenerator (node :< _) =
  node {nodePrng = snd (randomInt (nodePrng node))}
-- mcmc: perturb -------------------------------------------------------------- | |
-- | Apply 'perturbNode' at every position of the execution.
perturb :: Execution a -> Execution a
perturb execution = execution =>> perturbNode
-- | Redraw the root node of an execution.
--
-- A sampling node gets a fresh draw from its own distribution using the
-- node's generator; the new node records the log density of the draw, the
-- draw, and the generator state left after sampling.  A 'DiracF' node is
-- returned unchanged.
perturbNode :: Execution a -> Node
perturbNode (node@(Node _ _ gen0) :< layer) = case layer of
  DiracF _ -> node
  BernoulliF p _ ->
    rebuild (logDensityBernoulli p) VBool (sampleState (RF.bernoulli p) gen0)
  BetaF a b _ ->
    rebuild (logDensityBeta a b) VDouble (sampleState (RF.beta a b) gen0)
  NormalF m s _ ->
    rebuild (logDensityNormal m s) VDouble (sampleState (RF.normal m s) gen0)
  where
    -- Assemble a node from a (draw, generator) pair.
    rebuild :: (x -> Double) -> (x -> Value) -> (x, PureMT) -> Node
    rebuild density wrap (draw, gen) = Node (density draw) (wrap draw) gen
-- mcmc: markov chain --------------------------------------------------------- | |
-- | Run a Markov chain over executions of @prior@ for @epochs@ iterations
-- and return the final execution.
--
-- Arguments: the number of iterations, the observations, the prior model,
-- and a function giving the log-likelihood of one observation under a value
-- produced by the prior.
--
-- Each iteration perturbs the current execution to obtain a proposal, scores
-- both, and accepts the proposal with the probability computed by
-- 'moveProbability'.  On rejection the generators of the current execution
-- are advanced with 'stepGenerators' so the next proposal differs.
--
-- A non-positive @epochs@ returns the initial execution immediately.  The
-- guard is @n <= 0@ rather than @n == 0@ so that a negative count terminates
-- instead of counting down forever.
invert :: Int -> [a] -> Model b -> (b -> a -> Double) -> Model (Execution b)
invert epochs obs prior ll = loop epochs (execute (prior >>= dirac))
  where
    -- Total log-likelihood of the observations under one candidate value.
    obsCost value = L.fold (L.premap (ll value) L.sum) obs

    loop n current
      | n <= 0 = return current
      | otherwise = do
          let proposal = perturb current
              ccostPrior = score current
              pcostPrior = score proposal
              ccost = ccostPrior + obsCost (run current)
              pcost = pcostPrior + obsCost (run proposal)
              -- Forward and backward move costs: negated log depth of the
              -- execution being left plus the prior cost of the target.
              fwcost = negate (log (fromIntegral (depth current))) + pcostPrior
              bwcost = negate (log (fromIntegral (depth proposal))) + ccostPrior
              prob = moveProbability ccost pcost bwcost fwcost
          accept <- bernoulli prob
          let next = if accept then proposal else stepGenerators current
          loop (pred n) next
-- | Acceptance probability for a move, computed from log-scale costs as
-- @exp (min 0 (proposal - current + bw - fw))@.  A NaN result is replaced
-- by 0.
moveProbability :: Double -> Double -> Double -> Double -> Double
moveProbability current proposal bw fw
  | isNaN ratio = 0
  | otherwise = ratio
  where
    logRatio = proposal - current + bw - fw
    -- Argument order matters: @min 0 x@ lets a NaN @x@ through to the guard.
    ratio = exp (min 0 logRatio)
-- Data.Bits.Extended --------------------------------------------------------- | |
-- | One step of a pure xorshift generator: returns the successor seed and
-- an output word.
--
-- See: https://en.wikipedia.org/wiki/Xorshift.
xorshift :: Seed -> (Seed, Word64)
xorshift (Seed s0 s1) = (Seed s1 next, next + s1)
  where
    mixed = s0 `xor` (s0 `shiftL` 23)
    next = mixed `xor` s1 `xor` (mixed `shiftR` 17) `xor` (s1 `shiftR` 26)
-- test ----------------------------------------------------------------------- | |
-- | Example prior: draw a probability from @beta 1 2@ and use it for a
-- Bernoulli draw.
test :: Model Bool
test = beta 1 2 >>= bernoulli
-- | Example observations.
xs :: [Double]
xs =
  [ -1.7
  , -1.8
  , -2.01
  , -2.4
  , 1.9
  , 1.8
  ]
-- | Example likelihood: a normal log density with second parameter 0.5,
-- centred at -2 when the flag is 'True' and at 2 otherwise.
model :: Bool -> Double -> Double
model left = logDensityNormal centre 0.5
  where
    centre = if left then negate 2 else 2
-- | The example inference problem: 1000 iterations of 'invert' over 'xs'
-- with prior 'test' and likelihood 'model'.
posterior :: Model (Execution Bool)
posterior = invert iterations xs test model
  where
    iterations = 1000
-- | Sample one final execution from 'posterior' and print its root node.
main :: IO ()
main = sample (toSampler posterior) >>= print . extract
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment