Last active
June 20, 2020 05:58
-
-
Save brandonwillard/b1b1a410d2e8f83761a7353c871a3373 to your computer and use it in GitHub Desktop.
PyMC3 Poisson-Zero HMM Testing Example
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import theano.tensor as tt | |
import pymc3 as pm | |
import arviz as az | |
import matplotlib.pyplot as plt | |
from pymc3_hmm.distributions import HMMStateSeq, SwitchingProcess | |
from pymc3_hmm.step_methods import FFBSStep, TransMatConjugateStep | |
def simulate_poiszero_hmm(N, mus=(10.0, 30.0),
                          pi_0_a=np.r_[1, 1, 1],
                          Gamma=np.r_['0,2,1',
                                      [5, 1, 1],
                                      [1, 3, 1],
                                      [1, 1, 5]]
                          ):
    """Draw one prior-predictive sample from a Poisson-Zero HMM.

    The model has ``len(mus) + 1`` states: state 0 emits the constant 0 and
    every remaining state emits from a Poisson with the matching rate in
    ``mus``.

    Parameters
    ----------
    N
        Sequence length; passed to ``HMMStateSeq`` and used as the length of
        the placeholder observation vector.
    mus
        Poisson rates, one per non-zero state.  Any array-like is accepted.
    pi_0_a
        Dirichlet concentration parameters for the initial state
        probabilities (one entry per state).
    Gamma
        Square array whose rows are the Dirichlet concentration parameters of
        the corresponding transition-matrix rows.

    Returns
    -------
    tuple
        ``(sample, model)`` where ``sample`` is the result of
        ``pm.sample_prior_predictive(samples=1)`` and ``model`` is the
        ``pm.Model`` the sample was drawn from.

    Raises
    ------
    ValueError
        If ``pi_0_a``, ``mus`` and ``Gamma`` do not describe the same number
        of states.
    """
    # Coerce up front so plain lists/tuples (including the default ``mus``)
    # expose ``.size`` / ``.shape`` like the array arguments do.
    mus = np.asarray(mus)
    pi_0_a = np.asarray(pi_0_a)
    Gamma = np.asarray(Gamma)

    # One extra state for the constant-zero emission.
    n_states = mus.size + 1
    if Gamma.ndim != 2 or not (
        pi_0_a.size == n_states == Gamma.shape[0] == Gamma.shape[1]
    ):
        # An explicit exception survives ``python -O``, unlike ``assert``.
        raise ValueError(
            'Inconsistent number of states: '
            f'pi_0_a.size={pi_0_a.size}, mus.size + 1={n_states}, '
            f'Gamma.shape={Gamma.shape}'
        )

    with pm.Model() as test_model:
        # One Dirichlet-distributed row per state of the transition matrix.
        trans_rows = [pm.Dirichlet(f'p_{i}', r) for i, r in enumerate(Gamma)]
        P_tt = tt.stack(trans_rows)
        P_rv = pm.Deterministic('P_tt', P_tt)

        pi_0_tt = pm.Dirichlet('pi_0', pi_0_a)

        S_rv = HMMStateSeq('S_t', N, P_rv, pi_0_tt)

        # The zeros are only a placeholder of the right length; the values
        # callers read under 'Y_t' come from the prior predictive draw below.
        Y_rv = SwitchingProcess('Y_t',
                                [pm.Constant.dist(0)] + [pm.Poisson.dist(mu)
                                                         for mu in mus],
                                S_rv, observed=np.zeros(N))

        y_test_point = pm.sample_prior_predictive(samples=1)

    return y_test_point, test_model
# Draw a 3000-step Poisson-Zero HMM series to use as test data; only the
# sampled values are needed here, so the simulation model is discarded.
poiszero_sim = simulate_poiszero_hmm(N=3000, mus=np.array([5000, 7000]))[0]
y_test = poiszero_sim['Y_t']
# Show the simulated observations (top panel) and the state sequence that
# generated them (bottom panel) as step plots.
fig, ax = plt.subplots(nrows=2, figsize=(15, 6.0))

panels = (
    (ax[0], y_test, r'$y_t$', 'black'),
    (ax[1], poiszero_sim['S_t'], r'$S_t$', 'blue'),
)
for ax_, series, label, color in panels:
    ax_.plot(series,
             label=label, color=color,
             drawstyle='steps-pre', linewidth=0.5)
    ax_.legend()

