Last active
April 22, 2022 19:50
-
-
Save cscherrer/7b00f9ade4e32bb28273e2fc1f360eda to your computer and use it in GitHub Desktop.
ZigZag sampler with a Soss model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Soss | |
using ZigZagBoomerang | |
using Soss: logdensity, xform, ConditionalModel | |
using ForwardDiff | |
using ForwardDiff: gradient! | |
using LinearAlgebra | |
using SparseArrays | |
using StructArrays | |
using TransformVariables | |
using MeasureTheory | |
using TupleVectors | |
# Variable transform for a `SpikeMixture`: delegate to the transform of the
# measure stored in its `m` field.
# NOTE(review): presumably `m` is the continuous component of the mixture, and
# the model in this script does not use `SpikeMixture` at all — confirm whether
# this method is still needed.
function Soss.xform(μ::SpikeMixture)
    component = μ.m
    return Soss.xform(component)
end
"""
    partiali(f, d)

Build a closure `(x, i) -> δ` that returns the partial derivative of the
scalar function `f` with respect to coordinate `i`, evaluated at `x`
(a length-`d` vector), using a single forward-mode dual-number pass.

The seed direction for coordinate `i` is column `i` of a dense `d×d` Bool
identity matrix that is built once, when `partiali` is called.
"""
function partiali(f, d)
    # Dense identity: O(d^2) memory, but each call only takes a view of one column.
    basis = collect(I(d))
    return function (x, i)
        seed = @inbounds view(basis, :, i)
        # Pair each x[k] with seed[k] as a dual number.
        # NOTE(review): `ForwardDiff.Dual{}` carries no tag, so nesting another
        # ForwardDiff pass inside `f` could mix up perturbations — confirm.
        duals = StructArray{ForwardDiff.Dual{}}((x, seed))
        # `f` is expected to return a Dual with exactly one partial.
        return f(duals).partials[]
    end
end
"""
    zigzag(ℓ, t0, x0, v0, T=1000.0; c=10.0, adapt=false, kwargs...)

Run a ZigZag sampler with `ZigZagBoomerang.pdmp` on the scalar function `ℓ`
(the potential, i.e. negative log density), from time `t0` up to `T`,
starting at position `x0` with velocity `v0`.

Coordinate-wise partial derivatives of `ℓ` are obtained with `partiali`.

# Keywords
- `c`: event-rate bound, used with the same value for every coordinate.
- `adapt`: forwarded to `pdmp`.
- `kwargs...`: any further keyword arguments, forwarded unchanged to `pdmp`.

# Returns
Whatever `pdmp` returns (the script destructures it as
`trace, final, (num, acc)`).
"""
function zigzag(ℓ, t0, x0, v0, T=1000.0; c=10.0, adapt=false, kwargs...)
    d = length(x0)
    ∇ϕi = partiali(ℓ, d)
    # Identity `Γ` and zero `μ` for the ZigZag flow.
    # NOTE(review): presumably a standard-normal reference in ZigZagBoomerang's
    # parameterization — confirm.
    # `zero(x0)` stays exactly zero even if `x0` contains Inf/NaN (unlike `0 * x0`).
    Γ = sparse(I(d))
    μ = zero(x0)
    return pdmp(∇ϕi, t0, x0, v0, T, c * ones(d), ZigZag(Γ, μ); adapt=adapt, kwargs...)
end
# Soss model: Bayesian simple linear regression with covariate argument `x`.
#   α ~ Normal(), β ~ Normal()         (intercept and slope priors)
#   y[j] ~ Normal(α + β * x[j], 2.0)   for each index j of `x`
# NOTE(review): the synthetic observations generated below use unit-scale noise
# (`randn`), while the likelihood here fixes the second `Normal` argument at 2.0
# (presumably the standard deviation) — confirm the mismatch is intended.
m = @model x begin
    α ~ Normal()
    β ~ Normal()
    # Linear predictor for every observation.
    yhat = α .+ β .* x
    y ~ For(eachindex(x)) do j
        Normal(yhat[j], 2.0)
    end
end
# Synthetic data: 20 standard-normal covariates and responses drawn around the
# line -0.1 + 2x with standard-normal noise.
x = randn(20);
obs = -0.1 .+ 2 .* x .+ randn(length(x));
# Condition the model on the observations and get the transform between
# unconstrained vectors and the model's parameters.
post = m(x=x) | (y=obs,)
t = xform(post)
"""
    ℓ(x, tr=t, model=post)

Potential for the sampler: the negative log density of the conditioned model
`model` at the parameters `θ = tr(x)`, minus the log-Jacobian of the transform.
In other words, the negative log posterior expressed in the unconstrained
coordinates `x`.

`ℓ(x)` reads the script globals `t` and `post`. Passing them in as arguments
lets the body be compiled for their concrete types (a function barrier). The
old body read untyped, non-`const` globals on every call in the sampler's hot
loop.
"""
function ℓ(x, tr=t, model=post)
    θ, logjac = TransformVariables.transform_and_logjac(tr, x)
    return -logdensity(model, θ) - logjac
end
# Initial position: simulate the conditioned model, keep only the variables the
# transform knows about, and map them to unconstrained coordinates.
tkeys = keys(t.transformations)
vars = Soss.select(simulate(post).trace, tkeys)
t0 = 0.0
x0 = inverse(t, vars)
# ZigZag velocities have unit-magnitude components. Events only flip their
# signs, so random `randn` speeds would persist (a near-zero speed stalls its
# coordinate).
v0 = rand([-1.0, 1.0], dimension(t))
# Final time of the continuous-time run; actually passed to `zigzag`.
T = 100.0
mytrace, final, (num, acc) = zigzag(ℓ, t0, x0, v0, T; c=20.0);
# `mytrace` is a continuous-time object. Discretize it with step 0.1 to obtain
# samples (times `ts`, unconstrained points `xs`), then map every point back to
# named model parameters.
ts, xs = ZigZagBoomerang.sep(discretize(mytrace, 0.1));
samples = TupleVector(broadcast(xform(post), xs))
using Plots

# Trace plot: both coefficients against sampler time, β drawn in red.
p = plot(ts, samples.α; label="α")
plot!(p, ts, samples.β; label="β", color=:red)
using TupleVectors: unwrap

# Unwrap the TupleVector to treat it as a NamedTuple of Vectors, so `@with`
# can refer to the fields `α` and `β` by name; this plots β against α.
p2 = @with(unwrap(samples), begin
    plot(α, β, legend=false)
end)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
👍