Last active April 22, 2022 19:50
Gist: cscherrer/7b00f9ade4e32bb28273e2fc1f360eda
ZigZag sampler with a Soss model
using Soss
using ZigZagBoomerang
using Soss: logdensity, xform, ConditionalModel
using ForwardDiff
using ForwardDiff: gradient!
using LinearAlgebra
using SparseArrays
using StructArrays
using TransformVariables
using MeasureTheory
using TupleVectors
# Let Soss.xform handle a SpikeMixture by delegating to its wrapped measure m.m
Soss.xform(m::SpikeMixture) = Soss.xform(m.m)
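# partiali(f, d) returns (x, i) -> ∂f/∂x_i, computed with one forward-mode pass:
# each x_k is paired with the k-th entry of the unit vector e_i as its partial,
# so the derivative of f in direction e_i is read off the resulting dual number.
function partiali(f, d)
    id = collect(I(d))                                  # d×d identity matrix
    ith(i) = @inbounds @view id[:, i]                   # i-th unit vector e_i
    function (x, i)
        sa = StructArray{ForwardDiff.Dual{}}((x, ith(i)))  # duals with value x, partial e_i
        δ = f(sa).partials[]                            # the single partial ∂f/∂x_i
        return δ
    end
end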
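# Run a ZigZag process targeting exp(-ℓ) from time t0, position x0, velocity v0 up to time T.
# c*ones(d) gives the per-coordinate constants used to bound the event rates.
function zigzag(ℓ, t0, x0, v0, T=1000.0; c=10.0, adapt=false)
    d = length(x0)
    ∇ϕi = partiali(ℓ, d)    # coordinate-wise partial derivatives of the potential ℓ
    pdmp(∇ϕi, t0, x0, v0, T, c*ones(d), ZigZag(sparse(I(d)), 0*x0); adapt=adapt)
end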
# Bayesian linear regression: α, β ~ Normal(), y_j ~ Normal(α + β x_j, 2.0)
m = @model x begin
    α ~ Normal()
    β ~ Normal()
    yhat = α .+ β .* x
    y ~ For(eachindex(x)) do j
        Normal(yhat[j], 2.0)
    end
end
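# Simulated data: intercept -0.1, slope 2, standard normal noise
x = randn(20);
obs = -0.1 .+ 2 .* x .+ randn(20);
# Condition on the observed y; t maps unconstrained vectors to named tuples (α, β)
post = m(x=x) | (y=obs,)
t = xform(post)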
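# Potential for the sampler: negative log posterior in unconstrained coordinates,
# including the log-Jacobian of the transform
function ℓ(x)
    (θ, logjac) = TransformVariables.transform_and_logjac(t, x)
    -logdensity(post, θ) - logjac
end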
# Initial state: a simulated draw of the parameters, mapped to unconstrained
# coordinates, and a random initial velocity
tkeys = keys(t.transformations)
vars = Soss.select(simulate(post).trace, tkeys)
t0 = 0.0
x0 = inverse(t, vars)
v0 = randn(dimension(t))
T = 100.0
mytrace, final, (num, acc) = zigzag(ℓ, t0, x0, v0, T; c=20.0);
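# mytrace is a continuous-time trajectory; discretize it to obtain samples
ts, xs = ZigZagBoomerang.sep(discretize(mytrace, 0.1));
# Map each unconstrained point back to a named tuple (α, β).
# Newer TransformVariables versions deprecate calling t(x); use transform(t, x) instead.
samples = TupleVector(TransformVariables.transform.(Ref(t), xs))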
# Trace plots of α and β against time
using Plots
p = plot(ts, samples.α, label="α")
plot!(p, ts, samples.β, color=:red, label="β")
using TupleVectors: unwrap
# Unwrap a TupleVector to treat it as a NamedTuple of Vectors,
# then plot the path of the (α, β) samples
p2 = @with unwrap(samples) begin
    plot(α, β, legend=false)
end
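For this model the potential `ℓ` can also be written out by hand, which gives a check on what `partiali` computes. This is a sketch that assumes `xform(post)` maps a 2-vector to (α, β) by the identity. Both priors are unconstrained, so `logjac` would be 0. With noise standard deviation 2.0 the likelihood weight is 1/(2·2.0²) = 1/8:

```latex
\begin{aligned}
\ell(\alpha,\beta) &= \frac{\alpha^2}{2} + \frac{\beta^2}{2} + \frac{1}{8}\sum_{j=1}^{20} r_j^2 + \text{const},
  \qquad r_j = y_j - \alpha - \beta x_j,\\
\frac{\partial \ell}{\partial \alpha} &= \alpha - \frac{1}{4}\sum_{j=1}^{20} r_j,
  \qquad
\frac{\partial \ell}{\partial \beta} = \beta - \frac{1}{4}\sum_{j=1}^{20} x_j\, r_j .
\end{aligned}
```

If the coordinates are ordered (α, β), these should match `partiali(ℓ, 2)(x, 1)` and `partiali(ℓ, 2)(x, 2)`.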
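`partiali` relies on forward-mode differentiation. If each input carries the matching entry of the unit vector e_i as its derivative part, one evaluation of `f` returns ∂f/∂x_i. The sketch below shows the same idea in plain Julia, using a hand-rolled dual number instead of `ForwardDiff.Dual`. `MiniDual`, `partial_i`, and the test `potential` (same form as `ℓ` for the regression model, with made-up data) are hypothetical names used only for illustration.

```julia
# A hypothetical, minimal dual number: a value plus one directional derivative.
# Only the operations used below are defined; this is not ForwardDiff's Dual.
struct MiniDual
    val::Float64   # primal value
    der::Float64   # derivative along the seeded direction
end

Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.val + b.val, a.der + b.der)
Base.:+(a::MiniDual, c::Real) = MiniDual(a.val + c, a.der)
Base.:+(c::Real, a::MiniDual) = a + c
Base.:-(a::MiniDual, b::MiniDual) = MiniDual(a.val - b.val, a.der - b.der)
Base.:-(a::MiniDual, c::Real) = MiniDual(a.val - c, a.der)
Base.:-(c::Real, a::MiniDual) = MiniDual(c - a.val, -a.der)
Base.:*(a::MiniDual, b::MiniDual) = MiniDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::MiniDual) = MiniDual(c * a.val, c * a.der)
Base.:*(a::MiniDual, c::Real) = c * a

# ∂f/∂x_i at x: seed coordinate i with derivative 1 (the unit vector e_i),
# evaluate f once, and read the derivative off the resulting dual number.
function partial_i(f, x::AbstractVector{<:Real}, i::Integer)
    xd = [MiniDual(x[k], k == i ? 1.0 : 0.0) for k in eachindex(x)]
    return f(xd).der
end

# Hypothetical test potential with the same shape as ℓ for the regression model:
# θ₁²/2 + θ₂²/2 + (1/8) Σⱼ (yⱼ - θ₁ - θ₂ xⱼ)²   (constants dropped)
function potential(θ, xs, ys)
    acc = 0.5 * (θ[1] * θ[1]) + 0.5 * (θ[2] * θ[2])
    for j in eachindex(xs)
        r = ys[j] - θ[1] - θ[2] * xs[j]
        acc = acc + 0.125 * (r * r)
    end
    return acc
end

xs = [0.5, -1.0, 2.0]
ys = [1.0, 0.0, 3.5]
f(θ) = potential(θ, xs, ys)
partial_i(f, [0.3, 1.2], 1)   # ≈ -0.15   (= θ₁ - Σ r / 4)
partial_i(f, [0.3, 1.2], 2)   # ≈ 1.0125  (= θ₂ - Σ x r / 4)
```

Each call costs one evaluation of `f`, so a full gradient needs d passes. A coordinate-wise sampler like ZigZag often needs only one partial at a time, which is presumably why `partiali` takes the index `i`.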
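The `zigzag`/`pdmp` call and the `discretize` step can be illustrated with a self-contained one-dimensional sketch in plain Julia, without ZigZagBoomerang. The sketch assumes a standard normal target. For that target the next event time can be drawn exactly by inverting the integrated switching rate, so no rate bound like `c` is needed. The `c` argument and the `(num, acc)` counts returned above suggest that `pdmp` instead bounds the rates and thins proposed events. `zigzag_gaussian_1d` and `discretize_path` are hypothetical names for this illustration, not ZigZagBoomerang functions.

```julia
using Random, Statistics

# Hypothetical 1-D ZigZag sampler for a standard normal target, potential U(x) = x²/2.
# Switching rate λ(x, v) = max(0, v * U'(x)) = max(0, v * x). Along a segment
# x(s) = x + v*s with v = ±1 the rate is max(0, a + s) with a = v*x, so the next
# event time solves ∫₀^τ max(0, a + s) ds = E with E ~ Exp(1), which gives
#     τ = -a + sqrt(max(a, 0)^2 + 2E).
# Because event times are exact (no thinning), every event flips the velocity.
function zigzag_gaussian_1d(x0::Float64, v0::Float64, T::Float64;
                            rng::AbstractRNG = Random.default_rng())
    abs(v0) == 1 || throw(ArgumentError("v0 must be ±1"))  # the τ formula assumes unit speed
    t, x, v = 0.0, x0, v0
    skeleton = [(t, x, v)]   # (event time, position, velocity after the event)
    while t < T
        a = v * x
        τ = -a + sqrt(max(a, 0.0)^2 + 2 * randexp(rng))
        t += τ
        x += v * τ
        v = -v
        push!(skeleton, (t, x, v))
    end
    return skeleton
end

# Read the piecewise-linear path off the skeleton on the grid 0, Δ, 2Δ, …, ≤ T,
# which is the role discretize(mytrace, 0.1) plays in the script above.
function discretize_path(skeleton, Δ::Float64, T::Float64)
    ts = collect(0.0:Δ:T)
    xs = similar(ts)
    k = 1
    for (n, s) in enumerate(ts)
        # move to the segment that starts at the last event time ≤ s
        while k < length(skeleton) && skeleton[k + 1][1] <= s
            k += 1
        end
        tk, xk, vk = skeleton[k]
        xs[n] = xk + vk * (s - tk)
    end
    return ts, xs
end

rng = MersenneTwister(1)
skel = zigzag_gaussian_1d(0.0, 1.0, 50_000.0; rng = rng)
ts, xs = discretize_path(skel, 0.1, 50_000.0)
mean(xs), var(xs)   # should be near 0 and 1 for a long enough run
```

Samples taken on a fixed grid are equally spaced in time but correlated, so a longer trajectory (larger `T`) gives better estimates than a finer grid.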
Hi,
I wanted to reproduce this computation. However, running
I get
Maybe a typo?
Thanks,
VP
Hi Vincent, this is because of a change in TransformVariables. For a transform `t` and vector `x`, you used to be able to call `t(x)`. I prefer that, but the author deprecated this use, so now you need to call `TransformVariables.transform(t, x)` instead. Does that fix the problem?
Hi Chad,
Thanks a lot for your quick answer.
Following your remark, I wrote:
`samples = TupleVector(TransformVariables.transform.(xform(post),xs))`
and it worked :)
All this seems very interesting, but I need to put in some work to understand it better :)
Best,
VP