Last active
October 25, 2017 15:06
-
-
Save PonDad/ee2fbe3dcb871045bd69b59ab39c27f8 to your computer and use it in GitHub Desktop.
reinforcement-learning
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import pandas as pd | |
class QLearning:
    """Tabular Q-learning agent backed by a pandas DataFrame Q-table.

    Rows are state labels (strings), columns are the available actions.

    Parameters
    ----------
    actions : list
        Action identifiers, e.g. [0, 1, 2, 3].
    learning_rate : float
        Step size alpha for the Q-value update.
    reward_decay : float
        Discount factor gamma for future rewards.
    e_greedy : float
        Probability of exploiting (choosing the greedy action).
    """

    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # actions = [0, 1, 2, 3]
        self.actions = actions
        self.alpha = learning_rate
        self.discount_factor = reward_decay
        self.epsilon = e_greedy
        # Float dtype up front: Q-values are fractional after the first update.
        self.q_table = pd.DataFrame(columns=self.actions, dtype=float)

    def check_state_exist(self, state):
        """Add *state* to the Q-table, initialized to zeros, if unseen."""
        if state not in self.q_table.index:
            # pd.DataFrame.append was removed in pandas 2.0; build the new
            # row as a one-row frame and concatenate instead.
            new_row = pd.DataFrame(
                [[0.0] * len(self.actions)],
                columns=self.q_table.columns,
                index=[state],
            )
            self.q_table = pd.concat([self.q_table, new_row])

    def learn(self, s, a, r, s_):
        """Q-learning update: move Q(s, a) toward r + gamma * max_a' Q(s_, a')."""
        # Ensure both states have rows; the original only checked s_, which
        # raises KeyError if learn() is called before get_action(s).
        self.check_state_exist(s)
        self.check_state_exist(s_)
        # .ix was removed from pandas; .loc is the label-based equivalent.
        q_predict = self.q_table.loc[s, a]
        # TD target: reward plus discounted best value of the next state.
        q_target = r + self.discount_factor * self.q_table.loc[s_, :].max()
        self.q_table.loc[s, a] += self.alpha * (q_target - q_predict)

    def get_action(self, state):
        """Epsilon-greedy action selection for *state*."""
        self.check_state_exist(state)
        # With probability epsilon, exploit the learned Q-values.
        if np.random.rand() < self.epsilon:
            state_action = self.q_table.loc[state, :]
            # Shuffle the index first so ties are broken randomly instead of
            # always favoring the lowest-numbered action.
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            # Series.argmax now returns a position; idxmax returns the label,
            # which is what callers expect (an action id).
            action = state_action.idxmax()
        else:
            # Otherwise explore with a uniformly random action.
            action = np.random.choice(self.actions)
        return action
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
np.random.seed(1) | |
import tkinter as tk | |
import time | |
from PIL import ImageTk, Image | |
UNIT = 100 # pixels | |
HEIGHT = 5 # grid height | |
WIDTH = 5 # grid width | |
class Env(tk.Tk):
    """5x5 grid-world rendered with tkinter for the Q-learning demo.

    The agent sprite ("cat", a rectangle image) moves one cell at a time.
    Reaching the circle gives reward +100 (terminal); hitting either
    triangle gives -100 (terminal); every other step gives 0.
    """

    def __init__(self):
        super(Env, self).__init__()
        # Action indices 0..3 map to up/down/left/right in step().
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('q learning')
        # BUG FIX: the x dimension previously used HEIGHT * UNIT; it only
        # worked because the grid happens to be square. Use WIDTH.
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
        self.buildGraphic()
        self.texts = []

    def buildGraphic(self):
        """Create the canvas, draw the grid, and place the sprite images."""
        self.canvas = tk.Canvas(self, bg='white',
                                height=HEIGHT * UNIT,
                                width=WIDTH * UNIT)
        # Vertical grid lines every UNIT pixels.
        for c in range(0, WIDTH * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        # Horizontal grid lines; the x extent is the grid WIDTH (bug fix:
        # previously HEIGHT * UNIT, harmless only on a square grid).
        for r in range(0, HEIGHT * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, WIDTH * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # Load sprites. Image.ANTIALIAS was removed in Pillow 10;
        # Image.LANCZOS is the same resampling filter under its current name.
        self.rectangle_image = ImageTk.PhotoImage(
            Image.open("../resources/rectangle.png").resize((65, 65), Image.LANCZOS))
        self.triangle_image = ImageTk.PhotoImage(
            Image.open("../resources/triangle.png").resize((65, 65)))
        self.circle_image = ImageTk.PhotoImage(
            Image.open("../resources/circle.png").resize((65, 65)))
        # Place sprites; canvas coordinates are cell centers (UNIT/2 offsets).
        self.cat = self.canvas.create_image(50, 50, image=self.rectangle_image)
        self.triangle1 = self.canvas.create_image(250, 150, image=self.triangle_image)
        self.triangle2 = self.canvas.create_image(150, 250, image=self.triangle_image)
        self.circle = self.canvas.create_image(250, 250, image=self.circle_image)
        self.canvas.pack()

    def reset(self):
        """Put the agent back in the top-left cell and return its state."""
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.cat)
        # (Removed an unused `origin` local that was computed but never read.)
        self.cat = self.canvas.create_image(50, 50, image=self.rectangle_image)
        return self.coords_to_state(self.canvas.coords(self.cat))

    def text_value(self, row, col, contents, action, font='Helvetica', size=10, style='normal', anchor="nw"):
        """Draw *contents* in cell (row, col), offset toward the given action's edge."""
        # Per-action anchor offsets inside a cell: 0=up, 1=down, 2=left, 3=right.
        if action == 0:
            origin_x, origin_y = 7, 42
        elif action == 1:
            origin_x, origin_y = 85, 42
        elif action == 2:
            origin_x, origin_y = 42, 5
        else:
            origin_x, origin_y = 42, 77
        x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)
        font = (font, str(size), style)
        return self.texts.append(self.canvas.create_text(x, y, fill="black", text=contents, font=font, anchor=anchor))

    def print_value_all(self, q_table):
        """Clear and redraw all four Q-values in every visited cell."""
        for i in self.texts:
            self.canvas.delete(i)
        self.texts.clear()
        for i in range(HEIGHT):
            for j in range(WIDTH):
                for action in range(0, 4):
                    state = [i, j]
                    if str(state) in q_table.index:
                        # .ix was removed from pandas; .loc is label-based.
                        temp = q_table.loc[str(state), action]
                        self.text_value(j, i, round(temp, 2), action)

    def coords_to_state(self, coords):
        """Convert canvas pixel coordinates (cell center) to [col, row]."""
        x = int((coords[0] - 50) / 100)
        y = int((coords[1] - 50) / 100)
        return [x, y]

    def state_to_coords(self, state):
        """Convert a [col, row] state to canvas pixel coordinates (cell center)."""
        x = int(state[0] * 100 + 50)
        y = int(state[1] * 100 + 50)
        return [x, y]

    def step(self, action):
        """Apply *action*, move the agent, and return (state, reward, done)."""
        s = self.canvas.coords(self.cat)
        base_action = np.array([0, 0])
        self.render()
        # Each branch moves one cell, clamped at the grid border.
        if action == 0:  # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < (HEIGHT - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT
        elif action == 3:  # right
            if s[0] < (WIDTH - 1) * UNIT:
                base_action[0] += UNIT
        self.canvas.move(self.cat, base_action[0], base_action[1])  # move agent
        s_ = self.canvas.coords(self.cat)  # next position (pixels)
        # Reward function: goal = +100, hazards = -100, otherwise 0.
        if s_ == self.canvas.coords(self.circle):
            reward = 100
            done = True
        elif s_ in [self.canvas.coords(self.triangle1), self.canvas.coords(self.triangle2)]:
            reward = -100
            done = True
        else:
            reward = 0
            done = False
        s_ = self.coords_to_state(s_)
        return s_, reward, done

    def render(self):
        """Slow the loop down slightly and process pending tkinter events."""
        time.sleep(0.05)
        self.update()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from environment import Env | |
from agent import QLearning | |
def update():
    """Run 1000 training episodes, updating the agent's Q-table each step."""
    for episode in range(1000):
        # Reset the environment and get the starting state.
        state = env.reset()
        done = False
        while not done:
            # Redraw the GUI.
            env.render()
            # Ask the agent for an action in the current state.
            action = agent.get_action(str(state))
            # Apply it; receive the next state, the reward, and whether
            # the episode terminated.
            state_, reward, done = env.step(action)
            # Q-learning update for the (s, a, r, s') transition.
            agent.learn(str(state), action, reward, str(state_))
            # Advance to the next state and refresh the on-screen values.
            state = state_
            env.print_value_all(agent.q_table)
    # All episodes finished.
    print('game over')
    # env.destroy()
# Entry point: build the grid-world environment, create a Q-learning agent
# whose action list matches the environment's action count, then train.
# `env` and `agent` are module globals read by update().
if __name__ == "__main__":
    env = Env()
    agent = QLearning(actions=list(range(env.n_actions)))
    update()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment