Skip to content

Instantly share code, notes, and snippets.

@mschauer
Last active October 5, 2022 22:47
Show Gist options
  • Save mschauer/30c93cdc1eafbe912a9d75ac22c65a3b to your computer and use it in GitHub Desktop.
Save mschauer/30c93cdc1eafbe912a9d75ac22c65a3b to your computer and use it in GitHub Desktop.
import Pkg
cd(@__DIR__)
# Pkg.activate(@__DIR__)
using Tilde, Pathfinder, PDMats, StructArrays
using ForwardDiff
using ForwardDiff: Dual
using LinearAlgebra, Random, Statistics, StatsBase, SparseArrays
using ZigZagBoomerang
using ZigZagBoomerang: StickyBarriers, StructuredTarget, StickyUpperBounds, StickyFlow, EndTime
using MCMCChains
using ArraysOfArrays
# Configuration
# Seed the global RNG once: every random draw below (mock data, Pathfinder, initial
# velocity, sampler) depends on this seed and on the order of the calls.
Random.seed!(1)
κ = 0.01 # stickiness, passed as (κ, κ) to each StickyBarriers below
T = 5000.0 # sampling time (end time of the sampler, see EndTime(T))
c = 0.01 # initial per-coordinate bound constant, passed as fill(c, d) to StickyUpperBounds
progress = true # show progress bar
PLOT = true # plot posterior trace
nsamples = 200 # the trace is discretised on a grid with step Δt = T/nsamples
Δt = T/nsamples
# Generate mock data
println("Data...")
# Design matrix: an intercept column of ones followed by 22 standard-normal covariates
# (20000 observations × 23 columns).
X = hcat(ones(20000),randn(20000,22))
# Transposed copy with one column per observation, the layout the model reads via view(Xt, :, j).
Xt = Matrix(X')
d = size(X, 2)
n = size(X, 1)
# True coefficients: nonzero only at positions 1, 5 and 11 (the rest are exactly zero).
betas = vcat([-0.8], zeros(3), [1.0], zeros(5), [0.9], zeros(12))
@assert length(betas) == d
# Binary responses: one Bernoulli draw per observation, with the linear predictor X*betas
# passed as the `logitp` (log-odds) parameter.
y = (X*betas) .|> x->rand(Bernoulli(logitp = x))
@info "Important coef positions: $(findall(betas.!=0)), Average rate: $(mean(y))"
# Simple logistic model
# Prior: θ ~ Normal() ^ d (presumably d independent standard normals via the power
# measure — confirm against MeasureTheory docs).
# Likelihood: for each observation j (column j of Xt), y[j] is Bernoulli with log-odds
# equal to the dot product of column j with θ.
# Xt must be laid out as (d, n) = size(Xt), i.e. one column per observation; the script
# passes Matrix(X').
model_lr = @model (Xt, y) begin
d, n = size(Xt)
θ ~ Normal() ^ d
for j in 1:n
logitp = view(Xt, :, j)' * θ
y[j] ~ Bernoulli(logitp = logitp)
end
end
# Gradients
"""
    make_grads(model_lr, At, y, d)

Condition `model_lr(At, y)` on the observations `y`. Return the log density and the
derivative callables used by Pathfinder and the sticky ZigZag sampler.

# Arguments
- `model_lr`: Tilde model, called as `model_lr(At, y)`.
- `At`: design matrix in the layout the model expects (this script passes `Xt`, one
  column per observation).
- `y`: observed responses; the model is conditioned on `(; y)`.
- `d`: length of the parameter vector `x`/`θ` the returned callables are evaluated at.

# Returns
`(post, ℓ, ∇neglogp!, ∂neglogp)`, where
- `post` is the conditioned model;
- `ℓ(θ)` is `Tilde.unsafe_logdensityof(post, transform(as(post), θ))`;
- `∇neglogp!(g, t, x, args...)` writes the gradient of `-ℓ` at `x` into `g` and returns
  `nothing` (`t` and `args` are ignored);
- `∂neglogp(x, i)` returns the `i`-th partial derivative of `-ℓ` at `x`, computed in a
  single forward-mode dual pass.
"""
function make_grads(model_lr, At, y, d)
    post = model_lr(At, y) | (; y)
    as_post = as(post)
    obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
    ℓ(θ) = -obj(θ)
    # The config's seed vector and chunk size must match the length of the `x` passed to
    # `∇neglogp!`. They used to be hard-coded to 25, but the model has `d` parameters
    # (23 in this script), so size them from `d`. One chunk of size `d` keeps ForwardDiff
    # in vector mode (a single pass for the full gradient). Only the seed's length and
    # eltype matter, so use zeros rather than drawing from the global RNG.
    gconfig = ForwardDiff.GradientConfig(obj, zeros(d), ForwardDiff.Chunk{d}())
    function ∇neglogp!(y, t, x, args...)
        ForwardDiff.gradient!(y, obj, x, gconfig)
        return
    end
    # Seed vector shared by every call of ∂neglogp. It holds the i-th unit vector only
    # while partial i is being evaluated and is all zeros otherwise, so calls must not
    # overlap (e.g. from several threads).
    ith = zeros(d)
    function ∂neglogp(x, i)
        # Should use StructArrays, but Tilde seems to break that. Non-lazy alternative:
        # ForwardDiff.partials(obj([Dual{}(x[j], 1.0*(i==j)) for j in eachindex(x)]))[]
        # NOTE(review): `mappedarray` (MappedArrays.jl) is not imported explicitly in this
        # script; presumably it is re-exported by one of the `using`s — confirm.
        ith[i] = 1
        sa = mappedarray(ForwardDiff.Dual{}, x, ith)  # lazily pairs x[j] with seed ith[j]
        δ = obj(sa).partials[]
        ith[i] = 0
        return δ
    end
    return post, ℓ, ∇neglogp!, ∂neglogp
end
post, ℓ, ∇neglogp!, ∂neglogp = make_grads(model_lr, Xt, y, d)
# Pathfinding
println("Pathfinder...")
init_scale = 1
# Reuse an existing `pf_result` from the current session (e.g. when the script is
# re-included in the same REPL) instead of rerunning Pathfinder on ℓ.
if !@isdefined pf_result
@time pf_result = pathfinder(ℓ; dim=d, init_scale)
end
# Keep only the diagonal of Pathfinder's fitted covariance Σ, and take its inverse Γ as a
# sparse matrix. Γ is used by the ZigZag flow, the upper bounds (G1, StickyUpperBounds)
# and ∇ϕ below.
M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ))
Γ = sparse(inv(M))
# Disabled alternative: invert the full Σ instead of its diagonal.
#Γ = sparse(inv(pf_result.fit_distribution.Σ))
# Start the sampler at Pathfinder's mean; μ is also the mean passed to ZigZag(Γ, μ).
x0 = μ = pf_result.fit_distribution.μ
# Initial velocity: a standard-normal draw mapped through PDMats.unwhiten with M.
v0 = PDMats.unwhiten(M, randn(length(x0)))
# Sticky sampler
println("Sticky sampler...")
# One StickyBarriers per coordinate, built from the location pair (0.0, 0.0), kinds
# (:sticky, :sticky) and stickiness (κ, κ). Presumably this is a sticky point at zero
# from both sides — see the ZigZagBoomerang docs.
barriers = [StickyBarriers((0.0, 0.0), (:sticky, :sticky), (κ, κ)) for i in 1:d]
# NOTE(review): reassigns d from the Pathfinder mean; it should equal size(X, 2) above.
d = length(x0)
# Initial state: per-coordinate times (all zero), positions and velocities.
t0 = fill(0.0, d)
u0 = (t0, x0, v0)
# Dependency graph for ∂neglogp: partial i may depend on every coordinate 1:d (dense).
target = StructuredTarget([i => 1:d for i in 1:d], ∂neglogp)
flow = StickyFlow(ZigZag(Γ, μ))
strong_upperbounds = false
adapt = true
multiplier = 1.7 # increase bounds
G = target.G
# For each coordinate i: the row indices of the stored entries in column i of Γ.
# With the diagonal M above this is just [i]. Iterating axes(Γ, 1) matches the columns
# because Γ is square.
G1 = [i => rowvals(Γ)[nzrange(Γ, i)] for i in axes(Γ, 1)]
upper_bounds = StickyUpperBounds(G, G1, Γ, fill(c, d); adapt=adapt, strong = strong_upperbounds, multiplier= multiplier)
end_time = EndTime(T)
# NOTE(review): ∇ϕ is not referenced anywhere below in this script.
∇ϕ(x, i) = ZigZagBoomerang.idot(Γ, i, x) # sparse computation
elapsed_time = @elapsed begin
trace, _, _, acc = @time stickyzz(u0, target, flow, upper_bounds, barriers, end_time; progress=progress)
end
@info "Upper bounds: $(upper_bounds.c)"
# Ratio acc.acc / acc.num from the sampler's statistics (presumably accepted / proposed events).
println("acc ", acc.acc/acc.num)
# Plot the continuous trace (lines) and its Δt-discretised subsamples (dots) for a few
# selected coordinates, each with its true coefficient from `betas` as a dashed line.
if PLOT
    ts, xs = ZigZagBoomerang.sep(trace)
    tsd, xsd = ZigZagBoomerang.sep(ZigZagBoomerang.discretise(trace, Δt))
    println("Plot...")
    colors = [:green, :red, :blue, :violet]
    using GLMakie
    fig = Figure()
    fig1 = fig
    r = 1:length(ts)
    rd = 1:length(tsd)
    ax = Axis(fig[1, 1]; title = "trace")
    is = [1, 2, 5, 11]
    for (slot, coord) in enumerate(is)
        hue = colors[slot]
        scatter!(ax, tsd[rd], getindex.(xsd[rd], coord); color = hue, markersize = 3.0)
        lines!(ax, ts[r], getindex.(xs[r], coord); color = (hue, 0.4))
        lines!(ax, ts[r], fill(betas[coord], length(r)); linestyle = :dash, color = (hue, 0.5))
    end
    display(fig)
end
# Your samples: discretise the trace on a grid with step Δt and stack the state vectors
# into a matrix (one column per grid point, via ArraysOfArrays.flatview). Transpose it so
# each row is one draw, and wrap it in an MCMCChains.Chains annotated with the sampler's
# wall-clock time.
samples = let
    discretised = ZigZagBoomerang.discretise(trace, Δt)
    states = ZigZagBoomerang.sep(discretised)[2]
    flatview(VectorOfSimilarVectors(states))
end
chain = setinfo(MCMCChains.Chains(samples'), (; start_time = 0.0, stop_time = elapsed_time));
chain
# NOTE(review): the script above also loads Tilde, StatsBase and GLMakie, plus the stdlibs
# LinearAlgebra, Random, Statistics, SparseArrays and Pkg. None of them is listed here;
# add them (with their UUIDs) if this environment is meant to be activated via
# Pkg.activate.
[deps]
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
ZigZagBoomerang = "36347407-b186-4a6a-8c98-4f4567861712"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment