Last active
November 24, 2019 17:05
-
-
Save devmotion/6bab2561a9c340bac03c50b3fec73441 to your computer and use it in GitHub Desktop.
Elliptical slice sampling
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Turing
using StatsPlots
using LinearAlgebra
using Random
using Statistics
"""
    demo(N::Integer; n::Integer = 10)

Run elliptical slice sampling (`ESS`) for `N` iterations on a conjugate Gaussian model
with `n` observations and compare the resulting chain with the analytical posterior.

The model is `m ~ Normal(0, 1)` with observations `x[i] ~ Normal(m, σ²)` (i.i.d., known
variance `σ² = 0.3`). The posterior of `m` is therefore Gaussian with variance
`τ² = 1 / (1 + n / σ²)` and mean `μ = τ² / σ² * sum(x)`.

The global RNG is seeded with `1234` before the observations are generated. The function
prints the analytical and empirical posterior mean and variance, and returns the result
of `plot!`: the chain plot with the analytical posterior density drawn (dashed) on its
second subplot.
"""
function demo(N::Integer; n::Integer = 10)
    # observation noise variance, shared by the data-generating process, the model and
    # the analytical posterior below
    σ² = 0.3

    # `MvNormal(μ, Σ)` interprets a `UniformScaling` as the covariance matrix, so it must
    # be scaled by the variance σ² (scaling by sqrt(σ²) would give variance sqrt(σ²))
    @model gdemo(x) = begin
        m ~ Normal(0, 1)
        x ~ MvNormal(fill(m, length(x)), σ² * I)
    end

    # observations (`rand(dist, n)` already returns a vector)
    Random.seed!(1234)
    x = rand(Normal(1.4, sqrt(σ²)), n)

    # generate MCMC chain
    chain = sample(gdemo(x), ESS(), N)

    # analytical posterior of the conjugate model
    τ² = inv(1 + length(x) / σ²)
    μ = τ² / σ² * sum(x)
    posterior = Normal(μ, sqrt(τ²))

    # compare analytical and empirical estimates
    chain_array = vec(convert(Array, chain[:m]))
    @show μ, mean(chain_array)
    @show τ², var(chain_array)

    # plot chain and overlay the posterior pdf on the density subplot
    plot(chain)
    return plot!(posterior; subplot = 2, linestyle = :dash)
end
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| using Random: randexp | |
"""
    ESS()
    ESS(space::Symbol...)

Elliptical slice sampling algorithm. The symbols in `space`, stored as a tuple in the
type parameter, select the variables this algorithm samples. An empty space is handled
by `assume` as "all variables".
"""
struct ESS{space} <: InferenceAlgorithm end

ESS() = ESS{()}()
ESS(syms::Symbol...) = ESS{syms}()

"""
    getspace(alg::ESS)
    getspace(::Type{ESS{space}})

Return the tuple of variable symbols stored in the type parameter of an `ESS` algorithm.
"""
getspace(::Type{ESS{S}}) where {S} = S
getspace(alg::ESS) = getspace(typeof(alg))
"""
    transition_type(spl::Sampler{<:ESS})

Return the type of the transition value `Transition(spl)` built from an ESS sampler.
"""
function transition_type(spl::Sampler{<:ESS})
    current = Transition(spl)
    return typeof(current)
end

"""
    alg_str(::Sampler{<:ESS})

Return the display name of the algorithm, `"ESS"`.
"""
function alg_str(::Sampler{<:ESS})
    return "ESS"
end
"""
    ESSState{V<:VarInfo, F<:AbstractFloat}

State of an ESS sampler. `vi` is the `VarInfo` from which `step!` reads the current
values and log density, and `lp` is a floating-point field.

NOTE(review): `lp` is initialized to `0.0` and is neither read nor updated by the ESS code
visible here; confirm whether it is still needed.
"""
mutable struct ESSState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
    vi::V
    lp::F
end

function ESSState(model::Model)
    varinfo = VarInfo(model)
    return ESSState(varinfo, 0.0)
end
"""
    Sampler(alg::ESS, model::Model, s::Selector)

Construct an ESS sampler for `model` with selector `s`, an empty `Dict{Symbol, Any}` of
info and a fresh `ESSState(model)`.
"""
Sampler(alg::ESS, model::Model, s::Selector) =
    Sampler(alg, Dict{Symbol, Any}(), s, ESSState(model))
"""
    step!(rng, model, spl::Sampler{<:ESS}, N; kwargs...)

Initial ESS step. Returns `Transition(spl)` without running the model or modifying the
sampler state.
"""
step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) =
    Transition(spl)
"""
    step!(rng, model, spl::Sampler{<:ESS}, N, t::Transition; kwargs...)

Perform one elliptical slice sampling update of the variables selected by `spl` and
return `Transition(spl)`.

The current values `f` of the sampled variables are combined with an auxiliary draw `ν`
from their prior along the ellipse `f * cos(θ) + ν * sin(θ)`. The angle `θ` is drawn
uniformly from a bracket that always contains `θ = 0` (the current state). After each
rejected proposal the bracket is shrunk, until the log density at the proposal reaches a
randomly drawn slice threshold. The accepted proposal is left in `spl.state.vi`.

All random numbers drawn directly here come from `rng`; the prior draw of `ν` happens
inside `runmodel!`/`assume`.

NOTE(review): the threshold and the proposals are compared using `getlogp(vi)`, i.e. the
log joint, which includes the prior terms of the sampled variables. Elliptical slice
sampling is only exact when this comparison uses the log-likelihood alone; otherwise the
prior is effectively counted twice. The method also assumes a zero-mean Gaussian prior
for the sampled variables. Confirm both before relying on the results.
"""
function step!(
    rng::AbstractRNG,
    model::Model,
    spl::Sampler{<:ESS},
    ::Integer,
    ::Transition;
    kwargs...
)
    vi = spl.state.vi

    # inside Gibbs, other samplers may have changed the values, so recompute the stored
    # log density at the current values
    if spl.selector.tag !== :default
        runmodel!(model, vi)
    end

    # current values of the sampled variables and the log density at them
    f = copy(vi[spl])
    logp_f = getlogp(vi)

    # slice threshold log(u) with u ~ Uniform(0, exp(logp_f)), since -log(U) ~ Exp(1)
    threshold = logp_f - randexp(rng)

    # running the model with `spl` redraws the sampled variables from their prior (see
    # `assume`); copy like `f` so that later writes to `vi` cannot alias ν
    runmodel!(model, vi, spl)
    ν = copy(vi[spl])

    # initial angle; the bracket [θ - 2π, θ] contains θ = 0
    θ = 2 * π * rand(rng)
    θₘᵢₙ = θ - 2 * π
    θₘₐₓ = θ

    # evaluate the initial proposal at the stored values; running the model with `spl`
    # here would redraw them from the prior instead of scoring the proposal
    sinθ, cosθ = sincos(θ)
    proposal = @. f * cosθ + ν * sinθ
    vi[spl] = proposal
    runmodel!(model, vi)

    # shrink the bracket until the proposal lies above the slice threshold
    while getlogp(vi) < threshold
        if θ < 0
            θₘᵢₙ = θ
        else
            θₘₐₓ = θ
        end

        # sample new angle and update the proposal
        θ = θₘᵢₙ + rand(rng) * (θₘₐₓ - θₘᵢₙ)
        sinθ, cosθ = sincos(θ)
        @. proposal = f * cosθ + ν * sinθ

        # write the proposal back explicitly; `vi[spl] = proposal` is not assumed to alias
        # `proposal`, so mutating it alone need not update `vi`
        vi[spl] = proposal
        runmodel!(model, vi)
    end

    return Transition(spl)
end
"""
    assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo)

Handle a random variable `vn ~ dist` while running the model with an ESS sampler.

Variables outside the sampler's space keep their current value `vi[vn]`. Variables in the
space (every variable, if the space is empty) get a fresh draw from `dist`. The draw is
stored in `vi` and assigned to the sampler's selector with `setgid!`. Note that this
redraw happens on every call. In both cases the value and `logpdf(dist, value)` are
returned.
"""
function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo)
    space = getspace(spl.alg)
    sampled_here = isempty(space) || vn.sym in space

    if !sampled_here
        value = vi[vn]
        return value, logpdf(dist, value)
    end

    value = rand(dist)
    vi[vn] = vectorize(dist, value)
    setgid!(vi, spl.selector, vn)
    return value, logpdf(dist, value)
end
"""
    observe(spl::Sampler{<:ESS}, dist::Distribution, value, vi::VarInfo)

Handle an observation under an ESS sampler by delegating to
`observe(nothing, dist, value, vi)` and returning its result.
"""
observe(spl::Sampler{<:ESS}, dist::Distribution, value::Any, vi::VarInfo) =
    observe(nothing, dist, value, vi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment