@JasonPekos
Created April 13, 2024 18:07
using Turing, HiddenMarkovModels
using PosteriorDB
using Makie, CairoMakie
using LinearAlgebra, LogExpFunctions
using DataFrames
# Get the dataset and validated reference draws from PosteriorDB
pdb = PosteriorDB.database() # Data import
data = PosteriorDB.load(PosteriorDB.dataset(pdb, "hmm_example"))
ref_post = DataFrame(PosteriorDB.load(PosteriorDB.reference_posterior(pdb, "hmm_example-hmm_example")))
ref_draws = DataFrame([vcat(ref_post[!, c]...) for c in names(ref_post)], names(ref_post))

# Plotting Function(s)
function plot_states(gq, data)
    f = Figure()
    ax = Axis(f[1, 1])
    scatter!(ax, 1:data["N"], data["y"])
    for i in eachindex(gq)
        lines!(ax, 1:data["N"], gq[i], color = :grey, alpha = 0.1)
    end
    return f
end

function plot_draws(chains, names_pair)
    f = Figure()
    ax = Axis(f[1, 1], title = "Turing vs PosteriorDB Reference Draws")
    for (i, ch) in enumerate(chains)
        scatter!(ax,
            ch[!, names_pair[1]],
            ch[!, names_pair[2]],
            alpha = 0.4,
            label = "draws $i")
    end
    return f
end

# Define The Models:
@model function example_hmm_marginalized(N, K, y)
    mu ~ MvNormal([3, 10], I)
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')
    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])
    _, filtered_likelihood = forward(hmm, y)
    Turing.@addlogprob! only(filtered_likelihood)
    seq, _ = viterbi(hmm, y) # Probably do not want this in the model?
    return [mu[s] for s in seq]
end

# Sample
chn_marg = sample(example_hmm_marginalized(values(data)...), NUTS(), 1000, discard_initial = 1000)
df_chn_marg = DataFrame(chn_marg)
plot_draws([df_chn_marg, ref_draws], ["theta1[1]", "theta2[1]"])
gq = generated_quantities(example_hmm_marginalized(values(data)...), chn_marg);
plot_states(gq, data)
# Compare Reference Draws:
@torfjelde

torfjelde commented Apr 14, 2024

This is neato stuff!

Regarding finding the optimal states, you can also do:

@model function example_hmm_marginalized(N, K, y, ::Val{IncludeGenerated}=Val(false)) where {IncludeGenerated}
    mu ~ MvNormal([3, 10], I)
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')

    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])

    _, filtered_likelihood = forward(hmm, y)

    Turing.@addlogprob! only(filtered_likelihood)

    # Conditional generation of the hidden states.
    if IncludeGenerated
        seq, _ = viterbi(hmm, y)
        return [mu[s] for s in seq]
    else
        return nothing
    end
end

and then instantiate two models:

model_inference = example_hmm_marginalized(N, K, y)
model_generation = example_hmm_marginalized(N, K, y, Val(true))

where the check will be compiled away in the first instantiation of the model, but included in the second:)
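
A dependency-free sketch of the same `Val` trick, outside of Turing.jl, may make the mechanism clearer. The function `summarize` and its return fields are made up for illustration; only the `::Val{Flag}` pattern is taken from the model above. Because the flag is a type parameter, it is known at compile time, so the compiler can typically drop the unused branch in each specialization.

```julia
# Hypothetical example: the flag is a type parameter, so each value of
# `IncludeExtra` gets its own specialized method body.
function summarize(xs, ::Val{IncludeExtra}=Val(false)) where {IncludeExtra}
    total = sum(xs)
    if IncludeExtra
        # Only reachable in the `Val(true)` specialization.
        return (total = total, mean = total / length(xs))
    else
        return nothing
    end
end

cheap = summarize([1, 2, 3])             # flag defaults to Val(false)
full = summarize([1, 2, 3], Val(true))   # extra work is included
```

Here `cheap` plays the role of `model_inference` and `full` the role of `model_generation`.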

Also note that normalize(x, 1) is available from LinearAlgebra, which might be preferable to softmax(ones(K)), etc. Or even better, make use of FillArrays.jl:

using FillArrays

normalized_ones(n) = Fill(1/n, n)

which are "lazily" constructed arrays.
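
As a quick sanity check that these are interchangeable here, the sketch below compares the three ways of building a uniform probability vector. `naive_softmax` is a hand-written stand-in for `LogExpFunctions.softmax`, and `fill` is used as an eager analogue of `Fill`, so that only the standard library is needed; `K = 3` is an arbitrary choice.

```julia
using LinearAlgebra

K = 3

# Hand-written stand-in for `LogExpFunctions.softmax`, to stay dependency-free.
naive_softmax(x) = exp.(x) ./ sum(exp.(x))

uniform_via_softmax = naive_softmax(ones(K))     # exp(1) / (K * exp(1)) == 1/K
uniform_via_normalize = normalize(ones(K), 1)    # divide by the 1-norm
uniform_direct = fill(1 / K, K)                  # eager analogue of Fill(1/K, K)
```

All three give the same vector, so the choice is mostly about readability and avoiding a needless allocation.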

EDIT: Moreover, it seems HiddenMarkovModels.jl supports logdensityof, so we can simply do:

@model function example_hmm_marginalized(N, K, y, ::Val{IncludeGenerated}=Val(false)) where {IncludeGenerated}
    mu ~ MvNormal([3, 10], I)
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')

    # Replaced `forward` call with `logdensityof`
    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])
    Turing.@addlogprob! logdensityof(hmm, y)

    # Conditional generation of the hidden states.
    if IncludeGenerated
        seq, _ = viterbi(hmm, y)
        return [mu[s] for s in seq]
    else
        return nothing
    end
end
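
For intuition about what the marginalized log density is, here is a hand-rolled forward algorithm in log space for the same kind of model (unit-variance normal emissions). This is a sketch under stated assumptions, not how HiddenMarkovModels.jl implements it: the names `normal_logpdf`, `logsumexp_vec` and `hmm_loglikelihood` are hypothetical, and `trans[i, j]` is assumed to be the probability of moving from state `i` to state `j`.

```julia
# Log-density of Normal(μ, 1) at x, written out to avoid dependencies.
normal_logpdf(x, μ) = -0.5 * (x - μ)^2 - 0.5 * log(2π)

# Numerically stable log(sum(exp.(v))).
function logsumexp_vec(v)
    m = maximum(v)
    return m + log(sum(exp.(v .- m)))
end

# Forward algorithm in log space: sums over all hidden state sequences.
# init[k]     : probability of starting in state k
# trans[i, j] : probability of moving from state i to state j
# mus[k]      : mean of the unit-variance normal emission in state k
function hmm_loglikelihood(init, trans, mus, y)
    K = length(init)
    logα = [log(init[k]) + normal_logpdf(y[1], mus[k]) for k in 1:K]
    for t in 2:length(y)
        prev = logα
        logα = [logsumexp_vec(prev .+ log.(trans[:, j])) + normal_logpdf(y[t], mus[j]) for j in 1:K]
    end
    return logsumexp_vec(logα)
end

init = [0.5, 0.5]
trans = [0.9 0.1; 0.2 0.8]
mus = [3.0, 10.0]
y = [2.5, 9.0, 10.5]
ll = hmm_loglikelihood(init, trans, mus, y)
```

The cost is linear in the length of `y`, whereas summing over every state sequence explicitly would be exponential.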

EDIT 2:

We can also make this work as an observation in Turing.jl with a couple of method definitions:

# HACK: Make `AbstractHMM` work with DynamicPPL.jl.
DynamicPPL.check_tilde_rhs(right::HiddenMarkovModels.AbstractHMM) = right
# https://github.com/TuringLang/DynamicPPL.jl/blob/6a2454fac6a0f6da436976818196cf0749d5e30d/src/context_implementations.jl#L261-L264
function DynamicPPL.observe(right::HiddenMarkovModels.AbstractHMM, left, vi)
    DynamicPPL.increment_num_produce!(vi)
    return HiddenMarkovModels.logdensityof(right, left), vi
end

# Define The Models:
@model function example_hmm_marginalized(N, K, y, ::Val{IncludeGenerated}=Val(false)) where {IncludeGenerated}
    mu ~ MvNormal([3, 10], I)
    theta1 ~ Dirichlet(softmax(ones(K)))
    theta2 ~ Dirichlet(softmax(ones(K)))
    θ = vcat(theta1', theta2')

    hmm = HMM(softmax(ones(K)), θ, [Normal(mu[1], 1), Normal(mu[2], 1)])
    y ~ hmm

    # Conditional generation of the hidden states.
    if IncludeGenerated
        seq, _ = viterbi(hmm, y)
        return [mu[s] for s in seq]
    else
        return nothing
    end
end

Now everything works as before:)

@gdalle

gdalle commented Apr 15, 2024

Would it make sense to put the hacks in a package extension for HiddenMarkovModels.jl?

@torfjelde

Was wondering about doing the same in DynamicPPL.jl 😅

The slight annoyance is that allowing the user to do y ~ hmm might give them an indication that you're allowed to also sample from this, but this is unfortunately not possible (we assume Distribution in Turing.jl for sampling). Buuut it is really a nice example 🤷 So all in all, a bit uncertain

@yebai

yebai commented Apr 15, 2024

do y ~ hmm might give them an indication that you're allowed to also sample from this,

To clarify, it is pretty simple to sample forwardly from HMM right?

@yebai

yebai commented Apr 15, 2024

Although I agree we probably shouldn't create a package extension, these packages should interoperate via their standard API rather than require glue code. If the standard API is not sufficient, then we should probably fix that!

@torfjelde

To clarify, it is pretty simple to sample forwardly from HMM right?

In HiddenMarkovModels.jl (IIUC), the current rand returns both the latent states and the observed states, hence it's not compatible with the current assumptions that we make, i.e. that whatever we want to call rand on produces a single output.

So there's two options:

  1. Allow non-distributions-like rand calls (and other stuff) in DynamicPPL / Turing.
  2. Add NamedTupleVariate to Distributions.jl, which AFAIK has been discussed extensively before: JuliaStats/Distributions.jl#1803

But even so, we'd have to do something fancy to add proper support for something like HMM as you effectively want to do the following:

  1. If y is given => marginalize (i.e. use logdensityof(hmm, y)).
  2. If y is not given => sample both y and the latents.

Could we make it much easier to specialize to scenarios like this? Probably, but not super clear to me how.
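
To make the "two outputs" point concrete, here is a minimal forward sampler for a two-state HMM with unit-variance normal emissions, written with the standard library only. It is a sketch, not the HiddenMarkovModels.jl implementation: `sample_categorical`, `sample_hmm` and the field names of the returned tuple are hypothetical.

```julia
using Random

# Draw an index from a probability vector `p` by inverting its CDF.
function sample_categorical(rng, p)
    u = rand(rng)
    c = 0.0
    for (k, pk) in enumerate(p)
        c += pk
        u <= c && return k
    end
    return length(p)  # guard against floating-point round-off
end

# Forward-sample N steps; returns BOTH the latent states and the observations.
function sample_hmm(rng, init, trans, mus, N)
    states = Vector{Int}(undef, N)
    obs = Vector{Float64}(undef, N)
    states[1] = sample_categorical(rng, init)
    obs[1] = mus[states[1]] + randn(rng)
    for t in 2:N
        states[t] = sample_categorical(rng, trans[states[t - 1], :])
        obs[t] = mus[states[t]] + randn(rng)
    end
    return (state_seq = states, obs_seq = obs)
end

draw = sample_hmm(MersenneTwister(1), [0.5, 0.5], [0.9 0.1; 0.2 0.8], [3.0, 10.0], 50)
```

A `~` statement that expects a single value from `rand` has no obvious place to put `draw.state_seq`, which is the mismatch described above.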

@gdalle

gdalle commented Apr 15, 2024

The slight annoyance is that allowing the user to do y ~ hmm might give them an indication that you're allowed to also sample from this, but this is unfortunately not possible (we assume Distribution in Turing.jl for sampling)

Will it error nicely if the user tries it with a non-Distribution, or return something wrong?

Buuut it is really a nice example 🤷

And also I didn't expect things to compose so well out of the box 😍

To clarify, it is pretty simple to sample forwardly from HMM right?

Definitely. The only hurdle is that I don't use Distributions.Categorical to sample state transitions, but roll out my own categorical instead to avoid the dependence. So even if the observation distributions are Distributions, Turing might have a hard time with my sampler on the state part.

https://github.com/gdalle/HiddenMarkovModels.jl/blob/6bfb23a7684f3fcddfa51989716fdd88ed67c46f/src/types/abstract_hmm.jl#L145-L171

In HiddenMarkovModels.jl (IIUC), the current rand returns both the latent states and the observed states, hence it's not compatible with the current assumptions that we make

It's also a bit manual in that it doesn't follow the official Random API of first generating a Sampler. Maybe I should?
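
For reference, hooking a custom type into the Random API can look like the toy example below, which follows the pattern in the Julia manual for types that need no precomputed sampler state. The `Die` type is purely illustrative and has nothing to do with HiddenMarkovModels.jl.

```julia
using Random

# A toy random object: an n-sided die.
struct Die
    nsides::Int
end

# Declaring the element type lets `rand(die, n)` return a `Vector{Int}`.
Base.eltype(::Type{Die}) = Int

# `Random.SamplerTrivial` is the default sampler wrapper; `d[]` unwraps the Die.
Random.rand(rng::AbstractRNG, d::Random.SamplerTrivial{Die}) = rand(rng, 1:d[].nsides)

single = rand(MersenneTwister(0), Die(6))      # one roll
several = rand(MersenneTwister(0), Die(6), 5)  # array method comes for free
```

Defining the single method on the sampler type is enough to get the scalar and array variants of `rand`, with or without an explicit RNG.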

@yebai

yebai commented Apr 15, 2024

We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects. At the moment, these can only be from Distributions.jl. We would like to support more, including HMMs, SSMs, BUGS models, Turing submodels, etc.

Related: TuringLang/DynamicPPL.jl#523

@torfjelde

Probably the easiest way to make something that's "nicely" operable within a Turing.jl model is to just define a

struct MarginalizedHMM{M,V,F} <: Distribution{V,F}
    hmm::M
end

in combination with a nice marginalize(hmm) constructor that extracts the variate form, etc. from the emission distributions in hmm (though this would have to assume the emission distributions are all of the "same" type; is this assumed in your work @gdalle ?).

Then you could do

y ~ marginalize(hmm)

This would of course still not expose the latent variables to Turing.jl as we'd ideally like, but it does make the experience a bit more seamless 🤷

@gdalle

gdalle commented Apr 16, 2024

We could standardise what appears on the right side of tilde statements and provide a standard interface for user-provided distribution objects.

That was also the idea behind HiddenMarkovModels.jl: who cares if it's a Distributions.Distribution, as long as it has DensityInterface.logdensityof and Random.rand.

this would have to assume the emission distributions are all of the "same" type; is this assumed in your work @gdalle ?

Yeah, you need to be able to apply every emission distribution to every observation object without error (the logdensity may return -Inf though). While it is theoretically possible to have a vector of emissions that looks like emissions=[Normal(), Gamma(), Exponential()], maybe there are some dispatches like fit(eltype(emissions), ...) that would cause trouble. They should be easy enough to replace with fit(typeof(emissions[i]), ...) if needed
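
The `eltype` versus `typeof(emissions[i])` distinction can be seen without any packages. In the sketch below, `Emission`, `GaussianEmission` and `ExponentialEmission` are hypothetical stand-ins for distribution types: a vector mixing concrete types gets an abstract element type, which is why dispatching on `eltype(emissions)` could pick the wrong method or none at all.

```julia
# Hypothetical stand-ins for emission distribution types.
abstract type Emission end

struct GaussianEmission <: Emission
    μ::Float64
end

struct ExponentialEmission <: Emission
    rate::Float64
end

emissions = [GaussianEmission(0.0), ExponentialEmission(1.0)]

container_type = eltype(emissions)                 # the abstract supertype
per_state_types = [typeof(e) for e in emissions]   # concrete type of each entry
```

Looking up the type per element keeps the concrete information that the container's element type has lost.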
