Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Created June 27, 2023 11:05
Show Gist options
  • Save torfjelde/5a0d4d45241d9ef1ad07a94b12b26dcd to your computer and use it in GitHub Desktop.
Save torfjelde/5a0d4d45241d9ef1ad07a94b12b26dcd to your computer and use it in GitHub Desktop.
using Random: Random, AbstractRNG
using DynamicPPL: OrderedDict, SamplingContext, AbstractContext, IsParent, VarName, Distribution, evaluate!!, VarInfo
using DynamicPPL: Model
import DynamicPPL: tilde_assume, dot_tilde_assume, childcontext, setchildcontext, NodeTrait
"""
    PriorExtractorContext(; priors=OrderedDict{VarName,Any}(), context=SamplingContext())

Parent context that wraps another context and is used to record the distribution
each assumed random variable is drawn from during model evaluation.

# Fields
- `priors`: collection the recorded `VarName => distribution` entries are written
  into. By default a fresh `OrderedDict{VarName,Any}` is created per instance.
- `context`: the wrapped child context that every statement is forwarded to
  (a `SamplingContext()` by default).
"""
Base.@kwdef struct PriorExtractorContext{D,Ctx} <: AbstractContext
    # Destination for recorded priors; the default expression is evaluated on
    # every keyword construction, so instances never share a dict by accident.
    priors::D = OrderedDict{VarName,Any}()
    # Child context that performs the actual evaluation work.
    context::Ctx = SamplingContext()
end
# `PriorExtractorContext` always wraps a child context, so it is a parent node
# in DynamicPPL's context tree.
function NodeTrait(::PriorExtractorContext)
    return IsParent()
end

# Return the wrapped child context.
function childcontext(ctx::PriorExtractorContext)
    return ctx.context
end

# Rebuild the wrapper around a new `child`. The *same* `priors` collection is
# passed through (no copy), so entries recorded via the rebuilt context still
# land in the original collection.
function setchildcontext(ctx::PriorExtractorContext, child)
    return PriorExtractorContext(ctx.priors, child)
end
# Intercept an assumption `vn ~ right`: store `right` as the prior of `vn`, then
# forward the statement, unchanged, to the wrapped context and return whatever
# it returns.
function tilde_assume(context::PriorExtractorContext, right, vn, vi)
    inner = childcontext(context)
    setprior!(context, vn, right)
    return tilde_assume(inner, right, vn, vi)
end
# Intercept a broadcasted assumption (`left .~ right`): store `right` as the
# prior for the variable name(s) in `vn`, then forward the statement, unchanged,
# to the wrapped context and return whatever it returns.
function dot_tilde_assume(context::PriorExtractorContext, right, left, vn, vi)
    inner = childcontext(context)
    setprior!(context, vn, right)
    return dot_tilde_assume(inner, right, left, vn, vi)
end
"""
    setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)

Record `dist` as the prior of the single variable `vn` in `context.priors`.
An existing entry for `vn` is overwritten.
"""
function setprior!(context::PriorExtractorContext, vn::VarName, dist::Distribution)
    context.priors[vn] = dist
    return nothing
end

"""
    setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dist::Distribution)

Record the one distribution `dist` as the prior of every variable name in `vns`
(e.g. for `x .~ Normal()`).
"""
function setprior!(
    context::PriorExtractorContext,
    vns::AbstractArray{<:VarName},
    dist::Distribution,
)
    for vn in vns
        context.priors[vn] = dist
    end
    return nothing
end

"""
    setprior!(context::PriorExtractorContext, vns::AbstractArray{<:VarName}, dists::AbstractArray{<:Distribution})

Record the element-wise priors `dists` for the variable names `vns`.

Names and distributions are paired using standard Julia broadcasting, matching
the element-wise meaning of `.~`. The two arrays therefore need not have the same
shape. For example, a length-`n` vector `dists` against an `n × m` array `vns`
assigns `dists[i]` to every `vns[i, j]`.

# Throws
- `DimensionMismatch` if the shapes of `vns` and `dists` are not
  broadcast-compatible.
"""
function setprior!(
    context::PriorExtractorContext,
    vns::AbstractArray{<:VarName},
    dists::AbstractArray{<:Distribution},
)
    # `zip` stops at the shorter of its arguments, which silently dropped priors
    # whenever `dists` was broadcast over a larger `vns`. Broadcasting `tuple`
    # yields exactly one (name, distribution) pair per broadcast element instead.
    for (vn, dist) in broadcast(tuple, vns, dists)
        context.priors[vn] = dist
    end
    return nothing
end
"""
extract_priors(model::Model)
Extract the priors from a model. This is done by sampling from the model and
recording the distributions that are used to generate the samples.
"""
function extract_priors(model::Model)
context = PriorExtractorContext()
evaluate!!(model, VarInfo(), context)
return context.priors
end
@torfjelde
Copy link
Author

This allows us to do things like the following:

julia> model = DynamicPPL.TestUtils.DEMO_MODELS[1]
Model{typeof(DynamicPPL.TestUtils.demo_dot_assume_dot_observe), (:x, Symbol("##arg#378")), (), (), Tuple{Vector{Float64}, DataType}, Tuple{}, DefaultContext}(DynamicPPL.TestUtils.demo_dot_assume_dot_observe, (x = [1.5, 2.0], var"##arg#378" = Vector{Float64}), NamedTuple(), DefaultContext())

julia> extract_priors(model)
OrderedDict{VarName, Any} with 4 entries:
  s[1] => InverseGamma{Float64}(
  s[2] => InverseGamma{Float64}(
  m[1] => Normal{Float64}(μ=0.0, σ=3.78366)
  m[2] => Normal{Float64}(μ=0.0, σ=1.02669)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment