Skip to content

Instantly share code, notes, and snippets.

@ground0state
Last active January 23, 2020 14:13
Show Gist options
  • Save ground0state/3ceeb75b458c3a4ec5e4d079f5e71def to your computer and use it in GitHub Desktop.
Save ground0state/3ceeb75b458c3a4ec5e4d079f5e71def to your computer and use it in GitHub Desktop.
pystan sample
import hashlib
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pystan
# Stan model: i.i.d. normal likelihood for the observations X with unknown
# mean `mu` and standard deviation `sigma` (constrained to sigma >= 0).
# No explicit priors are declared, so Stan uses improper flat priors over each
# parameter's support.
# The `generated quantities` block draws M posterior-predictive values `y`
# from normal(mu, sigma) on each sampling iteration.
# English for the Japanese inline Stan comments: N = number of training data
# points; M = number of posterior-predictive draws; "parameter section";
# mu = mean; sigma = standard deviation; "model section that declares the
# model"; "section that outputs the posterior predictive distribution".
# NOTE(review): the inline comment on the sampling statement says "simple
# linear regression model", but there is no predictor here; it is a plain
# normal model. The loop could also be vectorized as `X ~ normal(mu, sigma);`.
model = """
data {
int<lower=0> N; // 学習データの数
int<lower=0> M; // 事後予測の数
real X[N]; //
}
parameters { // parameterセクション
real mu; // 平均
real<lower=0> sigma; // 標準偏差
}
model { // モデルを宣言するmodelセクション
for (n in 1:N) {
X[n] ~ normal(mu, sigma); // 線形単回帰モデル
}
}
generated quantities { // 事後予測分布を出力するセクション
real y[M];
for (m in 1:M){
y[m] = normal_rng(mu, sigma);
}
}
"""
# Compile the Stan model, or load a previously compiled copy from disk.
def stan_model_cache(model_code, model_name=None, **kwargs):
    """Return a compiled ``pystan.StanModel``, reusing a pickled copy if present.

    Compiling a Stan program is slow, so the compiled model is pickled to a
    file in the current working directory and loaded from there on later
    calls. The file name embeds an MD5 digest of ``model_code``. Editing the
    Stan program therefore produces a new cache file instead of silently
    reusing a stale compiled model.

    Parameters
    ----------
    model_code : str
        Stan program source.
    model_name : str, optional
        Label included in the cache file name.
    **kwargs
        Accepted for call compatibility only; currently ignored (not
        forwarded to ``pystan.StanModel``).

    Returns
    -------
    The unpickled cached object, or a freshly compiled ``pystan.StanModel``
    (which is then written to the cache file).
    """
    # Encode as UTF-8: the Stan source may contain non-ASCII comments.
    code_hash = hashlib.md5(model_code.encode("utf-8")).hexdigest()
    if model_name is None:
        cache_fn = "cached-model-{}.pkl".format(code_hash)
    else:
        cache_fn = "cached-model-{}-{}.pkl".format(model_name, code_hash)

    try:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # cache files that this script wrote itself.
        with open(cache_fn, "rb") as f:
            sm = pickle.load(f)
    except FileNotFoundError:
        pass  # No cache yet for this model source: compile below.
    except Exception as err:
        # A corrupt or incompatible cache (e.g. written by another PyStan
        # version) should not be fatal: report it, recompile and overwrite.
        print("Ignoring unreadable cache {}: {!r}".format(cache_fn, err))
    else:
        print("Using cached StanModel")
        return sm

    sm = pystan.StanModel(model_code=model_code)
    with open(cache_fn, "wb") as f:
        pickle.dump(sm, f)
    return sm
# MCMC: simulate observations, compile (or load) the model, and sample.
# The data generator is seeded as well. Stan's `seed` alone cannot make the run
# reproducible if the observed data change on every execution.
X = np.random.RandomState(seed=1).normal(0, 5, 100)  # 100 draws, mean 0, sd 5
# N = number of observations, M = posterior-predictive draws per iteration.
stan_data = {'N': len(X), 'X': X, 'M': 200}
sm = stan_model_cache(model_code=model, model_name="default")
fit = sm.sampling(
    data=stan_data,
    chains=4,
    iter=2000,    # iterations per chain, including warmup
    warmup=1000,  # warmup iterations, not kept as posterior draws
    thin=1,
    seed=1,
    n_jobs=-1,    # run chains in parallel on all available CPUs
)
# Trace plot of the fitted parameters, with the subplot layout tightened
# before the figure is shown.
# NOTE(review): later PyStan 2.x releases deprecate StanFit4Model.plot() in
# favour of ArviZ (e.g. az.plot_trace(fit)); confirm against the installed version.
fit.plot()
plt.tight_layout()
plt.show()
# Posterior predictive distribution: pull the simulated `y` draws out of the
# fit, flatten them (draws x M) into a single 1-D array, and histogram them.
samples = fit.extract()
y_pred = np.ravel(samples["y"])
plt.hist(y_pred)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment