Last active
January 25, 2025 16:38
-
-
Save madsbuch/5a8a1fc9b70621dd93dd70058754b126 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE GADTs #-} | |
{- | |
The following code is based on experimental code by Aslan Askerov,
which in turn is based on Ramsey and Pfeffer's "Stochastic Lambda
Calculus and Monads of Probability Distributions". The implementation
of random n is from the Audebaud and Paulin-Mohring paper, as is the
random walk example.
This gist is used here http://madsbuch.com/the-probability-monad/ | |
The class hierarchy is as follows: | |
+----------------------------------------+ | |
| | | |
| Monad | | |
| | | |
+--------------------+-------------------+ | |
| | |
V | |
+-----------------------------------------+ | |
| | | |
| Probability Monad | | |
| | | |
+------+--------------+--------------+----+ | |
| | | | |
V V V | |
+-------------+ +---------+ +-----------+ | |
| | | | | | | |
| Expectation | | Support | | Sample | | |
| Monad | | Monad | | Monad | | |
| | | | | | | |
+-------------+ +---------+ +-----------+ | |
Make sure to have the `random` package installed:
`cabal install random`
-} | |
module ProbMonad where | |
import Data.List | |
import qualified System.Random as R | |
import Control.Applicative -- Otherwise you can't do the Applicative instance. | |
import Control.Monad (liftM, ap) | |
-- | A probability, i.e. a number from 0 to 1. The range is a convention
-- only; the alias itself does not enforce it.
type Probability = Double
-- | Probability monad: a monad extended with a weighted binary choice.
class Monad m => ProbabilityMonad m where
  -- | @choose p a b@ combines the computations @a@ and @b@ using the
  -- weight @p@.
  choose :: Probability -> m a -> m a -> m a
-- | Support monad: a probability monad whose outcomes can be listed.
class ProbabilityMonad m => SupportMonad m where
  -- | The list of outcomes of a distribution.
  support :: m a -> [a]
-- | Expectation monad: a probability monad over which a real-valued
-- function can be averaged.
class ProbabilityMonad m => ExpMonad m where
  -- | @expectation h d@ is the expectation of @h@ under the distribution @d@.
  expectation :: (a -> Double) -> m a -> Double
-- | Sampling monad: a probability monad from which values can be drawn.
class ProbabilityMonad m => SamplingMonad m where
  -- | Draw one value using the supplied generator; the result pairs the
  -- value with a generator to use for subsequent draws.
  sample :: R.RandomGen g => m a -> g -> (a, g)
-- | A distribution over @a@ represented as a function that turns a
-- real-valued function on @a@ into a single 'Double'.
newtype PExp a = PExp ((a -> Double) -> Double)
-- PExp needs to be a functor to be a monad; 'fmap' is derived from the
-- 'Monad' instance below.
instance Functor PExp where
  fmap = liftM

-- PExp needs to be an applicative to be a monad. 'pure' is defined here
-- (the canonical place) rather than as 'return' in the 'Monad' instance;
-- '<*>' is derived from '>>='.
instance Applicative PExp where
  -- The wrapped function simply applies its argument to @x@.
  pure x = PExp (\h -> h x)
  (<*>) = ap

-- PExp is a monad. 'return' is intentionally left at its default
-- (@return = pure@).
instance Monad PExp where
  -- Run the outer functional @d@ on a function that, for each
  -- intermediate value @x@, runs the functional produced by @k x@ on @h@.
  PExp d >>= k = PExp (\h -> d (\x -> runPExp (k x) h))
    where
      -- Unwrap the newtype.
      runPExp (PExp f) = f
-- PExp is a probability monad: the two functionals are mixed linearly,
-- with weight @p@ on the first and @1 - p@ on the second.
instance ProbabilityMonad PExp where
  choose p (PExp first) (PExp second) = PExp mixture
    where
      mixture h = p * first h + (1 - p) * second h
-- Not easily implemented: the representation is an opaque function.
-- Calling 'support' on a 'PExp' raises an error with a descriptive
-- message.
instance SupportMonad PExp where
  support _ = error "ProbMonad.support: not implemented for PExp"
-- Easily implemented: the wrapped function already computes the
-- expectation, so it is applied directly to @h@.
instance ExpMonad PExp where
  expectation h dist = case dist of
    PExp run -> run h
-- Not easily implemented. Using 'sample' on a 'PExp' raises an error
-- with a descriptive message.
instance SamplingMonad PExp where
  sample = error "ProbMonad.sample: not implemented for PExp"
{-- The general probability monad --}

-- | A distribution described as a syntax tree of its three building
-- blocks; the tree is interpreted by the class instances for 'P'.
data P a where
  -- | A plain value.
  R :: a -> P a
  -- | A distribution together with a continuation. The intermediate
  -- type @b@ does not appear in the result, which is the reason for
  -- using GADT syntax.
  B :: P b -> (b -> P a) -> P a
  -- | A weight together with two alternative distributions.
  C :: Probability -> P a -> P a -> P a
-- P needs to be a functor to be a monad; 'fmap' is derived from the
-- 'Monad' instance below.
instance Functor P where
  fmap = liftM

-- P needs to be an applicative to be a monad. 'pure' is defined here
-- (the canonical place) rather than as 'return' in the 'Monad' instance;
-- '<*>' is derived from '>>='.
instance Applicative P where
  pure = R
  (<*>) = ap

-- P is a monad: bind just records the distribution and its continuation
-- in a 'B' node. 'return' is intentionally left at its default
-- (@return = pure@).
instance Monad P where
  (>>=) = B
-- P is a probability monad: a choice is recorded as a 'C' node.
instance ProbabilityMonad P where
  choose = C
-- The support is collected structurally; the weight of a choice is
-- ignored and duplicates are not removed.
instance SupportMonad P where
  support dist = case dist of
    R x -> [x]
    -- Every outcome of @d@ is fed to the continuation.
    B d k -> concatMap (support . k) (support d)
    C _ d1 d2 -> support d1 ++ support d2
-- The expectation is computed by recursion over the tree.
instance ExpMonad P where
  expectation h dist = case dist of
    R x -> h x
    -- Average, over @d@, the expectation of @h@ under each continuation.
    B d k -> expectation (expectation h . k) d
    -- Weighted sum of the two branches.
    C p d1 d2 -> (p * expectation h d1) + ((1 - p) * expectation h d2)
-- Sampling walks the tree, threading the generator through every step.
instance SamplingMonad P where
  sample dist gen = case dist of
    -- A plain value consumes no randomness.
    R x -> (x, gen)
    -- Sample the first stage, then continue with the updated generator.
    B d k ->
      let (x, gen') = sample d gen
       in sample (k x) gen'
    -- Draw a random number and take the first branch when it is below @p@.
    C p d1 d2 ->
      let (u, gen') = R.random gen
          branch
            | u < p = d1
            | otherwise = d2
       in sample branch gen'
{-- Helper functions --} | |
-- | Indicator of a truth value: 1 for 'True', 0 for 'False'.
prob :: Bool -> Probability
prob True = 1
prob False = 0
-- | Uniform distribution over the elements of a non-empty list.
--
-- The distribution is a chain of 'choose' nodes: with @n@ elements
-- remaining, the head is picked with probability @1 / n@ and the rest of
-- the chain otherwise.
--
-- Calls 'error' when given an empty list, since no distribution exists
-- over zero outcomes.
uniform :: [a] -> P a
uniform [] = error "uniform: empty list"
uniform items = go (length items) items
  where
    -- @go n ys@ expects @n@ to be the length of @ys@; the length is
    -- computed once up front and decremented instead of being recomputed
    -- at every step.
    go _ [y] = return y
    go n (y : ys) = choose (1.0 / fromIntegral n) (return y) (go (n - 1) ys)
    go _ [] = error "uniform: impossible, list proven non-empty"
-- | Take @n@ samples from a distribution.
--
-- The generator is split before every draw: one half is used for the
-- sample, the other half for the remaining draws. Each result is paired
-- with the generator returned by its own 'sample' call.
--
-- A count of zero or less yields the empty list.
nSamples :: R.RandomGen g => Int -> P a -> g -> [(a, g)]
nSamples n dist rGen
  | n <= 0 = []
  | otherwise = sample dist g1 : nSamples (n - 1) dist g2
  where
    (g1, g2) = R.split rGen
{-- Examples --} | |
-- | The six faces of a die, in ascending order.
data Dice
  = One
  | Two
  | Three
  | Four
  | Five
  | Six
  deriving (Eq, Ord, Enum, Show, Read)
-- | Simple example of support: the outcomes of a fair die.
example01a :: [Dice]
example01a = support die
  where
    die :: P Dice
    die = uniform [One .. Six]
-- | Simple example of expectation: the expected value of the indicator
-- "the die shows Six" under a fair die.
example01b :: Double
example01b = expectation isSix die
  where
    die :: P Dice
    die = uniform [One .. Six]
    isSix face = prob (face == Six)
-- | Simple example of sampling: ten throws of a fair die, using a
-- generator seeded with 42. Only the sampled values are kept.
example01c :: [Dice]
example01c = map fst throws
  where
    die :: P Dice
    die = uniform [One .. Six]
    throws = nSamples 10 die (R.mkStdGen 42)
{- Using prior distributions -} | |
-- | Support of a distribution derived from a fair die by turning every
-- Six into a One.
example02a :: [Dice]
example02a = support relabelled
  where
    relabelled :: P Dice
    relabelled = sixToOne <$> uniform [One .. Six]
    sixToOne face
      | face == Six = One
      | otherwise = face
-- | Expected value of the indicator "the result is Six" for a fair die
-- whose Six has been turned into a One.
example02b :: Double
example02b = expectation isSix relabelled
  where
    relabelled :: P Dice
    relabelled = sixToOne <$> uniform [One .. Six]
    sixToOne face
      | face == Six = One
      | otherwise = face
    isSix face = prob (face == Six)
-- | Ten samples from a fair die whose Six has been turned into a One,
-- using a generator seeded with 42. Only the sampled values are kept.
example02c :: [Dice]
example02c = map fst throws
  where
    relabelled :: P Dice
    relabelled = sixToOne <$> uniform [One .. Six]
    sixToOne face
      | face == Six = One
      | otherwise = face
    throws = nSamples 10 relabelled (R.mkStdGen 42)
-- We need a completely enumerable world

-- | Random walk: flip a fair coin; on 'True' stop and return the current
-- position, otherwise move one step up and repeat.
walk :: Num a => a -> P a
walk position = do
  stop <- uniform [True, False]
  if stop
    then return position
    else walk (position + 1)
-- | Support of the walk started at 0.
--
-- This doesn't terminate when fully evaluated: the walk can stop at any
-- position, so the support is the infinite list 0, 1, 2, ...
example03a :: [Integer]
example03a = support (walk 0)
-- | Expectation of the indicator "the walk started at 0 stops below 5".
--
-- NOTE(review): 'expectation' on 'P' evaluates both branches of every
-- choice and the walk recurses without bound, so this looks like it
-- will not terminate -- confirm before relying on it.
example03b :: Double
example03b = expectation (prob . (< 5)) (walk 0)
-- | Ten samples of the walk started at 0, using a generator seeded with
-- 42. Only the sampled values are kept.
example03c :: [Integer]
example03c = map fst (nSamples 10 (walk 0) (R.mkStdGen 42))
-- | Tally of 10000 samples of the walk started at 0 (generator seeded
-- with 42): each distinct outcome paired with the number of times it
-- was drawn, in ascending order of outcome.
mc :: [(Integer, Int)]
mc = [(outcome, length run) | run@(outcome : _) <- group (sort outcomes)]
  where
    outcomes = map fst (nSamples 10000 (walk 0) (R.mkStdGen 42))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment