@PonDad
Last active October 25, 2017 15:06
reinforcement-learning
# agent.py: the module name is taken from "from agent import QLearning" in the main script below
import numpy as np
import pandas as pd


class QLearning:
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        # actions = [0, 1, 2, 3]
        self.actions = actions
        self.alpha = learning_rate
        self.discount_factor = reward_decay
        self.epsilon = e_greedy
        # one row per visited state, one column per action
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64)

    # Check whether the state has been visited before; if not, initialize it
    def check_state_exist(self, state):
        # initialize only states that have not been visited yet
        if state not in self.q_table.index:
            # add the new state to q_table, initialized to [0, 0, 0, 0]
            # (DataFrame.append was removed in pandas 2.0, so the row is added with .loc)
            self.q_table.loc[state] = [0.0] * len(self.actions)

    # Update the Q-function according to the Q-learning algorithm
    def learn(self, s, a, r, s_):
        # make sure the next state has been seen; if not, initialize it
        self.check_state_exist(s_)
        q_1 = self.q_table.loc[s, a]
        # target: the reward plus the discounted maximum Q-value of the next state
        q_2 = r + self.discount_factor * self.q_table.loc[s_, :].max()
        self.q_table.loc[s, a] += self.alpha * (q_2 - q_1)

    # Choose an action for the current state
    def get_action(self, state):
        self.check_state_exist(state)
        # with probability epsilon, act greedily with respect to the Q-function
        if np.random.rand() < self.epsilon:
            # pick the best action, shuffling the index first so ties are broken at random
            state_action = self.q_table.loc[state, :]
            state_action = state_action.reindex(np.random.permutation(state_action.index))
            action = state_action.idxmax()
        # otherwise act randomly
        else:
            # pick an arbitrary action
            action = np.random.choice(self.actions)
        return action
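The learn method above implements the standard tabular update Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)). The agent can be exercised on its own, without the GUI environment below; the transition in this sketch is made up purely for illustration (the state strings mimic how the main loop stores states, and the reward of 100 mirrors the goal reward of the environment):

# Minimal standalone check of the agent, assuming the QLearning class above is
# available in the current module; the transition used here is invented for illustration.
import numpy as np

np.random.seed(0)
agent = QLearning(actions=[0, 1, 2, 3])

state, next_state = "[0, 0]", "[1, 0]"     # states are stored as strings, as in the main loop
action = agent.get_action(state)           # epsilon-greedy choice; the row is created on demand
agent.learn(state, action, 100, next_state)  # applies the tabular Q-learning update above
print(agent.q_table)                       # the chosen (state, action) entry is now alpha * 100 = 1.0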
# environment.py: the module name is taken from "from environment import Env" in the main script below
import time

import numpy as np
import tkinter as tk
from PIL import ImageTk, Image

np.random.seed(1)

UNIT = 100   # pixels per grid cell
HEIGHT = 5   # grid height
WIDTH = 5    # grid width


class Env(tk.Tk):
    def __init__(self):
        super(Env, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('q learning')
        # geometry is "widthxheight"
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))
        self.buildGraphic()
        self.texts = []

    def buildGraphic(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=HEIGHT * UNIT,
                                width=WIDTH * UNIT)
        # create grid lines (0-500 in steps of 100)
        for c in range(0, WIDTH * UNIT, UNIT):   # vertical lines
            x0, y0, x1, y1 = c, 0, c, HEIGHT * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, HEIGHT * UNIT, UNIT):  # horizontal lines
            x0, y0, x1, y1 = 0, r, WIDTH * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)
        # load images (Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter)
        self.rectangle_image = ImageTk.PhotoImage(Image.open("../resources/rectangle.png").resize((65, 65), Image.LANCZOS))
        self.triangle_image = ImageTk.PhotoImage(Image.open("../resources/triangle.png").resize((65, 65)))
        self.circle_image = ImageTk.PhotoImage(Image.open("../resources/circle.png").resize((65, 65)))
        # add images to the canvas: the rectangle is the agent, the triangles are
        # obstacles, and the circle is the goal
        self.cat = self.canvas.create_image(50, 50, image=self.rectangle_image)
        self.triangle1 = self.canvas.create_image(250, 150, image=self.triangle_image)
        self.triangle2 = self.canvas.create_image(150, 250, image=self.triangle_image)
        self.circle = self.canvas.create_image(250, 250, image=self.circle_image)
        # pack all
        self.canvas.pack()
    def reset(self):
        self.update()
        time.sleep(0.5)
        # put the agent back at the top-left cell
        self.canvas.delete(self.cat)
        origin = np.array([UNIT / 2, UNIT / 2])
        self.cat = self.canvas.create_image(50, 50, image=self.rectangle_image)
        # return the initial observation
        return self.coords_to_state(self.canvas.coords(self.cat))

    def text_value(self, row, col, contents, action, font='Helvetica', size=10,
                   style='normal', anchor="nw"):
        # per-action offsets inside a cell: 0=up, 1=down, 2=left, 3=right
        if action == 0:
            origin_x, origin_y = 7, 42
        elif action == 1:
            origin_x, origin_y = 85, 42
        elif action == 2:
            origin_x, origin_y = 42, 5
        else:
            origin_x, origin_y = 42, 77
        x, y = origin_y + (UNIT * col), origin_x + (UNIT * row)
        font = (font, str(size), style)
        return self.texts.append(self.canvas.create_text(x, y, fill="black", text=contents,
                                                         font=font, anchor=anchor))

    def print_value_all(self, q_table):
        # clear the previously drawn values
        for i in self.texts:
            self.canvas.delete(i)
        self.texts.clear()
        # draw the current Q-value of every (state, action) pair on the grid
        for i in range(HEIGHT):
            for j in range(WIDTH):
                for action in range(0, 4):
                    state = [i, j]
                    if str(state) in q_table.index:
                        temp = q_table.loc[str(state), action]
                        self.text_value(j, i, round(temp, 2), action)

    def coords_to_state(self, coords):
        # canvas pixel position of a cell centre -> [column, row] grid state
        x = int((coords[0] - 50) / 100)
        y = int((coords[1] - 50) / 100)
        return [x, y]

    def state_to_coords(self, state):
        # [column, row] grid state -> canvas pixel position of the cell centre
        x = int(state[0] * 100 + 50)
        y = int(state[1] * 100 + 50)
        return [x, y]
    def step(self, action):
        s = self.canvas.coords(self.cat)
        base_action = np.array([0, 0])
        self.render()
        # move one cell in the chosen direction, staying inside the grid
        if action == 0:    # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < (HEIGHT - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT
        elif action == 3:  # right
            if s[0] < (WIDTH - 1) * UNIT:
                base_action[0] += UNIT
        self.canvas.move(self.cat, base_action[0], base_action[1])  # move agent
        s_ = self.canvas.coords(self.cat)  # next state
        # reward function: +100 at the circle (goal), -100 at a triangle (obstacle), 0 otherwise
        if s_ == self.canvas.coords(self.circle):
            reward = 100
            done = True
        elif s_ in [self.canvas.coords(self.triangle1), self.canvas.coords(self.triangle2)]:
            reward = -100
            done = True
        else:
            reward = 0
            done = False
        s_ = self.coords_to_state(s_)
        return s_, reward, done

    def render(self):
        time.sleep(0.05)
        self.update()
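coords_to_state and state_to_coords above convert between the pixel coordinates of a cell centre (50, 150, ..., 450 on each axis, since UNIT = 100) and the [column, row] grid state. A quick standalone check of that round trip, restating the two conversions outside the class with the same constants:

# Standalone restatement of the two conversions above (UNIT = 100, 5x5 grid),
# just to make the centre-pixel <-> grid-cell mapping explicit.
def coords_to_state(coords):
    return [int((coords[0] - 50) / 100), int((coords[1] - 50) / 100)]

def state_to_coords(state):
    return [int(state[0] * 100 + 50), int(state[1] * 100 + 50)]

assert coords_to_state([50, 50]) == [0, 0]                         # top-left cell, the reset position
assert coords_to_state([250, 250]) == [2, 2]                       # the circle (goal) cell
assert state_to_coords([2, 1]) == [250, 150]                       # triangle1's cell
assert state_to_coords(coords_to_state([450, 450])) == [450, 450]  # round trip, bottom-right cell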
# main script: builds the environment and the agent, then runs 1000 training episodes
from environment import Env
from agent import QLearning


def update():
    for episode in range(1000):
        # reset the environment and get the initial state
        state = env.reset()
        while True:
            # render the GUI
            env.render()
            # ask the agent for an action in the current state
            action = agent.get_action(str(state))
            # take the action; receive the next state, the reward,
            # and whether the episode has finished
            state_, reward, done = env.step(action)
            # pass the transition (s, a, r, s_) to the agent's learn function
            agent.learn(str(state), action, reward, str(state_))
            # the next state becomes the current state
            state = state_
            env.print_value_all(agent.q_table)
            # the episode has ended, so break
            if done:
                break
    # all episodes are finished: game over
    print('game over')
    # env.destroy()


if __name__ == "__main__":
    env = Env()
    agent = QLearning(actions=list(range(env.n_actions)))
    update()
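Because the main loop passes str(state) to the agent, each row of agent.q_table is keyed by a string such as "[0, 0]" and the columns are the four action indices. A small sketch for reading the greedy policy out of the learned table after update() returns; the helper and the action-name mapping below are additions, not part of the gist, though the mapping follows the comments in env.step:

# Hypothetical helper, not part of the original gist: extract the greedy action
# for every visited state from the learned table. Assumes `agent` is the
# QLearning instance created in the __main__ block above.
ACTION_NAMES = {0: 'up', 1: 'down', 2: 'left', 3: 'right'}  # matches the comments in env.step()

def greedy_policy(q_table):
    # rows are keyed by str(state), e.g. "[0, 0]"; columns are the action indices 0-3
    return {state: ACTION_NAMES[int(row.idxmax())] for state, row in q_table.iterrows()}

print(greedy_policy(agent.q_table))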