Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active July 29, 2020 16:55
Show Gist options
  • Save torfjelde/4bed4361990d5040acc147da6afae530 to your computer and use it in GitHub Desktop.
Save torfjelde/4bed4361990d5040acc147da6afae530 to your computer and use it in GitHub Desktop.
Implementation of a `condition` method for Turing.jl
julia> using Turing
julia> function condition(
model::Turing.Model,
latent::NamedTuple;
sampler = DynamicPPL.SampleFromPrior()
)
# Fix the variables named in `latent` to the given values and draw the
# remaining variables from their priors using `sampler`, returning a
# `Turing.Inference.Transition` holding all values and the log-probability.
#
# NOTE(review): this reaches into DynamicPPL internals (`vi.metadata`,
# `setval!`, `settrans!`, `set_flag!`) as they existed in July 2020;
# newer DynamicPPL releases may break it.
vi = Turing.VarInfo(model)
md = vi.metadata
# Walk every variable symbol and every concrete `VarName` it contains.
for v in keys(md)
for vn in md[v].vns
vn_symbol = Symbol(vn)
if vn_symbol ∈ keys(latent)
tmp = latent[vn_symbol]
# `setval!` expects vector-shaped values, so wrap bare scalars.
val = tmp isa Real ? [tmp] : tmp
DynamicPPL.setval!(vi, val, vn)
# Record that the stored value is in the original (untransformed) space.
DynamicPPL.settrans!(vi, false, vn)
else
# delete so we can sample from prior
DynamicPPL.set_flag!(vi, vn, "del")
end
end
end
# Execute `model` on the parameters set in `vi` and sample those with `"del"` flag using `sampler`
model(vi, sampler)
# Convert `VarInfo` into `NamedTuple` and save
theta = DynamicPPL.tonamedtuple(vi)
lp = Turing.getlogp(vi)
return Turing.Inference.Transition(theta, lp)
end
condition (generic function with 1 method)
julia> function transition2nt(t::Turing.Inference.Transition; include = nothing, exclude = nothing)
    # Convert a `Transition` into a `NamedTuple` mapping variable names to values.
    # By default every variable in `t.θ` is kept; `include` restricts to a
    # whitelist and `exclude` removes a blacklist of symbols.
    keep = include === nothing ? keys(t.θ) : include
    drop = exclude === nothing ? [] : exclude
    selected = [k for k in keys(t.θ) if (k in keep) && !(k in drop)]
    # Each entry of `t.θ` is a `(values, names)` tuple; `first` pulls out the values.
    return (; [(k, first(t.θ[k])) for k in selected]...)
end
transition2nt (generic function with 1 method)
julia> @model gsdemo(xs, ys) = begin
    # Model: each element of `xs` and `ys` is an i.i.d. draw from Normal(μ, σ),
    # with uniform priors on both μ and σ.
    # Assumptions (priors)
    σ ~ Uniform(0.0, 5.0)
    μ ~ Uniform(0.0, 5.0)
    # Observations. `eachindex(xs, ys)` is the idiomatic iteration form and,
    # unlike `1:length(xs)`, also throws if the two arrays' indices disagree
    # instead of silently dropping or mis-pairing observations.
    for i in eachindex(xs, ys)
        xs[i] ~ Normal(μ, σ)
        ys[i] ~ Normal(μ, σ)
    end
end
DynamicPPL.ModelGen{var"###generator#274",(:xs, :ys),(),Tuple{}}(##generator#274, NamedTuple())
julia> m_obs = gsdemo([missing], [missing]);
julia> t = condition(m_obs, (σ = [3.0], μ = [1.0]))
Turing.Inference.Transition{NamedTuple{(:σ, :μ, :xs, :ys),NTuple{4,Tuple{Array{Float64,1},Array{String,1}}}},Float64}((σ = ([3.0], ["σ"]), μ = ([1.0], ["μ"]), xs = ([4.997428020248069], ["xs[1]"]), ys = ([-2.2354138174930265], ["ys[1]"])), -8.723273765696497)
julia> transition2nt(t; exclude = [:σ, :μ])
(xs = [4.997428020248069], ys = [-2.2354138174930265])
julia> transition2nt(t)
(σ = [3.0], μ = [1.0], xs = [4.997428020248069], ys = [-2.2354138174930265])
julia> # Sample observations from conditioned model
num_obs = 100;
julia> m_obs = gsdemo(fill(missing, num_obs), fill(missing, num_obs));
julia> results = transition2nt(condition(m_obs, (σ = [3.0], μ = [1.0])); exclude = [:σ, :μ]);
julia> xs = results.xs; ys = results.ys;
julia> chain = sample(gsdemo(xs, ys), SMC(), 1000)
Sampling 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:00
Object of type Chains, with data of type 1000×5×1 Array{Float64,3}
Log evidence = -514.5954182059356
Iterations = 1:1000
Thinning interval = 1
Chains = 1
Samples per chain = 1000
internals = le, lp, weight
parameters = μ, σ
2-element Array{ChainDataFrame,1}
Summary Statistics
parameters mean std naive_se mcse ess r_hat
────────── ────── ────── ──────── ────── ─────── ──────
μ 1.0431 0.2660 0.0084 0.0424 32.8009 1.0559
σ 3.1289 0.1727 0.0055 0.0398 19.2294 1.0048
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
────────── ────── ────── ────── ────── ──────
μ 0.5624 0.8371 1.0107 1.1559 1.6197
σ 2.9387 2.9999 3.1187 3.2550 3.5593
@torfjelde
Copy link
Author

Important notes:

  • Due to TuringLang/Turing.jl#1352 which is being fixed in TuringLang/DynamicPPL.jl#147 the above code will NOT work correctly for multivariate variables. As soon as this PR is merged, this snippet will work correctly for non-univariate cases also with some minor changes.
  • Even though σ and μ are univariate rvs, the values need to be specified as a Vector.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment