@jsks
Last active November 2, 2024 20:39
Metropolis-within-Gibbs sampling for Poisson regression
#!/usr/bin/env -S julia -t auto
#
# Metropolis-within-Gibbs sampler for a Poisson regression
#
# y_i ~ poisson(λ_i)
# λ_i = exp(X_i' β)
# β_j ~ normal(0, 5)
#
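# Each coefficient β_j is updated in turn with a univariate random-walk
# Metropolis step: propose β_j* ~ normal(β_j, σ_j) and accept with
# probability min(1, exp(lp_proposal - lp_current)), where lp denotes the
# unnormalized log posterior. During warmup the proposal scales σ_j are
# adapted in batches towards a target acceptance rate.
#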
###
using Base.Threads, Distributions, LinearAlgebra, Plots,
    SpecialFunctions, StatsBase, StatsPlots

struct ChainFit
    draws::Array{Float64, 2}   # niter × D matrix of posterior draws
    acceptance::Float64        # acceptance rate after the adaptation phase
    adapt::Vector{Float64}     # final tuned proposal standard deviations
end

# Poisson log-likelihood: Σ_i [ y_i⋅log(λ_i) - λ_i - log(y_i!) ], where
# λ = exp(Xβ). `λ` is a pre-allocated buffer reused across evaluations.
function log_poisson_pmf(y, X, β, λ)
    mul!(λ, X, β)
    λ .= exp.(λ)
    sum(y .* log.(λ) .- λ .- loggamma.(y .+ 1.0))
end

# Sum of the log prior densities over the coefficient vector
function log_priors(priors::Vector{<:Distribution}, β)
    s = 0.0
    @inbounds for i in eachindex(β)
        s += logpdf(priors[i], β[i])
    end
    return s
end

function gibbs(y::Vector{Int}, X::Matrix{Float64}, priors::Vector{<:Distribution};
               adapt_rate = 0.05,
               adapt_target = 0.234,
               nadapt = 5_000,
               niter = 10_000)
    N, D = size(X)

    β = rand(Uniform(-2, 2), D)
    σ = fill(1.0, D)

    # Adaptively tune the proposal standard deviations in batches
    batch_accept = zeros(D)
    batch_size = 100

    accept = 0
    samples = zeros(niter, D)
    proposal = zeros(D)

    # Pre-allocate vector for Poisson rates
    λ = zeros(N)

    lp_current = log_poisson_pmf(y, X, β, λ) + log_priors(priors, β)
    @inbounds for t in 1:niter
        for j in 1:D
            # Univariate random-walk proposal for the j-th coefficient
            copy!(proposal, β)
            proposal[j] = rand(Normal(β[j], σ[j]))

            lp_proposal = log_poisson_pmf(y, X, proposal, λ) + log_priors(priors, proposal)
            if rand() <= min(1, exp(lp_proposal - lp_current))
                if t <= nadapt
                    batch_accept[j] += 1
                else
                    accept += 1
                end

                β[j] = proposal[j]
                lp_current = lp_proposal
            end
        end

        # At the end of each batch, multiplicatively rescale each proposal
        # standard deviation towards the target acceptance rate
        if t <= nadapt && t % batch_size == 0
            for j in 1:D
                σ[j] *= exp(adapt_rate * (batch_accept[j] / batch_size - adapt_target))
                σ[j] = clamp(σ[j], 1e-4, 10.0)
                batch_accept[j] = 0
            end
        end

        samples[t, :] = β
        t % 1000 == 0 && @info "Thread $(threadid()): iteration $t/$niter"
    end

    # `accept` is only incremented after the adaptation phase, so normalize
    # by the number of post-adaptation updates
    return ChainFit(samples, accept / (D * (niter - nadapt)), σ)
end
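
# A minimal posterior summary sketch: mean and central 90% interval for
# each coefficient, assuming `draws` is the niter × D matrix stored in a
# ChainFit. `summarize` is a hypothetical helper, not used further below.
function summarize(draws::Matrix{Float64})
    [(mean = mean(col), q05 = quantile(col, 0.05), q95 = quantile(col, 0.95))
     for col in eachcol(draws)]
end
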
###
# Plotting functions
function traceplot(samples, coefficient)
    M = size(samples[1], 1)
    p = plot(1:M, samples[1][:, coefficient], label="Chain 1", color=1, alpha=0.5)
    for i in 2:length(samples)
        plot!(1:M, samples[i][:, coefficient], label="Chain $i", color=i, alpha=0.5)
    end
    return p
end

function scatterplot(samples, i, j)
    p = scatter(samples[1][:, i], samples[1][:, j])
    for n in 2:length(samples)
        scatter!(samples[n][:, i], samples[n][:, j])
    end
    return p
end

function parplot(samples, coefficient, true_value)
    draws = vcat([chain[:, coefficient] for chain in samples]...)
    p = density(draws, xrotation = 45)
    vline!([true_value], linestyle=:dash)
    return p
end

# Rootogram on the square-root scale: observed count frequencies as bars,
# with posterior predictive medians and 90% intervals overlaid
function rootogram(y, yhat)
    sqrt_count = x -> Dict((k, sqrt(v)) for (k, v) in countmap(x))
    low = x -> quantile(x, 0.05)
    high = x -> quantile(x, 0.95)

    ycounts = sqrt_count(y)
    x, y_freq = map(collect, zip(sort(collect(ycounts), by = x -> x[1])...))
    p = plot(x, y_freq, label="Observed", seriestype=:bar, alpha=0.5)

    post_counts = sqrt_count.(eachcol(yhat))
    px = map(keys, post_counts) |> Iterators.flatten |> unique |> sort

    # There's definitely a better way to do this...
    post_freq = [median(get(draw, d, 0) for draw in post_counts) for d in px]
    post_freq_low = post_freq .- [low(get(draw, d, 0) for draw in post_counts) for d in px]
    post_freq_high = [high(get(draw, d, 0) for draw in post_counts) for d in px] .- post_freq

    scatter!(px, post_freq, yerror=(post_freq_low, post_freq_high), markersize=2.5, label="Predicted")
    return p
end
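
# Posterior predictive simulation: for a single coefficient draw β,
# replicate the dataset with ỹ_i ~ poisson(exp(X_i'β))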
function predict(X, β)
    N = size(X, 1)
    λ = exp.(X * β)
    [rand(Poisson(λ[i])) for i in 1:N]
end

###
# Start by simulating a dataset
N = 10_000
D = 3
scale = x -> (x .- mean(x)) ./ std(x)
# Two covariates and an intercept
covariates = rand(Normal(0, 1), N, D - 1) |> eachcol .|> scale
X = hcat(ones(N), covariates...)
β = rand(Normal(0, 1 / D), D)
λ = exp.(X * β)
y = map(rand, Poisson.(λ))
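
# A quick sanity check sketch: the sample mean of the simulated counts
# should sit near the average Poisson rate
@info "mean(y) = $(mean(y)), mean(λ) = $(mean(λ))"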
###
# Fit model to simulated data
priors = [Normal(0, 5) for _ in 1:D]
chains = fetch.([@spawn gibbs(y, X, priors, niter=100_000, nadapt=50_000) for _ in 1:4])
# Discard burn-in and thin each chain by a factor of 50
burnin = 50_000
samples = [model.draws[(burnin+1):50:end, :] for model in chains]
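
# A rough Gelman-Rubin R̂ per coefficient, as a numeric companion to the
# traceplots below. A minimal sketch assuming every chain in `samples`
# retains the same number of draws; `rhat` is a hypothetical helper.
function rhat(samples, coefficient)
    chains = [chain[:, coefficient] for chain in samples]
    n = length(first(chains))
    W = mean(var.(chains))      # mean within-chain variance
    B = n * var(mean.(chains))  # between-chain variance
    sqrt(((n - 1) / n * W + B / n) / W)
end

rhat.(Ref(samples), 1:D)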
# Traceplots for each parameter
plots = [traceplot(samples, i) for i in 1:D]
plot(plots...)
# Scatter plot of each pair of parameters
plots = [scatterplot(samples, i, j) for i in 1:D for j in (i+1):D]
plot(plots...)
# Coefficient estimates
plots = [parplot(samples, i, β[i]) for i in 1:D]
plot(plots...)
# Finally, posterior predictions using square root frequencies
β_hat = vcat(samples...)
# One replicated dataset per retained draw, stored column-wise (N × 500)
yhat = hcat([predict(X, β_hat[m, :]) for m in 1:500]...)
rootogram(y, yhat)