Created
April 13, 2024 18:07
-
-
Save JasonPekos/82be830e4bf390fd1cc2886a7518aede to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Turing, HiddenMarkovModels
using PosteriorDB
using Makie, CairoMakie
using LinearAlgebra, LogExpFunctions
using DataFrames
using Random
# Get the dataset and validated reference draws from PosteriorDB
pdb = PosteriorDB.database() # Data import
data = PosteriorDB.load(PosteriorDB.dataset(pdb, "hmm_example"))

# Reference posterior as loaded: each column holds one collection of draws per row.
# NOTE(review): rows are presumably the individual reference chains — confirm.
ref_post = DataFrame(
    PosteriorDB.load(PosteriorDB.reference_posterior(pdb, "hmm_example-hmm_example")),
)

# Concatenate the per-row draw vectors of every column into one flat column.
# `reduce(vcat, col)` is used rather than `vcat(col...)` to avoid splatting a
# collection into varargs.
ref_draws = DataFrame(
    [reduce(vcat, ref_post[!, c]) for c in names(ref_post)],
    names(ref_post),
)
# Plotting Function(s)

"""
    plot_states(gq, data)

Scatter the observations `data["y"]` against `1:data["N"]` and overlay every
element of `gq` as a faint grey line on the same axis.

# Arguments
- `gq`: collection whose elements are each drawn as one line over `1:data["N"]`
  (presumably one decoded path per posterior draw — confirm against the caller).
- `data`: container indexable with the keys `"N"` and `"y"`.

# Returns
The `Figure` holding the plot.
"""
function plot_states(gq, data)
    fig = Figure()
    axis = Axis(fig[1, 1])
    timesteps = 1:data["N"]

    scatter!(axis, timesteps, data["y"])
    # Low alpha so overlapping paths build up visual density.
    for path in gq
        lines!(axis, timesteps, path; color = :grey, alpha = 0.1)
    end
    return fig
end
"""
    plot_draws(chains, names_pair)

Scatter one column against another for every table in `chains`, all on a single
axis, so that several sets of draws can be compared visually.

# Arguments
- `chains`: collection of tables supporting `tbl[!, name]` column access.
- `names_pair`: two column names; `names_pair[1]` is plotted on the x-axis and
  `names_pair[2]` on the y-axis.

# Returns
The `Figure` holding the plot. Each table is labelled `"draws i"` (its position
in `chains`) and a legend is shown whenever at least one table was plotted.
"""
function plot_draws(chains, names_pair)
    f = Figure()
    ax = Axis(f[1, 1], title = "Turing vs PosteriorDB Reference Draws")
    xname, yname = names_pair[1], names_pair[2]
    for (i, ch) in enumerate(chains)
        scatter!(ax, ch[!, xname], ch[!, yname]; alpha = 0.4, label = "draws $i")
    end
    # The labels above are only visible through a legend. Skip it when nothing
    # was plotted, since there would be no labelled plots to list.
    isempty(chains) || axislegend(ax)
    return f
end
# Define The Models:

"""
    example_hmm_marginalized(N, K, y, decode_states = true)

Hidden Markov model in which the hidden state sequence is not sampled: the
quantity returned by `forward(hmm, y)` is added to the log density instead.

The model is written for two states: `mu` has two components, the transition
matrix has exactly two rows (`theta1`, `theta2`) and there are two unit-variance
Normal emissions, so `K` is expected to be 2. `N` is accepted for interface
compatibility with the dataset but is not used in the body.

# Arguments
- `N`: number of observations (unused here).
- `K`: length of each transition row and of the initial-state vector.
- `y`: observation sequence passed to `forward` and `viterbi`.
- `decode_states`: when `true` (default, the original behaviour) the model
  returns `mu` indexed by the `viterbi` state sequence. Pass `false` while
  sampling so the decoding is not repeated on every log-density evaluation;
  the model then returns `nothing`.
"""
@model function example_hmm_marginalized(N, K, y, decode_states = true)
    mu ~ MvNormal([3, 10], I)
    # NOTE(review): softmax(ones(K)) is 1/K per component, so this is
    # Dirichlet(1/K, ..., 1/K), not the flat Dirichlet(ones(K)). Confirm this is
    # the prior intended for the comparison against the reference draws.
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))

    # Row i of θ is the i-th sampled probability vector.
    θ = vcat(theta1', theta2')
    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])

    _, filtered_likelihood = forward(hmm, y)
    Turing.@addlogprob! only(filtered_likelihood)

    if decode_states
        seq, _ = viterbi(hmm, y)
        return [mu[s] for s in seq]
    end
    return nothing
end

# Sample
# Look the model arguments up by key: `values(data)...` relies on the iteration
# order of a Dict, which is not guaranteed to be (N, K, y).
hmm_args = (data["N"], data["K"], data["y"])

# Explicit, seeded RNG so the comparison plots are reproducible.
rng = Xoshiro(20240413)

# Decoding is only needed for the generated quantities, so skip it while sampling.
chn_marg = sample(
    rng,
    example_hmm_marginalized(hmm_args..., false),
    NUTS(),
    1000;
    discard_initial = 1000,
)
df_chn_marg = DataFrame(chn_marg)
plot_draws([df_chn_marg, ref_draws], ["theta1[1]", "theta2[1]"])

# Re-run the model over the sampled parameters with decoding switched on.
gq = generated_quantities(example_hmm_marginalized(hmm_args...), chn_marg);
plot_states(gq, data)
# Compare Reference Draws:
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
That was also the idea behind HiddenMarkovModels.jl: who cares if it's a `Distributions.Distribution`, as long as it has `DensityInterface.logdensityof` and `Random.rand`.

Yeah, you need to be able to apply every emission distribution to every observation object without error (the logdensity may return `-Inf` though). While it is theoretically possible to have a vector of emissions that looks like `emissions = [Normal(), Gamma(), Exponential()]`, maybe there are some dispatches like `fit(eltype(emissions), ...)` that would cause trouble. They should be easy enough to replace with `fit(typeof(emissions[i]), ...)` if needed.