Skip to content

Instantly share code, notes, and snippets.

@torfjelde
Last active February 10, 2022 14:11
Show Gist options
  • Save torfjelde/59fc2a8060955365125ff99496ecc860 to your computer and use it in GitHub Desktop.
Save torfjelde/59fc2a8060955365125ff99496ecc860 to your computer and use it in GitHub Desktop.
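Example of (un)linearizing a DynamicPPL `VarInfo`: flattening its values into a single parameter vector and mapping such a vector back to per-variable values.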
using DynamicPPL
# Build a `VarInfo` for `model` and compute the variable-name => index-range map for it.
function varnames_to_ranges(model::DynamicPPL.Model)
    vi = DynamicPPL.VarInfo(model)
    return varnames_to_ranges(vi)
end

# An untyped `VarInfo` keeps its bookkeeping in the single `metadata` field;
# the range computation is delegated to the method for that field's type.
function varnames_to_ranges(varinfo::DynamicPPL.UntypedVarInfo)
    md = varinfo.metadata
    return varnames_to_ranges(md)
end
# Return a `Dict` mapping every variable name in a typed `VarInfo` to its index range
# in one flat vector covering all of `varinfo.metadata`.
#
# `varinfo.metadata` is iterated as a collection of `Metadata` objects (presumably one
# per variable symbol). The ranges of each one are shifted by the total length of all
# preceding ones, so the result indexes a single concatenated vector. This assumes
# each metadata's ranges start at 1 and are contiguous; the offset arithmetic below
# relies on that.
function varnames_to_ranges(varinfo::DynamicPPL.TypedVarInfo)
    # Use a `Ref` rather than a reassigned local: the `do` closure below updates the
    # counter, and capturing a reassigned variable would box it (type-unstable).
    offset = Ref(0)
    dicts = map(varinfo.metadata) do md
        vns2ranges = varnames_to_ranges(md)
        vals = collect(values(vns2ranges))
        shift = offset[]
        vals_offset = map(r -> shift .+ r, vals)
        # Advance by the largest range end, i.e. the number of entries held by `md`.
        # `last` (unlike `r[end]`) does not throw on an empty range.
        offset[] += reduce((curr, r) -> max(curr, last(r)), vals; init=0)
        return Dict(zip(keys(vns2ranges), vals_offset))
    end
    # With no metadata at all, `reduce(merge, dicts)` would throw on the empty collection.
    isempty(dicts) && return Dict{DynamicPPL.VarName,UnitRange{Int}}()
    return reduce(merge, dicts)
end
# Map each variable name stored in a single `Metadata` to its entry in `metadata.ranges`.
function varnames_to_ranges(metadata::DynamicPPL.Metadata)
    # `metadata.idcs` is looked up by variable name to get that variable's position
    # in `metadata.ranges`.
    positions = [metadata.idcs[vn] for vn in metadata.vns]
    return Dict(zip(metadata.vns, metadata.ranges[positions]))
end
# Return the values stored in `varinfo` for the variables selected by
# `SampleFromPrior()` (presumably all of them) as one flat vector.
function linearize(varinfo::VarInfo)
    return getindex(varinfo, SampleFromPrior())
end
# Convenience method: unlinearize into a `Dict` (the only output container implemented
# in this file), forwarding all keyword arguments such as `varnames2ranges`.
function unlinearize(θ::AbstractVector, varinfo::VarInfo; kwargs...)
    container = Dict
    return unlinearize(θ, varinfo, container; kwargs...)
end
# Rebuild one variable's value from its index range `r` in the flat vector `θ`.
# Univariate distributions yield the scalar at the first index. Any other distribution
# is rebuilt from its slice by `DynamicPPL.reconstruct`.
# `Distributions` is reached through `DynamicPPL` so that `using DynamicPPL` alone
# is enough for this to resolve.
function _unlinearize_value(::DynamicPPL.Distributions.UnivariateDistribution, θ, r)
    return θ[first(r)]
end
_unlinearize_value(dist, θ, r) = DynamicPPL.reconstruct(dist, θ[r])

# Split the flat vector `θ` (e.g. from `linearize(varinfo)`) back into one value per
# variable name, returned as a `Dict` from variable name to value.
#
# Keywords
# - `varnames2ranges`: mapping from variable name to its index range in `θ`; defaults
#   to `varnames_to_ranges(varinfo)`. Pass it explicitly to avoid recomputing it.
function unlinearize(
    θ::AbstractVector, varinfo::VarInfo, ::Type{Dict};
    varnames2ranges=varnames_to_ranges(varinfo),
)
    # Iterating the Dict yields pairs in the same order as `keys(varnames2ranges)`.
    vals = [
        _unlinearize_value(DynamicPPL.getdist(varinfo, varname), θ, r)
        for (varname, r) in varnames2ranges
    ]
    return Dict(zip(keys(varnames2ranges), vals))
end
julia> @model function demo()
s ~ Dirac(1)
x = Matrix{Float64}(undef, 2, 3)
x[1, 1] ~ Dirac(2)
x[2, 1] ~ Dirac(3)
x[3] ~ Dirac(4)
y ~ Dirac(5)
x[4] ~ Dirac(5)
x[:, 3] ~ MvNormal([0, 0], [1 0; 0 1])
return s, x, y
end
demo (generic function with 2 methods)
julia> vi_typed = VarInfo(demo()); θ_typed = linearize(vi_typed);
julia> unlinearize(θ_typed, vi_typed)
Dict{VarName, Any} with 7 entries:
s => 1.0
x[:,3] => [-0.11168, -0.370387]
x[4] => 5.0
x[1,1] => 2.0
x[2,1] => 3.0
y => 5.0
x[3] => 4.0
julia> vi_untyped = VarInfo(); demo()(vi_untyped); θ_untyped = linearize(vi_untyped);
julia> unlinearize(θ_untyped, vi_untyped)
Dict{VarName, Any} with 7 entries:
s => 1
x[:,3] => Real[0.838597, 0.965344]
x[4] => 5
x[1,1] => 2
x[2,1] => 3
y => 5
x[3] => 4
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment