# Simple Theano DLM Simulation Example using Symbolic PyMC
#
# Gist: brandonwillard/0f95fff4644ff8d0360c50fd17af7b7f
# (last active March 16, 2020 19:51)
#
# Note: GitHub flagged this file as possibly containing hidden or bidirectional
# Unicode text that may be interpreted or compiled differently than it appears.
# To review it, open the file in an editor that reveals hidden Unicode characters.
import numpy as np
import theano
import theano.tensor as tt
import matplotlib.pyplot as plt

from theano.printing import debugprint as tt_dprint
from symbolic_pymc.theano.random_variables import NormalRV, MvNormalRV, InvGammaRV, observed

# Run without a C compiler (pure-Python ops) and with the fast-compile mode.
theano.config.cxx = ""
theano.config.mode = "FAST_COMPILE"
# Do not fail when graph inputs carry no test values.
theano.config.compute_test_value = 'ignore'

# Number of time steps to simulate.  (Alternative: derive it from an observed
# series instead, e.g. `y_tt = tt.dvector(name='y'); N_t = y_tt.size`.)
N_t = tt.iscalar("N_t")

# State-evolution matrix G and observation matrix F (float64 inputs).
G_tt = tt.dmatrix('G_tt')
F_tt = tt.dmatrix('F_tt')
# The latent-state dimension is taken from the number of columns of F.
N_theta_tt = F_tt.shape[-1]

# Initial-state prior: MvNormal with zero mean and covariance 10 * I.
theta_0_rv = MvNormalRV(tt.zeros([N_theta_tt]), 10. * tt.eye(N_theta_tt), name='theta')
# InvGamma(0.5, 0.5) priors on the state-noise and observation-noise scales.
nu_scale = InvGammaRV(0.5, 0.5, name='nu_scale')
eps_scale = InvGammaRV(0.5, 0.5, name='eps_scale')


def state_step(theta_tm1, G, nu_scale):
    """Advance the latent state one step: ``theta_t = G . theta_{t-1} + nu``.

    ``nu`` is a ``MvNormalRV`` with zero mean and covariance
    ``nu_scale * I``.  It is built inside the step function so the noise term
    is part of the scan's inner graph.

    Parameters
    ----------
    theta_tm1
        Symbolic state vector from the previous step.
    G
        Symbolic state-evolution matrix.
    nu_scale
        Symbolic multiplier applied to the identity covariance of ``nu``.

    Returns
    -------
    The symbolic next-state vector.
    """
    nu_rv = MvNormalRV(tt.zeros_like(theta_tm1),
                       nu_scale * tt.eye(theta_tm1.shape[-1]),
                       name='nu')
    return G.dot(theta_tm1) + nu_rv

# Unroll the state equation for N_t steps, starting from theta_0.  Each row of
# theta_rv is one state; the initial state theta_0 itself is not included.
# NOTE(review): the updates returned by scan are discarded here; confirm the
# random-number state advances between iterations/evaluations as intended.
theta_rv, _ = theano.scan(
    fn=state_step,
    non_sequences=[G_tt, nu_scale],
    outputs_info={"initial": theta_0_rv},
    n_steps=N_t,
    name='theta',
)


def obs_step(theta_t, F, eps_scale):
    """Emit one observation: ``y_t = F . theta_t + eps``.

    ``eps`` is a single ``NormalRV(0, eps_scale)`` draw added (broadcast) to
    every entry of ``F . theta_t``.

    NOTE(review): ``eps_scale`` is passed directly as NormalRV's second
    parameter, whereas ``state_step`` uses ``nu_scale`` as a covariance
    multiplier; confirm which parameterization is intended for each.

    Parameters
    ----------
    theta_t
        Symbolic latent state at time t.
    F
        Symbolic observation matrix.
    eps_scale
        Symbolic second parameter of the observation-noise NormalRV.

    Returns
    -------
    The symbolic observation vector at time t.
    """
    eps_rv = NormalRV(0, eps_scale, name='eps')
    return F.dot(theta_t) + eps_rv

# Map each latent state in theta_rv to an observation.
Y_rv, _ = theano.scan(
    fn=obs_step,
    non_sequences=[F_tt, eps_scale],
    sequences=[theta_rv],
    name='Y',
)

# Print the symbolic graph of the simulated observations.
tt_dprint(Y_rv)

# Simulate 100 steps with a 2-dimensional state.  np.c_ stacks its arguments
# as columns, so G = [[1.0, 0.5], [0.0, 0.1]] and F = [[1.0, 0.2]] (1 x 2).
# Inputs are cast to the dtype of the symbolic matrices they feed (float64
# for dmatrix) rather than to floatX, which may differ.
y_sim = Y_rv.eval({N_t: 100,
                   G_tt: np.c_[[1., 0.], [0.5, 0.1]].astype(G_tt.dtype),
                   F_tt: np.c_[[1.], [0.2]].astype(F_tt.dtype)})

# In a non-interactive session, follow this with plt.show() to display it.
plt.plot(y_sim)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.