Last active
February 27, 2016 02:12
-
-
Save EnsekiTT/f0ad3043a31e88e4f9d4 to your computer and use it in GitHub Desktop.
Reinforcement Learning 第4章のプログラミング課題（ギャンブラー問題）をやってみた。価値反復でギャンブラーの最適行動（最適方策）を求める。
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # -*- coding:utf-8 -*- | |
| import matplotlib.pyplot as plt | |
# --- Problem constants and shared state ---
# Gambler's capital in the non-terminal states: 1 .. 99
S = [*range(1, 100)]
# State-value function indexed by capital 0 .. 100, all zero to start
V = [0.0 for _ in range(101)]
# Winning condition: reaching capital 100 has value 1
V[100] = 1.0
# Snapshots of V recorded by value_iteration
Vhist = []
# Policy (stake per capital); placeholder values 0 .. 100 until get_pi
# overwrites the entries for the states in S
pi = [*range(101)]
# Probability that the coin lands heads (the gambler wins the stake)
p = 0.4
# Expected value of staking `bet` with capital `s`
def gamble(s, bet):
    """Return the expected state value of staking ``bet`` from capital ``s``.

    With probability ``p`` (heads) capital becomes ``s + bet``; otherwise it
    becomes ``s - bet``. Reads the module-level ``V`` and ``p``.
    """
    win, lose = s + bet, s - bet
    return p * V[win] + (1.0 - p) * V[lose]
# Value iteration
def value_iteration(delta, theta):
    """Run in-place value iteration on the module-level value table ``V``.

    Each sweep replaces ``V[s]`` for every capital ``s`` in ``S`` with the best
    expected value over all legal stakes ``1 .. min(s, 100 - s)``. After each
    sweep a snapshot of ``V`` is appended to ``Vhist``. Sweeps repeat while the
    largest change seen in a sweep is still ``>= theta``.

    Args:
        delta: initial "largest change"; must be >= theta for any sweep to run.
        theta: convergence threshold on the per-sweep max |V_new - V_old|.
    """
    while delta >= theta:
        delta = 0.0
        for s in S:
            old = V[s]
            # Bellman optimality backup: the new value is the MAXIMUM over all
            # stakes, not the value of whichever stake happens to come last.
            V[s] = max(gamble(s, bet) for bet in range(1, min(s, 100 - s) + 1))
            delta = max(delta, abs(old - V[s]))
        # One snapshot per completed sweep, for view_Vhist.
        Vhist.append(list(V))
def view_Vhist():
    """Print every recorded value-function snapshot and plot them together.

    Each entry of the module-level ``Vhist`` is printed and drawn as one line
    on the current matplotlib figure; the figure is shown once at the end.
    """
    history = Vhist
    for sweep_values in history:
        print(sweep_values)
        plt.plot(sweep_values)
    plt.show()
# Compute the deterministic policy π (pi)
def get_pi(threshold):
    """Fill the module-level policy ``pi`` with a greedy stake for each state in ``S``.

    For each capital ``s``, the smallest legal stake is the starting candidate.
    A larger stake replaces it only if its expected value beats the current
    best by more than ``threshold``. Adding this tolerance makes the result
    come out right: near-ties caused by floating-point noise in ``V`` resolve
    to the smaller stake.
    See http://kivantium.hateblo.jp/entry/2015/09/29/181954

    Args:
        threshold: minimum improvement required to prefer a larger stake.

    Returns:
        The module-level ``pi`` list, updated in place.
    """
    for s in S:
        stakes = range(1, min(s, 100 - s) + 1)
        if not stakes:
            continue
        # Seed with a legal stake so pi[s] is always a valid bet, even when
        # every stake's value is below the threshold (otherwise pi[s] would
        # keep its placeholder initial value, which may be an illegal stake).
        best_bet = stakes[0]
        best_val = gamble(s, best_bet)
        for bet in stakes[1:]:
            val = gamble(s, bet)
            if val > best_val + threshold:
                best_bet, best_val = bet, val
        pi[s] = best_bet
    return pi
# Display pi
def view_pi(pi):
    """Print and bar-plot the stake chosen for each non-terminal capital.

    Args:
        pi: policy list indexed by capital. Entries ``1 .. len(pi) - 2`` are
            plotted and printed (capitals 1 .. 99 for the 101-entry table).
    """
    states = range(1, len(pi) - 1)
    stakes = [pi[s] for s in states]
    # Printed text is every stake followed by a comma, then a newline.
    print(''.join(str(b) + ',' for b in stakes))
    # Plot against the actual capital so bar s shows pi[s]; slicing pi[:99]
    # would include the unused pi[0] and drop pi[99].
    plt.bar(list(states), stakes)
    plt.show()
if __name__ == '__main__':
    # Stop value iteration once a sweep changes V by less than this.
    theta = 1e-12
    # Initial "largest change"; only needs to be >= theta so a sweep runs.
    delta = 1.0
    # Minimum improvement before get_pi prefers a larger stake.
    threshold = 1e-6

    # Solve for V, then show how it evolved sweep by sweep.
    value_iteration(delta=delta, theta=theta)
    view_Vhist()

    # Extract the greedy policy from V and show it.
    pi = get_pi(threshold=threshold)
    view_pi(pi=pi)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment