Last active
October 5, 2022 22:47
-
-
Save mschauer/30c93cdc1eafbe912a9d75ac22c65a3b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Pkg | |
cd(@__DIR__) | |
# Pkg.activate(@__DIR__) | |
using Tilde, Pathfinder, PDMats, StructArrays | |
using ForwardDiff | |
using ForwardDiff: Dual | |
using LinearAlgebra, Random, Statistics, StatsBase, SparseArrays | |
using ZigZagBoomerang | |
using ZigZagBoomerang: StickyBarriers, StructuredTarget, StickyUpperBounds, StickyFlow, EndTime | |
using MCMCChains | |
using ArraysOfArrays | |
# ---- Configuration ----------------------------------------------------------
Random.seed!(1)      # fixed seed so the data, Pathfinder start and sampler are reproducible

κ = 1e-2             # stickiness of the barriers at zero
T = 5e3              # total continuous sampling time
c = 1e-2             # per-coordinate constant handed to StickyUpperBounds (via fill(c, d))

progress = true      # show progress bar
PLOT = true          # plot posterior trace

nsamples = 200               # number of grid points when discretising the trace
Δt = T / nsamples            # grid spacing used for discretisation
# ---- Mock data --------------------------------------------------------------
# Design matrix: an intercept column followed by 22 standard-normal covariates.
println("Data...")
X = [ones(20000) randn(20000, 22)]
Xt = permutedims(X)          # one observation per column, as the model expects
n, d = size(X)

# Ground-truth coefficients: only positions 1, 5 and 11 are nonzero.
betas = [-0.8; zeros(3); 1.0; zeros(5); 0.9; zeros(12)]
@assert length(betas) == d

# Bernoulli responses whose log-odds are the linear predictor X * betas.
y = [rand(Bernoulli(logitp = η)) for η in X * betas]
@info "Important coef positions: $(findall(!iszero, betas)), Average rate: $(mean(y))"
# Simple logistic model.
# Prior: θ ~ Normal()^d (d independent standard normals).
# Likelihood: y[j] ~ Bernoulli with log-odds ⟨Xt[:, j], θ⟩, where `Xt` stores one
# observation per column.
model_lr = @model (Xt, y) begin
    d, n = size(Xt)
    θ ~ Normal()^d
    for j in 1:n
        η = dot(view(Xt, :, j), θ)   # same value as view(Xt, :, j)' * θ for real entries
        y[j] ~ Bernoulli(logitp = η)
    end
end
# Gradients
"""
    make_grads(model_lr, At, y, d)

Condition `model_lr(At, y)` on the observations `y` and build the density and
derivative oracles used by Pathfinder and the sticky sampler.

Returns the tuple `(post, ℓ, ∇neglogp!, ∂neglogp)`:

- `post`: the conditioned Tilde posterior `model_lr(At, y) | (; y)`.
- `ℓ(θ)`: unnormalised log posterior at the unconstrained vector `θ`, i.e.
  `Tilde.unsafe_logdensityof(post, transform(as(post), θ))`.
- `∇neglogp!(out, t, x, args...)`: write the gradient of `-ℓ` at `x` into `out`
  with ForwardDiff; `t` and `args` are accepted but ignored. Returns `nothing`.
- `∂neglogp(x, i)`: the `i`-th partial derivative of `-ℓ` at `x`, computed with a
  single forward-mode pass seeded on coordinate `i`.

`d` must equal the length of the parameter vector. It sizes both the ForwardDiff
configuration and the one-hot seed.
"""
function make_grads(model_lr, At, y, d)
    post = model_lr(At, y) | (; y)
    as_post = as(post)
    obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
    ℓ(θ) = -obj(θ)

    # Size the config from `d`. It used to be hard-coded to 25, but d == 23 in this
    # script, so ForwardDiff's chunk-mode check (chunk size must not exceed
    # length(x)) would fail on the first call. A full-length chunk keeps the
    # single-pass vector mode that was intended. `zeros` is used because only the
    # shape matters, and `rand` would needlessly advance the global RNG.
    gconfig = ForwardDiff.GradientConfig(obj, zeros(d), ForwardDiff.Chunk{d}())
    function ∇neglogp!(out, t, x, args...)
        ForwardDiff.gradient!(out, obj, x, gconfig)
        return nothing
    end

    # Reusable one-hot seed. ∂neglogp sets entry i, evaluates `obj` once on Duals,
    # then clears it again.
    ith = zeros(d)
    function ∂neglogp(x, i)
        # A lazy mapped array of Duals avoids allocating a Dual vector per call.
        # Eager equivalent: obj([Dual(x[j], 1.0 * (i == j)) for j in eachindex(x)]).
        # (The original author notes that StructArrays would be preferable but did
        # not work with Tilde.)
        ith[i] = 1
        try
            sa = mappedarray(ForwardDiff.Dual{}, x, ith)
            return obj(sa).partials[]
        finally
            # Always clear the seed, even if `obj` throws. Otherwise every later call
            # would silently be seeded on two coordinates and return wrong partials.
            ith[i] = 0
        end
    end

    return post, ℓ, ∇neglogp!, ∂neglogp
end
# Posterior and its derivative oracles.
post, ℓ, ∇neglogp!, ∂neglogp = make_grads(model_lr, Xt, y, d)

# Pathfinding
println("Pathfinder...")
init_scale = 1
# Skip the (expensive) Pathfinder run when `pf_result` already exists,
# e.g. when the script is re-run in the same session.
if !(@isdefined pf_result)
    @time pf_result = pathfinder(ℓ; dim = d, init_scale = init_scale)
end

# Diagonal metric M from the Pathfinder covariance, plus the Pathfinder mean μ.
M, μ = let fit = pf_result.fit_distribution
    PDMats.PDiagMat(diag(fit.Σ)), fit.μ
end
Γ = sparse(inv(M))           # sparse (here diagonal) precision handed to the ZigZag flow
# Dense alternative: Γ = sparse(inv(pf_result.fit_distribution.Σ))

# Start at the Pathfinder mean. Note: x0 and μ are the same array, not a copy.
x0 = μ
# Initial velocity: a randn draw mapped through PDMats.unwhiten with M.
v0 = PDMats.unwhiten(M, randn(length(x0)))
# Sticky sampler
println("Sticky sampler...")

# Take the dimension from the actual initial state *before* building anything sized
# by it. Previously `barriers` used the `d` left over from the data section and `d`
# was only reassigned afterwards, so the two agreed only by coincidence.
d = length(x0)

# One sticky barrier per coordinate, placed at 0 on both sides with stickiness κ.
barriers = [StickyBarriers((0.0, 0.0), (:sticky, :sticky), (κ, κ)) for _ in 1:d]
t0 = fill(0.0, d)
u0 = (t0, x0, v0)
# Presumably the dependency structure of the partials: every coordinate's partial
# is declared to depend on all d coordinates (dense logistic likelihood).
target = StructuredTarget([i => 1:d for i in 1:d], ∂neglogp)
flow = StickyFlow(ZigZag(Γ, μ))
strong_upperbounds = false
adapt = true
multiplier = 1.7 # increase bounds
G = target.G
# Sparsity pattern of Γ: coordinate i => row indices of the stored entries in column i.
G1 = [i => rowvals(Γ)[nzrange(Γ, i)] for i in axes(Γ, 1)]
upper_bounds = StickyUpperBounds(G, G1, Γ, fill(c, d); adapt=adapt, strong = strong_upperbounds, multiplier= multiplier)
end_time = EndTime(T)
# Sparse row-times-vector with Γ. Not referenced elsewhere in this script.
∇ϕ(x, i) = ZigZagBoomerang.idot(Γ, i, x) # sparse computation

# `@elapsed` does not open a new scope, so `trace` and `acc` become globals here.
elapsed_time = @elapsed begin
    trace, _, _, acc = @time stickyzz(u0, target, flow, upper_bounds, barriers, end_time; progress=progress)
end
@info "Upper bounds: $(upper_bounds.c)"
println("acc ", acc.acc/acc.num)
# Plot the continuous trace, its Δt-discretisation and the true coefficients for
# a few selected coordinates.
if PLOT
    ts, xs = ZigZagBoomerang.sep(trace)
    tsd, xsd = ZigZagBoomerang.sep(ZigZagBoomerang.discretise(trace, Δt))
    println("Plot...")
    colors = [:green, :red, :blue, :violet]
    using GLMakie
    fig1 = fig = Figure()
    r = 1:length(ts)
    rd = 1:length(tsd)
    ax = Axis(fig[1,1], title = "trace")
    is = [1, 2, 5, 11]          # coordinates to show (includes the nonzero betas)
    for (k, col) in zip(is, colors)
        # discretised samples as dots
        scatter!(ax, tsd[rd], getindex.(xsd[rd], k); color = col, markersize = 3.0)
        # continuous piecewise-linear trace
        lines!(ax, ts[r], getindex.(xs[r], k); color = (col, 0.4))
        # true coefficient as a dashed horizontal reference line
        lines!(ax, ts[r], fill(betas[k], length(r)); linestyle = :dash, color = (col, 0.5))
    end
    display(fig)
end
# Your samples: the trace discretised on a Δt grid, flattened to a matrix and
# wrapped as an MCMCChains.Chains (transposed, presumably so rows are iterations).
samples = let grid = ZigZagBoomerang.sep(ZigZagBoomerang.discretise(trace, Δt))[2]
    flatview(VectorOfSimilarVectors(grid))
end
chain = setinfo(MCMCChains.Chains(samples'), (start_time = 0.0, stop_time = elapsed_time))
chain
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[deps] | |
ArraysOfArrays = "65a8f2f4-9b39-5baf-92e2-a9cc46fdf018" | |
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" | |
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" | |
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" | |
Pathfinder = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" | |
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" | |
ZigZagBoomerang = "36347407-b186-4a6a-8c98-4f4567861712" |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment