@cscherrer
Last active April 22, 2022 19:50
ZigZag sampler with a Soss model
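```julia
using Soss
using ZigZagBoomerang
using Soss: logdensity, xform, ConditionalModel
using ForwardDiff
using ForwardDiff: gradient!
using LinearAlgebra
using SparseArrays
using StructArrays
using TransformVariables
using MeasureTheory
using TupleVectors

# Let `xform` handle a `SpikeMixture` through its wrapped measure `m.m`
Soss.xform(m::SpikeMixture) = Soss.xform(m.m)

# Return a closure (x, i) -> ∂f/∂xᵢ evaluated at x, using one forward-mode pass:
# x is paired with the i-th unit vector as the dual (derivative) part.
function partiali(f, d)
    id = collect(I(d))                    # d×d identity matrix
    ith(i) = @inbounds @view id[:, i]     # i-th unit vector, without copying
    function (x, i)
        sa = StructArray{ForwardDiff.Dual{}}((x, ith(i)))
        δ = f(sa).partials[]              # derivative of f along eᵢ
        return δ
    end
end
```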
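The `partiali` closure computes one coordinate ∂ϕ/∂xᵢ at a time by evaluating `f` on dual numbers whose derivative part is the i-th unit vector (the `ith(i)` column of the identity). A minimal stand-alone sketch of that idea, using a hypothetical `MiniDual` type in place of `ForwardDiff.Dual` and supporting only the operations the example needs:

```julia
# Hypothetical minimal dual number (illustration only, not ForwardDiff.Dual):
# `val` is the value, `der` the derivative along one seeded direction.
struct MiniDual
    val::Float64
    der::Float64
end

Base.:+(a::MiniDual, b::MiniDual) = MiniDual(a.val + b.val, a.der + b.der)
Base.:*(a::MiniDual, b::MiniDual) = MiniDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Real, a::MiniDual) = MiniDual(c * a.val, c * a.der)

# ∂f/∂xᵢ at x: seed derivative 1 in coordinate i and 0 elsewhere (the unit
# vector eᵢ), then read the derivative part of f's output.
function partial_i(f, x::Vector{Float64}, i::Int)
    xd = [MiniDual(x[j], j == i ? 1.0 : 0.0) for j in eachindex(x)]
    return f(xd).der
end

# Example potential with a cross term: ϕ(x) = x₁²/2 + 3x₁x₂ + x₂²/2
ϕ(x) = 0.5 * (x[1] * x[1]) + 3.0 * (x[1] * x[2]) + 0.5 * (x[2] * x[2])
```

For instance, `partial_i(ϕ, [1.0, 2.0], 1)` gives x₁ + 3x₂ = 7.0 and `partial_i(ϕ, [1.0, 2.0], 2)` gives 3x₁ + x₂ = 5.0. One pass per coordinate fits the ZigZag sampler, whose flip rate for coordinate i depends only on ∂ᵢϕ.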
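```julia
# Run a ZigZag process for the potential ℓ (negative log density),
# starting from (t0, x0, v0) and stopping at time T.
# c: rate bound passed to `pdmp`, one entry per coordinate
function zigzag(ℓ, t0, x0, v0, T=1000.0; c=10.0, adapt=false)
    d = length(x0)
    ∇ϕi = partiali(ℓ, d)
    # identity sparsity pattern `I(d)` and zero reference mean
    pdmp(∇ϕi, t0, x0, v0, T, c*ones(d), ZigZag(sparse(I(d)), 0*x0); adapt=adapt)
end
```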
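The `zigzag` wrapper hands the per-coordinate partials to `ZigZagBoomerang.pdmp`, which simulates the ZigZag process: the state moves in straight lines at velocity v, and coordinate i flips the sign of its velocity at rate max(0, vᵢ ∂ᵢϕ(x)). For intuition only, here is a toy version for a standard normal target, where event times can be drawn exactly by inverting the integrated rate. `toy_zigzag` and `flip_time` are hypothetical names; the sketch assumes velocities in {-1, +1} and is not how ZigZagBoomerang implements the sampler (its `c` and `adapt` arguments suggest it works with rate bounds instead of exact inversion).

```julia
using Random

# Time until coordinate i flips when its rate along the ray is max(0, a + s),
# with a = v[i] * x[i] (standard normal target, |v[i]| = 1):
# solve ∫₀^τ max(0, a + s) ds = e for an Exp(1) draw e.
flip_time(a, e) = -a + sqrt(max(a, 0.0)^2 + 2 * e)

# Toy ZigZag for ϕ(x) = |x|²/2; returns the skeleton of (time, position) pairs.
function toy_zigzag(x0::Vector{Float64}, v0::Vector{Float64}, T::Float64;
                    rng::AbstractRNG = Random.default_rng())
    t, x, v = 0.0, copy(x0), copy(v0)
    skeleton = [(t, copy(x))]
    while true
        τ = [flip_time(v[i] * x[i], randexp(rng)) for i in eachindex(x)]
        i = argmin(τ)                      # first coordinate to flip
        if t + τ[i] >= T                   # horizon reached before the next event
            x .+= v .* (T - t)
            push!(skeleton, (T, copy(x)))
            return skeleton
        end
        t += τ[i]
        x .+= v .* τ[i]                    # straight-line motion up to the event
        v[i] = -v[i]                       # flip the velocity of coordinate i
        push!(skeleton, (t, copy(x)))
    end
end
```

For example, `toy_zigzag(zeros(2), [1.0, -1.0], 100.0)` returns the event skeleton of a 2-dimensional run; all coordinates get fresh clocks after every event, which is valid because the process is Markov.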
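```julia
# Linear regression: y[j] ~ Normal(α + β*x[j], 2.0)
m = @model x begin
    α ~ Normal()
    β ~ Normal()
    yhat = α .+ β .* x
    y ~ For(eachindex(x)) do j
        Normal(yhat[j], 2.0)
    end
end

# Simulated data: intercept -0.1, slope 2, unit-variance noise
x = randn(20);
obs = -0.1 .+ 2x + 1randn(20);
post = m(x=x) | (y=obs,)   # condition the model on the observations

# Bijection between unconstrained ℝᵈ and the parameters (α, β)
t = xform(post)

# Potential: negative log posterior in unconstrained coordinates
function ℓ(x)
    (θ, logjac) = TransformVariables.transform_and_logjac(t, x)
    -logdensity(post, θ) - logjac
end
```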
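Since α and β both have `Normal` priors, they are unconstrained, so `t` should be the identity map and `logjac` zero; `ℓ` is then just the negative log posterior. Written out by hand (up to an additive constant, and assuming the second argument of `Normal` is the standard deviation), the potential and the gradient that `partiali` recovers coordinate by coordinate are below. `potential` and `grad_potential` are hypothetical helpers, handy for checking the AD result.

```julia
# Potential (negative log posterior, up to a constant) of the regression model:
# α, β ~ Normal(0, 1), y[j] ~ Normal(α + β*x[j], σ) with σ = 2.
function potential(θ, x, y; σ=2.0)
    α, β = θ
    r = y .- α .- β .* x                  # residuals
    return (α^2 + β^2) / 2 + sum(abs2, r) / (2σ^2)
end

# Gradient: ∂α = α - Σr/σ², ∂β = β - Σ x·r/σ²
function grad_potential(θ, x, y; σ=2.0)
    α, β = θ
    r = y .- α .- β .* x
    return [α - sum(r) / σ^2, β - sum(x .* r) / σ^2]
end
```

With the script's data, `grad_potential(θ, x, obs)` can be compared against `[partiali(ℓ, 2)(θ, i) for i in 1:2]` at the same point θ.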
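```julia
# Initial position from a simulated draw, mapped to unconstrained coordinates
tkeys = keys(t.transformations)
vars = Soss.select(simulate(post).trace, tkeys)
t0 = 0.0
x0 = inverse(t, vars)
v0 = randn(dimension(t))
T = 100.0   # length of the simulated trajectory
mytrace, final, (num, acc) = zigzag(ℓ, t0, x0, v0, T; c=20.0);

# mytrace is a continuous-time object; discretize it to obtain samples
ts, xs = ZigZagBoomerang.sep(discretize(mytrace, 0.1));
# map each unconstrained vector back to a named tuple (α = ..., β = ...)
samples = TupleVector(TransformVariables.transform.(t, xs))
```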
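As we understand it, `discretize(mytrace, 0.1)` reads the continuous ZigZag path off at times spaced 0.1 apart; between events the path moves in straight lines, so positions on the grid follow by linear interpolation. A stand-alone sketch for a skeleton stored as (time, position) pairs; `discretize_skeleton` is a hypothetical helper, not the ZigZagBoomerang function, and it assumes at least two events sorted by time.

```julia
# Evaluate a piecewise-linear trajectory, given as (t, x) event pairs sorted by
# time, on the grid t₀, t₀ + dt, t₀ + 2dt, ... up to the final event time.
function discretize_skeleton(skeleton::Vector{Tuple{Float64,Vector{Float64}}}, dt::Float64)
    ts = Float64[]
    xs = Vector{Vector{Float64}}()
    t0, tend = skeleton[1][1], skeleton[end][1]
    k = 1                                   # start index of the current segment
    for n in 0:floor(Int, (tend - t0) / dt)
        s = t0 + n * dt
        # advance to the segment [t_k, t_{k+1}] that contains s
        while k < length(skeleton) - 1 && skeleton[k + 1][1] <= s
            k += 1
        end
        (ta, xa), (tb, xb) = skeleton[k], skeleton[k + 1]
        w = tb > ta ? (s - ta) / (tb - ta) : 0.0
        push!(ts, s)
        push!(xs, xa .+ w .* (xb .- xa))    # linear interpolation within the segment
    end
    return ts, xs
end
```

For example, the path (0, [0]) → (1, [1]) → (2, [0]) on a grid of spacing 0.5 gives positions 0, 0.5, 1, 0.5, 0.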
```julia
using Plots
# Trace plots of α and β against time
p = plot(ts, samples.α, label="α")
plot!(p, ts, samples.β, color=:red, label="β")

using TupleVectors: unwrap
# Unwrap a TupleVector to treat it as a NamedTuple of Vectors
p2 = @with unwrap(samples) begin
    plot(α, β, legend=false)
end
```
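Because the priors and the likelihood are Gaussian with a known noise scale (2.0, assuming that argument is the standard deviation), the posterior of (α, β) is available in closed form, which gives a check on the ZigZag output: for a long enough run, the means of `samples.α` and `samples.β` should be close to μ. A minimal sketch with a hypothetical `exact_posterior` helper; with the script's data it would be called as `μ, Σ = exact_posterior(x, obs)`.

```julia
using LinearAlgebra

# Exact posterior of θ = (α, β) for the linear-Gaussian model in the script:
# prior N(0, I), likelihood y ~ N(Xθ, σ²I) with X = [1 x] and σ = 2.
function exact_posterior(x::Vector{Float64}, y::Vector{Float64}; σ=2.0)
    X = hcat(ones(length(x)), x)
    P = I + X' * X / σ^2            # posterior precision
    Σ = inv(Symmetric(P))           # posterior covariance
    μ = Σ * (X' * y) / σ^2          # posterior mean
    return μ, Σ
end
```

The closed form relies on conjugacy; for a model with non-Gaussian parts the sampler has no such reference, which is where a PDMP sampler earns its keep.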
@vincent-picaud

Hi Chad,

Thanks a lot for your quick answer.

Following your remark, I wrote:

```julia
samples = TupleVector(TransformVariables.transform.(xform(post),xs))
```

and it worked :)

All of this seems very interesting, but I need to put in some work to understand it better :)

Best,
VP

@cscherrer
Author

👍
