Skip to content

Instantly share code, notes, and snippets.

@devmotion
Last active November 24, 2019 17:05
Show Gist options
  • Select an option

  • Save devmotion/6bab2561a9c340bac03c50b3fec73441 to your computer and use it in GitHub Desktop.

Select an option

Save devmotion/6bab2561a9c340bac03c50b3fec73441 to your computer and use it in GitHub Desktop.
Elliptical slice sampling
using Turing
using StatsPlots
using Random
using Statistics
"""
    demo(N::Int; n::Int = 10)

Draw `N` samples of the mean `m` of a Gaussian toy model with `ESS()` and compare them
against the closed-form conjugate posterior.

`n` observations are simulated from `Normal(1.4, sqrt(σ²))` after seeding the global RNG
with `1234`. The analytic and empirical posterior mean and variance are printed with
`@show`. The function returns the result of `plot!`, which overlays the analytic posterior
density (dashed) on the second subplot of the chain plot.
"""
function demo(N::Int; n::Int = 10)
    σ² = 0.3  # variance of the observation noise

    # prior: m ~ N(0, 1); every observation is centred at m
    # NOTE(review): depending on the installed Distributions version, the second argument
    # of `MvNormal` may be read as a covariance rather than a scale; confirm it agrees
    # with `σ²` being the noise variance in the analytic posterior below.
    @model gdemo(x) = begin
        m ~ Normal(0, 1)
        x ~ MvNormal(fill(m, length(x)), sqrt(σ²) * I)
    end

    Random.seed!(1234)
    data = vec(rand(Normal(1.4, sqrt(σ²)), n))

    chain = sample(gdemo(data), ESS(), N)

    # conjugate update: posterior precision is 1 + n / σ²
    τ² = inv(1 + length(data) / σ²)
    μ = τ² / σ² * sum(data)
    exact = Normal(μ, sqrt(τ²))

    # empirical moments of the sampled m vs. the exact ones
    chain_array = vec(convert(Array, chain[:m]))
    @show μ, mean(chain_array)
    @show τ², var(chain_array)

    # trace/density of the chain, exact posterior dashed on the density panel
    plot(chain)
    return plot!(exact; subplot = 2, linestyle = :dash)
end
using Random: randexp
"""
    ESS()
    ESS(space::Symbol...)

Elliptical slice sampling algorithm. The variable names in `space` are stored as a tuple
in the type parameter; an empty tuple means the sampler acts on every variable.
"""
struct ESS{space} <: InferenceAlgorithm end

ESS(space::Symbol...) = ESS{space}()
ESS() = ESS{()}()

# Variable names the sampler is restricted to (the type parameter `space`).
getspace(::Type{ESS{space}}) where {space} = space
getspace(alg::ESS) = getspace(typeof(alg))

transition_type(spl::Sampler{<:ESS}) = typeof(Transition(spl))

alg_str(::Sampler{<:ESS}) = "ESS"
"""
    ESSState(model::Model)

Mutable state of an `ESS` sampler: the `VarInfo` holding the current sample (`vi`) and a
floating-point `lp` field, initialised to `0.0` by this constructor.
"""
mutable struct ESSState{V<:VarInfo, F<:AbstractFloat} <: AbstractSamplerState
    vi::V  # current sample and its accumulated log density
    lp::F
end

function ESSState(model::Model)
    varinfo = VarInfo(model)
    return ESSState(varinfo, 0.0)
end
"""
    Sampler(alg::ESS, model::Model, s::Selector)

Build a sampler for `alg` with an empty `Dict{Symbol, Any}` info dictionary and a fresh
`ESSState` for `model`.
"""
function Sampler(alg::ESS, model::Model, s::Selector)
    return Sampler(alg, Dict{Symbol, Any}(), s, ESSState(model))
end
# First step of an ESS chain: no update is performed; the transition simply records the
# sample currently stored in the sampler state.
step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) =
    Transition(spl)
# Key in `spl.info` that tells the ESS `assume` whether to draw the sampler's variables
# from their prior (`true`) or to re-evaluate the values stored in the `VarInfo` (`false`).
const ESS_DRAW_FROM_PRIOR = :ess_draw_from_prior

# Run `model` with `spl` in prior-draw mode (`draw_prior = true`) or evaluation mode,
# always restoring evaluation mode afterwards (even if the model throws).
function _ess_runmodel!(model::Model, spl::Sampler{<:ESS}, vi::VarInfo, draw_prior::Bool)
    spl.info[ESS_DRAW_FROM_PRIOR] = draw_prior
    try
        runmodel!(model, vi, spl)
    finally
        spl.info[ESS_DRAW_FROM_PRIOR] = false
    end
    return vi
end

"""
    step!(rng, model, spl::Sampler{<:ESS}, N, t::Transition; kwargs...)

Perform one elliptical slice sampling update (Murray, Adams & MacKay, 2010) of the
variables handled by `spl` and return the resulting `Transition`.

The slice is defined on the log-likelihood, i.e. the accumulated log density without the
prior terms of the ESS variables. Those terms are accounted for by the elliptical proposal
`f cos θ + ν sin θ` with `ν` drawn from the prior. NOTE(review): this assumes the prior of
the ESS variables is a zero-mean Gaussian; nonzero prior means are not shifted out here.

After a proposal is accepted, the model is re-run once more so the state (and the
returned transition) carries the joint log density.
"""
function step!(
    rng::AbstractRNG,
    model::Model,
    spl::Sampler{<:ESS},
    ::Integer,
    ::Transition;
    kwargs...
)
    vi = spl.state.vi

    # log-likelihood of the current sample; always recomputed, because `logp` may hold the
    # joint density or have been changed by other Gibbs components
    _ess_runmodel!(model, spl, vi, false)
    f = copy(vi[spl])

    # slice level: log-likelihood threshold for the next sample
    threshold = getlogp(vi) - randexp(rng)

    # auxiliary draw ν from the prior of the ESS variables
    _ess_runmodel!(model, spl, vi, true)
    ν = copy(vi[spl])

    # initial angle and bracket [θ - 2π, θ]
    θ = 2 * π * rand(rng)
    θₘᵢₙ = θ - 2 * π
    θₘₐₓ = θ

    while true
        # write the proposal into the VarInfo and evaluate its log-likelihood
        sinθ, cosθ = sincos(θ)
        vi[spl] = @. f * cosθ + ν * sinθ
        _ess_runmodel!(model, spl, vi, false)
        getlogp(vi) < threshold || break

        # rejected: shrink the bracket towards θ = 0 (the current sample) and redraw
        if θ < 0
            θₘᵢₙ = θ
        else
            θₘₐₓ = θ
        end
        θ = θₘᵢₙ + rand(rng) * (θₘₐₓ - θₘᵢₙ)
    end

    # store the joint log density of the accepted sample
    runmodel!(model, vi)

    return Transition(spl)
end

"""
    assume(spl::Sampler{<:ESS}, dist, vn, vi)

Handle `vn ~ dist` when the model is run with an ESS sampler and return the value together
with its contribution to `logp`.

Variables in the sampler's space (all variables if the space is empty) are drawn from
`dist` in prior-draw mode (see `ESS_DRAW_FROM_PRIOR`) and read from `vi` otherwise. Their
log prior is not accumulated, so `logp` holds the log-likelihood that ESS slices on. Other
variables are read from `vi` and contribute `logpdf(dist, r)`.
"""
function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo)
    space = getspace(spl.alg)
    if isempty(space) || vn.sym in space
        if get(spl.info, ESS_DRAW_FROM_PRIOR, false)
            r = rand(dist)
            vi[vn] = vectorize(dist, r)
        else
            r = vi[vn]
        end
        setgid!(vi, spl.selector, vn)
        # the Gaussian prior is built into the elliptical proposal
        return r, zero(getlogp(vi))
    end

    r = vi[vn]
    return r, logpdf(dist, r)
end
# Observations under an ESS sampler are scored exactly like the sampler-agnostic
# fallback (`observe` with `nothing` as the sampler).
observe(::Sampler{<:ESS}, dist::Distribution, value::Any, vi::VarInfo) =
    observe(nothing, dist, value, vi)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment