Created
          October 5, 2015 09:13 
        
      - 
      
 - 
        
Save peace098beat/f8045fd0a7f2a12498f5 to your computer and use it in GitHub Desktop.  
    [map-qt.py ver2] textの追加
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -*- | |
| """ | |
| Q-Learning | |
| 強化学習で迷路を探索する | |
| Pythonで迷路探索の学習コード。しかしPygameが必要。 | |
| http://qiita.com/hogefugabar/items/74bed2851a84e978b61c | |
| PySideで動かす。 | |
| """ | |
| import sys | |
| import os | |
| # ユーザーモジュール | |
| import const | |
| from state import * | |
| from agent import * | |
| # PySide系モジュール | |
| from PySide.QtGui import * | |
| from PySide.QtCore import * | |
# Field dimensions, taken from the shared const module.
NUM_ROW = const.NUM_ROW
NUM_COL = const.NUM_COL

# Field cell kinds (compared against State.k throughout this file).
START, GOAL, ROAD, WALL = 0, 1, 2, 3

# Action codes, plus the list holding all of them.
LEFT, RIGHT, UP, DOWN = 0, 1, 2, 3
DIRECT = [LEFT, RIGHT, UP, DOWN]

# Initial interval handed to the QTimer that drives gameWindow.mainloop.
INTERVAL_TIME = 0.1
def printMapKind(field):
    """Print the kind code (``.k``) of every cell in ``field``, one row per line.

    The rows and cells of ``field`` itself are walked, rather than a
    window fixed by the module-level NUM_ROW / NUM_COL, so a field of any
    size (including an empty one) is printed completely and never raises
    IndexError.

    Args:
        field: Two-dimensional sequence (rows of cells); every cell must
            expose a ``k`` attribute.
    """
    # A single parenthesised argument prints identically with the
    # Python 2 print statement and the Python 3 print function.
    print('*** state kind ***')
    for row in field:
        print([cell.k for cell in row])
def printMapReward(field):
    """Print the reward (``.r``) of every cell in ``field``, one row per line.

    The rows and cells of ``field`` itself are walked, rather than a
    window fixed by the module-level NUM_ROW / NUM_COL, so a field of any
    size (including an empty one) is printed completely and never raises
    IndexError.

    Args:
        field: Two-dimensional sequence (rows of cells); every cell must
            expose an ``r`` attribute.
    """
    # A single parenthesised argument prints identically with the
    # Python 2 print statement and the Python 3 print function.
    print('*** state reward ***')
    for row in field:
        print([cell.r for cell in row])
def printMapMaxQ(field):
    """Print ``get_max_q_action().q`` for every cell in ``field``, one row per line.

    The rows and cells of ``field`` itself are walked, rather than a
    window fixed by the module-level NUM_ROW / NUM_COL, so a field of any
    size (including an empty one) is printed completely and never raises
    IndexError.

    Args:
        field: Two-dimensional sequence (rows of cells); every cell must
            provide ``get_max_q_action()`` returning an object with a
            ``q`` attribute.
    """
    # A single parenthesised argument prints identically with the
    # Python 2 print statement and the Python 3 print function.
    print('*** state q ***')
    for row in field:
        print([cell.get_max_q_action().q for cell in row])
class Map(object):
    """Maze environment that moves one agent around a grid of State cells.

    Attributes:
        field: NUM_ROW x NUM_COL nested list of State objects, indexed
            as ``field[row][col]``.
        agent: The Agent being moved; its ``y`` / ``x`` attributes index
            ``field`` as row / column.
        state: The State cell the agent currently stands on.
        action: The action object chosen by the most recent step().
        generation: Number of step() calls since the last clear().
        run: When true, qupdate() advances the simulation.
        cursor: ``[0, 0]`` coordinate pair; not used by this class.
    """

    def __init__(self):
        # Provisional grid so that self.field exists; clear() replaces
        # every cell with the layout described by const.FIELD.
        self.field = [[State(GOAL) for x in range(NUM_COL)] for y in range(NUM_ROW)]
        self.agent = Agent()
        self.state = self.field[self.agent.y][self.agent.x]
        self.generation = 0
        self.run = True
        # The original `[0][0]` indexed a one-element list and therefore
        # stored the int 0 rather than a coordinate pair.
        self.cursor = [0, 0]
        # Build the real field, then mark which moves each cell allows.
        self.clear()
        self.set_all_possibility()

    def qdraw(self):
        """Drawing hook; intentionally does nothing in this class."""
        pass

    def qupdate(self):
        """Advance the simulation by one step while ``self.run`` is true."""
        if self.run:
            self.step()

    def step(self):
        """Run one learning step: choose an action, move, update Q."""
        self.generation += 1
        # Ask the current state for an action (per the original author's
        # note: the action with the highest Q value).
        self.action = self.state.action_select()
        # Move the agent according to that action.
        self.agent.move(self.action)
        # Update the action's Q value using the state just entered.
        self.state = self.field[self.agent.y][self.agent.x]
        self.action.update_q_value(self.state)
        # On reaching the goal, send the agent back to the start.
        if self.state.k == GOAL:
            self.agent.move_start()
            self.state = self.field[self.agent.y][self.agent.x]

    def clear(self):
        """Reset the generation counter, the agent and every field cell.

        Every cell is replaced by a fresh State built from const.FIELD;
        goal cells get a reward of 100.

        NOTE: the fresh State objects have not been through
        set_possibility(); call set_all_possibility() afterwards (as
        __init__ does) before stepping.
        """
        self.generation = 0
        self.agent = Agent()
        for row in range(NUM_ROW):
            for col in range(NUM_COL):
                kind = const.FIELD[row][col]
                cell = State(kind)
                if kind == GOAL:
                    cell.r = 100
                self.field[row][col] = cell
        # Bind the current state only after the grid has been rebuilt;
        # binding it earlier left self.state pointing at a discarded
        # State that never receives any allowed actions.
        self.state = self.field[self.agent.y][self.agent.x]

    def _is_open(self, row, col):
        """Return True when (row, col) is inside the grid and not a wall."""
        if not (0 <= row < NUM_ROW and 0 <= col < NUM_COL):
            return False
        return self.field[row][col].k in (ROAD, START, GOAL)

    def set_possibility(self, row, col):
        """Enable, on cell (row, col), each move that leads to a non-wall cell.

        Nothing is changed when the cell itself is a wall. Neighbours
        outside the grid count as blocked, so border cells are handled
        without negative-index wrap-around or IndexError.
        """
        if not self._is_open(row, col):
            return
        cell = self.field[row][col]
        neighbours = (
            (LEFT, row, col - 1),
            (RIGHT, row, col + 1),
            (UP, row - 1, col),
            (DOWN, row + 1, col),
        )
        for direction, n_row, n_col in neighbours:
            if self._is_open(n_row, n_col):
                cell.set_action(direction, True)

    def set_all_possibility(self):
        """Run set_possibility() on every cell except the outermost ring."""
        for row in range(1, NUM_ROW - 1):
            for col in range(1, NUM_COL - 1):
                self.set_possibility(row, col)
class gameWindow(QWidget, Map):
    """Qt widget that runs the Map simulation on a timer and paints it.

    The grid is painted with the field's row index along the horizontal
    pixel axis and the column index along the vertical one; the cursor
    rectangle and sizeHint() follow the same convention.
    """

    # Grid extent in cells. NOTE: on instances these names shadow
    # QWidget.width() / QWidget.height(), so pixel dimensions must be
    # read through self.size() instead.
    width = NUM_ROW
    height = NUM_COL
    # Edge length of one grid cell, in pixels.
    Margin = 20
    # Timer interval; keyPressEvent adjusts it in steps of 20 within 1..200.
    interval_time = INTERVAL_TIME

    def __init__(self, parent=None):
        """Create the widget, draw the initial grid and start the timer.

        Args:
            parent: Optional parent widget, forwarded to QWidget.
        """
        QWidget.__init__(self, parent)
        Map.__init__(self)
        self.cursor_pos = [0, 0]
        self.iter_num = 0
        self.debugText = 'debug:\n'
        # Create the background pixmap and draw the initial grid on it.
        self.refreshPixmap()
        painter = QPainter(self.pixmap)
        self.drawGrid(painter)
        painter.end()
        # Main loop: every timer tick calls mainloop().
        self.timer = QTimer()
        self.timer.timeout.connect(self.mainloop)
        # QTimer works in whole milliseconds, so convert explicitly.
        # NOTE(review): INTERVAL_TIME is 0.1, which truncates to 0 here;
        # if it was meant as seconds the value should be 100 - confirm.
        self.timer.start(int(self.interval_time))
        print('%s %s' % (self.cursor_pos[0], self.cursor_pos[1]))

    def mainloop(self):
        """Timer slot: advance the simulation and schedule a repaint."""
        self.iter_num += 1
        self.qupdate()
        self.cursor_pos = [self.agent.y, self.agent.x]
        self.update()

    def paintEvent(self, *args, **kwargs):
        """Paint the background pixmap, the grid, the cursor and debug text."""
        painter = QStylePainter(self)
        painter.drawPixmap(0, 0, self.pixmap)
        self.drawGrid(painter)
        # Cursor: a blue rectangle over the cell at cursor_pos.
        cell_w, cell_h = self.Margin, self.Margin
        painter.setBrush(QBrush(Qt.blue, Qt.SolidPattern))
        painter.drawRect(self.cursor_pos[0] * cell_w, self.cursor_pos[1] * cell_h, cell_w, cell_h)
        # Text rectangle, relative to the widget's pixel size. The size
        # comes from self.size() because self.width / self.height are
        # the class-level cell counts, not the QWidget methods.
        size = self.size()
        x = int(0.1 * size.width())
        y = int(0.7 * size.height())
        text_w = int(0.3 * size.width())
        text_h = int(0.3 * size.height())
        self.debugText = 'debut'
        label = self.debugText
        painter.setFont(QFont('Helvetica [Cronyx]', 13, QFont.Bold))
        # u'' literal instead of unicode(): same text, valid on Python 2 and 3.
        painter.drawText(x, y, text_w, text_h, Qt.AlignLeft | Qt.AlignTop, u'Debug text:\n%s' % label)

    def refreshPixmap(self):
        """Recreate the background pixmap at the widget's current size."""
        self.pixmap = QPixmap(self.size())
        self.pixmap.fill(self, 0, 0)

    def drawGrid(self, painter):
        """Draw every field cell as a Margin-sized rectangle.

        Walls are black, the goal red and the start green. Road cells
        are white until their best Q value becomes non-zero, after which
        they are tinted according to that value.

        Args:
            painter: An active QPainter (or subclass) to draw with.
        """
        cell_w = cell_h = self.Margin
        for x in range(NUM_ROW):
            for y in range(NUM_COL):
                cell = self.field[x][y]
                if cell.k == WALL:
                    painter.setBrush(QBrush(Qt.black, Qt.SolidPattern))
                elif cell.k == ROAD:
                    painter.setBrush(QBrush(Qt.white, Qt.SolidPattern))
                    best_q = cell.get_max_q_action().q
                    if best_q != 0:
                        # Scale so that q == 100 maps to full intensity,
                        # then clamp: QColor needs ints within 0..255.
                        level = int(max(0.0, min(255.0, best_q / 100.0 * 255.0)))
                        painter.setBrush(QBrush(QColor(level, 255 - level, 255 - level, 127), Qt.SolidPattern))
                elif cell.k == GOAL:
                    painter.setBrush(QBrush(Qt.red, Qt.SolidPattern))
                elif cell.k == START:
                    painter.setBrush(QBrush(Qt.green, Qt.SolidPattern))
                painter.drawRect(x * cell_w, y * cell_h, cell_w, cell_h)

    def sizeHint(self):
        """Return the preferred size: the cell counts times Margin pixels."""
        return QSize(self.width * self.Margin, self.height * self.Margin)

    def keyPressEvent(self, event):
        """Move the cursor with the arrow keys; change speed with + and -.

        The cursor is clamped to the grid. Plus shortens and Minus
        lengthens the timer interval by 20, clamped to 1..200.
        """
        key = event.key()
        if key == Qt.Key_Up:
            self.cursor_pos[1] = max(0, self.cursor_pos[1] - 1)
        elif key == Qt.Key_Down:
            self.cursor_pos[1] = min(NUM_COL - 1, self.cursor_pos[1] + 1)
        elif key == Qt.Key_Left:
            self.cursor_pos[0] = max(0, self.cursor_pos[0] - 1)
        elif key == Qt.Key_Right:
            self.cursor_pos[0] = min(NUM_ROW - 1, self.cursor_pos[0] + 1)
        elif key == Qt.Key_Plus:
            self.interval_time = max(1, self.interval_time - 20)
            # QTimer.setInterval expects whole milliseconds.
            self.timer.setInterval(int(self.interval_time))
        elif key == Qt.Key_Minus:
            self.interval_time = min(200, self.interval_time + 20)
            self.timer.setInterval(int(self.interval_time))
        print('INTERVAL_TIME %s' % self.interval_time)
        self.update()
if __name__ == '__main__':
    # Script entry point: show the maze window and run the Qt event loop.
    app = QApplication(sys.argv)
    win = gameWindow()
    win.show()
    # sys.exit() raises SystemExit, so the farewell message must be
    # printed after the event loop returns but before exiting; placed
    # after sys.exit() it could never run.
    exit_code = app.exec_()
    print('Fin .. map-qt.py')
    sys.exit(exit_code)
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment