Last active
September 6, 2022 09:06
-
-
Save dermesser/0030ad422e1aa9cb90743fed1e8a890e to your computer and use it in GitHub Desktop.
Primitive Hamiltonian Monte Carlo (HMC) sampler
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Plots | |
using Random | |
using Distributions, DistributionsAD | |
using LinearAlgebra | |
import Zygote: gradient | |
"""
    HMC{T,F,S,D}

Immutable configuration of the Hamiltonian Monte Carlo sampler.

The container type `S` of `sup` and the distribution type `D` of `pdist` are type
parameters so that every field has a concrete type. Signatures written as `HMC{T,F}`
keep matching because the two trailing parameters are left free.

# Fields
- `sup`: one `(lower, upper)` tuple per coordinate; the initial position is drawn
  uniformly from it and proposals are tested against it.
- `invM`: inverse mass matrix used in the kinetic energy and the position update.
- `pdist`: multivariate normal from which the momentum is drawn.
- `L`: number of leapfrog steps per proposal.
- `logpdf`: callable returning the log-density at a position vector.
- `ΔT`: leapfrog step size.
"""
struct HMC{T,F,S<:AbstractArray{Tuple{T,T}},D<:AbstractMvNormal}
    sup::S
    invM::Matrix{T}
    pdist::D
    L::Int64
    logpdf::F
    ΔT::T
end
"""
    HMCnew(logpdf; L=10, ΔT=0.1, sup=[(-10., 10.)], M=diagm(ones(length(sup))))

Assemble an `HMC` sampler configuration for the log-density `logpdf`.

# Keywords
- `L`: number of leapfrog steps stored in the configuration.
- `ΔT`: leapfrog step size.
- `sup`: `(lower, upper)` bounds, one tuple per coordinate.
- `M`: mass matrix. Its inverse is stored, and the momentum distribution is a
  zero-mean `MvNormal` with covariance `M`.
"""
function HMCnew(
    logpdf::F;
    L=10,
    ΔT::T=0.1,
    sup=[(-10., 10.)],
    M=(diagm(ones(length(sup)))),
)::HMC{T,F} where {T<:Real,F<:Function}
    # TODO: Adapt mass.
    mass_inverse = inv(M)
    momentum_dist = MvNormal(zeros(length(sup)), M)
    return HMC(sup, mass_inverse, momentum_dist, L, logpdf, ΔT)
end
"""
    HMCState{T}

Mutable phase-space point carried through the sampler.

# Fields
- `x`: position vector.
- `p`: momentum vector.
- `t`: accumulated integration time.
"""
mutable struct HMCState{T}
    x::AbstractArray{T}
    p::AbstractArray{T}
    t::T
end
"""
    HMCState(hmc::HMC)

Create a starting state for `hmc`: each coordinate of the position is drawn uniformly
from its `(lower, upper)` bounds in `hmc.sup`, the momentum is all zeros, and the time
is zero.
"""
function HMCState(hmc::HMC{T,F}) where {T<:Real,F<:Function}
    start = [rand(Uniform(lo, hi)) for (lo, hi) in hmc.sup]
    momentum = zeros(length(hmc.sup))
    return HMCState(start, momentum, 0.)
end
"""
    copy(s::HMCState)

Return a new `HMCState` whose position and momentum arrays are copies of those in `s`
and whose time equals `s.t`.

This is a method of `Base.copy`, not a new function named `copy`, so `copy` keeps
working for arrays and other values in the same module.
"""
function Base.copy(s::HMCState{T})::HMCState{T} where {T}
    return HMCState(Base.copy(s.x), Base.copy(s.p), s.t)
end
"""
    H(logpdf, invM, p, x)

Evaluate the Hamiltonian at position `x` and momentum `p`: the negated log-density
`-logpdf(x)` plus half the quadratic form `p' * invM * p`.
"""
function H(
    logpdf::F, invM::Matrix{T}, p::A, x::A
)::T where {F<:Function,T<:Real,A<:AbstractArray{T}}
    potential = -logpdf(x)
    kinetic = 1/2 * p' * invM * p
    return potential + kinetic
end
"""
    in_support(x, sup)

Return `true` when, for every `(lower, upper)` tuple in `sup`, the coordinate of `x`
at the same 1-based position lies inside the closed interval `[lower, upper]`.
"""
function in_support(
    x::A, sup::AbstractArray{Tuple{T,T}}
)::Bool where {T<:Real,A<:AbstractArray{T}}
    for (i, (lo, hi)) in enumerate(sup)
        # Stop at the first coordinate that leaves its interval.
        lo <= x[i] <= hi || return false
    end
    return true
end
"""
    transition_probability(hmc, s0, s1)

Return the probability of accepting the move from state `s0` to the proposal `s1`.

The result is zero when `s1.x` lies outside `hmc.sup`; otherwise it is
`min(1, exp(-(H(s1) - H(s0))))`, with `H` evaluated using `hmc.logpdf` and `hmc.invM`.
"""
function transition_probability(
    hmc::HMC{T,F}, s0::HMCState{T}, s1::HMCState{T}
)::T where {T<:Real,F}
    # A proposal that leaves the support is never accepted.
    in_support(s1.x, hmc.sup) || return zero(T)
    proposed = H(hmc.logpdf, hmc.invM, s1.p, s1.x)
    current = H(hmc.logpdf, hmc.invM, s0.p, s0.x)
    return min(one(T), exp(-(proposed - current)))
end
"""
    leapfrog_step(hmc, s)

Advance `s` by one leapfrog step of size `hmc.ΔT` and return it.

The step is a half momentum update along the gradient of `hmc.logpdf` at the current
position, a full position update along `hmc.invM * p`, and a second half momentum
update at the new position. `s` itself is modified: its `x` and `p` fields are rebound
to newly computed arrays and `t` is increased by `hmc.ΔT`.
"""
function leapfrog_step(
    hmc::HMC{T,F}, s::HMCState{T}
)::HMCState{T} where {T<:Real,F<:Function}
    half_step = hmc.ΔT / 2
    s.p = s.p + half_step * gradient(hmc.logpdf, s.x)[1]
    s.x = s.x + hmc.ΔT * hmc.invM * s.p
    # The second half-kick uses the gradient at the updated position.
    s.p = s.p + half_step * gradient(hmc.logpdf, s.x)[1]
    s.t += hmc.ΔT
    return s
end
"""
    sample(hmc, s)

Run one HMC transition starting from `s` and return the resulting state.

A fresh momentum is drawn from `hmc.pdist`, the state is integrated for `hmc.L`
leapfrog steps, and the end point is accepted with the probability given by
`transition_probability`. On acceptance the integrated state is returned (this is the
same object as `s`, which is modified in place). On rejection a copy holding the
starting position is returned. Callers should therefore always continue with the
return value.
"""
function sample(hmc::HMC{T,F}, s::HMCState{T})::HMCState{T} where {T,F}
    # Refresh the momentum BEFORE snapshotting the start state: the acceptance ratio
    # must compare the Hamiltonian at the trajectory's actual starting point, which
    # includes the newly drawn momentum, not the one left over from the last call.
    rand!(hmc.pdist, s.p)
    s0 = copy(s)
    s1 = s
    for _ in 1:hmc.L
        s1 = leapfrog_step(hmc, s1)
    end
    α = transition_probability(hmc, s0, s1)
    # Metropolis step: keep the proposal with probability α, else fall back to s0.
    return rand(Uniform(0, 1)) <= α ? s1 : s0
end
"""
    test_sample_snd()

Draw 10000 samples of a standard normal target with the HMC sampler and return a plot
that overlays their pdf-normalized histogram with the target's density.
"""
function test_sample_snd()
    target = Normal(0, 1)
    # NOTE(review): the support is [0, 1] although the histogram bins span [-4, 4] —
    # confirm this restriction is intended.
    sampler = HMCnew(x -> logpdf(target, x[1]), sup=[(0., 1.)])
    state = HMCState(sampler)
    N = 10000
    draws = zeros(N)
    for k in 1:N
        state = sample(sampler, state)
        draws[k] = state.x[1]
    end
    plot()
    histogram!(draws, bins=LinRange(-4, 4, 50), normalize=:pdf)
    plot!(x -> pdf(target, x))
    return current()
end
"""
    normal_ppd(µs, σs, n=1000; rng=Random.default_rng())

Draw `n` values, each from a `Normal(µ, σ)` whose `µ` and `σ` are themselves drawn at
random from `µs` and `σs`.

# Arguments
- `µs`, `σs`: objects to sample the mean and the standard deviation from.
- `n`: number of values to draw.

# Keywords
- `rng`: random number generator used for every draw; pass a seeded generator for
  reproducible output.

# Returns
A `Vector{Float64}` of length `n`.
"""
function normal_ppd(
    µs, σs, n=1000; rng::AbstractRNG=Random.default_rng()
)::Vector{Float64}
    # The samplers are reused for all n draws, hence the Val(Inf) repetition hint.
    µ_sampler = Random.Sampler(rng, µs, Val(Inf))
    σ_sampler = Random.Sampler(rng, σs, Val(Inf))
    draws = zeros(n)
    for i in eachindex(draws)
        µ, σ = rand(rng, µ_sampler), rand(rng, σ_sampler)
        draws[i] = rand(rng, Normal(µ, σ))
    end
    return draws
end
"""
    test_sample_mcmc()

Sample the mean and standard deviation of a normal model with the HMC sampler.

Thirty observations are drawn from `Normal(5, 2)`; the target density combines normal
priors on both parameters with the likelihood of the observations. Returns a `2 × 1000`
matrix whose columns are the successive sampler positions.
"""
function test_sample_mcmc()
    true_dist = Normal(5, 2)
    observations = rand(true_dist, 30)
    prior_µ = Normal(3, 1)
    prior_σ = Normal(1, 2)
    # Log-prior of both parameters plus the log-likelihood of the observations.
    function loglik(θ)
        log_prior = logpdf(prior_µ, θ[1]) + logpdf(prior_σ, θ[2])
        return log_prior +
               sum(logpdf(Normal(θ[1], abs(θ[2])), o) for o in observations; init=0)
    end
    sampler = HMCnew(loglik; sup=[(-10., 10.), (0., 10.)])
    state = HMCState(sampler)
    N = 1000
    chain = zeros(2, N)
    for col in 1:N
        state = sample(sampler, state)
        chain[:, col] = state.x
    end
    return chain
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment