@JasonPekos
Created April 13, 2024 18:07
```julia
using Turing, HiddenMarkovModels
using PosteriorDB
using Makie, CairoMakie
using LinearAlgebra, LogExpFunctions
using DataFrames

# Get the dataset and validated reference draws from PosteriorDB
pdb = PosteriorDB.database()
data = PosteriorDB.load(PosteriorDB.dataset(pdb, "hmm_example"))  # keys "N", "K", "y"
ref_post = DataFrame(PosteriorDB.load(PosteriorDB.reference_posterior(pdb, "hmm_example-hmm_example")))
# Concatenate the per-chain draw vectors stored in each column
ref_draws = DataFrame([vcat(ref_post[!, c]...) for c in names(ref_post)], names(ref_post))

# Plotting function(s)
function plot_states(gq, data)
    f = Figure()
    ax = Axis(f[1, 1], xlabel = "t", ylabel = "y")
    scatter!(ax, 1:data["N"], data["y"])
    # One faint line per posterior draw: the emission mean along that draw's Viterbi path
    for i in eachindex(gq)
        lines!(ax, 1:data["N"], gq[i], color = :grey, alpha = 0.1)
    end
    return f
end

function plot_draws(chains, names_pair)
    f = Figure()
    ax = Axis(f[1, 1], title = "Turing vs PosteriorDB Reference Draws",
              xlabel = names_pair[1], ylabel = names_pair[2])
    for (i, ch) in enumerate(chains)
        scatter!(ax, ch[!, names_pair[1]], ch[!, names_pair[2]],
                 alpha = 0.4, label = "draws $i")
    end
    axislegend(ax)  # without this, the "draws $i" labels are never shown
    return f
end

# Define the model (assumes K == 2); the hidden states are marginalized out
@model function example_hmm_marginalized(N, K, y)
    mu ~ MvNormal([3, 10], I)
    # softmax(ones(K)) == fill(1/K, K), so these are Dirichlet(1/K, ..., 1/K) priors;
    # a flat prior on the simplex would be Dirichlet(ones(K))
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')  # transition matrix: row i = P(next state | state i)
    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])
    _, logL = forward(hmm, y)  # second output: log-likelihood of each sequence
    Turing.@addlogprob! only(logL)
    # Viterbi runs at every log-density evaluation, so it probably does not belong in the model
    seq, _ = viterbi(hmm, y)
    return [mu[s] for s in seq]
end

# Sample. Pass the data by key: splatting values(data) relies on the Dict's
# iteration order matching the model's argument order, which is not guaranteed.
model = example_hmm_marginalized(data["N"], data["K"], data["y"])
chn_marg = sample(model, NUTS(), 1000, discard_initial = 1000)
df_chn_marg = DataFrame(chn_marg)
plot_draws([df_chn_marg, ref_draws], ["theta1[1]", "theta2[1]"])
gq = generated_quantities(model, chn_marg);
plot_states(gq, data)

# Compare Reference Draws:
```
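In the model, `forward(hmm, y)` sums over every hidden state path so that NUTS only has to explore `mu`, `theta1`, and `theta2`, while `viterbi(hmm, y)` picks the single most probable path for the current parameters. The following is a dependency-free sketch of both recursions, assuming unit-variance Normal emissions and `trans[i, j]` = probability of moving from state i to state j (the same row convention as `θ` above). All function names are hypothetical helpers, not HiddenMarkovModels.jl's API; the brute-force enumeration exists only to check the forward recursion on a short sequence.

```julia
# Log-density of Normal(m, 1) at x, written out so the sketch needs no packages
normlogpdf1(x, m) = -0.5 * (x - m)^2 - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v)))
function logsumexp_base(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Forward recursion in log space: logα[k] = log p(y[1:t], state at t = k)
function forward_loglik(init, trans, mu, y)
    K = length(init)
    logα = [log(init[k]) + normlogpdf1(y[1], mu[k]) for k in 1:K]
    for t in 2:length(y)
        next_logα = similar(logα)
        for k in 1:K
            next_logα[k] = logsumexp_base([logα[i] + log(trans[i, k]) for i in 1:K]) +
                           normlogpdf1(y[t], mu[k])
        end
        logα = next_logα
    end
    return logsumexp_base(logα)  # log p(y), hidden states summed out
end

# Reference value: enumerate all K^T state paths explicitly (small T only)
function brute_force_loglik(init, trans, mu, y)
    K, T = length(init), length(y)
    terms = Float64[]
    for path in Iterators.product(ntuple(_ -> 1:K, T)...)
        lp = log(init[path[1]]) + normlogpdf1(y[1], mu[path[1]])
        for t in 2:T
            lp += log(trans[path[t-1], path[t]]) + normlogpdf1(y[t], mu[path[t]])
        end
        push!(terms, lp)
    end
    return logsumexp_base(terms)
end

# Viterbi: same recursion with max in place of logsumexp, plus back-pointers
function viterbi_path(init, trans, mu, y)
    K, T = length(init), length(y)
    δ = [log(init[k]) + normlogpdf1(y[1], mu[k]) for k in 1:K]
    back = zeros(Int, T, K)
    for t in 2:T
        next_δ = similar(δ)
        for k in 1:K
            scores = [δ[i] + log(trans[i, k]) for i in 1:K]
            back[t, k] = argmax(scores)
            next_δ[k] = scores[back[t, k]] + normlogpdf1(y[t], mu[k])
        end
        δ = next_δ
    end
    path = zeros(Int, T)
    path[T] = argmax(δ)
    for t in T:-1:2
        path[t-1] = back[t, path[t]]
    end
    return path
end

init  = [0.5, 0.5]            # uniform, like softmax(ones(2))
trans = [0.9 0.1; 0.2 0.8]
mu    = [3.0, 10.0]
y     = [2.8, 3.5, 9.7, 10.2, 3.1]
forward_loglik(init, trans, mu, y) ≈ brute_force_loglik(init, trans, mu, y)  # true
viterbi_path(init, trans, mu, y)                                             # [1, 1, 2, 2, 1]
```

The forward pass costs O(T·K²) instead of the O(K^T) of enumeration, which is what makes marginalizing cheap enough to repeat at every gradient evaluation. The Viterbi pass has the same cost, so keeping it inside the model roughly doubles the work per evaluation.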
@yebai

yebai commented Apr 15, 2024

We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects. At the moment, these can only be from Distributions.jl. We would like to support more, including HMMs, SSMs, BUGS models, Turing submodels, etc.

Related: TuringLang/DynamicPPL.jl#523

@torfjelde

Probably the easiest way to make something that's "nicely" operable within a Turing.jl model is to just define a

struct MarginalizedHMM{M,F<:VariateForm,S<:ValueSupport} <: Distribution{F,S}
    hmm::M
end

in combination with a nice marginalize(hmm) constructor that extracts the variate form, etc. from the emission distributions in hmm (though this would have to assume the emission distributions are all of the "same" type; is this assumed in your work, @gdalle?).

Then you could do

y ~ marginalize(hmm)

This would of course still not expose the latent variables to Turing.jl as we'd ideally like, but it does make the experience a bit more seamless 🤷
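The `marginalize(hmm)` constructor suggested above has to read the observation type off the emission distributions, which only works when they all agree. Here is a minimal base-Julia sketch of that check. `ToyEmission`, `ToyNormal`, `ToyPoisson`, `ToyHMM`, and `obstype` are hypothetical stand-ins for Distributions.jl and HiddenMarkovModels.jl types; a real version would subtype `Distribution{F,S}` and extract the variate form and value support rather than a single observation type.

```julia
# Hypothetical emission types, parameterized by the type of one observation
abstract type ToyEmission{T} end
struct ToyNormal <: ToyEmission{Float64}
    mu::Float64
end
struct ToyPoisson <: ToyEmission{Int}
    lambda::Float64
end

# Hypothetical HMM container: initial probabilities, transition matrix, emissions
struct ToyHMM{E<:ToyEmission}
    init::Vector{Float64}
    trans::Matrix{Float64}
    emissions::Vector{E}
end

# Wrapper whose second type parameter records the shared observation type X
struct MarginalizedHMM{M<:ToyHMM,X}
    hmm::M
end

obstype(::ToyEmission{T}) where {T} = T

# Refuse to build the wrapper unless every emission agrees on the observation type
function marginalize(hmm::ToyHMM)
    types = unique(map(obstype, hmm.emissions))
    length(types) == 1 ||
        throw(ArgumentError("emissions disagree on observation type: $types"))
    return MarginalizedHMM{typeof(hmm),only(types)}(hmm)
end

h = ToyHMM([0.5, 0.5], [0.9 0.1; 0.2 0.8], [ToyNormal(3.0), ToyNormal(10.0)])
m = marginalize(h)  # a MarginalizedHMM{ToyHMM{ToyNormal}, Float64}
# Mixing ToyNormal and ToyPoisson emissions makes marginalize throw an ArgumentError
```

Failing at construction time keeps the homogeneity assumption explicit instead of surfacing later as a confusing dispatch error inside the model.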

@gdalle

gdalle commented Apr 16, 2024

> We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects.

That was also the idea behind HiddenMarkovModels.jl: who cares if it's a Distributions.Distribution, as long as it has DensityInterface.logdensityof and Random.rand.
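The duck-typing idea can be shown without any packages: generic code that only calls `logdensityof` and `rand` accepts any type implementing those two methods, whether or not it subtypes `Distribution`. In this sketch `logdensityof` is a local stand-in for `DensityInterface.logdensityof`, and `MyExponential`, `seq_logdensity`, and `simulate` are hypothetical. An impossible observation gets `-Inf` instead of an error.

```julia
using Random

# Local stand-in for DensityInterface.logdensityof, so this runs without packages
function logdensityof end

# An emission type that does not subtype any Distributions.jl type
struct MyExponential
    rate::Float64
end

logdensityof(d::MyExponential, x::Real) = x < 0 ? -Inf : log(d.rate) - d.rate * x
Random.rand(rng::AbstractRNG, d::MyExponential) = randexp(rng) / d.rate

# Generic code that relies only on the two methods above
seq_logdensity(d, xs) = sum(x -> logdensityof(d, x), xs)
simulate(rng, d, n) = [rand(rng, d) for _ in 1:n]

d = MyExponential(2.0)
seq_logdensity(d, [0.5, 1.0])  # 2log(2) - 3 ≈ -1.6137
seq_logdensity(d, [-1.0])      # -Inf
```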

> this would have to assume the emission distributions are all of the "same" type; is this assumed in your work, @gdalle?

Yeah, you need to be able to apply every emission distribution to every observation object without error (the log-density may return -Inf, though). While it is theoretically possible to have a vector of emissions that looks like emissions = [Normal(), Gamma(), Exponential()], there may be some dispatches like fit(eltype(emissions), ...) that would cause trouble. They should be easy enough to replace with fit(typeof(emissions[i]), ...) if needed.
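The `eltype` pitfall can be reproduced with plain Julia types: a vector holding several concrete emission types widens to their common abstract supertype, so a method selected by `eltype(emissions)` has nothing concrete to dispatch on, while selecting by `typeof(emissions[i])` works. The types and `fit` methods below are hypothetical placeholders (parameters omitted, only dispatch matters), not Distributions.jl's `fit`.

```julia
# Hypothetical emission types standing in for Normal, Gamma, Exponential
abstract type AbstractEmission end
struct NormalE <: AbstractEmission end
struct GammaE <: AbstractEmission end
struct ExponentialE <: AbstractEmission end

# Hypothetical fit: one method per concrete type, none for the abstract type
fit(::Type{NormalE}, xs) = NormalE()
fit(::Type{GammaE}, xs) = GammaE()
fit(::Type{ExponentialE}, xs) = ExponentialE()

xs = [0.3, 1.2, 2.5]
emissions = [NormalE(), GammaE(), ExponentialE()]
eltype(emissions)  # AbstractEmission: the concrete types are lost

# fit(eltype(emissions), xs) would throw a MethodError here.
# Dispatching on each element's own type works:
fitted = [fit(typeof(e), xs) for e in emissions]
```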
