@npyoung
Last active July 30, 2022 05:41
A state space model distribution for pymc3
@Nilavro

Nilavro commented Mar 1, 2018

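Try this code instead of the MCMC sampling.

Replace:

    trace = mc.sample(10000, start=mc.find_MAP())

With:

    inference = mc.ADVI()
    %time approx = mc.fit(n=30000, method=inference)

(`%time` is the IPython line magic for timing a single statement; drop it outside Jupyter.)

ADVI is faster, at the cost of returning an approximation to the posterior rather than MCMC samples.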
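As a rough illustration of what `mc.fit` with `mc.ADVI()` optimizes, here is a minimal NumPy sketch of Gaussian ADVI in one dimension (where mean-field and full-rank coincide). It is not PyMC3's implementation: it skips the transformation of constrained variables to an unconstrained space and uses plain fixed-step gradient ascent. The helper `advi_gaussian_1d` and the toy target N(3, 0.5²) are illustrative assumptions.

```python
import numpy as np


def advi_gaussian_1d(grad_log_p, n_iter=5000, lr=0.01, seed=0):
    """Fit q(z) = N(mu, sigma^2) to an unnormalized log-density by stochastic
    gradient ascent on the ELBO = E_q[log p(z)] + entropy(q).

    sigma = exp(omega) keeps the scale positive; each step uses a single
    reparameterized draw z = mu + sigma * eps with eps ~ N(0, 1)."""
    rng = np.random.default_rng(seed)
    mu, omega = 0.0, 0.0
    for _ in range(n_iter):
        eps = rng.standard_normal()
        sigma = np.exp(omega)
        z = mu + sigma * eps                # reparameterization trick
        g = grad_log_p(z)                   # d/dz log p(z)
        grad_mu = g                         # unbiased estimate of dELBO/dmu
        grad_omega = g * eps * sigma + 1.0  # +1 is d(entropy)/d(omega)
        mu += lr * grad_mu
        omega += lr * grad_omega
    return mu, np.exp(omega)


# Toy target: log p(z) = -(z - 3)^2 / (2 * 0.5^2) + const, i.e. N(3, 0.5^2).
mu_hat, sigma_hat = advi_gaussian_1d(lambda z: -(z - 3.0) / 0.25)
print(f"q(z) = N({mu_hat:.2f}, {sigma_hat:.2f}^2)")
```

Because the toy target is itself Gaussian, the variational family contains it, so the estimates should land near mean 3 and standard deviation 0.5, up to the noise from single-sample gradients.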

@michael-ziedalski

Did this end up going anywhere? The code looked promising.

@tanmoy7989

tanmoy7989 commented Mar 22, 2019

Same question as above. Did you develop this further?
I am using this with the ADVI suggestion by Nilavro and it isn't slow!

@Nilavro

Nilavro commented Mar 25, 2019

I used this code in my research. See the BSSPy package for that work: https://arxiv.org/pdf/1901.07469.pdf.

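Just remember the caveat that mean-field VI does not preserve the Markov dependency $\mathcal{N}(x_t \mid x_{t-1})$: the approximate posterior factorizes into independent factors, so in particular $q(x_{1:T}) = \prod_t q(x_t)$, and it cannot represent correlations between consecutive states.
There is a more detailed discussion of mean-field and structured mean-field approximations here: http://www.ee.columbia.edu/~sfchang/course/svia-F03/papers/factorial-HMM-97.pdf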

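I think full-rank ADVI may preserve this dependency, but it can be slower: it fits a dense covariance whose Cholesky factor has D(D+1)/2 entries for D latent dimensions, and D grows with the length of the series.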
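To make the mean-field caveat concrete, here is a small NumPy sketch, assuming a scalar linear-Gaussian chain as a simplified stand-in for the gist's model (the helper `chain_precision` and the values a = 0.9 and unit noise precisions are made up for illustration). For a Gaussian target with precision matrix Λ, the mean-field Gaussian that minimizes KL(q‖p) has independent factors with variance 1/Λ_ii, so it has zero correlation between x_t and x_{t-1} by construction and underestimates the marginal variances. A full-rank Gaussian family contains the exact posterior here, so full-rank ADVI can in principle recover the correlation, up to optimization noise.

```python
import numpy as np


def chain_precision(T, a=0.9, q=1.0, r=1.0):
    """Posterior precision matrix of x_1..x_T for a scalar AR(1) chain
    x_t = a * x_{t-1} + noise (precision q), x_1 ~ N(0, 1/q),
    observed through y_t = x_t + noise (precision r)."""
    Lam = np.zeros((T, T))
    Lam[0, 0] += q                      # prior factor on x_1
    for t in range(1, T):
        # transition factor N(x_t | a * x_{t-1}, 1/q)
        Lam[t, t] += q
        Lam[t - 1, t - 1] += q * a * a
        Lam[t, t - 1] -= q * a
        Lam[t - 1, t] -= q * a
    Lam += r * np.eye(T)                # observation factors N(y_t | x_t, 1/r)
    return Lam


T = 20
Lam = chain_precision(T)
Sigma_exact = np.linalg.inv(Lam)  # exact posterior covariance

# Optimal mean-field Gaussian (minimizing KL(q || p)): independent factors
# with variance 1 / Lam[i, i], so every cross-covariance is zero.
Sigma_mf = np.diag(1.0 / np.diag(Lam))

t = T // 2
var_exact, var_mf = Sigma_exact[t, t], Sigma_mf[t, t]
corr_exact = Sigma_exact[t, t - 1] / np.sqrt(Sigma_exact[t, t] * Sigma_exact[t - 1, t - 1])
print(f"corr(x_t, x_(t-1)): exact {corr_exact:.2f}, mean-field {Sigma_mf[t, t - 1]:.2f}")
print(f"var(x_t):           exact {var_exact:.3f}, mean-field {var_mf:.3f}")
```

In a linear-Gaussian chain the posterior precision does not depend on the observed values, which is why no data appear in the sketch; only the posterior mean would.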

I would also appreciate some thoughts on how the precision `Tau` is declared:


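    import numpy as np
    import pymc3 as mc
    import theano.tensor as T
    # StateSpaceModel is the distribution defined in this gist; x and y are the data arrays
    with mc.Model():
        A = mc.Normal('A', mu=np.eye(2), tau=1e-5, shape=(2, 2))
        Tau = mc.Gamma('tau', mu=100, sd=100)  # <- a single scalar precision for the whole state
        X = StateSpaceModel('x', A=A, B=T.zeros((1, 1)), u=T.zeros((x.shape[0], 1)), tau=Tau, shape=(y.shape[0], 2))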

It seems that a single scalar random variable for the precision term is being broadcast across every state dimension, which implies isotropic noise. I would prefer to declare a full covariance matrix instead. I am trying to put together code that uses the reparameterization trick and will push a new module soon.
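A minimal NumPy sketch of the difference, assuming a 2-D state as in the snippet above: a scalar precision `tau` broadcast over both dimensions implies the noise covariance (1/τ)I, while a full covariance Σ can be sampled through the reparameterization trick as Lε, with L the Cholesky factor of Σ and ε standard normal. The value τ = 100 echoes the `Gamma` prior mean in the snippet; the matrix `Sigma` is made up for illustration. (In PyMC3 itself, `LKJCholeskyCov` is one option for putting a prior on such a Cholesky factor.)

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 200_000, 2

# Scalar precision tau broadcast over both state dimensions:
# the implied noise covariance is (1 / tau) * I, i.e. isotropic.
tau = 100.0
cov_scalar = np.eye(d) / tau

# Illustrative full noise covariance with correlated dimensions.
Sigma = np.array([[0.02, 0.008],
                  [0.008, 0.01]])
L = np.linalg.cholesky(Sigma)  # lower-triangular, Sigma = L @ L.T

# Reparameterization trick: noise = L @ eps with eps ~ N(0, I), so each
# sample is a differentiable function of L (and hence of Sigma).
eps = rng.standard_normal((n, d))
noise_full = eps @ L.T            # row-wise L @ eps_i
noise_scalar = eps / np.sqrt(tau)

print(np.cov(noise_full, rowvar=False))    # close to Sigma
print(np.cov(noise_scalar, rowvar=False))  # close to (1 / tau) * I
```

Writing the noise as a deterministic transform of parameter-free ε is what lets gradients of a Monte Carlo ELBO flow back into the covariance parameters.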
