Skip to content

Instantly share code, notes, and snippets.

@peace098beat
Created October 5, 2015 09:13
Show Gist options
  • Save peace098beat/f8045fd0a7f2a12498f5 to your computer and use it in GitHub Desktop.
Save peace098beat/f8045fd0a7f2a12498f5 to your computer and use it in GitHub Desktop.
[map-qt.py ver2] textの追加
# -*- coding: utf-8 -*-
"""
Q-Learning
強化学習で迷路を探索する
Pythonで迷路探索の学習コード。しかしPygameが必要。
http://qiita.com/hogefugabar/items/74bed2851a84e978b61c
PySideで動かす。
"""
import sys
import os
# ユーザーモジュール
import const
from state import *
from agent import *
# PySide系モジュール
from PySide.QtGui import *
from PySide.QtCore import *
# Grid dimensions (rows x columns), shared with the maze layout in const.
NUM_ROW = const.NUM_ROW
NUM_COL = const.NUM_COL
# Cell kinds, stored in State.k.
START, GOAL, ROAD, WALL = 0, 1, 2, 3
# Action indices; DIRECT lists all of them in order.
LEFT, RIGHT, UP, DOWN = 0, 1, 2, 3
DIRECT = [LEFT, RIGHT, UP, DOWN]
# Period between learning steps (consumed by gameWindow's timer).
INTERVAL_TIME = 0.1
def _print_map(field, title, cell_value):
    """Print a ``*** title ***`` banner, then one list per grid row.

    Each list holds ``cell_value(cell)`` for the NUM_COL cells of that row,
    where the grid is indexed as ``field[row][col]``.
    """
    print('*** %s ***' % title)
    for row in range(NUM_ROW):
        print([cell_value(field[row][col]) for col in range(NUM_COL)])


def printMapKind(field):
    """Print the kind (``State.k``) of every cell, one row per line."""
    _print_map(field, 'state kind', lambda cell: cell.k)


def printMapReward(field):
    """Print the reward (``State.r``) of every cell, one row per line."""
    _print_map(field, 'state reward', lambda cell: cell.r)


def printMapMaxQ(field):
    """Print the Q value of each cell's best action, one row per line."""
    _print_map(field, 'state q', lambda cell: cell.get_max_q_action().q)
class Map(object):
    """Maze environment for tabular Q-learning.

    Holds a NUM_ROW x NUM_COL grid of State objects in ``self.field``,
    indexed as ``field[row][col]``, plus the Agent walking it.  Each call to
    step() performs one learning update; reaching a GOAL cell sends the
    agent back to its start position.
    """

    def __init__(self):
        # Placeholder grid so clear() can assign cells by index; every cell
        # is replaced by the real layout from const.FIELD inside clear().
        self.field = [[State(GOAL) for _ in range(NUM_COL)] for _ in range(NUM_ROW)]
        self.agent = Agent()
        # State the agent currently occupies (refreshed by clear()).
        self.state = self.field[self.agent.y][self.agent.x]
        self.generation = 0
        # When False, qupdate() does nothing.
        self.run = True
        # [row, col]. The original `[0][0]` evaluated to the int 0.
        # NOTE(review): in gameWindow this attribute shadows QWidget.cursor().
        self.cursor = [0, 0]
        # Load the real maze and reset the agent / generation counter.
        self.clear()
        # Register the legal moves of every interior cell.
        self.set_all_possibility()

    def qdraw(self):
        """Drawing hook; intentionally empty here."""
        pass

    def qupdate(self):
        """Advance learning by one step while ``self.run`` is set."""
        if self.run:
            self.step()

    def step(self):
        """Perform one learning step.

        Selects an action from the current state, moves the agent, then
        updates the chosen action's Q value using the newly reached state.
        On reaching the goal, the agent is sent back to its start cell.
        """
        self.generation += 1
        # a = pi(S): the selection policy lives in State.action_select()
        # (presumably greedy/epsilon-greedy on Q — see state module).
        self.action = self.state.action_select()
        self.agent.move(self.action)
        # S': the state the agent landed in, used to update Q(S, a).
        self.state = self.field[self.agent.y][self.agent.x]
        self.action.update_q_value(self.state)
        if self.state.k == GOAL:
            self.agent.move_start()
            self.state = self.field[self.agent.y][self.agent.x]

    def clear(self):
        """Reset generation counter, agent and grid.

        Every cell is rebuilt from ``const.FIELD``; GOAL cells get reward 100.
        Rebuilt cells are fresh State objects, so moves must be registered
        again with set_all_possibility() (as __init__ does).
        """
        self.generation = 0
        self.agent = Agent()
        for row in range(NUM_ROW):
            for col in range(NUM_COL):
                kind = const.FIELD[row][col]
                cell = State(kind)
                if kind == GOAL:
                    cell.r = 100
                self.field[row][col] = cell
        # Look up the agent's state only AFTER rebuilding the grid. Doing it
        # first (as before) left self.state pointing at a discarded
        # placeholder State that is not part of the field.
        self.state = self.field[self.agent.y][self.agent.x]

    def set_possibility(self, row, col):
        """Enable each move out of (row, col) whose target cell is not a wall.

        Nothing happens when (row, col) itself is a wall. Neighbours outside
        the grid are treated as blocked instead of being indexed directly,
        which would wrap around for -1 or raise IndexError past the far edge.
        """
        walkable = (ROAD, START, GOAL)
        cell = self.field[row][col]
        if cell.k not in walkable:
            return
        moves = ((LEFT, 0, -1), (RIGHT, 0, 1), (UP, -1, 0), (DOWN, 1, 0))
        for direction, d_row, d_col in moves:
            n_row, n_col = row + d_row, col + d_col
            if not (0 <= n_row < len(self.field) and 0 <= n_col < len(self.field[n_row])):
                continue
            if self.field[n_row][n_col].k in walkable:
                cell.set_action(direction, True)

    def set_all_possibility(self):
        """Register the legal moves of every interior (non-border) cell.

        Border cells are skipped; presumably const.FIELD surrounds the maze
        with walls — TODO confirm.
        """
        for row in range(1, NUM_ROW - 1):
            for col in range(1, NUM_COL - 1):
                self.set_possibility(row, col)
class gameWindow(QWidget, Map):
    """Qt widget that runs the Q-learning maze on a timer and renders it.

    Cells are ``Margin`` pixels square. The field row is drawn on the
    horizontal axis and the column on the vertical axis.
    """

    # Grid size in cells. These were previously named ``width``/``height``,
    # which shadowed QWidget.width()/height() and broke paintEvent.
    grid_width = NUM_ROW
    grid_height = NUM_COL
    # Pixel size of one grid cell.
    Margin = 20
    # Timer period in integer milliseconds, as QTimer requires. INTERVAL_TIME
    # (0.1) reads as seconds and was previously passed straight to
    # QTimer.start(); the +/- key handler works in a 1..200 ms range.
    interval_time = int(round(INTERVAL_TIME * 1000))

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        Map.__init__(self)
        # Highlighted cell as [row, col]; mainloop() makes it follow the agent.
        self.cursor_pos = [0, 0]
        self.iter_num = 0
        # Text shown in the lower-left overlay by paintEvent().
        self.debugText = 'debug:\n'
        # Background pixmap with the grid painted once.
        self.refreshPixmap()
        painter = QPainter(self.pixmap)
        self.drawGrid(painter)
        painter.end()
        # Main loop: one learning step per timer tick.
        self.timer = QTimer(self)
        self.timer.timeout.connect(self.mainloop)
        self.timer.start(self.interval_time)

    def mainloop(self):
        """Timer callback: advance learning one step and schedule a repaint."""
        self.iter_num += 1
        self.qupdate()
        # Follow the agent ([row, col], matching drawGrid's axes).
        self.cursor_pos = [self.agent.y, self.agent.x]
        self.update()

    def paintEvent(self, *args, **kwargs):
        """Draw background, grid, agent cell and the debug text overlay."""
        painter = QStylePainter(self)
        painter.drawPixmap(0, 0, self.pixmap)
        self.drawGrid(painter)
        # Agent / cursor cell in blue.
        cell = self.Margin
        painter.setBrush(QBrush(Qt.blue, Qt.SolidPattern))
        painter.drawRect(self.cursor_pos[0] * cell, self.cursor_pos[1] * cell, cell, cell)
        # Debug text box in the lower-left part of the widget. Previously this
        # overwrote self.debugText with 'debut' on every paint.
        w, h = self.width(), self.height()
        text_rect = QRectF(0.1 * w, 0.7 * h, 0.3 * w, 0.3 * h)
        painter.setFont(QFont('Helvetica [Cronyx]', 13, QFont.Bold))
        painter.drawText(text_rect, Qt.AlignLeft | Qt.AlignTop, 'Debug text:\n%s' % self.debugText)
        painter.end()

    def refreshPixmap(self):
        """Recreate the background pixmap at the widget's current size."""
        self.pixmap = QPixmap(self.size())
        # Fill with the widget's background (Qt4 fill(widget, x, y) overload).
        self.pixmap.fill(self, 0, 0)

    def drawGrid(self, painter):
        """Paint every cell with a colour that depends on its kind.

        The colours are: walls black, start green, goal red. Road cells are
        white until their best action has a non-zero Q value; they are then
        tinted from cyan toward red as Q approaches 100, the goal reward.
        """
        cell = self.Margin
        for row in range(NUM_ROW):
            for col in range(NUM_COL):
                cell_state = self.field[row][col]
                kind = cell_state.k
                if kind == WALL:
                    painter.setBrush(QBrush(Qt.black, Qt.SolidPattern))
                elif kind == ROAD:
                    max_q = cell_state.get_max_q_action().q
                    if max_q != 0:
                        # Clamp to 0..255 ints so QColor never gets floats or
                        # out-of-range components (Q may leave [0, 100]).
                        level = int(min(max(max_q / 100.0 * 255.0, 0.0), 255.0))
                        painter.setBrush(QBrush(QColor(level, 255 - level, 255 - level, 127), Qt.SolidPattern))
                    else:
                        painter.setBrush(QBrush(Qt.white, Qt.SolidPattern))
                elif kind == GOAL:
                    painter.setBrush(QBrush(Qt.red, Qt.SolidPattern))
                elif kind == START:
                    painter.setBrush(QBrush(Qt.green, Qt.SolidPattern))
                painter.drawRect(row * cell, col * cell, cell, cell)

    def sizeHint(self):
        """Preferred size: exactly the grid, ``Margin`` pixels per cell."""
        return QSize(self.grid_width * self.Margin, self.grid_height * self.Margin)

    def keyPressEvent(self, event):
        """Handle arrow keys (move cursor) and +/- (change timer period).

        Cursor moves are overwritten on the next timer tick by mainloop().
        Unhandled keys are passed on to QWidget.
        """
        key = event.key()
        if key == Qt.Key_Up:
            self.cursor_pos[1] = max(self.cursor_pos[1] - 1, 0)
        elif key == Qt.Key_Down:
            self.cursor_pos[1] = min(self.cursor_pos[1] + 1, NUM_COL - 1)
        elif key == Qt.Key_Left:
            self.cursor_pos[0] = max(self.cursor_pos[0] - 1, 0)
        elif key == Qt.Key_Right:
            self.cursor_pos[0] = min(self.cursor_pos[0] + 1, NUM_ROW - 1)
        elif key == Qt.Key_Plus:
            # Faster: shorter period, never below 1 ms.
            self.interval_time = max(self.interval_time - 20, 1)
            self.timer.setInterval(self.interval_time)
        elif key == Qt.Key_Minus:
            # Slower: longer period, capped at 200 ms.
            self.interval_time = min(self.interval_time + 20, 200)
            self.timer.setInterval(self.interval_time)
        else:
            QWidget.keyPressEvent(self, event)
        print('INTERVAL_TIME %s' % self.interval_time)
        self.update()
if __name__ == '__main__':
    # Entry point: build the window and run the Qt event loop.
    app = QApplication(sys.argv)
    win = gameWindow()
    win.show()
    exit_code = app.exec_()
    # Print before exiting. The original print followed sys.exit() and was
    # unreachable.
    print('Fin .. map-qt.py')
    sys.exit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment