@JasonPekos
Created April 13, 2024 18:07
using Turing, HiddenMarkovModels
using PosteriorDB
using Makie, CairoMakie
using LinearAlgebra, LogExpFunctions
using DataFrames
# Get the dataset and validated reference draws from PosteriorDB
pdb = PosteriorDB.database() # Data import
data = PosteriorDB.load(PosteriorDB.dataset(pdb, "hmm_example"))
ref_post = DataFrame(PosteriorDB.load(PosteriorDB.reference_posterior(pdb, "hmm_example-hmm_example")))
ref_draws = DataFrame([vcat(ref_post[!, c]...) for c in names(ref_post)], names(ref_post))
# Plotting Function(s)
function plot_states(gq, data)
    f = Figure()
    ax = Axis(f[1, 1])
    scatter!(ax, 1:data["N"], data["y"])
    for i in eachindex(gq)
        lines!(ax, 1:data["N"], gq[i], color = :grey, alpha = 0.1)
    end
    return f
end
function plot_draws(chains, names_pair)
    f = Figure()
    ax = Axis(f[1, 1], title = "Turing vs PosteriorDB Reference Draws")
    for (i, ch) in enumerate(chains)
        scatter!(ax,
                 ch[!, names_pair[1]],
                 ch[!, names_pair[2]],
                 alpha = 0.4,
                 label = "draws $i")
    end
    axislegend(ax)  # show the "draws $i" labels set above
    return f
end
# Define The Models:
@model function example_hmm_marginalized(N, K, y)
    mu ~ MvNormal([3, 10], I)
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')  # row k: transition probabilities out of state k
    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])
    # Forward algorithm: the second output is the log-likelihood of y with the
    # hidden states summed out, which is added to the model's log density.
    _, filtered_likelihood = forward(hmm, y)
    Turing.@addlogprob! only(filtered_likelihood)
    seq, _ = viterbi(hmm, y) # Probably do not want this in the model?
    return [mu[s] for s in seq]
end
# Sample (index the data explicitly rather than relying on the iteration order of values(data))
model = example_hmm_marginalized(data["N"], data["K"], data["y"])
chn_marg = sample(model, NUTS(), 1000, discard_initial = 1000)
df_chn_marg = DataFrame(chn_marg)
plot_draws([df_chn_marg, ref_draws], ["theta1[1]", "theta2[1]"])
gq = generated_quantities(model, chn_marg);
plot_states(gq, data)
# Compare Reference Draws:
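What `forward(hmm, y)` contributes through `Turing.@addlogprob!` is, as used above, the marginal log-likelihood log p(y[1:T]) with the hidden states summed out by the forward recursion. A minimal standard-library sketch of that recursion follows, assuming the same setup as the model (uniform-or-given initial probabilities, a transition matrix with `trans[i, j]` the probability of moving from state i to state j, unit-variance Gaussian emissions). The helper names `logsumexp_stable`, `normal_logpdf` and `forward_loglik` are hypothetical; this is not HiddenMarkovModels.jl's implementation.

```julia
# Minimal forward algorithm in log space (standard library only).
# Hypothetical helper names; not the HiddenMarkovModels.jl implementation.

# Numerically stable log(sum(exp.(x))).
function logsumexp_stable(x::AbstractVector{<:Real})
    m = maximum(x)
    isfinite(m) || return m  # all entries -Inf (or some +Inf)
    return m + log(sum(exp.(x .- m)))
end

# Log-density of Normal(μ, 1) at y, matching the unit-variance emissions in the model.
normal_logpdf(y::Real, μ::Real) = -0.5 * log(2π) - 0.5 * (y - μ)^2

# log p(y[1:T]) for an HMM with initial probabilities `init`, transition matrix
# `trans` (trans[i, j] = P(state j at t+1 | state i at t)) and emission means `mus`.
function forward_loglik(init, trans, mus, y)
    K = length(init)
    # logα[k] = log p(y[1:t], state k at time t), starting at t = 1
    logα = [log(init[k]) + normal_logpdf(y[1], mus[k]) for k in 1:K]
    for t in 2:length(y)
        logα = [logsumexp_stable([logα[i] + log(trans[i, j]) for i in 1:K]) +
                normal_logpdf(y[t], mus[j]) for j in 1:K]
    end
    # Summing over the final state gives the marginal likelihood.
    return logsumexp_stable(logα)
end
```

For a short sequence, the result can be checked against brute-force enumeration of all K^T state paths. In the model itself, `LogExpFunctions.logsumexp` (already imported in the script) would do the job of `logsumexp_stable`.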
@gdalle

gdalle commented Apr 15, 2024

> The slight annoyance is that allowing the user to do y ~ hmm might give them an indication that you're allowed to also sample from this, but this is unfortunately not possible (we assume Distribution in Turing.jl for sampling)

Will it error nicely if the user tries it with a non-Distribution, or return something wrong?

> Buuut it is really a nice example 🤷

And also I didn't expect things to compose so well out of the box 😍

> To clarify, it is pretty simple to sample forwardly from HMM right?

Definitely. The only hurdle is that I don't use Distributions.Categorical to sample state transitions, but roll my own categorical instead to avoid the dependency. So even if the observation distributions are Distributions, Turing might have a hard time with my sampler on the state part.

https://github.com/gdalle/HiddenMarkovModels.jl/blob/6bfb23a7684f3fcddfa51989716fdd88ed67c46f/src/types/abstract_hmm.jl#L145-L171
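Forward sampling of this kind can be sketched with the standard library alone. The version below assumes unit-variance Gaussian emissions and uses a hand-rolled inverse-CDF categorical draw in place of `Distributions.Categorical`; like the `rand` discussed in this thread, it returns both the latent state sequence and the observations. The names `rand_categorical` and `sample_hmm` are hypothetical and do not reproduce the linked HiddenMarkovModels.jl code.

```julia
using Random

# Hand-rolled categorical draw by inverse CDF (no Distributions.jl dependency):
# returns the first index k with u < p[1] + ... + p[k], where u ~ Uniform[0, 1).
function rand_categorical(rng::AbstractRNG, p::AbstractVector{<:Real})
    u = rand(rng)
    c = zero(eltype(p))
    for (k, pk) in enumerate(p)
        c += pk
        u < c && return k
    end
    return length(p)  # guard against round-off when sum(p) is slightly below 1
end

# Forward simulation of an HMM with unit-variance Gaussian emissions.
# `trans[i, j]` is the probability of moving from state i to state j.
# Returns both the latent state sequence and the observation sequence.
function sample_hmm(rng::AbstractRNG, init, trans, mus, len::Integer)
    state_seq = Vector{Int}(undef, len)
    obs_seq = Vector{Float64}(undef, len)
    for t in 1:len
        # first state from `init`, later states from the row of the previous state
        p = t == 1 ? init : view(trans, state_seq[t - 1], :)
        state_seq[t] = rand_categorical(rng, p)
        obs_seq[t] = mus[state_seq[t]] + randn(rng)
    end
    return (state_seq = state_seq, obs_seq = obs_seq)
end
```

Keeping only `.obs_seq` of the result gives the observations alone, which is the part a `y ~ ...` statement would be concerned with.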

> In HiddenMarkovModels.jl (IIUC), the current rand returns both the latent states and the observed states, hence it's not compatible with the current assumptions that we make.

It's also a bit manual, in that it doesn't follow the official Random API of first generating a Sampler. Maybe I should?
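For reference, Julia's `Random` documentation describes the lightweight way to plug a custom type into `rand`: define `rand(rng, ::Random.SamplerTrivial{T})`, plus `Base.eltype` so that array draws get a concrete element type. The sketch below applies this to a hypothetical `MyCategorical` type; it is only an illustration of the API, not a proposal for how HiddenMarkovModels.jl should structure its samplers. Types that benefit from precomputation (e.g. caching cumulative probabilities) can instead define a custom `Random.Sampler`; that variant is not shown.

```julia
using Random

# Hypothetical lightweight categorical type (illustration only).
struct MyCategorical{V<:AbstractVector{<:Real}}
    p::V  # probabilities, assumed to sum to 1
end

# Draws are Ints, so array draws such as rand(d, n) get a concrete element type.
Base.eltype(::Type{<:MyCategorical}) = Int

# One method on SamplerTrivial is enough for rand(d), rand(rng, d) and rand(rng, d, n).
function Random.rand(rng::AbstractRNG, sp::Random.SamplerTrivial{<:MyCategorical})
    p = sp[].p  # sp[] unwraps the MyCategorical instance
    u = rand(rng)
    c = zero(eltype(p))
    for (k, pk) in enumerate(p)
        c += pk
        u < c && return k
    end
    return length(p)  # round-off guard
end
```

With this single method, Random's generic machinery provides the other `rand` variants, including vector draws such as `rand(rng, MyCategorical([0.2, 0.3, 0.5]), 100)`.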

@yebai

yebai commented Apr 15, 2024

We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects. At the moment, these can only be from Distributions.jl. We would like to support more, including HMMs, SSMs, BUGS models, Turing submodels, etc.

Related: TuringLang/DynamicPPL.jl#523

@torfjelde

Probably the easiest way to make something that's "nicely" operable within a Turing.jl model is to just define a

struct MarginalizedHMM{M,V,F} <: Distribution{V,F}
    hmm::M
end

in combination with a nice marginalize(hmm) constructor that extracts the variate form, etc. from the emission distributions in hmm (though this would have to assume the emission distributions are all of the "same" type; is this assumed in your work @gdalle?).

Then you could do

y ~ marginalize(hmm)

This would of course still not expose the latent variables to Turing.jl as we'd ideally like, but it does make the experience a bit more seamless 🤷

@gdalle

gdalle commented Apr 16, 2024

> We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects.

That was also the idea behind HiddenMarkovModels.jl: who cares if it's a Distributions.Distribution, as long as it has DensityInterface.logdensityof and Random.rand.

> this would have to assume the emission distributions are all of the "same" type; is this assumed in your work @gdalle?

Yeah, you need to be able to apply every emission distribution to every observation object without error (the logdensity may return -Inf, though). While it is theoretically possible to have a vector of emissions like emissions=[Normal(), Gamma(), Exponential()], there may be some dispatches like fit(eltype(emissions), ...) that would cause trouble. They should be easy enough to replace with fit(typeof(emissions[i]), ...) if needed.
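The `eltype` pitfall can be illustrated without any packages: a vector that mixes concrete emission types has an abstract element type, so a method that dispatches on `eltype(emissions)` finds no match, while `typeof(emissions[i])` recovers each concrete type. The types `AbstractEmission`, `GaussEmission`, `ExpEmission` and the function `fit_emission` are hypothetical stand-ins for `Distribution` subtypes and `fit`; HiddenMarkovModels.jl's actual fitting code is not reproduced here.

```julia
using Statistics  # for mean

# Hypothetical stand-ins for Distributions.Distribution and two concrete subtypes.
abstract type AbstractEmission end
struct GaussEmission <: AbstractEmission
    μ::Float64  # mean (unit variance assumed)
end
struct ExpEmission <: AbstractEmission
    θ::Float64  # mean of the exponential
end

# Hypothetical per-type estimators (unweighted maximum likelihood, for brevity).
fit_emission(::Type{GaussEmission}, x::AbstractVector{<:Real}) = GaussEmission(mean(x))
fit_emission(::Type{ExpEmission}, x::AbstractVector{<:Real}) = ExpEmission(mean(x))

emissions = [GaussEmission(0.0), ExpEmission(1.0)]  # eltype is AbstractEmission
x = [1.0, 2.0, 3.0]

# Dispatching on the container's element type fails: no method for the abstract type.
err = try
    fit_emission(eltype(emissions), x)
    nothing
catch e
    e
end

# Dispatching on each element's concrete type works.
refit = [fit_emission(typeof(e), x) for e in emissions]
```

Here `err` holds a `MethodError`, while `refit` contains one refitted emission per state, each of its original concrete type.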
