Last active
December 19, 2019 22:42
-
-
Save devmotion/601d95112df0920cdb9098cd2f2942e2 to your computer and use it in GitHub Desktop.
ESS examples
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Turing | |
using StatsPlots | |
using Random | |
using Statistics | |
"""
    demo(N::Int; n::Int = 10, seed::Integer = 1234, sampler = ESS())

Sample the posterior of the mean `m` of a Gaussian model and compare it with the
closed-form posterior.

The model has prior `m ~ Normal(0, 1)` and `n` observations
`x[i] ~ Normal(m, sqrt(σ²))` with fixed noise variance `σ² = 0.3`. It is sampled
with `N` iterations of `sampler`.

The observations are simulated from `Normal(1.4, sqrt(σ²))` after seeding the
global RNG with `seed`. The conjugate posterior `Normal(μ, sqrt(τ²))` is computed
in closed form. Its mean and variance are printed next to the chain's sample mean
and variance. The function returns a plot of the chain with the analytic posterior
density overlaid.

# Keywords
- `n`: number of simulated observations.
- `seed`: seed passed to `Random.seed!` before the observations are simulated.
- `sampler`: Turing sampler for `m`. Pass e.g. `NUTS()` instead of the default
  `ESS()`.

# Throws
- `ArgumentError` if `N < 1` or `n < 0`.
"""
function demo(N::Int; n::Int = 10, seed::Integer = 1234, sampler = ESS())
    N > 0 || throw(ArgumentError("number of MCMC iterations must be positive, got $N"))
    n >= 0 || throw(ArgumentError("number of observations must be non-negative, got $n"))

    # observation noise variance and the corresponding standard deviation
    σ² = 0.3
    σ = sqrt(σ²)

    # define model; named so that it does not shadow the global `gdemo` function
    @model normal_model(x) = begin
        m ~ Normal(0, 1)
        x ~ MvNormal(fill(m, length(x)), σ)
    end

    # simulate observations; `sample` below gets no explicit RNG, so presumably the
    # same seeded global stream also drives the sampler
    Random.seed!(seed)
    x = rand(Normal(1.4, σ), n)

    # generate MCMC chain
    chain = sample(normal_model(x), sampler, N)

    # closed-form posterior: precision 1 + n / σ², mean τ² * sum(x) / σ²
    τ² = inv(1 + length(x) / σ²)
    μ = τ² / σ² * sum(x)
    posterior = Normal(μ, sqrt(τ²))

    # compare analytic posterior moments with chain estimates
    chain_array = vec(convert(Array, chain[:m]))
    @show μ, mean(chain_array)
    @show τ², var(chain_array)

    # plot chain and overlay the posterior pdf on subplot 2
    # (assumed to be the density panel of the chain plot recipe)
    p = plot(chain)
    plot!(p, posterior; subplot = 2, linestyle = :dash)
    return p
end
"""
    gdemo(N::Int; nparticles::Int = 15, x::Real = 1.5, y::Real = 2.0,
          seed::Integer = 100, sampler = Gibbs(CSMC(nparticles, :s), ESS(:m)))

Sample `N` iterations of `sampler` for a conjugate Gaussian model and compare the
chain with the closed-form normal-inverse-gamma posterior.

The model is:

    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))

The posterior parameters `(μ, λ, α, β)` are computed from the observations `x` and
`y`. For the defaults `x = 1.5`, `y = 2.0` they are `μ = 7/6`, `λ = 3`, `α = 3` and
`β = 49/12`.

The analytic marginal posteriors are:
- `m`: `LocationScale(μ, sqrt(β / (λ * α)), TDist(2α))`.
- `s`: `InverseGamma(α, β)`.

Their means and variances are printed next to the chain estimates. The function
returns a two-row plot of the chains for `m` and `s` with the analytic densities
overlaid.

# Keywords
- `nparticles`: number of particles of the default `CSMC` step for `s`.
- `x`, `y`: observed values.
- `seed`: seed passed to `Random.seed!` before sampling.
- `sampler`: Turing sampler, e.g. `Gibbs(CSMC(nparticles, :s), HMC(0.2, 4, :m))`.

# Throws
- `ArgumentError` if `N < 1` or `nparticles < 1`.
"""
function gdemo(
    N::Int;
    nparticles::Int = 15,
    x::Real = 1.5,
    y::Real = 2.0,
    seed::Integer = 100,
    sampler = Gibbs(CSMC(nparticles, :s), ESS(:m)),
)
    N > 0 || throw(ArgumentError("number of MCMC iterations must be positive, got $N"))
    nparticles > 0 ||
        throw(ArgumentError("number of particles must be positive, got $nparticles"))

    # prior hyperparameters shared by the model and the analytic posterior:
    # s ~ InverseGamma(α0, β0), m | s ~ Normal(μ0, sqrt(s / λ0))
    μ0, λ0, α0, β0 = 0, 1, 2, 3

    # define model; named so that it does not shadow this enclosing function
    @model gdemo_model(x, y) = begin
        s ~ InverseGamma(α0, β0)
        m ~ Normal(μ0, sqrt(s / λ0))
        x ~ Normal(m, sqrt(s))
        y ~ Normal(m, sqrt(s))
        return s, m
    end

    # generate MCMC chain
    Random.seed!(seed)
    chain = sample(gdemo_model(x, y), sampler, N)

    # closed-form normal-inverse-gamma update of (μ0, λ0, α0, β0) given the data
    obs = (x, y)
    nobs = length(obs)
    x̄ = sum(obs) / nobs
    λ = λ0 + nobs
    μ = (λ0 * μ0 + nobs * x̄) / λ
    α = α0 + nobs / 2
    β = β0 + sum(o -> abs2(o - x̄), obs) / 2 + λ0 * nobs * (x̄ - μ0)^2 / (2 * λ)
    posterior_m = LocationScale(μ, sqrt(β / (λ * α)), TDist(2 * α))
    posterior_s = InverseGamma(α, β)

    # compare analytic posterior moments with chain estimates
    chain_array_m = vec(convert(Array, chain[:m]))
    @show mean(posterior_m), mean(chain_array_m)
    @show var(posterior_m), var(chain_array_m)
    chain_array_s = vec(convert(Array, chain[:s]))
    @show mean(posterior_s), mean(chain_array_s)
    @show var(posterior_s), var(chain_array_s)

    # plot chains and overlay the posterior pdfs on subplot 2
    # (assumed to be the density panel of each chain plot)
    p1 = plot(chain[:m])
    plot!(p1, posterior_m; subplot = 2, linestyle = :dash)
    p2 = plot(chain[:s])
    plot!(p2, posterior_s; subplot = 2, linestyle = :dash)
    return plot(p1, p2; layout = @layout [a; b])
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment