Created
          October 5, 2015 04:58 
        
      - 
      
 - 
        
Save peace098beat/bb197ebd1d82470f3b13 to your computer and use it in GitHub Desktop.  
    [強化学習] QLearn
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| action.py | |
| """ | |
| import const | |
class Action(object):
    """A single move out of a State, together with its learned Q value.

    Derives from ``object`` so that the ``q``/``d``/``p`` property setters
    also work on Python 2: on a classic class, ``action.q = v`` silently
    creates an instance attribute instead of calling ``setq``, leaving the
    real ``q_value`` unchanged (and ``update_q_value`` would then update a
    value that ``.q`` no longer returns).
    """

    def __init__(self, direction, possibility):
        # Learned action value Q(s, a); starts at 0
        self.q_value = 0.00
        # Whether this move may be taken (Map.set_possibility enables moves
        # towards non-wall neighbors)
        self.possibility = possibility
        # Direction code: 0=left, 1=right, 2=up, 3=down (see Agent.move)
        self.direction = direction

    # Q value
    def getq(self):
        return self.q_value

    def setq(self, value):
        self.q_value = value

    q = property(getq, setq)

    # Direction
    def getd(self):
        return self.direction

    def setd(self, value):
        self.direction = value

    d = property(getd, setd)

    # Possibility flag
    def getp(self):
        return self.possibility

    def setp(self, value):
        self.possibility = value

    p = property(getp, setp)

    def set_possibility(self, direction, possibility):
        """Set both the direction code and the possibility flag."""
        self.possibility = possibility
        self.direction = direction

    def update_q_value(self, state, alpha=None, gamma=None):
        """Apply one Q-learning update to this action's Q value.

        Q <- Q + alpha * (r + gamma * max_a' Q(s', a') - Q)

        Args:
            state: the state s' reached by taking this action; its reward
                ``state.r`` and ``state.get_max_q_action().q`` are used.
            alpha: learning rate; defaults to ``const.LEARNING_RATE``.
            gamma: discount rate; defaults to ``const.DISCOUNT_RATE``.
        """
        if alpha is None:
            alpha = const.LEARNING_RATE
        if gamma is None:
            gamma = const.DISCOUNT_RATE
        # Reward for entering the next state
        r = state.r
        # Best Q value available from the next state
        max_q = state.get_max_q_action().q
        self.q_value = self.q_value + alpha * (r + gamma * max_q - self.q_value)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| agent.py | |
| エージェント | |
| 現在地と移動に関する情報を保持 | |
| """ | |
| import const | |
class Agent(object):
    """The learner's position on the map (column x, row y).

    Derives from ``object`` so the ``x``/``y`` property setters also work on
    Python 2 (on a classic class, assignment bypasses ``setx``/``sety``).
    """

    # (dx, dy) for each direction code: 0=left, 1=right, 2=up, 3=down
    _OFFSETS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}

    def __init__(self):
        # Position on the map, initialised to the start cell
        self.current_x = const.START_X
        self.current_y = const.START_Y

    # Column (x)
    def getx(self):
        return self.current_x

    def setx(self, value):
        self.current_x = value

    x = property(getx, setx)

    # Row (y)
    def gety(self):
        return self.current_y

    def sety(self, value):
        self.current_y = value

    y = property(gety, sety)

    def move(self, action):
        """Move one cell in direction ``action.d`` if ``action.p`` allows it.

        Impossible moves and unknown direction codes leave the position
        unchanged.
        """
        if not action.p:
            return
        dx, dy = self._OFFSETS.get(action.d, (0, 0))
        self.current_x += dx
        self.current_y += dy

    def move_start(self):
        """Put the agent back on the start cell."""
        self.current_x = const.START_X
        self.current_y = const.START_Y
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| const.py | |
| ユーザ定義の定数 | |
| """ | |
import random

# --- Learning parameters --------------------------------------------------
# Epsilon: probability that State.action_select picks a random action
EPSILON_RATE = 0.4
# Learning rate (alpha) of the Q-learning update
LEARNING_RATE = 0.1
# Discount rate (gamma) of the Q-learning update
DISCOUNT_RATE = 0.9

# --- Screen / cell geometry -------------------------------------------------
# Cell size in pixels
CS = 45
# Screen size in pixels: 13 columns x 10 rows of cells
SCR_X = CS * 13
SCR_Y = CS * 10

# Number of rows/columns that fit on the screen. Floor division keeps these
# integers on Python 3 as well (they are passed to range()).
NUM_ROW = SCR_Y // CS
NUM_COL = SCR_X // CS

# Start position: column (x) and row (y)
START_X = 2
START_Y = 2

# Cell kinds stored in FIELD (the map modules use the same codes)
_START, _GOAL, _ROAD, _WALL = 0, 1, 2, 3

# Hand-made 13 x 10 maze matching the screen size above.
# Used only when USE_RANDOM_FIELD is False.
MAZE_FIELD = [
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
    [3, 2, 2, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3],
    [3, 2, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 3],
    [3, 2, 3, 2, 3, 3, 3, 2, 3, 2, 2, 2, 3],
    [3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3, 2, 3],
    [3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 2, 2, 3],
    [3, 2, 3, 2, 2, 2, 3, 2, 2, 2, 3, 2, 3],
    [3, 2, 3, 2, 3, 2, 2, 2, 3, 2, 3, 2, 3],
    [3, 2, 2, 2, 3, 2, 3, 2, 2, 2, 3, 1, 3],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
]

# When True (the previous hard-wired behavior), FIELD is a random square maze
# of RANDOM_FIELD_SIZE cells per side and NUM_ROW/NUM_COL are overridden to
# match it; when False, the hand-made MAZE_FIELD is used.
USE_RANDOM_FIELD = True
RANDOM_FIELD_SIZE = 40
# Number of random wall placements (positions may repeat, so the actual
# number of inner walls can be lower)
NUM_WALL = RANDOM_FIELD_SIZE * 8


def _make_random_field(num_row, num_col, num_wall):
    """Return a num_row x num_col grid with a walled border, random inner
    walls, the start cell at (START_X, START_Y) and the goal one cell in from
    the bottom-right corner.

    NOTE: walls are placed without a connectivity check, so the goal is not
    guaranteed to be reachable from the start.
    """
    field = [[_ROAD] * num_col for _ in range(num_row)]
    for y in range(num_row):
        for x in range(num_col):
            if x == 0 or y == 0 or x == num_col - 1 or y == num_row - 1:
                field[y][x] = _WALL
    for _ in range(num_wall):
        y = random.randint(1, num_row - 2)
        x = random.randint(1, num_col - 2)
        field[y][x] = _WALL
    # FIELD is indexed [row][column], i.e. [y][x]
    field[START_Y][START_X] = _START
    field[num_row - 2][num_col - 2] = _GOAL
    return field


if USE_RANDOM_FIELD:
    NUM_ROW = RANDOM_FIELD_SIZE
    NUM_COL = RANDOM_FIELD_SIZE
    FIELD = _make_random_field(NUM_ROW, NUM_COL, NUM_WALL)
else:
    FIELD = MAZE_FIELD
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| Q-Learning | |
| 強化学習で迷路を探索する | |
| Pythonで迷路探索の学習コード。しかしPygameが必要。 | |
| http://qiita.com/hogefugabar/items/74bed2851a84e978b61c | |
| PySideで動かす。 | |
| """ | |
| import sys | |
| import os | |
| # ユーザーモジュール | |
| import const | |
| from state import * | |
| from agent import * | |
| # PySide系モジュール | |
| from PySide.QtGui import * | |
| from PySide.QtCore import * | |
# Field size in cells, taken from const
NUM_ROW, NUM_COL = const.NUM_ROW, const.NUM_COL
# Cell kinds stored in the field
START, GOAL, ROAD, WALL = range(4)
# Action kinds: direction codes understood by Agent.move
DIRECT = [0, 1, 2, 3]
LEFT, RIGHT, UP, DOWN = DIRECT
# Timer interval handed to QTimer by gameWindow
INTERVAL_TIME = 0.1
def _print_map(field, title, cell_value):
    """Print *title*, then one list per row of *field* holding cell_value(cell)."""
    print(title)
    for row in field:
        print([cell_value(cell) for cell in row])


def printMapKind(field):
    """Print the kind code of every State in *field*, row by row."""
    _print_map(field, '*** state kind ***', lambda state: state.k)


def printMapReward(field):
    """Print the reward of every State in *field*, row by row."""
    _print_map(field, '*** state reward ***', lambda state: state.r)


def printMapMaxQ(field):
    """Print the highest Q value of every State in *field*, row by row."""
    _print_map(field, '*** state q ***', lambda state: state.get_max_q_action().q)
class Map(object):
    """Maze model for the Qt front end: the field of States, the Agent, and
    one-step Q-learning."""

    def __init__(self):
        # Learning steps taken since the last clear()
        self.generation = 0
        # qupdate() only steps while this is True
        self.run = True
        # Cursor position pair (note that `[0][0]` would be the int 0)
        self.cursor = [0, 0]
        # Build the field, agent and current state, and enable the legal moves
        self.clear()

    def qdraw(self):
        """No-op drawing hook."""
        pass

    def qupdate(self):
        """Advance learning by one step while ``run`` is True."""
        if self.run:
            self.step()

    def step(self):
        """Perform one Q-learning step.

        Picks an action for the current state (State.action_select), moves the
        agent, updates that action's Q value from the state just entered, and
        sends the agent back to the start after it reaches the goal.
        """
        self.generation += 1
        self.action = self.state.action_select()
        self.agent.move(self.action)
        self.state = self.field[self.agent.y][self.agent.x]
        self.action.update_q_value(self.state)
        if self.state.k == GOAL:
            self.agent.move_start()
            self.state = self.field[self.agent.y][self.agent.x]

    @staticmethod
    def _new_state(kind):
        """Return a fresh State of *kind*; GOAL cells get a reward of 100."""
        state = State(kind)
        if kind == GOAL:
            state.r = 100
        return state

    def clear(self):
        """Reset the generation counter, the agent and every cell of the field.

        Each cell becomes a new State of the kind given by const.FIELD, and the
        possible moves are recomputed so the reset field is usable.
        """
        self.generation = 0
        self.agent = Agent()
        self.field = [[self._new_state(const.FIELD[row][col]) for col in range(NUM_COL)]
                      for row in range(NUM_ROW)]
        # Look the current state up only after the field is rebuilt; doing it
        # first left self.state pointing at a discarded State object.
        self.state = self.field[self.agent.y][self.agent.x]
        self.set_all_possibility()

    def set_possibility(self, row, col):
        """Enable each move from (row, col) that leads to a non-wall neighbor.

        Wall cells get no moves. Must not be called for border cells, whose
        neighbors would be out of range or wrap around.
        """
        passable = (ROAD, START, GOAL)
        if self.field[row][col].k not in passable:
            return
        for direction, d_row, d_col in ((LEFT, 0, -1), (RIGHT, 0, 1), (UP, -1, 0), (DOWN, 1, 0)):
            if self.field[row + d_row][col + d_col].k in passable:
                self.field[row][col].set_action(direction, True)

    def set_all_possibility(self):
        """Run set_possibility for every cell except the outer border ring."""
        for row in range(1, NUM_ROW - 1):
            for col in range(1, NUM_COL - 1):
                self.set_possibility(row, col)
class gameWindow(QWidget, Map):
    """PySide window that advances the Q-learning Map on a QTimer and draws it.

    Keys: arrows move the cursor (mainloop overwrites it with the agent
    position on the next tick); '+' shortens and '-' lengthens the timer
    interval.
    """

    # Grid size in cells, used by sizeHint(). Named grid_* (not width/height)
    # so the class does not shadow the QWidget.width()/height() methods.
    grid_width = NUM_ROW
    grid_height = NUM_COL
    # Side length of one cell in pixels
    Margin = 20
    # Timer interval given to QTimer
    interval_time = INTERVAL_TIME

    def __init__(self, parent=None):
        QWidget.__init__(self, parent)
        Map.__init__(self)
        # Highlighted cell as [row, col]; mainloop() sets it to the agent position
        self.cursor_pos = [0, 0]
        self.iter_num = 0
        # Off-screen pixmap, pre-filled with the grid
        self.refreshPixmap()
        painter = QPainter(self.pixmap)
        self.drawGrid(painter)
        painter.end()
        # Main loop: one learning step per timer tick
        self.timer = QTimer()
        self.timer.timeout.connect(self.mainloop)
        # QTimer expects an integer interval; INTERVAL_TIME is a float
        self.timer.start(int(self.interval_time))
        print('%s %s' % (self.cursor_pos[0], self.cursor_pos[1]))

    def mainloop(self):
        """Timer callback: advance learning one step, track the agent, repaint."""
        self.iter_num += 1
        self.qupdate()
        self.cursor_pos = [self.agent.y, self.agent.x]
        self.update()

    def paintEvent(self, *args, **kwargs):
        """Paint the cached pixmap, the live grid and the agent cursor in blue."""
        painter = QStylePainter(self)
        painter.drawPixmap(0, 0, self.pixmap)
        self.drawGrid(painter)
        size = self.Margin
        painter.setBrush(QBrush(Qt.blue, Qt.SolidPattern))
        painter.drawRect(self.cursor_pos[0] * size, self.cursor_pos[1] * size, size, size)
        painter.end()

    def refreshPixmap(self):
        """Recreate the off-screen pixmap at the widget's size, filled with its background."""
        self.pixmap = QPixmap(self.size())
        self.pixmap.fill(self, 0, 0)
        painter = QPainter(self.pixmap)
        painter.initFrom(self)
        painter.end()

    def drawGrid(self, painter):
        """Draw every cell of the field with *painter*.

        Walls are black, the goal red, the start green, and roads white, or,
        once they have a non-zero max Q value, tinted red in proportion to it
        (100, the goal reward, maps to full intensity).

        NOTE: the row index is used as the horizontal pixel coordinate, so the
        field appears transposed relative to const.FIELD; paintEvent draws the
        cursor with the same convention.
        """
        size = self.Margin
        for row in range(NUM_ROW):
            for col in range(NUM_COL):
                cell = self.field[row][col]
                if cell.k == WALL:
                    painter.setBrush(QBrush(Qt.black, Qt.SolidPattern))
                elif cell.k == ROAD:
                    painter.setBrush(QBrush(Qt.white, Qt.SolidPattern))
                    max_q = cell.get_max_q_action().q
                    if max_q != 0:
                        # QColor channels are integers in 0..255
                        qval = max(0, min(255, int(max_q / 100.0 * 255.0)))
                        painter.setBrush(QBrush(QColor(qval, 255 - qval, 255 - qval, 127),
                                                Qt.SolidPattern))
                elif cell.k == GOAL:
                    painter.setBrush(QBrush(Qt.red, Qt.SolidPattern))
                elif cell.k == START:
                    painter.setBrush(QBrush(Qt.green, Qt.SolidPattern))
                painter.drawRect(row * size, col * size, size, size)

    def sizeHint(self):
        """Preferred size: one Margin-sized square per cell."""
        return QSize(self.grid_width * self.Margin, self.grid_height * self.Margin)

    def keyPressEvent(self, event):
        """Move the cursor with the arrow keys; change the timer speed with +/-."""
        e = event.key()
        if e == Qt.Key_Up:
            self.cursor_pos[1] = max(0, self.cursor_pos[1] - 1)
        elif e == Qt.Key_Down:
            self.cursor_pos[1] = min(NUM_COL - 1, self.cursor_pos[1] + 1)
        elif e == Qt.Key_Left:
            self.cursor_pos[0] = max(0, self.cursor_pos[0] - 1)
        elif e == Qt.Key_Right:
            self.cursor_pos[0] = min(NUM_ROW - 1, self.cursor_pos[0] + 1)
        elif e == Qt.Key_Plus:
            # Faster: shorter interval, but at least 1
            self.interval_time = max(1, self.interval_time - 20)
            self.timer.setInterval(int(self.interval_time))
        elif e == Qt.Key_Minus:
            # Slower: longer interval, but at most 200
            self.interval_time = min(200, self.interval_time + 20)
            self.timer.setInterval(int(self.interval_time))
        print('INTERVAL_TIME %s' % self.interval_time)
        self.update()
if __name__ == '__main__':
    app = QApplication(sys.argv)
    win = gameWindow()
    win.show()
    exit_code = app.exec_()
    # Print before exiting: sys.exit() raises SystemExit, so nothing after it runs.
    print('Fin .. map-qt.py')
    sys.exit(exit_code)
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| Q-Learning | |
| 強化学習で迷路を探索する | |
| Pythonで迷路探索の学習コード。しかしPygameが必要。 | |
| http://qiita.com/hogefugabar/items/74bed2851a84e978b61c | |
| """ | |
| import const | |
| import pygame | |
| from pygame.locals import * | |
| import random | |
| import sys | |
| from state import * | |
| from agent import * | |
SCR_RECT = Rect(0, 0, const.SCR_X, const.SCR_Y)
# Cell size in pixels
CS = const.CS
# Number of field rows/columns used and drawn. Floor division keeps these
# integers on Python 3 too, and clamping to const.FIELD's size prevents an
# IndexError in Map.clear() when FIELD is smaller than the screen.
# NOTE(review): when const.FIELD is larger than the screen, only its top-left
# NUM_ROW x NUM_COL corner is used, so the GOAL cell can fall outside it, and
# cells on the last row/column of that corner get no possible moves
# (set_all_possibility skips the outer ring).
NUM_ROW = min(SCR_RECT.height // CS, len(const.FIELD))
NUM_COL = min(SCR_RECT.width // CS, len(const.FIELD[0]))
# Cell kinds
START = 0
GOAL = 1
ROAD = 2
WALL = 3
# Road cell color (white)
CS_COLOR = (255, 255, 255)
# Action kinds (direction codes)
LEFT = 0
RIGHT = 1
UP = 2
DOWN = 3
# All directions
DIREC = [LEFT, RIGHT, UP, DOWN]
class Map(object):
    """Pygame front end for the Q-learning maze.

    Holds the field of States and the Agent. Constructing a Map opens the
    window and runs the main loop; it never returns normally, because the
    process exits when the window is closed or Esc is pressed.

    Keys: s = toggle continuous learning, n = single learning step,
    arrows = move the inspection cursor, space = print the Q values of the
    cell under the cursor, Esc = quit.
    """

    def __init__(self):
        pygame.init()
        self.screen = pygame.display.set_mode(SCR_RECT.size)
        pygame.display.set_caption(u"Q-Learning")
        self.font = pygame.font.SysFont("timesnewroman", 42)
        # Learning starts paused; 's' toggles it
        self.run = False
        # Inspection cursor as [column, row], starting near the field's center
        self.cursor = [NUM_COL // 2, NUM_ROW // 2]
        # Build the field, agent and current state, and enable the legal moves
        self.clear()
        self.draw(self.screen)
        # Main loop: learn (if running), redraw, then handle pending events
        while True:
            self.update()
            self.draw(self.screen)
            pygame.display.update()
            for event in pygame.event.get():
                self._handle_event(event)

    def _handle_event(self, event):
        """Dispatch one pygame event: window close or a key press."""
        if event.type == QUIT:
            self._quit()
        elif event.type == KEYDOWN:
            self._handle_key(event.key)

    def _quit(self):
        """Shut pygame down and exit the process."""
        pygame.quit()
        sys.exit()

    def _handle_key(self, key):
        """React to a single key press (see the class docstring for bindings)."""
        if key == K_ESCAPE:
            self._quit()
        elif key == K_s:
            self.run = not self.run
        elif key == K_n:
            self.step()
        elif key == K_LEFT:
            self.cursor[0] = max(0, self.cursor[0] - 1)
        elif key == K_RIGHT:
            self.cursor[0] = min(NUM_COL - 1, self.cursor[0] + 1)
        elif key == K_UP:
            self.cursor[1] = max(0, self.cursor[1] - 1)
        elif key == K_DOWN:
            self.cursor[1] = min(NUM_ROW - 1, self.cursor[1] + 1)
        elif key == K_SPACE:
            self._print_cursor_q()

    def _print_cursor_q(self):
        """Print the four Q values of the cell under the cursor, laid out as a cross."""
        x, y = self.cursor
        actions = self.field[y][x].action
        print('-----------------------------------')
        print(' %05.2f' % actions[UP].q)
        # Left and right share one line, separated by a space
        print('%05.2f' % actions[LEFT].q + ' ' + ' %05.2f' % actions[RIGHT].q)
        print(' %05.2f' % actions[DOWN].q)
        print('-----------------------------------')

    @staticmethod
    def _new_state(kind):
        """Return a fresh State of *kind*; GOAL cells get a reward of 100."""
        state = State(kind)
        if kind == GOAL:
            state.r = 100
        return state

    def clear(self):
        """Reset the generation counter, the agent and every cell of the field.

        Each cell becomes a new State of the kind given by const.FIELD, and the
        possible moves are recomputed so the reset field is usable.
        """
        self.generation = 0
        self.agent = Agent()
        self.field = [[self._new_state(const.FIELD[y][x]) for x in range(NUM_COL)]
                      for y in range(NUM_ROW)]
        # Look the current state up only after the field is rebuilt; doing it
        # first left self.state pointing at a discarded State object.
        self.state = self.field[self.agent.y][self.agent.x]
        self.set_all_possibility()

    def draw(self, screen):
        """Draw every cell, the agent (blue) and the cursor outline (green).

        The screen is never cleared, so every cell kind must be repainted each
        frame; START cells are therefore drawn like roads.
        """
        for y in range(NUM_ROW):
            for x in range(NUM_COL):
                rect = Rect(x * CS, y * CS, CS, CS)
                cell = self.field[y][x]
                if cell.k == WALL:
                    pygame.draw.rect(screen, (0, 0, 0), rect)
                elif cell.k in (ROAD, START):
                    self._draw_road(screen, cell, x, y)
                elif cell.k == GOAL:
                    pygame.draw.rect(screen, (100, 255, 255), rect)
                if y == self.agent.y and x == self.agent.x:
                    pygame.draw.rect(screen, (0, 0, 255), rect)
                pygame.draw.rect(screen, (50, 50, 50), rect, 1)
        pygame.draw.rect(screen, (0, 255, 0),
                         Rect(self.cursor[0] * CS, self.cursor[1] * CS, CS, CS), 5)

    def _draw_road(self, screen, cell, x, y):
        """Paint a walkable cell: white, or, once it has a non-zero max Q value,
        tinted red by that value with an arrow for the greedy action."""
        rect = Rect(x * CS, y * CS, CS, CS)
        pygame.draw.rect(screen, CS_COLOR, rect)
        best = cell.get_max_q_action()
        if best.q == 0:
            return
        # Scale Q so the goal reward (100) maps to 255; pygame color channels
        # must be integers in 0..255.
        val = max(0, min(255, int(best.q / 100.0 * 255.0)))
        pygame.draw.rect(screen, (255, 255 - val, 255 - val), rect)
        arrow = {UP: u"↑", DOWN: u"↓", LEFT: u"←"}.get(best.d, u"→")
        screen.blit(self.font.render(arrow, True, (0, 0, 0)), (x * CS, y * CS))

    def step(self):
        """Perform one Q-learning step.

        Picks an action for the current state (State.action_select), moves the
        agent, updates that action's Q value from the state just entered, and
        sends the agent back to the start after it reaches the goal.
        """
        self.generation += 1
        self.action = self.state.action_select()
        self.agent.move(self.action)
        self.state = self.field[self.agent.y][self.agent.x]
        self.action.update_q_value(self.state)
        if self.state.k == GOAL:
            self.agent.move_start()
            self.state = self.field[self.agent.y][self.agent.x]

    def update(self):
        """Advance learning by one step while ``run`` is True."""
        if self.run:
            self.step()

    def set_possibility(self, y, x):
        """Enable each move from (y, x) that leads to a non-wall neighbor.

        Wall cells get no moves. Must not be called for border cells, whose
        neighbors would be out of range or wrap around.
        """
        passable = (ROAD, START, GOAL)
        if self.field[y][x].k not in passable:
            return
        for direction, dy, dx in ((LEFT, 0, -1), (RIGHT, 0, 1), (UP, -1, 0), (DOWN, 1, 0)):
            if self.field[y + dy][x + dx].k in passable:
                self.field[y][x].set_action(direction, True)

    def set_all_possibility(self):
        """Run set_possibility for every cell except the outer border ring."""
        for y in range(1, NUM_ROW - 1):
            for x in range(1, NUM_COL - 1):
                self.set_possibility(y, x)
def main():
    """Open the Q-learning window; Map() runs the game loop and exits the process on quit."""
    Map()


if __name__ == '__main__':
    main()
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | # -*- coding:utf-8 -+- | |
| """ | |
| state.py | |
| 環境オブジェクト | |
| """ | |
| import const | |
| from action import * | |
| import random | |
class State(object):
    """One cell of the maze: its kind, its reward and the four Actions leaving it.

    Derives from ``object`` so the ``r``/``k`` property setters also work on
    Python 2 (on a classic class, ``state.r = 100`` bypasses ``setr`` and
    leaves ``reward``/``getr()`` at 0).
    """

    def __init__(self, kind):
        # Reward for entering this cell (the maps give GOAL cells 100)
        self.reward = 0
        # Cell kind code (START/GOAL/ROAD/WALL as defined by the maps)
        self.kind = kind
        # One Action per direction code 0..3, all initially not possible
        self.action = [Action(direction, False) for direction in range(4)]

    # Reward
    def getr(self):
        return self.reward

    def setr(self, value):
        self.reward = value

    r = property(getr, setr)

    # Cell kind (GOAL, ROAD, START, WALL)
    def getk(self):
        return self.kind

    def setk(self, value):
        self.kind = value

    k = property(getk, setk)

    def set_action(self, direction, possibility):
        """Set whether the move with direction code *direction* is possible."""
        self.action[direction].set_possibility(direction, possibility)

    def get_max_q_action(self):
        """Return the Action with the highest Q value.

        Ties go to the lowest direction code. While every Q value is exactly 0
        (nothing learned yet) a uniformly random Action is returned instead,
        possibly one whose move is not possible.
        """
        best = max(self.action, key=lambda act: act.q)
        if all(act.q == 0.00 for act in self.action):
            best = self.action[random.randint(0, 3)]
        return best

    def action_select(self):
        """Epsilon-greedy selection.

        With probability ``1 - const.EPSILON_RATE`` return the greedy action
        (get_max_q_action); otherwise a uniformly random one, which may be an
        impossible move.
        """
        if random.random() > const.EPSILON_RATE:
            chosen = self.get_max_q_action().d
        else:
            chosen = random.randint(0, 3)
        return self.action[chosen]
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment