Skip to content

Instantly share code, notes, and snippets.

@el-hult
Last active May 13, 2022 15:57
Show Gist options
  • Save el-hult/e4ea52f35259cd871b197d72de54e503 to your computer and use it in GitHub Desktop.
Save el-hult/e4ea52f35259cd871b197d72de54e503 to your computer and use it in GitHub Desktop.
import numpy as np
import scipy.stats as sps
import matplotlib.pyplot as plt
from scipy.special import digamma
# ----------------------------------------------------------------------
# Configuration: reproducible random draws and plot styling
# ----------------------------------------------------------------------
np.random.seed(999)  # fixed seed so the simulated series is reproducible
plt.rcParams.update({
    'font.family': "monospace",
    'figure.autolayout': True,
})
# ----------------------------------------------------------------------
# Simulate an AR(1) series with sporadic large shocks:
#   X_t = a * X_{t-1} + c * U_t + d * Z_t * V_t,   starting from X_{-1} = 0,
# with U_t, V_t standard normal and Z_t = 1 with probability w.
# ----------------------------------------------------------------------
a = 0.7   # true AR coefficient
c = 1     # scale of the ordinary innovation U_t
d = 20    # scale of the extra shock V_t, active only when Z_t = 1
T = 30    # series length
w = 0.2   # true probability that Z_t = 1
U = np.random.standard_normal(size=T)
V = np.random.standard_normal(size=T)
Z = (np.random.random(size=T) < w).astype(int)
X = np.zeros(T)
Xt = 0
for t, (u_t, v_t, z_t) in enumerate(zip(U, V, Z)):
    Xt = a * Xt + c * u_t + d * z_t * v_t
    X[t] = Xt
# ----------------------------------------------------------------------
# Priors:  a ~ N(mu0, sigma20),   w ~ Beta(alpha0, beta0)
# ----------------------------------------------------------------------
mu0, sigma20 = 0, 1   # Gaussian prior mean and variance on a
alpha0 = beta0 = 0.5  # Beta(1/2, 1/2): the Jeffreys prior for a Bernoulli probability
# ----------------------------------------------------------------------
# Approximate the posterior iteratively by coordinate-ascent variational
# inference with the mean-field factorization q(a) q(w) prod_t q(Z_t):
#   q(a)   = N(muT, sigma2T)
#   q(w)   = Beta(alphaT, betaT)
#   q(Z_t) = Bernoulli(gamma_t)
# Each sweep updates q(a) and q(w) from the current gamma, then
# recomputes gamma from the new q(a) and q(w).
# ----------------------------------------------------------------------
maxiter = 10                 # number of sweeps to perform
gamma = 0.3 * np.ones(T)     # initial responsibilities q(Z_t = 1)
y2 = np.zeros(T)             # work vector: expected squared residual per step
sigma2s = np.zeros(maxiter)  # history of sigma2T
mus = np.zeros(maxiter)      # history of muT
alphas = np.zeros(maxiter)   # history of alphaT
betas = np.zeros(maxiter)    # history of betaT
for it in range(maxiter):
    # Expected noise precision per step: 1/c^2 if Z_t = 0, 1/(c^2+d^2) if Z_t = 1.
    eta = (1 - gamma) / c**2 + gamma / (c**2 + d**2)
    # Gaussian update for q(a). Posterior precision = 1/sigma20 + sxx and
    # posterior mean = (mu0/sigma20 + sxy) / precision; multiplying through
    # by sigma20 gives the forms below, valid for any mu0 and sigma20.
    sxx = np.sum(eta[1:] * X[:-1]**2)
    sxy = np.sum(eta[1:] * X[1:] * X[:-1])
    muT = (mu0 + sigma20 * sxy) / (1 + sigma20 * sxx)
    sigma2T = sigma20 / (1 + sigma20 * sxx)
    # Beta update for q(w): add expected counts of Z_t = 1 and Z_t = 0.
    alphaT = alpha0 + np.sum(gamma)
    betaT = beta0 + np.sum(1 - gamma)
    # Expected squared residual (X_t - a X_{t-1})^2 under q(a); the series
    # starts from X_{-1} = 0, so the t = 0 term has no dependence on a.
    y2[0] = X[0]**2
    y2[1:] = X[1:]**2 - 2 * muT * X[1:] * X[:-1] + (muT**2 + sigma2T) * X[:-1]**2
    # Unnormalized log responsibilities. The common -digamma(alphaT + betaT)
    # term of E[log w] and E[log(1 - w)] cancels on normalization, so it is omitted.
    r1 = digamma(alphaT) - 0.5 * np.log(c**2 + d**2) - y2 / 2 / (c**2 + d**2)
    r0 = digamma(betaT) - 0.5 * np.log(c**2) - y2 / 2 / (c**2)
    # Normalize in log space: for large residuals exp(r0) underflows, and if
    # exp(r1) underflows as well the naive ratio becomes 0/0 = nan.
    gamma = np.exp(r1 - np.logaddexp(r0, r1))
    sigma2s[it] = sigma2T
    mus[it] = muT
    alphas[it] = alphaT
    betas[it] = betaT
# ----------------------------------------------------------------------
# Plot the variational posteriors per sweep, then the simulated data
# ----------------------------------------------------------------------
fig, axs = plt.subplots(1, 2, sharex=True, figsize=plt.figaspect(1/2))
iters = np.arange(maxiter)
# Left panel: posterior mean of a with a 95% band of +/- 1.96 posterior
# standard deviations. sigma2s holds variances, so take the square root.
a_sd = np.sqrt(sigma2s)
axs[0].plot(iters, mus)
axs[0].fill_between(iters,
                    mus + 1.96 * a_sd,
                    mus - 1.96 * a_sd,
                    alpha=0.3)
axs[0].axhline(a, color='black', linestyle='dashed')  # true a
axs[0].set_ylabel("$a$")
# Right panel: posterior mean of w with the central 95% interval of q(w).
axs[1].plot(iters, alphas / (alphas + betas))
axs[1].fill_between(iters,
                    sps.beta.ppf(0.05/2, alphas, betas),
                    sps.beta.ppf(1 - 0.05/2, alphas, betas),
                    alpha=0.3)
axs[1].axhline(w, color='black', linestyle='dashed')  # true w
axs[1].set_ylabel("$w$")
axs[1].set_xlabel("iter")
axs[0].set_xlabel("iter")
# Second figure: the series, with steps where the extra shock was drawn (Z_t = 1) highlighted.
fig, ax = plt.subplots()
ax.plot(X, alpha=0.3, color='C0')
ax.scatter(np.arange(T)[Z == 1], X[Z == 1], alpha=0.3, color='C1')
ax.set_xlabel("$t$")
ax.set_ylabel("$X_t$")
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment