Last active
July 29, 2020 16:55
-
-
Save torfjelde/4bed4361990d5040acc147da6afae530 to your computer and use it in GitHub Desktop.
Implementation of a `condition` method for Turing.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
julia> using Turing | |
"""
    condition(model::Turing.Model, latent::NamedTuple;
              sampler = DynamicPPL.SampleFromPrior())

Execute `model` once with the variables named in `latent` fixed to the supplied values
and return the result as a `Turing.Inference.Transition`.

A fresh `VarInfo` is built for `model`. For every variable name `vn` it holds:

- if `Symbol(vn)` is a key of `latent`, the matching value is stored with `setval!`
  (a `Real` is first wrapped in a one-element vector) and `settrans!` is called with
  `false` for that variable;
- otherwise the variable is given the `"del"` flag so that it is sampled with `sampler`
  when the model is executed.

The returned transition holds `DynamicPPL.tonamedtuple(vi)` and `Turing.getlogp(vi)`.
"""
function condition(
    model::Turing.Model,
    latent::NamedTuple;
    sampler = DynamicPPL.SampleFromPrior()
)
    varinfo = Turing.VarInfo(model)
    metadata = varinfo.metadata

    for sym in keys(metadata), vn in metadata[sym].vns
        key = Symbol(vn)

        if !haskey(latent, key)
            # Not conditioned on: flag for deletion so we can sample it from the prior.
            DynamicPPL.set_flag!(varinfo, vn, "del")
            continue
        end

        given = latent[key]
        # Scalars are wrapped in a one-element vector before being stored.
        fixed = given isa Real ? [given] : given
        DynamicPPL.setval!(varinfo, fixed, vn)
        DynamicPPL.settrans!(varinfo, false, vn)
    end

    # Run `model` on the values held in `varinfo`; variables carrying the `"del"` flag
    # are sampled using `sampler`.
    model(varinfo, sampler)

    # Package the `VarInfo` contents and the log-probability as a transition.
    return Turing.Inference.Transition(
        DynamicPPL.tonamedtuple(varinfo),
        Turing.getlogp(varinfo),
    )
end
condition (generic function with 1 method) | |
"""
    transition2nt(t::Turing.Inference.Transition; include = nothing, exclude = nothing)

Build a `NamedTuple` from `t.θ`, mapping each retained key `k` to `first(t.θ[k])`.

# Keywords
- `include`: collection of keys to keep; `nothing` (the default) keeps every key of `t.θ`.
- `exclude`: collection of keys to drop; `nothing` (the default) drops none.

A key is retained only if it is in `include` and not in `exclude`.
"""
function transition2nt(t::Turing.Inference.Transition; include = nothing, exclude = nothing)
    θ = t.θ
    keep = isnothing(include) ? keys(θ) : include
    drop = isnothing(exclude) ? [] : exclude

    # Exclusion is checked before inclusion.
    retained(k) = k ∉ drop && k ∈ keep

    return (; (k => first(θ[k]) for k in keys(θ) if retained(k))...)
end
transition2nt (generic function with 1 method) | |
# Demo model: `σ` and `μ` each receive a `Uniform(0.0, 5.0)` prior, and every element of
# `xs` and `ys` is modelled as `Normal(μ, σ)`.
@model gsdemo(xs, ys) = begin
    # Assumptions
    σ ~ Uniform(0.0, 5.0)
    μ ~ Uniform(0.0, 5.0)
    # Observations
    # `eachindex(xs, ys)` requires both inputs to share the same indices, so mismatched
    # inputs fail up front instead of silently ignoring trailing `ys` entries or hitting
    # a `BoundsError` part-way through the loop.
    for i in eachindex(xs, ys)
        xs[i] ~ Normal(μ, σ)
        ys[i] ~ Normal(μ, σ)
    end
end
DynamicPPL.ModelGen{var"###generator#274",(:xs, :ys),(),Tuple{}}(##generator#274, NamedTuple()) | |
julia> m_obs = gsdemo([missing], [missing]); | |
julia> t = condition(m_obs, (σ = [3.0], μ = [1.0])) | |
Turing.Inference.Transition{NamedTuple{(:σ, :μ, :xs, :ys),NTuple{4,Tuple{Array{Float64,1},Array{String,1}}}},Float64}((σ = ([3.0], ["σ"]), μ = ([1.0], ["μ"]), xs = ([4.997428020248069], ["xs[1]"]), ys = ([-2.2354138174930265], ["ys[1]"])), -8.723273765696497) | |
julia> transition2nt(t; exclude = [:σ, :μ]) | |
(xs = [4.997428020248069], ys = [-2.2354138174930265]) | |
julia> transition2nt(t) | |
(σ = [3.0], μ = [1.0], xs = [4.997428020248069], ys = [-2.2354138174930265]) | |
# Sample observations from the conditioned model: with every observation `missing`,
# `condition` fixes `σ` and `μ` and draws `xs` and `ys`.
num_obs = 100;
m_obs = gsdemo(fill(missing, num_obs), fill(missing, num_obs));
results = transition2nt(condition(m_obs, (σ = [3.0], μ = [1.0])); exclude = [:σ, :μ]);
xs = results.xs; ys = results.ys;

# Build the model on the simulated data before sampling; without this line `m` is
# undefined when `sample` is called.
m = gsdemo(xs, ys);
chain = sample(m, SMC(), 1000)
Sampling 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00 | |
Object of type Chains, with data of type 1000×5×1 Array{Float64,3} | |
Log evidence = -514.5954182059356 | |
Iterations = 1:1000 | |
Thinning interval = 1 | |
Chains = 1 | |
Samples per chain = 1000 | |
internals = le, lp, weight | |
parameters = μ, σ | |
2-element Array{ChainDataFrame,1} | |
Summary Statistics | |
parameters mean std naive_se mcse ess r_hat | |
────────── ────── ────── ──────── ────── ─────── ────── | |
μ 1.0431 0.2660 0.0084 0.0424 32.8009 1.0559 | |
σ 3.1289 0.1727 0.0055 0.0398 19.2294 1.0048 | |
Quantiles | |
parameters 2.5% 25.0% 50.0% 75.0% 97.5% | |
────────── ────── ────── ────── ────── ────── | |
μ 0.5624 0.8371 1.0107 1.1559 1.6197 | |
σ 2.9387 2.9999 3.1187 3.2550 3.5593 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Important notes:
`σ` and `μ` are univariate random variables, but their values need to be specified as a `Vector`.