plt.tight_layout()
# Define a model with (mostly) the same assumptions as our simulated data
with pm.Model() as test_model:
    # One Dirichlet prior with unit concentrations per transition-matrix row.
    p_0_rv = pm.Dirichlet('p_0', np.r_[1, 1, 1])
    p_1_rv = pm.Dirichlet('p_1', np.r_[1, 1, 1])
    p_2_rv = pm.Dirichlet('p_2', np.r_[1, 1, 1])

    P_tt = tt.stack([p_0_rv, p_1_rv, p_2_rv])
    P_rv = pm.Deterministic('P_tt', P_tt)

    # The initial-state probabilities are not estimated; the simulated draw
    # is reused as a fixed value.
    pi_0_tt = poiszero_sim['pi_0']

    S_rv = HMMStateSeq('S_t', y_test.shape[0], P_rv, pi_0_tt)
    # Initial state guess: 0 where the observation is zero, 1 elsewhere.
    # The builtin ``int`` is used because the ``np.int`` alias was deprecated
    # in NumPy 1.20 and removed in 1.24; both name the same dtype.
    S_rv.tag.test_value = (y_test > 0).astype(int)

    # Gamma(alpha, beta) has mean alpha / beta and variance alpha / beta**2,
    # so alpha = E**2 / Var and beta = E / Var give the stated mean/variance.
    E_1_mu, Var_1_mu = 1000.0, 1000.0
    mu_1_rv = pm.Gamma('mu_1', E_1_mu**2 / Var_1_mu, E_1_mu / Var_1_mu)

    E_2_mu, Var_2_mu = 100.0, 100.0
    mu_2_rv = pm.Gamma('mu_2', E_2_mu**2 / Var_2_mu, E_2_mu / Var_2_mu)

    # State 0 emits a constant zero; the third state's rate is mu_1 + mu_2,
    # which keeps it above the second state's rate because mu_2 is positive.
    Y_rv = SwitchingProcess('Y_t',
                            [pm.Constant.dist(0),
                             pm.Poisson.dist(mu_1_rv),
                             pm.Poisson.dist(mu_1_rv + mu_2_rv)],
                            S_rv, observed=y_test)
# Draw posterior samples, assigning a dedicated step method to each group of
# variables: NUTS for the Poisson rates, FFBSStep for the state sequence and
# TransMatConjugateStep for the transition-matrix rows.
with test_model:
    mu_step = pm.NUTS([mu_1_rv, mu_2_rv])
    ffbs = FFBSStep([S_rv])
    transitions = TransMatConjugateStep([p_0_rv, p_1_rv, p_2_rv], S_rv)

    # Order in which the step methods are handed to the sampler.
    steps = [ffbs, mu_step, transitions]

    sample_options = dict(step=steps, chains=1, return_inferencedata=True)
    trace = pm.sample(2000, **sample_options)
# Variables shown in both diagnostic plots below.
plot_var_names = ['mu_1', 'mu_2', 'p_0', 'p_1', 'p_2']

# Plot the posterior sample chains and their marginal distributions.
# ArviZ is called directly: `pm.traceplot` is only a deprecated alias of
# `az.plot_trace`.
az.plot_trace(trace, var_names=plot_var_names, compact=True)

# Plot the posterior auto-correlations
az.plot_autocorr(trace, var_names=plot_var_names)
# Sample posterior predictive values
with test_model:
    adds_pois_ppc = pm.sample_posterior_predictive(trace.posterior)
    # Converted while the model context is active so ArviZ can look up the
    # model instead of building the InferenceData without it.
    az_post_trace = az.from_pymc3(posterior_predictive=adds_pois_ppc)

# Compute high-density intervals for the posterior predictive samples
post_pred_imps_hpd_df = az.hdi(
    az_post_trace,
    hdi_prob=0.97,
    group='posterior_predictive',
    var_names=['Y_t'],
).to_dataframe()

# Move the interval bounds from the row index into a column level, giving
# one row per time step.
post_pred_imps_hpd_df = post_pred_imps_hpd_df.unstack(level='hdi')

# ArviZ labels the bounds 'lower' and 'higher'.  Rename by label rather than
# replacing the level values positionally: positional replacement swaps the
# two bounds whenever the level happens to be stored in sorted order.
post_pred_imps_hpd_df = post_pred_imps_hpd_df.rename(
    columns={'higher': 'upper'}, level='hdi')
# Overlay the observed series on the shaded posterior predictive interval.
fig, ax = plt.subplots(figsize=(18, 4.8))

observed_style = dict(label=r'$y_t$',
                      alpha=0.5,
                      color='red',
                      linewidth=0.8,
                      drawstyle='steps')
ax.plot(y_test, **observed_style)

# Lower/upper interval bounds for Y_t, indexed like the frame itself.
y_hdi_bounds = post_pred_imps_hpd_df['Y_t']
ax.fill_between(post_pred_imps_hpd_df.index,
                y_hdi_bounds['lower'],
                y_hdi_bounds['upper'],
                label=r'97\% HDI interval',
                color='b', step='pre',
                alpha=0.3)

ax.legend()
plt.tight_layout()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment