Created August 31, 2021 04:42
Save tomtung/c2fab9d0e22501b6e40ab7e5d6339ec7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "52eb889f", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import gym" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "7672fa88", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from enum import IntEnum\n", | |
"\n", | |
"class CellState(IntEnum):\n", | |
" \"\"\"State of a cell.\n", | |
" \n", | |
" A cell can be undug.\n", | |
" \n", | |
" If it's dug, it could be a rupee:\n", | |
" - Green Rupee (+1): 0 nearby bombs / rupoors\n", | |
" - Blue Rupee (+5): 1 or 2 nearby bombs / rupoors\n", | |
" - Red Rupee (+20): 3 or 4 nearby bombs / rupoors\n", | |
" - Silver Rupee (+100): 5 or 6 nearby bombs / rupoors\n", | |
" - Gold Rupee (+300): 7 or 8 nearby bombs / rupoors\n", | |
" \n", | |
" It could also be a rupoor, which reduces the total reward by 10\n", | |
" (without going into negative).\n", | |
" \n", | |
" Finally, it could also be a bomb, which doesn't reduce the reward\n", | |
" but terminates the episode immediately.\n", | |
"\n", | |
" \"\"\"\n", | |
" UNDUG = 0\n", | |
" GREEN = 1\n", | |
" BLUE = 2\n", | |
" RED = 3\n", | |
" SILVER = 4\n", | |
" GOLD = 5\n", | |
" RUPOOR = 6\n", | |
" BOMB = 7\n", | |
" \n", | |
" @classmethod\n", | |
" def from_adj_bad_count(cls, count):\n", | |
" count_to_val = {\n", | |
" 0: cls.GREEN,\n", | |
" 1: cls.BLUE,\n", | |
" 2: cls.BLUE,\n", | |
" 3: cls.RED,\n", | |
" 4: cls.RED,\n", | |
" 5: cls.SILVER,\n", | |
" 6: cls.SILVER,\n", | |
" 7: cls.GOLD,\n", | |
" 8: cls.GOLD,\n", | |
" }\n", | |
"\n", | |
" return count_to_val[count]\n", | |
" \n", | |
" def to_reward(self):\n", | |
" val_to_reward = {\n", | |
" self.GREEN: 1,\n", | |
" self.BLUE: 5,\n", | |
" self.RED: 20,\n", | |
" self.SILVER: 100,\n", | |
" self.GOLD: 300,\n", | |
" self.RUPOOR: -10,\n", | |
" self.BOMB: 0,\n", | |
" }\n", | |
" return val_to_reward[self]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "cd82b1b0", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import gym\n", | |
"import random\n", | |
"import numpy as np\n", | |
"\n", | |
"\n", | |
"class ThrillDiggerEnv(gym.Env):\n", | |
" metadata = {'render.modes': ['human']}\n", | |
" reward_range = (-10.0, 500.0)\n", | |
" \n", | |
" N_ROWS, N_COLS = 5, 8\n", | |
" N_CELLS = N_ROWS * N_COLS\n", | |
"\n", | |
" action_space = gym.spaces.Discrete(N_CELLS)\n", | |
" observation_space = gym.spaces.MultiDiscrete([len(CellState)] * N_CELLS)\n", | |
" \n", | |
" inner_states = [CellState.GREEN] * N_CELLS\n", | |
" is_dug = [False] * N_CELLS\n", | |
" total_reward = 0\n", | |
" is_done = False\n", | |
" \n", | |
" @property\n", | |
" def observation(self):\n", | |
" return [\n", | |
" cell_state if cell_is_dug else CellState.UNDUG\n", | |
" for cell_is_dug, cell_state in zip(self.is_dug, self.inner_states)\n", | |
" ]\n", | |
" \n", | |
" def reset(self):\n", | |
" grid_state = [\n", | |
" [None] * self.N_COLS\n", | |
" for _ in range(self.N_ROWS)\n", | |
" ]\n", | |
"\n", | |
" def fill_bombs_and_rupoor():\n", | |
" positions = [\n", | |
" (r, l)\n", | |
" for r in range(self.N_ROWS)\n", | |
" for l in range(self.N_COLS)\n", | |
" ]\n", | |
" random.shuffle(positions)\n", | |
" \n", | |
" for i in range(8):\n", | |
" r, l = positions[i]\n", | |
" grid_state[r][l] = CellState.RUPOOR\n", | |
" \n", | |
" for i in range(8, 16):\n", | |
" r, l = positions[i]\n", | |
" grid_state[r][l] = CellState.BOMB\n", | |
" \n", | |
" def is_bad(r, l):\n", | |
" return 0 <= l < self.N_COLS and \\\n", | |
" 0 <= r < self.N_ROWS and \\\n", | |
" grid_state[r][l] in (CellState.RUPOOR, CellState.BOMB)\n", | |
" \n", | |
" def set_rupee(r, l):\n", | |
" if grid_state[r][l] is not None:\n", | |
" return\n", | |
" \n", | |
" bad_count = sum([\n", | |
" int(is_bad(r + dr, l + dl))\n", | |
" for dr in [-1, 0, 1]\n", | |
" for dl in [-1, 0, 1]\n", | |
" ])\n", | |
" grid_state[r][l] = CellState.from_adj_bad_count(bad_count)\n", | |
" \n", | |
" fill_bombs_and_rupoor()\n", | |
" for r in range(self.N_ROWS):\n", | |
" for l in range(self.N_COLS):\n", | |
" set_rupee(r, l)\n", | |
" \n", | |
" self.inner_states = [\n", | |
" item\n", | |
" for row in grid_state\n", | |
" for item in row\n", | |
" ]\n", | |
" self.is_dug = [False] * self.N_CELLS\n", | |
" self.total_reward = 0\n", | |
" self.is_done = False\n", | |
" return np.array(self.observation, dtype=np.int64)\n", | |
"\n", | |
" def render(self, mode='human'):\n", | |
" for i in range(self.N_CELLS):\n", | |
" if i > 0 and i % self.N_COLS == 0:\n", | |
" print(\"\")\n", | |
"\n", | |
" name = self.inner_states[i].name.title()[:3]\n", | |
" if not self.is_dug[i]:\n", | |
" name = f\"({name})\"\n", | |
" \n", | |
" print(name, end=\"\\t\")\n", | |
" \n", | |
" print(\"\")\n", | |
"\n", | |
" def step(self, action):\n", | |
" # NB: agent should make sure to not dig cells that are already dug\n", | |
" reward = 0\n", | |
" if not self.is_done and not self.is_dug[action]:\n", | |
" self.is_dug[action] = True\n", | |
" self.is_done = self.is_done or self.inner_states[action] == CellState.BOMB\n", | |
"\n", | |
" # NB: Make sure that the total reward is always non-negative \n", | |
" new_total_reward = max(0, self.total_reward + self.inner_states[action].to_reward())\n", | |
" reward = new_total_reward - self.total_reward\n", | |
" self.total_reward = new_total_reward\n", | |
"\n", | |
" return (\n", | |
" np.array(self.observation, dtype=np.int64),\n", | |
" reward,\n", | |
" self.is_done,\n", | |
" {\"cell_state\": self.inner_states[action]}\n", | |
" )" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "3a1a13c8", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1.77 ms ± 60.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" | |
] | |
} | |
], | |
"source": [ | |
"%%timeit\n", | |
"# Benchmark: build a fresh env and run Stable-Baselines3's check_env on it.\n", | |
"from stable_baselines3.common.env_checker import check_env\n", | |
"\n", | |
"env = ThrillDiggerEnv()\n", | |
"check_env(env)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "26bc3dae", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"(Blu)\t(Blu)\t(Bom)\t(Blu)\t(Blu)\t(Blu)\t(Red)\t(Rup)\t\n", | |
"(Bom)\t(Sil)\t(Red)\t(Red)\t(Red)\t(Rup)\t(Red)\t(Bom)\t\n", | |
"(Rup)\t(Bom)\t(Rup)\t(Bom)\tBom\t(Rup)\t(Red)\t(Blu)\t\n", | |
"(Blu)\t(Red)\t(Red)\t(Red)\t(Red)\t(Red)\t(Sil)\t(Rup)\t\n", | |
"(Gre)\t(Blu)\t(Bom)\t(Blu)\t(Blu)\t(Bom)\t(Rup)\t(Rup)\t\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"True" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Start a new episode, dig cell 20, print the board (undug cells are shown\n", | |
"# in parentheses), and show whether that dig ended the episode.\n", | |
"env = ThrillDiggerEnv()\n", | |
"env.reset()\n", | |
"env.step(20)\n", | |
"env.render()\n", | |
"env.is_done" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"id": "f03e7628", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 0,\n", | |
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),\n", | |
" 0,\n", | |
" True,\n", | |
" {'cell_state': <CellState.RED: 3>})" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# In the saved run, cell 20 held a bomb, so this step comes after the\n", | |
"# episode ended: it changes nothing and returns 0 reward, yet info still\n", | |
"# reports cell 10's hidden state.\n", | |
"env.step(10)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a1bc9513", | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.