Skip to content

Instantly share code, notes, and snippets.

@MMesch
Created January 13, 2020 15:07
Show Gist options
  • Save MMesch/ed9f2a74ecd9035c525ae7fbced8de5a to your computer and use it in GitHub Desktop.
Save MMesch/ed9f2a74ecd9035c525ae7fbced8de5a to your computer and use it in GitHub Desktop.
monad-bayes walkthrough
# Overlay-shaped function (presumably used as a nixpkgs overlay: the first
# argument is ignored and `pkgs` is the package set being extended). It
# replaces `haskell.packages.ghc865` with a set that builds IHaskell (plus its
# display packages), monad-bayes and hvega from pinned upstream revisions, and
# patches a few packages that are broken or fail their tests in this nixpkgs.
_: pkgs:
let
  # Pinned upstream sources; update `rev` and `sha256` together.
  ihaskellSrc = pkgs.fetchFromGitHub {
    owner = "gibiansky";
    repo = "IHaskell";
    rev = "2318ee2a90cfc98390651657aec434586b963235";
    sha256 = "0svjzs81i77s710cfb7pxkfdi979mhjazpc2l9k9ha752spz04cj";
  };
  monadBayesSrc = pkgs.fetchFromGitHub {
    owner = "adscib";
    repo = "monad-bayes";
    rev = "fb87bf039bab35dcc82de8ccf8963a7a576af355";
    sha256 = "0jz7lswdzxzn5zzwypdawdj7j0y20aakmqggv9pw4sknajdqqqyf";
  };
  hVegaSrc = pkgs.fetchFromGitHub {
    owner = "DougBurke";
    repo = "hvega";
    rev = "hvega-0.4.0.0";
    sha256 = "1pg655a36nsz7h2l1sbyk4zzzjjw4dlah8794bc0flpigr7iik13";
  };

  # Haskell package-set extension: `self` is the final Haskell set,
  # `hspkgs` is the set being extended.
  overrides = self: hspkgs:
    let
      # Build one of the ihaskell-display/ihaskell-<name> sub-packages from
      # the pinned IHaskell checkout.
      callDisplayPackage = name:
        hspkgs.callCabal2nix
          "ihaskell-${name}"
          "${ihaskellSrc}/ihaskell-display/ihaskell-${name}"
          {};
      dontCheck = pkgs.haskell.lib.dontCheck;
      dontHaddock = pkgs.haskell.lib.dontHaddock;
    in
    {
      monad-bayes = hspkgs.callCabal2nix "monad-bayes" "${monadBayesSrc}" {};
      hvega = hspkgs.callCabal2nix "hvega" "${hVegaSrc}/hvega" {};
      ihaskell-hvega = hspkgs.callCabal2nix "ihaskell-hvega" "${hVegaSrc}/ihaskell-hvega" {};
      ihaskell = pkgs.haskell.lib.overrideCabal
        (hspkgs.callCabal2nix "ihaskell" ihaskellSrc {})
        (drv: {
          # Give the test suite a temporary HOME, and put the freshly built
          # ihaskell executable and the in-place package database on its paths.
          preCheck = ''
            export HOME=$(${pkgs.coreutils}/bin/mktemp -d)
            export PATH=$PWD/dist/build/ihaskell:$PATH
            export GHC_PACKAGE_PATH=$PWD/dist/package.conf.inplace/:$GHC_PACKAGE_PATH
          '';
          # Keep any flags the base derivation already sets.
          configureFlags = (drv.configureFlags or []) ++ [
            # otherwise the tests are agonisingly slow and the kernel times out
            "--enable-executable-dynamic"
          ];
          doHaddock = false;
        });
      ghc-parser = hspkgs.callCabal2nix "ghc-parser" "${ihaskellSrc}/ghc-parser" {};
      ipython-kernel = hspkgs.callCabal2nix "ipython-kernel" "${ihaskellSrc}/ipython-kernel" {};
      ihaskell-aeson = callDisplayPackage "aeson";
      ihaskell-blaze = callDisplayPackage "blaze";
      ihaskell-charts = callDisplayPackage "charts";
      ihaskell-diagrams = callDisplayPackage "diagrams";
      ihaskell-gnuplot = callDisplayPackage "gnuplot";
      ihaskell-graphviz = callDisplayPackage "graphviz";
      ihaskell-hatex = callDisplayPackage "hatex";
      ihaskell-juicypixels = callDisplayPackage "juicypixels";
      ihaskell-magic = callDisplayPackage "magic";
      ihaskell-plot = callDisplayPackage "plot";
      ihaskell-rlangqq = callDisplayPackage "rlangqq";
      ihaskell-static-canvas = callDisplayPackage "static-canvas";
      ihaskell-widgets = callDisplayPackage "widgets";
      # Marked as broken in this version of Nixpkgs.
      chell = hspkgs.callHackage "chell" "0.4.0.2" {};
      patience = hspkgs.callHackage "patience" "0.1.1" {};
      # Version compatible with ghc-lib-parser.
      hlint = hspkgs.callHackage "hlint" "2.2.1" {};
      # Tests not passing.
      Diff = dontCheck hspkgs.Diff;
      zeromq4-haskell = dontCheck hspkgs.zeromq4-haskell;
      funflow = dontCheck hspkgs.funflow;
      # Haddocks not building.
      ghc-lib-parser = dontHaddock hspkgs.ghc-lib-parser;
      # Missing dependency.
      aeson = pkgs.haskell.lib.addBuildDepends hspkgs.aeson [ self.contravariant ];
    };
in
{
  # Only the ghc865 package set is extended. Any overrides it already carries
  # are composed with ours rather than replaced.
  haskell = pkgs.haskell // {
    packages = pkgs.haskell.packages // {
      "ghc865" = pkgs.haskell.packages.ghc865.override (old: {
        overrides =
          pkgs.lib.composeExtensions
            (old.overrides or (_: _: {}))
            overrides;
      });
    };
  };
}
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
":e GADTs\n",
":e GeneralizedNewtypeDeriving\n",
":e FlexibleContexts\n",
":e MultiParamTypeClasses\n",
":e RankNTypes\n",
":e NoMonadFailDesugaring\n",
":e TupleSections\n",
"\n",
"import Data.Functor.Identity\n",
"import Control.Monad.Trans.Identity\n",
"\n",
"import Control.Arrow (second)\n",
"import Control.Monad (replicateM)\n",
"import qualified Data.List\n",
"\n",
"import Control.Monad.Trans\n",
"import Control.Monad.Coroutine hiding (suspend)\n",
"import Control.Monad.Coroutine.SuspensionFunctors\n",
"import Data.Either\n",
"import Control.Applicative (liftA2)\n",
"import Control.Monad.Writer\n",
"import Control.Monad.State\n",
"import Control.Monad.Trans.List\n",
"import Control.Monad.ST (RealWorld)\n",
"import qualified Data.Vector as V\n",
"import qualified Data.Vector.Generic as VG\n",
"import Control.Monad.Trans.Free.Church\n",
"import qualified System.Random.MWC as MWC\n",
"import Control.Monad.Free (Free(..))\n",
"import Control.Monad.Trans.Free (FreeT(..))\n",
"import Data.Char (toUpper)\n",
"\n",
"import Statistics.Distribution (ContDistr, quantile)\n",
"import Numeric.Log (Log, ln)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Walk through Monad-Bayes implementation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Free.hs\n",
"\n",
"This is basically an explanation of [this code](https://github.com/adscib/monad-bayes/blob/master/src/Control/Monad/Bayes/Free.hs).\n",
"We will go line by line through this file to understand what Monad-Bayes is doing.\n",
"\n",
"### `SamF` - a sampling function type\n",
"\n",
"The first type introduced in the above-mentioned Monad-Bayes code is `SamF`, which probably means something like _sampling function_.\n",
"It provides a data constructor `Random` that wraps a function of type `Double -> a`:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"newtype SamF a = Random (Double -> a)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How does this type reflect a _sampling function_?\n",
"Let's take a few moments to reflect on this.\n",
"What happens if we wrap a simple function that returns `True` for certain inputs and `False` for others: "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>a :: SamF Bool</span>"
],
"text/plain": [
"a :: SamF Bool"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"a = Random (\\x -> if x < 0.5 then True else False)\n",
":t a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`SamF Bool` is not a value of `Bool`.\n",
"But we can get a value of `Bool` out of it if we feed it a `Double`.\n",
"The idea is that if we feed a random `Double`, say from a uniform distribution between 0 and 1, we will also obtain a random `Bool`.\n",
"`SamF Bool` therefore corresponds to something like a _probabilistic_ `Bool`, or in other words something from which we can sample `Bool` values.\n",
"\n",
"Let's try to draw an actual `Bool` sample from `SamF Bool`.\n",
"To this end, we first need a random `Double` value that we get from a standard random number generator (described with the opaque type `MWC.Gen RealWorld`):"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"generateSampler :: MWC.Gen RealWorld -> SamF a -> IO a\n",
"generateSampler gen (Random func) = do\n",
" v <- MWC.uniform gen :: IO Double\n",
" return $ func v\n",
"\n",
"gen <- MWC.create :: IO (MWC.Gen RealWorld)\n",
"sample = generateSampler gen"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The function `sample` (that is, `generateSampler gen`) thus extracts a sample from our distribution of possible `Bool` values.\n",
"This distribution is directly linked to the input distribution of `Doubles` that define the randomness of the whole process.\n",
"Here is how it behaves:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"False"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"True"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"False"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample a\n",
"sample a\n",
"sample a\n",
"sample a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ok, this works, but why should we go through all this hassle?\n",
"Why shouldn't we just define a sampler that directly takes a function `Double -> a` and returns a value `a`?\n",
"Or why don't we just immediately use the function itself and inject a number from a random number generator?\n",
"We will hopefully come to a conclusion about these important questions later on.\n",
"\n",
"For now let's go to the next line in the `monad-bayes` source code and look at this simple definition of a functor instance for our new `SamF a` type:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"instance Functor SamF where\n",
" fmap f (Random k) = Random (f . k)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Remember that a functor's `fmap` simply maps a function over the values inside of a type.\n",
"In this case this value is the _sample_ that we take out of the _distribution_ that is described by `SamF`.\n",
"For example, we can now do something like this:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"10"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"10"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"sample $ fmap (\\x -> if x then 10 else 0) a\n",
"sample $ fmap (\\x -> if x then 10 else 0) a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The functor instance lets us chain further operations onto the random samples that we have constructed.\n",
"Our first primitive `SamF` thus represents a simple sampleable distribution.\n",
"Let's move on..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Essential Concepts\n",
" \n",
"The next type that we encounter, `FreeSampler`, is described by this short comment:\n",
"\n",
"> Free monad transformer over random sampling.\n",
"> Uses the Church-encoded version of the free monad\n",
"> for efficiency.\n",
"\n",
"Don't worry if you don't understand anything.\n",
"We'll try to explain it step by step."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Monads and Free Monads\n",
"\n",
"A Monad describes structure: computations that can be created with _return_ `a -> m a` and _joined_ together with `m (m a) -> m a`.\n",
"You can thus generate and bind effect-generating operations together.\n",
"Another explanation of a Monad is that it is a monoid in the category of endofunctors, in other words something that allows us to concatenate endofunctors, that is structure or computations, together.\n",
"\n",
"Monad Bayes extensively uses a so-called free monad transformer `FT` from the [free-5.1.2: Monads for free](https://hackage.haskell.org/package/free-5.1.2/docs/Control-Monad-Trans-Free-Church.html) package.\n",
"A free monad can be thought of as a linked list of instructions.\n",
"We get such a Monad _for free_, that is with a simple type constructor from the free monad's library, from a data structure that describes this set of instructions.\n",
"A _free_ Monad is thus to a computation what the _free_ monoid (the list) is to an ordinary type.\n",
"The essential things that we should remember are:\n",
"\n",
"> Free monads build syntax trees. See the example sections for details.\n",
"> A free monad over a functor resembles a list of that functor: \n",
">\n",
"> * return behaves like [] by not using the functor at all\n",
"> * wrap behaves like (:) by prepending another layer of the functor\n",
"> * liftF behaves like singleton by creating a list from a single layer of the functor.\n",
"\n",
"(from [here](https://hackage.haskell.org/package/transformers-free-1.0.1/docs/Control-Monad-Trans-Free.html#v:wrap))\n",
"\n",
"What structure, computations or effects are we talking about here in terms of random variables? Well, one action is defined by our `SamF` type, which draws a random variable of some type when we feed it with some basic randomness. Free monads are tightly related to the idea of first describing code as an abstract syntax tree data structure and then interpreting it in a separate step. They describe how different operations and actions in the syntax tree can be combined.\n",
"\n",
"#### Monad & Free Monad Transformer\n",
"\n",
"A Monad Transformer allows us to _compose multiple_ Monads.\n",
"For example, if we have monad `m1` and `m2`, the monad transformer describes the composed monad `m1 (m2 a)`.\n",
"Unfortunately there isn't a general way to do this, so we have to specify one of the monads that will be composed.\n",
"The other monad is usually taken as an argument.\n",
"\n",
"We can also stack several Monad transformers together. In Monad Bayes we will frequently see things like: `WriterT .. (ReaderT .. (StateT .. (FreeSampler (SamplerIO Double) ..)))` .\n",
"This means that we get read-write state via StateT, read-only state via ReaderT, write-only state via WriterT, our free monad instructions via FreeSampler, and some way to draw random numbers via SamplerIO.\n",
"This whole big stack will be expanded even further with custom datatypes such as `Traced` that give access to a past sample and result of a probabilistic computation, so that we can implement Metropolis-Hastings.\n",
"\n",
"This stack of Monad transformers is really the heart of Monad-Bayes because it allows us to represent a probabilistic computation in different ways.\n",
"\n",
"#### Church-encoding\n",
"\n",
"Not important for now. Let's just say that it speeds up the operations that we do with `FreeSampler`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### `FreeSampler` - step by step"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"newtype FreeSampler m a = FreeSampler (FT SamF m a)\n",
" deriving(Functor,Applicative,Monad,MonadTrans)\n",
" \n",
"runFreeSampler :: FreeSampler m a -> FT SamF m a\n",
"runFreeSampler (FreeSampler m) = m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First of all, `FreeSampler m a` is simply a `newtype` wrapper around the type `FT SamF m a`, and it comes with a `runFreeSampler` function that unwraps its contents.\n",
"\n",
"This type can be decomposed into a return type `a` with associated structure `FT SamF m`.\n",
"`FT` stands for the _Free Monad Transformer_ here.\n",
"In other words, it builds a Monad from a functor `SamF` and another arbitrary Monad `m`.\n",
"This Monad is basically a linked list of instructions, where each instruction either comes from the functor or from the Monad.\n",
"Think about this list as `f1(m1(m2(m3(f2(m4)))))`, where `f1,f2` are instructions from the functor `f` and `m1,m2,m3,m4` are instructions from the Monad `m`.\n",
"We run this list of instructions by recursively unwrapping one layer after the next and doing something depending on each instruction."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can prepend another \"layer\" of instructions:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"instance Monad m => MonadFree SamF (FreeSampler m) where\n",
" wrap = FreeSampler . wrap . fmap runFreeSampler"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `random` function of `MonadSample` is the `Random` data constructor of `SamF`, applied to `id` and lifted into a `FreeSampler` operation:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next function is interesting:\n",
"\n",
"> Lift a monad morphism from m to n into a monad morphism from (t m) to (t n)\n",
"> The first argument to hoist must be a monad morphism, even though the type system does not enforce this\n",
"\n",
"(from [here](https://hackage.haskell.org/package/pipes-4.3.12/docs/Pipes.html#v:hoist))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> FreeSampler m a -> FreeSampler n a\n",
"hoist f (FreeSampler m) = FreeSampler (hoistFT f m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All of this was about building up a set of instructions with the `FreeSampler` datatype.\n",
"We haven't yet interpreted this list."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Executing random sampling\n",
"\n",
"Now we are getting closer to the actual execution, that is, to the interpreter:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What we do with this function is that whenever we encounter a `Random` datatype in the list of instructions, we replace it with the `random` function."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is how we actually supply the `Double` values via the state monad:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"-- | Execute computation with supplied values for random choices.\n",
"withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a\n",
"withRandomness randomness (FreeSampler m) = evalStateT (iterTM f m) randomness where\n",
" f (Random k) = do\n",
" xs <- get\n",
" case xs of\n",
" [] -> error \"FreeSampler: the list of randomness was too short\"\n",
" y:ys -> put ys >> k y"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We iterate through the instructions.\n",
"When we encounter a `Random k` value, where `k` is a function that ingests a random `Double` and returns the rest of the computation, we take the full list of randomness out of the state monad with `get` and feed its first value into `k` (if the list is empty, we fail with an error).\n",
"We then put the remaining list of random values back into the state monad and keep iterating."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What do we have until now?\n",
"\n",
"We have the functor `SamF` with which we can describe Random functions:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>a :: SamF Double</span>"
],
"text/plain": [
"a :: SamF Double"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"a = Random (2*)\n",
":t a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have the Monad transformer `FreeSampler m a` with which we can wrap custom commands with basic Monads:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Identity 0.3"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Identity 0.1"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"probabilisticProgram :: FreeSampler Identity Double\n",
"probabilisticProgram = do\n",
" a <- liftF $ Random id\n",
" b <- liftF $ Random (\\x -> if x < 0.5 then 0.0 else 1.0)\n",
" return (a + b)\n",
"\n",
"a = runFreeSampler probabilisticProgram\n",
"\n",
"interpret2 :: SamF a -> a\n",
"interpret2 (Random k) = k 0.3\n",
"\n",
"val = runFreeSampler probabilisticProgram\n",
"eval = iterT interpret2 val\n",
"eval\n",
"\n",
"eval = withRandomness [0.1, 0.2] probabilisticProgram\n",
"eval"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introducing MonadSample\n",
"\n",
"We are now examining a variant that automatically draws a random variable via a type class `MonadSample` that provides functions to draw various random variables:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"-- | Class of monads that can draw random variables.\n",
"class Monad m => MonadSample m where\n",
" -- | A random variable distributed uniformly on [0,1].\n",
" random :: m Double\n",
" -- ...\n",
" bernoulli :: Double -> m Bool\n",
" bernoulli p = fmap (< p) random\n",
" \n",
" categorical :: VG.Vector v Double => v Double -> m Int\n",
" categorical ps = fromPMF (ps VG.!) where\n",
" -- | Draw a value from a discrete distribution using a sequence of draws from Bernoulli.\n",
" fromPMF :: MonadSample m => (Int -> Double) -> m Int\n",
" fromPMF p = f 0 1 where\n",
" f i r = do\n",
" when (r < 0) $ error \"fromPMF: total PMF above 1\"\n",
" let q = p i\n",
" when (q < 0 || q > 1) $ error \"fromPMF: invalid probability value\"\n",
" b <- bernoulli (q / r)\n",
" if b then pure i else f (i+1) (r-q)\n",
" \n",
" logCategorical :: (VG.Vector v (Log Double), VG.Vector v Double) => v (Log Double) -> m Int\n",
" logCategorical = categorical . VG.map (exp . ln)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We need to lift the `random` function over all the transformers that we use:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"instance MonadSample m => MonadSample (IdentityT m) where\n",
" random = lift random\n",
"\n",
"instance (Monoid w, MonadSample m) => MonadSample (WriterT w m) where\n",
" random = lift random\n",
"\n",
"instance MonadSample m => MonadSample (StateT s m) where\n",
" random = lift random\n",
"\n",
"instance MonadSample m => MonadSample (ListT m) where\n",
" random = lift random"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can make `FreeSampler` an instance of `MonadSample`:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"instance Monad m => MonadSample (FreeSampler m) where\n",
" random = FreeSampler $ liftF (Random id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that this just passes on the supplied random variable.\n",
"We can also introduce a monad that we call `Sampler` that generates a new random variable:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"newtype Sampler a = Sampler { runSampler :: IO a}\n",
" deriving(Functor, Applicative, Monad, MonadIO)\n",
"\n",
"gen <- MWC.create :: IO (MWC.Gen RealWorld)\n",
"instance MonadSample Sampler where\n",
" random = Sampler $ MWC.uniform gen"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This function draws a random variable from the transformed Monad and passes it on to the continuation stored in each `Random` instruction of the `FreeSampler`:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"-- | Execute random sampling in the transformed monad.\n",
"interpret :: MonadSample m => FreeSampler m a -> m a\n",
"interpret (FreeSampler m) = iterT f m where\n",
" f (Random k) = random >>= k"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"let's see:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.024810362882962"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"probabilisticProgram :: FreeSampler Sampler Double\n",
"probabilisticProgram = do\n",
" a <- liftF $ Random id\n",
" b <- liftF $ Random (\\x -> if x < 0.5 then 0.0 else 1.0)\n",
" return (a + b)\n",
"\n",
"eval = runSampler . interpret $ probabilisticProgram\n",
"eval"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now a function that injects part of the randomness and draws the remaining variables:"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"-- | Execute computation with supplied values for a subset of random choices.\n",
"-- Return the output value and a record of all random choices used, whether\n",
"-- taken as input or drawn using the transformed monad.\n",
"withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double])\n",
"withPartialRandomness randomness (FreeSampler m) =\n",
" runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness where\n",
" f (Random k) = do\n",
" -- This block runs in StateT [Double] (WriterT [Double]) m.\n",
" -- StateT propagates consumed randomness while WriterT records\n",
" -- randomness used, whether old or new.\n",
" xs <- get\n",
" x <- case xs of\n",
" [] -> random\n",
" y:ys -> put ys >> return y\n",
" tell [x]\n",
" k x"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, the `runWith` function runs the whole construct and returns the value together with a record of all random choices used:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"-- | Like 'withPartialRandomness', but use an arbitrary sampling monad.\n",
"runWith :: MonadSample m => [Double] -> FreeSampler Identity a -> m (a, [Double])\n",
"runWith randomness m = withPartialRandomness randomness $ hoist (return . runIdentity) m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To summarize, with `FreeSampler`, we can add random elements to any of the standard Monads that we use in Haskell, such as List, State, Writer or others.\n",
"These random elements are functions that need to be fed with some basic randomness to produce a random variable.\n",
"We will now see how we can implement Monads `m` on top of this functionality that are particularly useful for probabilistic computations..."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Weighted - Weighted.hs\n",
"\n",
"https://github.com/adscib/monad-bayes/blob/master/src/Control/Monad/Bayes/Weighted.hs"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"-- | Monads that can score different execution paths.\n",
"class Monad m => MonadCond m where\n",
" score :: Log Double -> m ()\n",
" \n",
"-- | Monads that support both sampling and scoring.\n",
"class (MonadSample m, MonadCond m) => MonadInfer m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, we need to lift the basic functions over some standard Monad Transformers that we are going to use:"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"instance MonadCond m => MonadCond (IdentityT m) where\n",
" score = lift . score\n",
"\n",
"instance (Monoid w, MonadCond m) => MonadCond (WriterT w m) where\n",
" score = lift . score\n",
"\n",
"instance MonadCond m => MonadCond (StateT s m) where\n",
" score = lift . score\n",
"\n",
"instance MonadCond m => MonadCond (ListT m) where\n",
" score = lift . score\n",
"\n",
"instance MonadInfer m => MonadInfer (StateT s m)\n",
"\n",
"instance (Monoid w, MonadInfer m) => MonadInfer (WriterT w m)\n",
"\n",
"instance MonadInfer m => MonadInfer (IdentityT m)\n",
"\n",
"instance MonadInfer m => MonadInfer (ListT m)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"-- | Executes the program using the prior distribution, while accumulating likelihood.\n",
"newtype Weighted m a = Weighted (StateT (Log Double) m a)\n",
" --StateT is more efficient than WriterT\n",
" deriving(Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)\n",
"\n",
"instance Monad m => MonadCond (Weighted m) where\n",
" score w = Weighted (modify (* w))\n",
"\n",
"instance MonadSample m => MonadInfer (Weighted m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here we are, with our first concrete Monad that uses random functions, the `Weighted` Monad.\n",
"It is a wrapper around `StateT` with a `Log Double` state variable that is going to be used to store the accumulated _likelihood_ of a computation.\n",
"It can additionally be wrapped around another monad `m` and return type `a`.\n",
"Importantly, the `Weighted` Monad is an instance of `MonadCond` and provides a `score` function that multiplies the `Log Double` that is stored in the `StateT` Monad with a number (multiplication of a `Log Double` is just an addition under the hood).\n",
"Given a weighted monad, we can extract a value and the accumulated likelihood with this function:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"-- | Obtain an explicit value of the likelihood for a given value.\n",
"runWeighted :: (Functor m) => Weighted m a -> m (a, Log Double)\n",
"runWeighted (Weighted m) = runStateT m 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Starting from the initial value `1`, the likelihood is then multiplied by the argument of every `score` call during a probabilistic computation."
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"-- | Compute the weight and discard the sample.\n",
"extractWeight :: Functor m => Weighted m a -> m (Log Double)\n",
"extractWeight m = snd <$> runWeighted m\n",
"\n",
"-- | Embed a random variable with explicitly given likelihood.\n",
"--\n",
"-- > runWeighted . withWeight = id\n",
"withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a\n",
"withWeight m = Weighted $ do\n",
" (x,w) <- lift m\n",
" modify (* w)\n",
" return x\n",
"\n",
"-- | Discard the weight.\n",
"-- This operation introduces bias.\n",
"prior :: (Functor m) => Weighted m a -> m a\n",
"prior = fmap fst . runWeighted\n",
"\n",
"-- | Combine weights from two different levels.\n",
"flatten :: Monad m => Weighted (Weighted m) a -> Weighted m a\n",
"flatten m = withWeight $ (\\((x,p),q) -> (x, p*q)) <$> runWeighted (runWeighted m)\n",
"\n",
"-- | Use the weight as a factor in the transformed monad.\n",
"applyWeight :: MonadCond m => Weighted m a -> m a\n",
"applyWeight m = do\n",
" (x, w) <- runWeighted m\n",
" score w\n",
" return x\n",
"\n",
"-- | Apply a transformation to the transformed monad.\n",
"hoistWeighted :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a\n",
"hoistWeighted t (Weighted m) = Weighted $ mapStateT t m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Playing with Weighted"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Identity ((),0.10000000000000002)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"Identity (True,5.000000000000001e-2)"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"a = withWeight $ Identity (True, 0.2)\n",
"b = a >> score 0.5\n",
"c = runWeighted b\n",
"c\n",
"\n",
"probComp :: Weighted Identity Bool\n",
"probComp = do\n",
" score 0.5\n",
" score 0.1\n",
" return True\n",
" \n",
"runWeighted probComp"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"probComp :: Weighted (FreeSampler Sampler) Double\n",
"probComp = do\n",
" a <- random\n",
" b <- lift . liftF $ Random (\\x -> if x < 0.5 then 0.0 else 1.0)\n",
" c <- random\n",
" score 0.5\n",
" score 0.1\n",
" return b"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((0.0,5.000000000000001e-2),[0.1,0.15936354678388287,0.6952687664728953])"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"a = runSampler . withPartialRandomness [0.1] $ runWeighted probComp\n",
"a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This starts to look like a probabilistic computation."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Traced - Traced/Common.hs & Traced/Static.hs\n",
"\n",
"https://github.com/adscib/monad-bayes/blob/master/src/Control/Monad/Bayes/Traced/Common.hs\n",
"\n",
"We are now ready to go through the Metropolis-Hastings implementation, that is, to understand the `mh` function in monad-bayes.\n",
"We start with the `Common.hs` file that exports a few important functions:\n",
"\n",
"```\n",
"module Control.Monad.Bayes.Traced.Common (\n",
" Trace,\n",
" singleton,\n",
" output,\n",
" scored,\n",
" bind,\n",
" mhTrans,\n",
" mhTrans'\n",
") where\n",
"```\n",
"\n",
"It defines a `Trace` datatype that bundles three things: a list of `Double`s (the _randomness_ that we have talked about before), the output value of an associated probabilistic computation, and a density, which can be seen as the accumulated likelihood of this sample (stored as a `Log Double`).\n",
"Here is its definition, together with `Functor`, `Applicative` and `Monad` instances:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"data Trace a =\n",
" Trace {\n",
" variables :: [Double],\n",
" output :: a,\n",
" density :: Log Double\n",
" }\n",
"\n",
"instance Functor Trace where\n",
" fmap f t = t {output = f (output t)}\n",
"\n",
"instance Applicative Trace where\n",
" pure x = Trace {variables = [], output = x, density = 1}\n",
" tf <*> tx = Trace {variables = variables tf ++ variables tx, output = output tf (output tx), density = density tf * density tx}\n",
"\n",
"instance Monad Trace where\n",
" t >>= f =\n",
" let t' = f (output t) in\n",
" t' {variables = variables t ++ variables t', density = density t * density t'}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Some functions that allow us to initialize and combine `Trace` objects:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"singleton :: Double -> Trace Double\n",
"singleton u = Trace {variables = [u], output = u, density = 1}\n",
"\n",
"scored :: Log Double -> Trace ()\n",
"scored w = Trace {variables = [], output = (), density = w}\n",
"\n",
"bind :: Monad m => m (Trace a) -> (a -> m (Trace b)) -> m (Trace b)\n",
"bind dx f = do\n",
" t1 <- dx\n",
" t2 <- f (output t1)\n",
" return $ t2 {variables = variables t1 ++ variables t2, density = density t1 * density t2}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And finally we arrive at a Metropolis-Hastings transition:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"-- | A single Metropolis-corrected transition of single-site Trace MCMC.\n",
"mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)\n",
"mhTrans m t = do\n",
" let us = variables t\n",
" a = output t\n",
" p = density t\n",
" us' <- do\n",
" let n = length us\n",
" i <- categorical $ V.replicate n (1 / fromIntegral n)\n",
" u' <- random\n",
" let (xs, _:ys) = splitAt i us\n",
" return $ xs ++ (u':ys)\n",
" ((b, q), vs) <- runWriterT $ runWeighted $ hoistWeighted (WriterT . withPartialRandomness us') m\n",
" let ratio = (exp . ln) $ min 1 (q * fromIntegral (length us) / (p * fromIntegral (length vs)))\n",
" accept <- bernoulli ratio\n",
" return $ if accept then Trace vs b q else t\n",
"\n",
"-- | A variant of 'mhTrans' with an external sampling monad.\n",
"mhTrans' :: MonadSample m => Weighted (FreeSampler Identity) a -> Trace a -> m (Trace a)\n",
"mhTrans' m = mhTrans (hoistWeighted (hoist (return . runIdentity)) m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's try this out:\n",
"\n",
"#### Playing with a single Metropolis-Hastings step"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[0.44984153252922365,0.0,0.0]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"0.0"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"5.000000000000001e-2"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"tr = Trace {variables=[0.0, 0.0, 0.0], output=0, density=1e-10}\n",
"a = mhTrans probComp tr\n",
"fmap variables $ runSampler a\n",
"fmap output $ runSampler a\n",
"fmap density $ runSampler a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final probabilistic computation is of type: `(StateT .. (FT SamF ..) ..)`.\n",
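"It gives access to state (used to store the likelihood), to random elements of type `Random` through `SamF`, and, in this case, to an underlying sampler that can also provide random variables.\n",
"\n",
"Note that each line of the cell above calls `runSampler a` separately, so unless the sampler is seeded deterministically, the printed `variables`, `output` and `density` may come from different runs."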
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Iterating through steps with the Traced Monad"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"-- | A tracing monad where only a subset of random choices are traced.\n",
"-- The random choices that are not to be traced should be lifted\n",
"-- from the transformed monad.\n",
"data Traced m a = Traced (Weighted (FreeSampler m) a) (m (Trace a))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"A traced computation contains the probabilistic computation _and_ a trace of its last execution.\n",
"We can extract the last trace:"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"traceDist :: Traced m a -> m (Trace a)\n",
"traceDist (Traced _ d) = d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and we can extract the probabilistic computation, i.e. the _model_, with:"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [],
"source": [
"model :: Traced m a -> Weighted (FreeSampler m) a\n",
"model (Traced m _) = m"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here are some instances for this computation:"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"instance Monad m => Functor (Traced m) where\n",
" fmap f (Traced m d) = Traced (fmap f m) (fmap (fmap f) d)\n",
"\n",
"instance Monad m => Applicative (Traced m) where\n",
" pure x = Traced (pure x) (pure (pure x))\n",
" (Traced mf df) <*> (Traced mx dx) = Traced (mf <*> mx) (liftA2 (<*>) df dx)\n",
"\n",
"instance Monad m => Monad (Traced m) where\n",
" (Traced mx dx) >>= f = Traced my dy where\n",
" my = mx >>= model . f\n",
" dy = dx `bind` (traceDist . f)\n",
"\n",
"instance MonadTrans Traced where\n",
" lift m = Traced (lift $ lift m) (fmap pure m)\n",
"\n",
"instance MonadSample m => MonadSample (Traced m) where\n",
" random = Traced random (fmap singleton random)\n",
"\n",
"instance MonadCond m => MonadCond (Traced m) where\n",
" score w = Traced (score w) (score w >> pure (scored w))\n",
"\n",
"instance MonadInfer m => MonadInfer (Traced m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"and some more functions:"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"hoistT :: (forall x. m x -> m x) -> Traced m a -> Traced m a\n",
"hoistT f (Traced m d) = Traced m (f d)\n",
"\n",
"marginal :: Monad m => Traced m a -> m a\n",
"marginal (Traced _ d) = fmap output d"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With this datatype, we can finally describe a generic Metropolis-Hastings algorithm:"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [],
"source": [
"mhStep :: MonadSample m => Traced m a -> Traced m a\n",
"mhStep (Traced m d) = Traced m d' where\n",
" d' = d >>= mhTrans m\n",
"\n",
"mh :: MonadSample m => Int -> Traced m a -> m [a]\n",
"mh n (Traced m d) = fmap (map output) t where\n",
" t = f n\n",
" f 0 = fmap (:[]) d\n",
" f k = do\n",
" x:xs <- f (k-1)\n",
" y <- mhTrans m x\n",
" return (y:x:xs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Intermission, What did we build here?\n",
"\n",
"The final type on which we can run an MCMC computation boils down to this:\n",
"\n",
"```\n",
"Traced (Weighted (FreeSampler Sampler) Bool) (Sampler (Trace Bool))\n",
" ^ ^ ^ ^ ^\n",
" likelihood StateT (Log Double) Random variables random generator return type randomness, state and output of previous computation\n",
" \n",
" runWeighted (=runStateT) runFreeT + iterT \n",
" extractWeight withRandomness\n",
" prior withPartialRandomness ->\n",
" \n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Questions:\n",
"\n",
"* Is `Traced` isomorphic to another `StateT` with `m (Trace Bool)` as its state variable?\n",
"* Can we get a simple specification for this type with the `capability` library, saying something like: `(HasState \"likelihood\" (Log Double), HasState \"Trace\" (Trace Bool), HasRandomElements(...), HasRandomGenerator (...)) => m Bool`"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sequential"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}\n",
" deriving(Functor,Applicative,Monad,MonadTrans,MonadIO)\n",
" \n",
"extract :: Await () a -> a\n",
"extract (Await f) = f ()\n",
"\n",
"-- | A point where the computation is paused.\n",
"suspend :: Monad m => Sequential m ()\n",
"suspend = Sequential await\n",
"\n",
"-- | Remove the remaining suspension points.\n",
"finish :: Monad m => Sequential m a -> m a\n",
"finish = pogoStick extract . runSequential\n",
"\n",
"-- | Run to the next suspension point.\n",
"-- If the computation is finished do nothing.\n",
"--\n",
"-- > finish = finish . advance\n",
"advance :: Monad m => Sequential m a -> Sequential m a\n",
"advance = Sequential . bounce extract . runSequential\n",
"\n",
"-- | Checks if no more suspension points remaining.\n",
"finished :: Monad m => Sequential m a -> m Bool\n",
"finished = fmap isRight . resume . runSequential\n",
"\n",
"-- | Transform the inner monad.\n",
"-- This operation only applies to computation up to the first suspension.\n",
"hoistFirstSeq :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a\n",
"hoistFirstSeq f = Sequential . Coroutine . f . resume . runSequential\n",
"\n",
"-- | Transform the inner monad.\n",
"-- The transformation is applied recursively through all the suspension points.\n",
"hoistSeq :: (Monad m, Monad n) =>\n",
" (forall x. m x -> n x) -> Sequential m a -> Sequential n a\n",
"hoistSeq f = Sequential . mapMonad f . runSequential\n",
"\n",
"-- | Apply a function a given number of times.\n",
"composeCopies :: Int -> (a -> a) -> (a -> a)\n",
"composeCopies k f = foldr (.) id (replicate k f)\n",
"\n",
"-- | Sequential importance sampling.\n",
"-- Applies a given transformation after each time step.\n",
"sis :: Monad m\n",
" => (forall x. m x -> m x) -- ^ transformation\n",
" -> Int -- ^ number of time steps\n",
" -> Sequential m a\n",
" -> m a\n",
"sis f k = finish . composeCopies k (advance . hoistFirstSeq f)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"instance MonadSample m => MonadSample (Sequential m) where\n",
" random = lift random\n",
" bernoulli = lift . bernoulli\n",
" categorical = lift . categorical\n",
"\n",
"instance MonadCond m => MonadCond (Sequential m) where\n",
" score w = lift (score w) >> suspend\n",
"\n",
"instance MonadInfer m => MonadInfer (Sequential m)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Testing Sequential"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
"probComp :: Sequential (Weighted (FreeSampler Sampler)) Double\n",
"probComp = do\n",
" a <- random\n",
" b <- lift . lift. liftF $ Random (\\x -> if x < 0.5 then 0.0 else 1.0)\n",
" suspend\n",
" c <- random\n",
" score 0.5\n",
" score 0.1\n",
" return b"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>resume :: forall (s :: * -> *) (m :: * -> *) r. Coroutine s m r -> m (Either (s (Coroutine s m r)) r)</span>"
],
"text/plain": [
"resume :: forall (s :: * -> *) (m :: * -> *) r. Coroutine s m r -> m (Either (s (Coroutine s m r)) r)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>bounce extract :: forall (m :: * -> *) x. Monad m => Coroutine (Await ()) m x -> Coroutine (Await ()) m x</span>"
],
"text/plain": [
"bounce extract :: forall (m :: * -> *) x. Monad m => Coroutine (Await ()) m x -> Coroutine (Await ()) m x"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>pogoStick extract :: forall (m :: * -> *) x. Monad m => Coroutine (Await ()) m x -> m x</span>"
],
"text/plain": [
"pogoStick extract :: forall (m :: * -> *) x. Monad m => Coroutine (Await ()) m x -> m x"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>resume . adv $ a :: Weighted (FreeSampler Sampler) (Either (Await () (Coroutine (Await ()) (Weighted (FreeSampler Sampler)) Double)) Double)</span>"
],
"text/plain": [
"resume . adv $ a :: Weighted (FreeSampler Sampler) (Either (Await () (Coroutine (Await ()) (Weighted (FreeSampler Sampler)) Double)) Double)"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>a :: Coroutine (Await ()) (Weighted (FreeSampler Sampler)) Double</span>"
],
"text/plain": [
"a :: Coroutine (Await ()) (Weighted (FreeSampler Sampler)) Double"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"<style>/* Styles used for the Hoogle display in the pager */\n",
".hoogle-doc {\n",
"display: block;\n",
"padding-bottom: 1.3em;\n",
"padding-left: 0.4em;\n",
"}\n",
".hoogle-code {\n",
"display: block;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"}\n",
".hoogle-text {\n",
"display: block;\n",
"}\n",
".hoogle-name {\n",
"color: green;\n",
"font-weight: bold;\n",
"}\n",
".hoogle-head {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-sub {\n",
"display: block;\n",
"margin-left: 0.4em;\n",
"}\n",
".hoogle-package {\n",
"font-weight: bold;\n",
"font-style: italic;\n",
"}\n",
".hoogle-module {\n",
"font-weight: bold;\n",
"}\n",
".hoogle-class {\n",
"font-weight: bold;\n",
"}\n",
".get-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"display: block;\n",
"white-space: pre-wrap;\n",
"}\n",
".show-type {\n",
"color: green;\n",
"font-weight: bold;\n",
"font-family: monospace;\n",
"margin-left: 1em;\n",
"}\n",
".mono {\n",
"font-family: monospace;\n",
"display: block;\n",
"}\n",
".err-msg {\n",
"color: red;\n",
"font-style: italic;\n",
"font-family: monospace;\n",
"white-space: pre;\n",
"display: block;\n",
"}\n",
"#unshowable {\n",
"color: red;\n",
"font-weight: bold;\n",
"}\n",
".err-msg.in.collapse {\n",
"padding-top: 0.7em;\n",
"}\n",
".highlight-code {\n",
"white-space: pre;\n",
"font-family: monospace;\n",
"}\n",
".suggestion-warning { \n",
"font-weight: bold;\n",
"color: rgb(200, 130, 0);\n",
"}\n",
".suggestion-error { \n",
"font-weight: bold;\n",
"color: red;\n",
"}\n",
".suggestion-name {\n",
"font-weight: bold;\n",
"}\n",
"</style><span class='get-type'>w :: Weighted (FreeSampler Sampler) Bool</span>"
],
"text/plain": [
"w :: Weighted (FreeSampler Sampler) Bool"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": [
"((True,5.000000000000001e-2),[0.1,0.8612396510899002,0.6406785100433282])"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
":t resume\n",
":t bounce extract\n",
":t pogoStick extract\n",
"adv = bounce extract\n",
"a = runSequential probComp\n",
":t resume . adv $ a\n",
":t a\n",
"w = fmap isRight . resume . adv . adv . adv $ a\n",
":t w\n",
"a = runSampler $ withPartialRandomness [0.1] $ runWeighted w\n",
"a"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Population"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [],
"source": [
"newtype Population m a = Population (Weighted (ListT m) a)\n",
" deriving(Functor,Applicative,Monad,MonadIO,MonadSample,MonadCond,MonadInfer)\n",
" \n",
"instance MonadTrans Population where\n",
" lift = Population . lift . lift"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```\n",
"Population (Weighted (ListT (FreeSampler Sampler)) Bool)\n",
" ^ ^ ^ ^ ^\n",
" likelihood StateT (Log Double) Random variables random generator return type\n",
" \n",
" runWeighted (=runStateT) runFreeT + iterT \n",
" extractWeight withRandomness\n",
" prior withPartialRandomness ->\n",
" \n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"-- | Explicit representation of the weighted sample with weights in log domain.\n",
"runPopulation :: Functor m => Population m a -> m [(a, Log Double)]\n",
"runPopulation (Population m) = runListT $ runWeighted m\n",
"\n",
"-- | Explicit representation of the weighted sample.\n",
"explicitPopulation :: Functor m => Population m a -> m [(a, Double)]\n",
"explicitPopulation = fmap (map (second (exp . ln))) . runPopulation\n",
"\n",
"-- | Initialise 'Population' with a concrete weighted sample.\n",
"fromWeightedList :: Monad m => m [(a,Log Double)] -> Population m a\n",
"fromWeightedList = Population . withWeight . ListT"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"-- | Increase the sample size by a given factor.\n",
"-- The weights are adjusted such that their sum is preserved.\n",
"-- It is therefore safe to use `spawn` in arbitrary places in the program\n",
"-- without introducing bias.\n",
"spawn :: Monad m => Int -> Population m ()\n",
"spawn n = fromWeightedList $ pure $ replicate n ((), 1 / fromIntegral n)\n",
"\n",
"resampleGeneric :: MonadSample m\n",
" => (V.Vector Double -> m [Int]) -- ^ resampler\n",
" -> Population m a -> Population m a\n",
"resampleGeneric resampler m = fromWeightedList $ do\n",
" pop <- runPopulation m\n",
" let (xs, ps) = unzip pop\n",
" let n = length xs\n",
" let z = sum ps\n",
" if z > 0 then do\n",
" let weights = V.fromList (map (exp . ln . (/z)) ps)\n",
" ancestors <- resampler weights\n",
" let xvec = V.fromList xs\n",
" let offsprings = map (xvec V.!) ancestors\n",
" return $ map (, z / fromIntegral n) offsprings\n",
" else\n",
" -- if all weights are zero do not resample\n",
" return pop\n",
"\n",
"-- | Systematic resampling helper.\n",
"systematic :: Double -> V.Vector Double -> [Int]\n",
"systematic u ps = f 0 (u / fromIntegral n) 0 0 [] where\n",
" prob i = ps V.! i\n",
" n = length ps\n",
" inc = 1 / fromIntegral n\n",
" f i _ _ _ acc | i == n = acc\n",
" f i v j q acc =\n",
" if v < q then\n",
" f (i+1) (v+inc) j q (j-1:acc)\n",
" else\n",
" f i v (j + 1) (q + prob j) acc\n",
"\n",
"-- | Resample the population using the underlying monad and a systematic resampling scheme.\n",
"-- The total weight is preserved.\n",
"resampleSystematic :: (MonadSample m)\n",
" => Population m a -> Population m a\n",
"resampleSystematic = resampleGeneric (\\ps -> (`systematic` ps) <$> random)\n",
"\n",
"-- | Multinomial resampler.\n",
"multinomial :: MonadSample m => V.Vector Double -> m [Int]\n",
"multinomial ps = replicateM (V.length ps) (categorical ps)\n",
"\n",
"-- | Resample the population using the underlying monad and a multinomial resampling scheme.\n",
"-- The total weight is preserved.\n",
"resampleMultinomial :: (MonadSample m)\n",
" => Population m a -> Population m a\n",
"resampleMultinomial = resampleGeneric multinomial\n",
"\n",
"-- | Separate the sum of weights into the 'Weighted' transformer.\n",
"-- Weights are normalized after this operation.\n",
"extractEvidence :: Monad m\n",
" => Population m a -> Population (Weighted m) a\n",
"extractEvidence m = fromWeightedList $ do\n",
" pop <- lift $ runPopulation m\n",
" let (xs, ps) = unzip pop\n",
" let z = sum ps\n",
" let ws = map (if z > 0 then (/ z) else const (1 / fromIntegral (length ps))) ps\n",
" score z\n",
" return $ zip xs ws\n",
"\n",
"-- | Push the evidence estimator as a score to the transformed monad.\n",
"-- Weights are normalized after this operation.\n",
"pushEvidence :: MonadCond m\n",
" => Population m a -> Population m a\n",
"pushEvidence = hoistPop applyWeight . extractEvidence\n",
"\n",
"-- | A properly weighted single sample, that is one picked at random according\n",
"-- to the weights, with the sum of all weights.\n",
"proper :: (MonadSample m)\n",
" => Population m a -> Weighted m a\n",
"proper m = do\n",
" pop <- runPopulation $ extractEvidence m\n",
" let (xs, ps) = unzip pop\n",
" index <- logCategorical $ V.fromList ps\n",
" let x = xs !! index\n",
" return x\n",
"\n",
"-- | Model evidence estimator, also known as pseudo-marginal likelihood.\n",
"evidence :: (Monad m) => Population m a -> m (Log Double)\n",
"evidence = extractWeight . runPopulation . extractEvidence\n",
"\n",
"-- | Picks one point from the population and uses model evidence as a 'score'\n",
"-- in the transformed monad.\n",
"-- This way a single sample can be selected from a population without\n",
"-- introducing bias.\n",
"collapse :: (MonadInfer m)\n",
" => Population m a -> m a\n",
"collapse = applyWeight . proper\n",
"\n",
"-- | Applies a random transformation to a population.\n",
"mapPopulation :: (Monad m) => ([(a, Log Double)] -> m [(a, Log Double)]) ->\n",
" Population m a -> Population m a\n",
"mapPopulation f m = fromWeightedList $ runPopulation m >>= f\n",
"\n",
"-- | Normalizes the weights in the population so that their sum is 1.\n",
"-- This transformation introduces bias.\n",
"normalize :: (Monad m) => Population m a -> Population m a\n",
"normalize = hoistPop prior . extractEvidence\n",
"\n",
"-- | Population average of a function, computed using unnormalized weights.\n",
"popAvg :: (Monad m) => (a -> Double) -> Population m a -> m Double\n",
"popAvg f p = do\n",
" xs <- explicitPopulation p\n",
" let ys = map (\\(x,w) -> f x * w) xs\n",
" let t = Data.List.sum ys\n",
" return t\n",
"\n",
"-- | Combine a population of populations into a single population.\n",
"flattenPop :: Monad m => Population (Population m) a -> Population m a\n",
"flattenPop m = Population $ withWeight $ ListT t where\n",
" t = f <$> (runPopulation . runPopulation) m\n",
" f d = do\n",
" (x,p) <- d\n",
" (y,q) <- x\n",
" return (y, p*q)\n",
"\n",
"-- | Applies a transformation to the inner monad.\n",
"hoistPop :: (Monad m, Monad n)\n",
" => (forall x. m x -> n x) -> Population m a -> Population n a\n",
"hoistPop f = fromWeightedList . f . runPopulation"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
"-- | Sequential importance resampling.\n",
"-- Basically an SMC template that takes a custom resampler.\n",
"sir :: Monad m\n",
" => (forall x. Population m x -> Population m x) -- ^ resampler\n",
" -> Int -- ^ number of timesteps\n",
" -> Int -- ^ population size\n",
" -> Sequential (Population m) a -- ^ model\n",
" -> Population m a\n",
"sir resampler k n = sis resampler k . hoistFirstSeq (spawn n >>)\n",
"\n",
"-- | Sequential Monte Carlo with multinomial resampling at each timestep.\n",
"-- Weights are not normalized.\n",
"smcMultinomial :: MonadSample m\n",
" => Int -- ^ number of timesteps\n",
" -> Int -- ^ number of particles\n",
" -> Sequential (Population m) a -- ^ model\n",
" -> Population m a\n",
"smcMultinomial = sir resampleMultinomial\n",
"\n",
"-- | Sequential Monte Carlo with systematic resampling at each timestep.\n",
"-- Weights are not normalized.\n",
"smcSystematic :: MonadSample m\n",
" => Int -- ^ number of timesteps\n",
" -> Int -- ^ number of particles\n",
" -> Sequential (Population m) a -- ^ model\n",
" -> Population m a\n",
"smcSystematic = sir resampleSystematic\n",
"\n",
"-- | Sequential Monte Carlo with multinomial resampling at each timestep.\n",
"-- Weights are normalized at each timestep and the total weight is pushed\n",
"-- as a score into the transformed monad.\n",
"smcMultinomialPush :: MonadInfer m\n",
" => Int -- ^ number of timesteps\n",
" -> Int -- ^ number of particles\n",
" -> Sequential (Population m) a -- ^ model\n",
" -> Population m a\n",
"smcMultinomialPush = sir (pushEvidence . resampleMultinomial)\n",
"\n",
"-- | Sequential Monte Carlo with systematic resampling at each timestep.\n",
"-- Weights are normalized at each timestep and the total weight is pushed\n",
"-- as a score into the transformed monad.\n",
"smcSystematicPush :: MonadInfer m\n",
" => Int -- ^ number of timesteps\n",
" -> Int -- ^ number of particles\n",
" -> Sequential (Population m) a -- ^ model\n",
" -> Population m a\n",
"smcSystematicPush = sir (pushEvidence . resampleSystematic)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Haskell - monad-bayes",
"language": "haskell",
"name": "ihaskell_monad-bayes"
},
"language_info": {
"codemirror_mode": "ihaskell",
"file_extension": ".hs",
"name": "haskell",
"pygments_lexer": "Haskell",
"version": "8.6.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# JupyterLab environment with an IHaskell kernel named "monad-bayes" that has
# the monad-bayes package available. Evaluates to `jupyterlabWithKernels.env`,
# presumably meant to be entered with `nix-shell`.
let
  # Pinned jupyterWith revision; its bundled nixpkgs (under /nix) is reused below.
  # The URL is a quoted string: bare URL literals are deprecated in Nix and
  # are rejected when the `no-url-literals` feature is enabled.
  jupyterLib = builtins.fetchGit {
    url = "https://github.com/tweag/jupyterWith";
    rev = "70f1dddd6446ab0155a5b0ff659153b397419a2d";
  };
  nixpkgsPath = jupyterLib + "/nix";

  # Haskell package overrides (source pins for the packages the kernel needs).
  haskellOverlay = import ./haskell-overlay.nix;

  pkgs = import nixpkgsPath {
    overlays = [ haskellOverlay ];
    # allowBroken lets packages marked broken in this nixpkgs snapshot evaluate.
    config = { allowUnfree = true; allowBroken = true; };
  };

  jupyter = import jupyterLib { inherit pkgs; };

  # IHaskell kernel built against GHC 8.6.5 with monad-bayes in scope.
  ihaskellWithPackages = jupyter.kernels.iHaskellWith {
    #extraIHaskellFlags = "--debug";
    haskellPackages = pkgs.haskell.packages.ghc865;
    name = "monad-bayes";
    packages = p: with p; [
      monad-bayes
    ];
  };

  # JupyterLab with the kernel above and the IHaskell lab extension.
  jupyterlabWithKernels =
    jupyter.jupyterlabWith {
      kernels = [ ihaskellWithPackages ];
      directory = jupyter.mkDirectoryWith {
        extensions = [
          "jupyterlab-ihaskell"
        ];
      };
    };
in
  jupyterlabWithKernels.env
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment