Last active: December 23, 2015 18:57
-
-
Save slinderman/1555e77aff0b2c7dc44d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division | |
import numpy as np | |
np.seterr(divide='ignore') # these warnings are usually harmless for this code | |
from matplotlib import pyplot as plt | |
import matplotlib | |
import os | |
matplotlib.rcParams['font.size'] = 8 | |
import pyhsmm | |
from pyhsmm.util.text import progprint_xrange | |
######################### | |
# posterior inference # | |
######################### | |
# Set the weak limit truncation level | |
Nmax = 100 | |
# and some hyperparameters | |
obs_dim = 1 | |
gauss_hypparams = {'mu_0':np.zeros(obs_dim), | |
'sigma_0':np.eye(obs_dim), | |
'kappa_0':0.25, | |
'nu_0':obs_dim+2} | |
poiss_hypparams = {'alpha_0': 1.0, 'beta_0': 1.0} | |
### HDP-HMM to generate true data | |
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for state in xrange(Nmax)] | |
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for state in xrange(Nmax)] | |
model = pyhsmm.models.WeakLimitHDPHMM(alpha=6.,gamma=6.,init_state_concentration=1., | |
obs_distns=obs_distns) | |
data, Z_true = model.generate(T=1000) | |
K_true = len(model.used_states) | |
# Fit with a HDPHMM with concentration parameter resampling | |
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for state in xrange(Nmax)] | |
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for state in xrange(Nmax)] | |
posteriormodel = pyhsmm.models.WeakLimitHDPHMM(alpha_a_0=6.0, alpha_b_0=1.0, | |
gamma_a_0=6., gamma_b_0=1.0, | |
init_state_concentration=1., | |
obs_distns=obs_distns) | |
posteriormodel.add_data(data) | |
# Fit with a HDPHMM *without* concentration parameter resampling | |
obs_distns = [pyhsmm.distributions.Gaussian(**gauss_hypparams) for state in xrange(Nmax)] | |
#obs_distns = [pyhsmm.distributions.Poisson(**poiss_hypparams) for state in xrange(Nmax)] | |
posteriormodel_concfixed = pyhsmm.models.WeakLimitHDPHMM(alpha=6.0, gamma=6., | |
init_state_concentration=1., | |
obs_distns=obs_distns) | |
posteriormodel_concfixed.add_data(data) | |
Ks_concresample = [] | |
Ks_concfixed = [] | |
for idx in progprint_xrange(100): | |
posteriormodel.resample_model() | |
Ks_concresample.append(len(posteriormodel.used_states)) | |
posteriormodel_concfixed.resample_model() | |
Ks_concfixed.append(len(posteriormodel_concfixed.used_states)) | |
plt.figure() | |
plt.plot(np.array(Ks_concresample), "-b", label="conc resampling") | |
plt.plot(np.array(Ks_concfixed), "-r", label="conc fixed") | |
plt.plot([0,100], K_true * np.ones(2), ':k', label="True") | |
plt.legend(loc="upper left") | |
plt.ylim(0,100) | |
plt.title("Obs distn: %s" % obs_distns[0].__class__) | |
plt.xlabel("Iteration") | |
plt.ylabel("Number of States") | |
plt.savefig("num_states.png") | |
plt.show() | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment