Skip to content

Instantly share code, notes, and snippets.

@brandonwillard
Created October 15, 2021 23:46
Show Gist options
  • Save brandonwillard/034d2d3f578a867818d327f243443761 to your computer and use it in GitHub Desktop.
Save brandonwillard/034d2d3f578a867818d327f243443761 to your computer and use it in GitHub Desktop.
An HMM with time-varying transition matrices constructed in AePPL
import aesara
import aesara.tensor as at
import numpy as np
from aeppl.joint_logprob import factorized_joint_logprob
# Uncomment to turn on Aesara's test-value computation, which uses the
# `tag.test_value` assigned to `S_0_rv` in this section.
# aesara.config.compute_test_value = "warn"
srng = at.random.RandomStream(seed=2320)

# N: number of time steps / rows of the design matrix;
# M: number of regression covariates (columns of the design matrix).
N, M = 100, 10

#
# Initial hidden state: S_0 ~ Categorical(p_S_0) over two states
#
p_S_0 = np.array([0.9, 0.1])
S_0_rv = srng.categorical(p_S_0, name="S_0")
# The test value must be attached *before* cloning so the value variable
# below inherits it through the copied tag.
S_0_rv.tag.test_value = 0

# Value variable standing in for S_0 in the log-probability graph.
s_0_vv = S_0_rv.clone()
s_0_vv.name = "s_0"
#
# Design matrix for regressions
#
# Draw the covariates from a seeded NumPy generator so the simulated data is
# reproducible, matching the seeded `srng` RandomStream used for the model's
# random variables. A seed different from srng's (2320) keeps the two
# generators' streams independent.
X = np.random.default_rng(5132).normal(size=(N, M))
# Shared (N, M) design matrix used by both the emission and the transition
# regressions; `borrow=True` lets Aesara reuse `X`'s buffer without copying.
X_at = aesara.shared(X, name="X", borrow=True)
#
# Emission distributions: one candidate emission per hidden state
#
beta_rv = srng.normal(0, 1, size=M, name="beta")
beta_vv = beta_rv.clone()
beta_vv.name = "beta_vv"

# State 0 candidate: Binomial(50, sigmoid(X @ beta)) counts.
mu_at = at.sigmoid(X_at.dot(beta_rv))
Y_binom_rv = srng.binomial(50, mu_at, size=N, name="Y_binom")

# State 1 candidate: Poisson with rate 0, i.e. draws that are always zero
# (presumably an intentional point mass at zero — confirm).
Y_pois_rv = srng.poisson(0, size=N, name="Y_pois")

# Column k of each row holds the emission for state k; `step_fn` picks the
# column matching the sampled state.
mixture_rv = at.stack([Y_binom_rv, Y_pois_rv], axis=1)
#
# Create a time-varying transition matrix driven by logistic regressions
#
# One regression coefficient vector `xi_s` per previous state `s`; the
# resulting probability of moving to state 0 at time t is sigmoid(X[t] @ xi_s).
xi_rv_vv = {}
z_parts = []
for s in range(2):
    xi_rv = srng.normal(0, 1, size=M, name=f"xi_{s}")
    xi_vv = xi_rv.clone()
    xi_vv.name = f"xi_{s}_vv"
    xi_rv_vv[xi_rv] = xi_vv
    # (N, M) @ (M, 1) -> (N, 1): one linear predictor per time step.
    z_part = X_at.dot(at.shape_padright(xi_rv))
    z_parts.append(z_part)
# Stack the per-state predictors: (N, 2, 1), indexed by [t, previous state, :].
z_tt = at.stack(z_parts, axis=1)
sig_z = at.sigmoid(z_tt)
# Complete each row with its complement: (N, 2, 2), so that
# Gammas_at[t, s] == [P(S_t = 0 | S_{t-1} = s), P(S_t = 1 | S_{t-1} = s)].
Gammas_at = at.concatenate([sig_z, 1.0 - sig_z], axis=2)
def step_fn(Gamma_t, mixture_t, S_tm1):
    """Sample one HMM step: the next hidden state and its emission.

    ``aesara.scan`` calls this once per time step, passing the current slices
    of the ``sequences`` first and then the previous state from
    ``outputs_info``.

    Parameters
    ----------
    Gamma_t
        The time-``t`` slice of ``Gammas_at``; row ``S_tm1`` is the
        probability vector of the next state.
    mixture_t
        The time-``t`` row of ``mixture_rv``, one candidate emission per state.
    S_tm1
        The hidden state at time ``t - 1``.

    Returns
    -------
    tuple
        ``(S_t, obs_t)``: the sampled state and the emission selected by it.
    """
    S_t = srng.categorical(Gamma_t[S_tm1], name="S_t")
    # TODO: Define the mixture here using `at.switch` or `ifelse`
    # See https://github.com/aesara-devs/aeppl/issues/76 and
    # https://github.com/aesara-devs/aeppl/issues/77
    obs_t = mixture_t[S_t]
    return S_t, obs_t
# Unroll the HMM over N steps. Each iteration receives one transition matrix
# and one mixture row (the sequences) plus the previous state; the second
# value returned by `aesara.scan` (its update dictionary) is unused here.
(S_1T_rv, Y_1T_rv), _ = aesara.scan(
    fn=step_fn,
    sequences=[Gammas_at, mixture_rv],
    outputs_info=[{"initial": S_0_rv, "taps": [-1]}, None],
    n_steps=N,
    strict=True,
    name="S_0T",
)
S_1T_rv.name = "S_1T"
Y_1T_rv.name = "Y_1T"

# Value variables for the two `Scan` outputs (names overridden after cloning).
s_1T_vv = S_1T_rv.clone()
s_1T_vv.name = "s_1T"
y_1T_vv = Y_1T_rv.clone()
y_1T_vv.name = "y_1T"

# Random variable -> value variable for every term of the joint density;
# the per-state `xi_*` regression coefficients are appended last.
rv_vv_map = {
    beta_rv: beta_vv,
    Y_1T_rv: y_1T_vv,
    S_1T_rv: s_1T_vv,
    S_0_rv: s_0_vv,
    **xi_rv_vv,
}

# XXX: Using `mixture_rv` as a `Scan` sequence input isn't supported yet.
# See https://github.com/aesara-devs/aeppl/issues/75 and the TODO in `step_fn`
logp_parts = factorized_joint_logprob(rv_vv_map)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment