Metropolis-within-Gibbs sampling for Poisson regression
#!/usr/bin/env -S julia -t auto
#
# Metropolis-within-Gibbs sampler for a Poisson regression
#
# y_i ~ Poisson(λ_i)
# λ_i = exp(X_i' β)
# β_j ~ Normal(0, 5)
#
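# Each coordinate β_j is updated in turn with a random-walk Metropolis step:
# a proposal is drawn from Normal(β_j, σ_j) and accepted with probability
# min(1, exp(lp_proposal - lp_current)), where lp is the log joint density.
#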
###
using Base.Threads, Distributions, LinearAlgebra, Plots,
    SpecialFunctions, StatsBase, StatsPlots
struct ChainFit
    draws::Array{Float64, 2}
    acceptance::Float64
    adapt::Vector{Float64}
end
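
# Poisson log-likelihood, Σ_i [y_i log(λ_i) - λ_i - log(y_i!)], with log(y!)
# computed as loggamma(y + 1). The rates are written into the pre-allocated
# buffer λ to avoid re-allocating on every call.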
function log_poisson_pmf(y, X, β, λ)
    mul!(λ, X, β)
    λ .= exp.(λ)
    sum(y .* log.(λ) .- λ .- loggamma.(y .+ 1.0))
end
function log_priors(priors::Vector{<:Distribution}, β)
    s = 0.0
    @inbounds for i in eachindex(β)
        s += logpdf(priors[i], β[i])
    end
    return s
end
function gibbs(y::Vector{Int}, X::Matrix{Float64}, priors::Vector{<:Distribution};
               adapt_rate = 0.05,
               adapt_target = 0.234,
               nadapt = 5_000,
               niter = 10_000)
    N, D = size(X)

    # Initial parameter values and per-coordinate proposal scales
    β = rand(Uniform(-2, 2), D)
    σ = fill(1.0, D)

    # Adaptively tune the proposal standard deviations in batches
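    # (After each batch of `batch_size` sweeps, every σ_j is rescaled by
    # exp(adapt_rate * (batch acceptance rate - adapt_target)), nudging the
    # per-coordinate acceptance rate toward adapt_target.)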
    batch_accept = zeros(D)
    batch_size = 100

    accept = 0
    samples = zeros(niter, D)
    proposal = zeros(D)

    # Pre-allocate vector for Poisson rates
    λ = zeros(N)

    lp_current = log_poisson_pmf(y, X, β, λ) + log_priors(priors, β)
    @inbounds for t in 1:niter
        for j in 1:D
            copy!(proposal, β)
            proposal[j] = rand(Normal(β[j], σ[j]))

            lp_proposal = log_poisson_pmf(y, X, proposal, λ) + log_priors(priors, proposal)
            if rand() <= min(1, exp(lp_proposal - lp_current))
                if t <= nadapt
                    batch_accept[j] += 1
                else
                    accept += 1
                end

                β[j] = proposal[j]
                lp_current = lp_proposal
            end
        end

        if t <= nadapt && t % batch_size == 0
            for j in 1:D
                σ[j] *= exp(adapt_rate * (batch_accept[j] / batch_size - adapt_target))
                σ[j] = clamp(σ[j], 1e-4, 10.0)
                batch_accept[j] = 0
            end
        end

        samples[t, :] = β
        t % 1000 == 0 && @info "Thread $(threadid()): iteration $t/$niter"
    end
    # Acceptance rate over the post-adaptation draws only, since `accept`
    # is incremented exclusively after the adaptation phase
    return ChainFit(samples, accept / (D * (niter - nadapt)), σ)
end
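
# Minimal usage sketch (illustrative; the full threaded workflow is below):
#   fit = gibbs(y, X, [Normal(0, 5) for _ in 1:size(X, 2)])
#   fit.acceptance   # post-adaptation Metropolis acceptance rate
#   fit.adapt        # tuned proposal standard deviations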
###
# Plotting functions
function traceplot(samples, coefficient)
    M = size(samples[1], 1)
    p = plot(1:M, samples[1][:, coefficient], label="Chain 1", color=1, alpha=0.5)
    for i in 2:length(samples)
        plot!(1:M, samples[i][:, coefficient], label="Chain $i", color=i, alpha=0.5)
    end
    return p
end
function scatterplot(samples, i, j)
    p = scatter(samples[1][:, i], samples[1][:, j])
    for n in 2:length(samples)
        scatter!(samples[n][:, i], samples[n][:, j])
    end
    return p
end
function parplot(samples, coefficient, true_value)
    draws = vcat([chain[:, coefficient] for chain in samples]...)
    p = density(draws, xrotation = 45)
    vline!([true_value], linestyle=:dash)
    return p
end
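
# Rootogram: plots the square-root frequencies of the observed counts against
# the median and 5-95% interval of the square-root frequencies across the
# posterior predictive datasets (columns of yhat).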
function rootogram(y, yhat)
    sqrt_count = x -> Dict((k, sqrt(v)) for (k, v) in countmap(x))
    low = x -> quantile(x, 0.05)
    high = x -> quantile(x, 0.95)

    ycounts = sqrt_count(y)
    x, y_freq = map(collect, zip(sort(collect(ycounts), by = first)...))
    p = plot(x, y_freq, label="Observed", seriestype=:bar, alpha=0.5)

    post_counts = sqrt_count.(eachcol(yhat))
    px = map(keys, post_counts) |> Iterators.flatten |> unique |> sort

    # There's definitely a better way to do this...
    post_freq = [median(get(draw, d, 0) for draw in post_counts) for d in px]
    post_freq_low = post_freq .- [low(get(draw, d, 0) for draw in post_counts) for d in px]
    post_freq_high = [high(get(draw, d, 0) for draw in post_counts) for d in px] .- post_freq

    scatter!(px, post_freq, yerror=(post_freq_low, post_freq_high), markersize=2.5, label="Predicted")
    return p
end
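
# Posterior predictive draw: simulate one replicated dataset from a single
# posterior sample of β.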
function predict(X, β)
    N = size(X, 1)
    λ = exp.(X * β)
    [rand(Poisson(λ[i])) for i in 1:N]
end
###
# Start by simulating a dataset
N = 10_000
D = 3

scale = x -> (x .- mean(x)) ./ std(x)

# Two standardized covariates and an intercept
covariates = rand(Normal(0, 1), N, D - 1) |> eachcol .|> scale
X = hcat(ones(N), covariates...)

β = rand(Normal(0, 1 / D), D)
λ = exp.(X * β)
y = map(rand, Poisson.(λ))
###
# Fit model to simulated data
priors = [Normal(0, 5) for _ in 1:D]
chains = fetch.([@spawn gibbs(y, X, priors, niter = 100_000, nadapt = 50_000) for _ in 1:4])

# Discard the adaptation phase as burn-in and thin to every 50th draw
burnin = 50_000
samples = [model.draws[(burnin + 1):50:end, :] for model in chains]
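# This leaves (100_000 - 50_000) / 50 = 1_000 retained draws per chain,
# i.e. 4_000 posterior draws in total across the four chains.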
# Traceplots for each parameter
plots = [traceplot(samples, i) for i in 1:D]
plot(plots...)

# Scatter plot of each pair of parameters
plots = [scatterplot(samples, i, j) for i in 1:D for j in (i + 1):D]
plot(plots...)

# Coefficient estimates
plots = [parplot(samples, i, β[i]) for i in 1:D]
plot(plots...)
# Finally, posterior predictions using square-root frequencies
β_hat = vcat(samples...)
idx = rand(1:size(β_hat, 1), 500)   # 500 random posterior draws across chains
yhat = hcat([predict(X, β_hat[m, :]) for m in idx]...)
rootogram(y, yhat)