Created July 2, 2020 13:21
-
-
Save idontgetoutmuch/5854e88a6a91d42aab05ef1a45e7df53 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import qualified Data.Vector.Unboxed as V | |
import Data.Random.Source.PureMT | |
import Data.Random | |
import Control.Monad.State | |
import Data.Histogram ( asList ) | |
import Data.Histogram.Fill | |
import Data.Histogram.Generic ( Histogram ) | |
import Debug.Trace | |
-- | Number of synthetic observations drawn by 'simpleXs'.
nSamples :: Int
nSamples = 10

-- | Seed for the PureMT generator used by 'simpleXs'.
seed :: Int
seed = 2
-- | Mean of the Gaussian prior on mu (see 'prior'); also the mean from which
-- mu is drawn when generating 'simpleXs'.
mu0 :: Double
mu0 = 11.0

-- | Standard deviation of the Gaussian prior on mu.
sigma0 :: Double
sigma0 = 2.0

-- | Standard deviation of each observation around mu, used both to generate
-- 'simpleXs' and in 'likelihood'.
sigma :: Double
sigma = 1.0
-- | NOTE(review): not referenced by any definition visible in this file;
-- confirm it is still needed.
d :: Double
d = 2
-- | Synthetic data: 'nSamples' draws from the two-level model
-- @mu ~ Normal mu0 sigma0@, @x ~ Normal mu sigma@, generated with a PureMT
-- generator seeded by 'seed'.
--
-- A fresh mu is drawn for every observation, because the mu draw sits inside
-- the replicated action.
-- NOTE(review): 'likelihood' treats all observations as sharing one mu. If
-- that is the intended model, draw mu once outside 'replicateM'.
simpleXs :: [Double]
simpleXs = flip evalState generator $ replicateM nSamples drawOne
  where
    generator = pureMT (fromIntegral seed)
    drawOne   = sample (Normal mu0 sigma0) >>= \m -> sample (Normal m sigma)
-- | @normalisedProposals s sd n@ returns @n@ independent draws from a
-- zero-mean normal distribution with standard deviation @sd@. The draws come
-- from a PureMT generator seeded with @s@. The chains below use them as
-- random-walk jumps.
normalisedProposals :: Int -> Double -> Int -> [Double]
normalisedProposals s sd n =
  evalState (replicateM n jump) (pureMT (fromIntegral s))
  where
    jump = sample (Normal 0.0 sd)
-- | @acceptOrRejects s n@ returns @n@ 'stdUniform' draws from a PureMT
-- generator seeded with @s@. The chains compare these draws against the
-- Metropolis acceptance probability.
acceptOrRejects :: Int -> Int -> [Double]
acceptOrRejects s n =
  flip evalState (pureMT (fromIntegral s)) $ replicateM n (sample stdUniform)
-- | Unnormalised Gaussian prior density of mu, centred on 'mu0' with standard
-- deviation 'sigma0'. The constant normalising factor is omitted because it
-- cancels in the ratio taken by 'acceptanceProb'.
prior :: Double -> Double
prior mu = exp (negate (squaredDev / twiceVariance))
  where
    squaredDev    = (mu - mu0) ** 2
    twiceVariance = 2 * sigma0 ** 2
-- | Unnormalised likelihood of observations @xs@ under a normal model with
-- mean @mu@ and standard deviation 'sigma'. It is the exponential of minus
-- the sum of the per-observation squared standardised deviations halved.
likelihood :: Double -> [Double] -> Double
likelihood mu xs = exp (negate (sum (map halfSq xs)))
  where
    halfSq x = (x - mu) ** 2 / (2 * sigma ** 2)
-- | Unnormalised posterior density of @mu@ given the data @xs@:
-- 'likelihood' times 'prior'.
posterior :: Double -> [Double] -> Double
posterior mu xs = prior mu * likelihood mu xs
-- | Metropolis acceptance probability for moving from @mu@ to @mu'@: the
-- posterior ratio capped at 1. 'min' is used deliberately rather than a
-- guard, so a NaN ratio propagates unchanged.
acceptanceProb :: Double -> Double -> [Double] -> Double
acceptanceProb mu mu' xs = min 1.0 ratio
  where
    ratio = posterior mu' xs / posterior mu xs
-- | One random-walk Metropolis step targeting 'posterior' given 'simpleXs'.
--
-- The first argument is the state (current mu, proposals accepted so far).
-- The second is (proposed jump, uniform draw). The candidate @mu + jump@ is
-- accepted when the draw is below its acceptance probability.
oneStep :: (Double, Int) -> (Double, Double) -> (Double, Int)
oneStep (mu, nAccs) (proposedJump, acceptOrReject)
  | acceptOrReject < alpha = (candidate, nAccs + 1)
  | otherwise              = (mu, nAccs)
  where
    candidate = mu + proposedJump
    alpha     = acceptanceProb mu candidate simpleXs
-- | Random-walk Metropolis chain for mu. It runs 300000 steps of 'oneStep'
-- from the state (10.0, 0), with jumps of standard deviation 0.4 (seed 3) and
-- accept/reject draws from seed 4. The first 200000 states are dropped as
-- burn-in. Each element is (mu, accepted count).
test :: [(Double, Int)]
test = drop burnIn chain
  where
    burnIn   = 200000
    nSteps   = 300000
    chain    = scanl oneStep (10.0, 0) (zip jumps uniforms)
    jumps    = normalisedProposals 3 0.4 nSteps
    uniforms = acceptOrRejects 4 nSteps
-- | Histogram builder with 400 equal-width bins covering
-- @10.0 - 1.5 * sigma0@ to @10.0 + 1.5 * sigma0@. 'forceDouble' is applied to
-- the resulting histogram.
hb :: HBuilder Double (Histogram V.Vector BinD Double)
hb = forceDouble -<< mkSimple bins
  where
    bins  = binD lower 400 upper
    lower = 10.0 - 1.5 * sigma0
    upper = 10.0 + 1.5 * sigma0
-- | Histogram ('hb') of the mu values retained by 'test', i.e. after its
-- first 200000 states are dropped.
hist :: Histogram V.Vector BinD Double
hist = fillBuilder hb $ fst <$> test
-- | Log-odds of @p@, i.e. @log (p / (1 - p))@; the inverse of 'invLogit'.
logit :: Floating a => a -> a
logit p = log odds
  where
    odds = p / (1 - p)
-- | Logistic function, the inverse of 'logit': maps the real line onto (0, 1).
--
-- Written as @1 / (1 + exp (-x))@ rather than @exp x / (1 + exp x)@. The
-- latter evaluates Infinity / Infinity = NaN once @exp x@ overflows
-- (x above about 709.8 for Double). This form instead saturates cleanly:
-- to 1 for large positive x, and to 0 for large negative x, where
-- 1 / Infinity = 0.
invLogit :: Floating a => a -> a
invLogit x = 1 / (1 + exp (negate x))
-- | Lower (@a@) and upper (@b@) ends of the open interval (a, b). 'simpleYs'
-- keeps only data inside it, and 'oneStep'' keeps mu inside it via the
-- logit transform. The signature makes explicit the Double type that was
-- previously fixed only by inference from use sites.
a, b :: Double
a = 9
b = 12
-- | The observations of 'simpleXs' lying strictly inside (a, b), kept in
-- their original order.
simpleYs :: [Double]
simpleYs = [x | x <- simpleXs, x > a, x < b]
-- simpleYs | |
-- | One Metropolis–Hastings step for mu, constrained to the open interval
-- (a, b), targeting 'posterior' given 'simpleYs'.
--
-- The random walk runs on the unconstrained scale
-- @z = logit ((mu - a) / (b - a))@ and maps back with
-- @mu = a + (b - a) * invLogit z@. The proposal is symmetric in z, not in mu,
-- so the target density on the z scale is the posterior times the Jacobian
-- @dmu/dz = (mu - a) * (b - mu) / (b - a)@. The acceptance ratio therefore
-- includes the ratio of Jacobians; the constant @1 / (b - a)@ cancels.
-- Without that factor the chain does not sample the posterior and drifts
-- towards the interval ends.
--
-- A proposal that rounds onto a or b has Jacobian 0, so it is never accepted.
-- This keeps 'logit' of the state finite.
--
-- Each step emits a 'trace' line with these fields: current mu, its z value,
-- the jump, the proposed z, the proposed mu, and the acceptance probability
-- used.
oneStep' :: (Double, Int) -> (Double, Double) -> (Double, Int)
oneStep' (mu, nAccs) (proposedJump, acceptOrReject) =
  trace (unwords [ show mu
                 , show liftedMu
                 , show proposedJump
                 , show proposedLiftedMu
                 , show proposedMu
                 , show alpha
                 ]) $
  if acceptOrReject < alpha
    then (proposedMu, nAccs + 1)
    else (mu, nAccs)
  where
    liftedMu         = logit ((mu - a) / (b - a))
    proposedLiftedMu = liftedMu + proposedJump
    proposedMu       = a + (b - a) * invLogit proposedLiftedMu
    -- Jacobian of z -> mu at a given mu, up to the constant 1 / (b - a).
    jacobian m       = (m - a) * (b - m)
    -- Uncapped posterior ratio times Jacobian ratio, then capped at 1.
    -- The cap must come after the Jacobian correction, which is why
    -- 'posterior' is used directly rather than 'acceptanceProb'.
    alpha            = min 1.0 ( (posterior proposedMu simpleYs * jacobian proposedMu)
                               / (posterior mu simpleYs * jacobian mu) )
-- logit ((9.000000000018865 - a) / (b - a)) | |
-- a + (b - a) * invLogit (logit ((9.000000000018865 - a) / (b - a))) | |
-- logit ((9.00000000002115 - a) / (b - a)) | |
-- a + (b - a) * invLogit (logit ((9.00000000002115 - a) / (b - a))) | |
-- | Short 40-step debugging run of the logit-scale sampler 'oneStep''. It
-- starts at (10.0, 0), uses jumps of standard deviation 10.0 (seed 3) and
-- accept/reject draws from seed 4. No burn-in is dropped; a
-- @drop 100000 $@ was previously commented out here.
test' :: [(Double, Int)]
test' = scanl oneStep' (10.0, 0) (zip jumps uniforms)
  where
    jumps    = normalisedProposals 3 10.0 40
    uniforms = acceptOrRejects 4 40
-- | The first 40 states of 'test''.
--
-- 'scanl' over 40 steps yields 41 states, the first being the initial state.
-- Fully evaluating this list therefore runs only 39 steps of 'oneStep'', and
-- the trace line for the 40th step is never printed.
-- NOTE(review): use @take 41@ (or 'test'' itself) if all 40 steps should run.
trial :: [(Double, Int)]
trial = take 40 test'
-- take 10 $ acceptOrRejects 4 300000 | |
-- take 10 $ drop 100000 $ acceptOrRejects 4 300000 | |
-- acceptanceProb 9.00000000002115 9.000000000018865 simpleYs | |
-- posterior 9.00000000002115 simpleYs | |
-- posterior 9.000000000018865 simpleYs | |
-- hist' :: Histogram V.Vector BinD Double | |
-- hist' = fillBuilder hb (map fst test') | |
-- xsA = map fst $ asList hist' | |
-- ysA = map snd $ asList hist' | |
-- length simpleYs | |
-- (/5) $ sum simpleYs | |
-- (/10) $ sum simpleXs |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.