Created January 23, 2020 14:16
-
-
Save ground0state/f2750abefe36e9379ff60a3bc8761df7 to your computer and use it in GitHub Desktop.
This code is from http://lofas.hatenablog.com/entry/2015/03/05/160230.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import print_function, division

import matplotlib.pyplot as plt
import numpy as np
import pystan
from scipy import stats
# Stan program for a univariate Gaussian mixture with M components.
# Each observation's component label is marginalised out: the log-likelihood
# of x[n] is log_sum_exp over m of (log(pi[m]) + normal_lpdf(x[n] | mu[m], sig[m])).
# Written with `=`, `normal_lpdf` and `target +=`, replacing the deprecated
# assignment arrow, `normal_log` and `increment_log_prob` (removed in Stan 2.33).
# `real x[N]` is the pre-2.26 array syntax that PyStan 2's bundled Stan needs;
# newer Stan would require `array[N] real x;`.
# NOTE(review): mu is unordered, so component labels can swap between chains;
# declaring `ordered[M] mu` would identify them if multiple chains are run.
model = """
data {
    int<lower=1> N;                 // number of observations
    int<lower=1> M;                 // number of mixture components
    real x[N];                      // observations
}
parameters {
    vector[M] mu;                   // component means
    vector<lower=0.0001>[M] sig;    // component standard deviations
    simplex[M] pi;                  // mixing weights (non-negative, sum to 1)
}
model {
    // No explicit priors: every parameter gets an implicit flat prior
    // over its constrained support.
    vector[M] log_pi = log(pi);     // computed once, not per observation
    vector[M] lps;
    for (n in 1:N) {
        for (m in 1:M) {
            lps[m] = log_pi[m] + normal_lpdf(x[n] | mu[m], sig[m]);
        }
        target += log_sum_exp(lps);
    }
}
"""
# --- Simulate data from a two-component Gaussian mixture -------------------
# Number of samples
N = 5000
# Number of mixture components
K = 2
# Mixing weight of the first component (the second component gets 1 - pi)
pi = 0.7
# Random seed, so the simulated data are reproducible
np.random.seed(0)
# Split the N samples between the components according to the mixing weight.
# Round (instead of truncating) and derive the second count by subtraction so
# that N_k1 + N_k2 == N exactly; truncation can lose a sample to float error
# (e.g. int(100 * 0.57) == 56), which would leave len(x) != N.
N_k1 = int(round(N * pi))
N_k2 = N - N_k1
# True parameters
mu1 = -5
sig1 = np.sqrt(25)  # standard deviation for variance 25
mu2 = 5
sig2 = np.sqrt(1)   # standard deviation for variance 1
x1 = np.random.normal(mu1, sig1, N_k1)
x2 = np.random.normal(mu2, sig2, N_k2)
# Observed variable: samples of both components concatenated
x = np.hstack((x1, x2))

# Histogram of the data overlaid with the true weighted component densities
# and their sum (the true mixture density).
base = np.linspace(np.min(x), np.max(x), 1000)
pdf1 = pi * stats.norm.pdf(base, mu1, sig1)
pdf2 = (1 - pi) * stats.norm.pdf(base, mu2, sig2)
# `density=True` replaces `normed=True`, which Matplotlib 3.1 removed.
plt.hist(x, bins=100, density=True)
plt.plot(base, pdf1, label='component 1')
plt.plot(base, pdf2, label='component 2')
plt.plot(base, pdf1 + pdf2, label='mixture')
plt.legend()
# --- Fit the mixture model with Stan (PyStan 2 API) ------------------------
# Take N from the data actually generated, so the length declared to Stan
# always matches len(x).
stan_data = {'N': len(x), 'M': K, 'x': x}
# Keep the Stan program text in `model` and store the compiled model under
# its own name instead of overwriting that string.
stan_model = pystan.StanModel(model_code=model)
# iter counts warm-up too: 100 adaptation iterations, 1900 retained draws.
# seed makes the MCMC run reproducible, like np.random.seed(0) above.
fit = stan_model.sampling(data=stan_data, iter=2000, warmup=100, chains=1,
                          seed=0)
# Posterior summary (means, intervals, n_eff, Rhat) for mu, sig and pi.
print(fit)
# NOTE(review): recent PyStan 2 releases warn that fit.plot() is deprecated
# in favour of ArviZ; kept here for compatibility with the original script.
fit.plot()
plt.tight_layout()
plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment