Skip to content

Instantly share code, notes, and snippets.

@ground0state
Created January 23, 2020 14:16
Show Gist options
  • Save ground0state/f2750abefe36e9379ff60a3bc8761df7 to your computer and use it in GitHub Desktop.
Save ground0state/f2750abefe36e9379ff60a3bc8761df7 to your computer and use it in GitHub Desktop.
from __future__ import print_function, division
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
import pystan
# Stan program for a univariate Gaussian mixture with M components.
# The discrete component label is marginalised out, so each observation
# contributes
#     log p(x[n]) = log_sum_exp_m( log(pi[m]) + log Normal(x[n] | mu[m], sig[m]) )
# to the log density.
# It uses `target +=`, `normal_lpdf(y | ...)` and `=`/`+=` (Stan >= 2.17)
# instead of `increment_log_prob`, `normal_log` and `<-`. Those older forms
# were deprecated and later removed from the Stan language.
# The old `real x[N]` array declaration is kept so the program still compiles
# under PyStan 2's bundled Stan, which predates the `array[N] real` syntax.
# NOTE(review): no priors are declared (mu and sig get improper flat priors).
# mu is not `ordered`, so component labels are not identified and label
# switching is possible, especially when running several chains.
model = """
data {
  int<lower=1> N;               // number of observations
  int<lower=1> M;               // number of mixture components
  real x[N];                    // observations
}
parameters {
  vector[M] mu;                 // component means
  vector<lower=0.0001>[M] sig;  // component standard deviations
  simplex[M] pi;                // mixing weights (non-negative, sum to 1)
}
model {
  // log(pi) does not depend on n, so compute it once.
  vector[M] log_pi = log(pi);
  for (n in 1:N) {
    vector[M] lps = log_pi;
    for (m in 1:M) {
      lps[m] += normal_lpdf(x[n] | mu[m], sig[m]);
    }
    target += log_sum_exp(lps);
  }
}
"""
# Number of samples.
N = 5000
# Number of mixture components.
K = 2
# Mixing weight of the first component (the second component gets 1 - pi).
pi = 0.7
# Random seed, for reproducible data.
np.random.seed(0)
# Split the N samples between the two components according to the mixing
# weight. Round once and give the remainder to the second component so that
# N_k1 + N_k2 == N exactly. Truncating the float product N * pi (e.g. 100 * 0.29
# -> 28.999...) would otherwise drop a sample and make len(x) disagree with N.
N_k1 = int(round(N * pi))
N_k2 = N - N_k1
# True generating parameters (sig* are standard deviations: variances 25 and 1).
mu1 = -5
sig1 = np.sqrt(25)
mu2 = 5
sig2 = np.sqrt(1)
x1 = np.random.normal(mu1, sig1, N_k1)
x2 = np.random.normal(mu2, sig2, N_k2)
# Observed variable: samples from both components, concatenated.
x = np.hstack((x1, x2))

# Overlay the normalised histogram with the weighted component densities and
# their sum (the true mixture density).
base = np.linspace(np.min(x), np.max(x), 1000)
pdf1 = pi * stats.norm.pdf(base, mu1, sig1)
pdf2 = (1 - pi) * stats.norm.pdf(base, mu2, sig2)
# `normed` was removed in Matplotlib 3.1; `density` is its replacement.
plt.hist(x, bins=100, density=True)
plt.plot(base, pdf1, label="component 1")
plt.plot(base, pdf2, label="component 2")
plt.plot(base, pdf1 + pdf2, label="mixture")
plt.legend()
# Fit the mixture model with Stan (PyStan 2 API).
# Take N from the data actually passed, so the size declared to Stan always
# matches len(x) even if the sample-splitting code above changes.
stan_data = {'N': len(x), 'M': K, 'x': x}
# Bind the compiled model to its own name instead of overwriting `model`,
# which holds the Stan source string.
stan_model = pystan.StanModel(model_code=model)
# One chain of 2000 iterations in total, the first 100 of which are warmup.
# NOTE(review): 100 warmup iterations is short for Stan's adaptation phase;
# PyStan's default is iter // 2.
fitchan = stan_model.sampling(data=stan_data, iter=2000, warmup=100, chains=1)
# Plot the posterior samples, then tighten the current figure's layout and
# show all open figures.
fitchan.plot()
plt.tight_layout()
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment