Last active
May 17, 2020 01:07
-
-
Save AnchorBlues/31e687173db81f3f46c250b41a2706a9 to your computer and use it in GitHub Desktop.
stanファイルが更新されていたときのみ、コンパイルをし直すような機構を実装する
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import hashlib | |
import dill | |
from pystan import StanModel | |
def textfile2hash(filename): | |
""" | |
テキストファイルの内容からハッシュ値を計算する | |
""" | |
with open(filename, mode="r") as f: | |
lines = f.readlines() | |
line = "".join(lines) | |
return hashlib.sha256(line.encode()).hexdigest() | |
def load_model(stan_file_path: str, model_dir: str, ext: str = "pickle") -> StanModel: | |
""" | |
1. stanファイルからハッシュ値を計算 | |
* ハッシュ値が一致するモデルオブジェクトファイルが存在なかった時、モデルのコンパイル・保存を行う | |
2. 当該モデルオブジェクトファイルからモデルをロードする | |
""" | |
stan_file_hash = textfile2hash(stan_file_path) | |
model_file_path = os.path.join( | |
model_dir, "{}.{}".format(stan_file_hash, ext)) | |
if not os.path.exists(model_file_path): | |
print("モデルファイル存在しなかったのでコンパイルします") | |
stm = StanModel(file=stan_file_path) | |
print("コンパイルが完了したので、保存を行います") | |
os.makedirs(os.path.dirname(model_file_path), exist_ok=True) | |
with open(model_file_path, mode="wb") as f: | |
dill.dump(stm, f) | |
print("モデルのロードを行います") | |
with open(model_file_path, mode="rb") as f: | |
stm = dill.load(f) | |
return stm | |
if __name__ == "__main__": | |
# stanファイルが"school_model.stan"、 | |
# pickle化したモデルオブジェクトファイルの保存先が"./stan_models"である場合の使用例 | |
model = load_model("school_model.stan", model_dir="./stan_models") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment