Last active
January 15, 2021 05:33
-
-
Save nagataka/3659810a39ce3b4aefca0ab09291b437 to your computer and use it in GitHub Desktop.
Blocking Maze for OpenAI Gym
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# OpenAI Gym custom environment mimicking the Blocking Maze.
# See Sutton and Barto, "Reinforcement Learning: An Introduction",
# Example 8.2: Blocking Maze.
import copy
import sys
from enum import Enum

import gym
from gym import error, spaces, utils
from gym.utils import seeding
# First maze layout: an 8-row by 11-column character grid (border included).
# Legend: '+', '-', '|' form the border, 'w' is a wall cell, 's' marks the
# start cell at (row 6, col 4) and 'G' the goal cell at (row 1, col 9).
# The wall spans the left of row 4, leaving a single gap at column 9.
# Every row must be exactly 11 characters wide, because the environment
# indexes the grid as map[row][col].
MAP1 = [
    "+---------+",
    "|        G|",
    "|         |",
    "|         |",
    "|wwwwwwww |",
    "|         |",
    "|   s     |",
    "+---------+",
]
# Second maze layout, same 8-row by 11-column grid and legend as the first
# one ('+', '-', '|' border, 'w' wall, 's' start at (6, 4), 'G' goal at
# (1, 9)). Here the wall spans the right of row 4, leaving a single gap at
# column 1. Every row must be exactly 11 characters wide, because the
# environment indexes the grid as map[row][col].
MAP2 = [
    "+---------+",
    "|        G|",
    "|         |",
    "|         |",
    "| wwwwwwww|",
    "|         |",
    "|   s     |",
    "+---------+",
]
class Action(Enum):
    """The four moves available to the agent.

    The integer values are the action ids that the environment's ``step``
    compares against.
    """

    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3
class MazeEnv1(gym.Env):
    """Gym environment for the Blocking Maze (Sutton & Barto, Example 8.2).

    The maze is a list of equal-width strings indexed as ``map[row][col]``,
    with row 0 at the top. Locations are ``(row, col)`` tuples in that grid
    (border included). Observations are integer cell ids produced by
    ``tuple2int``. Reaching the ``'G'`` cell yields reward 1 and ends the
    episode; every other step yields reward 0.
    """

    metadata = {'render.modes': ['human']}

    def __init__(self):
        print("init")
        # reset() re-runs __init__, so only install the default map when no
        # map is present yet; this keeps a map chosen by switch_maze().
        if not hasattr(self, 'map'):
            self.map = MAP1
        self.s_loc = (6, 4)   # start cell
        self.a_loc = (6, 4)   # current agent cell, begins on the start cell
        self.g_loc = (1, 9)   # goal cell
        # Characters the agent can never step onto (wall and border).
        self.invalid = ['w', '-', '|']
        self.action_space = spaces.Discrete(len(self.actions))
        # One id per interior cell, walls included, so that a cell keeps the
        # same id when switch_maze() moves the wall.
        self.observation_space = spaces.Discrete(
            self.row_length * self.column_length)

    def step(self, action):
        """Apply one move and return ``(observation, reward, done, info)``.

        Args:
            action: An ``Action`` member or its integer value.

        Returns:
            The agent's cell id after the move, the reward (1 on the goal
            cell, otherwise 0), whether the episode ended, and an empty
            info dict.

        Raises:
            ValueError: If ``action`` is not one of the four known moves.
        """
        if isinstance(action, Action):
            action = action.value

        row, col = self.a_loc
        # Row 0 is the top of the map, so moving up decreases the row index.
        if action == Action.UP.value:
            target = (row - 1, col)
        elif action == Action.DOWN.value:
            target = (row + 1, col)
        elif action == Action.LEFT.value:
            target = (row, col - 1)
        elif action == Action.RIGHT.value:
            target = (row, col + 1)
        else:
            print("Invalid action is specified")
            raise ValueError("Invalid action is specified: %r" % (action,))

        # A move into a wall or border cell leaves the agent where it was.
        if self.map[target[0]][target[1]] not in self.invalid:
            self.a_loc = target

        reward = 0
        done = False
        info = {}
        if self.map[self.a_loc[0]][self.a_loc[1]] == 'G':
            reward = 1
            done = True
        return self.tuple2int(self.a_loc), reward, done, info

    def reset(self):
        """Put the agent back on the start cell and return its cell id."""
        print("reset")
        self.__init__()
        return self.tuple2int(self.a_loc)

    def tuple2int(self, loc):
        """Convert a ``(row, col)`` map location to a row-major cell id.

        The border is excluded, so the top-left interior cell ``(1, 1)``
        maps to 0 and ids run up to ``row_length * column_length - 1``.
        """
        row, col = loc
        return (row - 1) * self.column_length + (col - 1)

    def render(self, mode='human'):
        """Write the maze to stdout with the agent drawn as a red ``A``."""
        outfile = sys.stdout
        rows = list(self.map)  # shallow copy: the shared map stays untouched
        r, c = self.a_loc
        rows[r] = (rows[r][:c]
                   + utils.colorize('A', 'red', bold=True)
                   + rows[r][c + 1:])
        outfile.write("\n".join(rows) + "\n")
        outfile.write("\n")

    def close(self):
        print("close")

    def switch_maze(self):
        """Swap in the second maze layout and restart from the start cell."""
        self.map = MAP2
        self.reset()

    @property
    def row_length(self):
        """Number of interior rows (map height without the two borders)."""
        return len(self.map) - 2

    @property
    def column_length(self):
        """Number of interior columns (map width without the two borders)."""
        return len(self.map[0]) - 2

    @property
    def actions(self):
        """All available actions, in id order."""
        return [Action.UP, Action.DOWN,
                Action.LEFT, Action.RIGHT]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment