# ==============================================================================
# 1. IMPORTS AND SETUP
# ==============================================================================
import numpy as np  # NumPy is essential for handling array operations and random number generation.
import pygame  # PyGame is used here exclusively for rendering the environment visually.
import gymnasium as gym  # The core library that provides the Gym/Gymnasium API and base classes.
from gymnasium import spaces  # Contains the classes used to define the Observation and Action Spaces.
from enum import Enum  # Used to create simple, readable constants for our actions (e.g., RIGHT=0).

# A constant used for identifying and registering the environment later.
ENV_ID = "CustomGym/GridWorld-v0"

# ==============================================================================
# 2. ACTIONS DEFINITION (The "What Can I Do?" part)
# ==============================================================================
class Actions(Enum):
    """Maps discrete action indices to movement directions."""
    RIGHT = 0  # If the agent chooses action '0', it means "Move Right".
    UP = 1     # If the agent chooses action '1', it means "Move Up".
    LEFT = 2   # If the agent chooses action '2', it means "Move Left".
    DOWN = 3   # If the agent chooses action '3', it means "Move Down".

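# As a quick illustration (not part of the environment itself), the Enum converts
# between indices and readable names in both directions, which is handy when
# decoding logged actions:
#
#     >>> Actions(2).name
#     'LEFT'
#     >>> Actions.LEFT.value
#     2
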
# ==============================================================================
# 3. THE ENVIRONMENT CLASS (The "World Rules" part)
# ==============================================================================
class GridWorldEnv(gym.Env):
    """
    A simple 2D grid world where the agent (blue circle) must reach the target (red square).
    It defines the rules for state transitions, rewards, and episode termination.
    """
    # Required attribute: defines supported visualization modes and framerate.
    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}

    def __init__(self, render_mode=None, size=5):
        """Initializes the environment's constants and defines the spaces."""
        super().__init__()
        self.size = size
        # 1. Store and print the initialized grid size.
        print(f"--- INIT --- Grid size set to: {self.size}x{self.size}")
        self.window_size = 512  # Width/height of the PyGame window, in pixels.

        # --- CRUCIAL API: Observation Space Definition ---
        # The Observation Space tells the RL agent *what kind of data* it will receive.
        # 2. Define the observation space (input to the agent's policy).
        self.observation_space = spaces.Dict(
            {
                # Agent location: Box(0, size-1, shape=(2,), int) -> A vector [x, y].
                "agent": spaces.Box(0, size - 1, shape=(2,), dtype=int),
                # Target location: Identical structure to the agent location.
                "target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
            }
        )
        # 3. Print the structure of the Observation Space.
        # Output Example: "--- INIT --- Observation Space defined (Agent Input):
        #                  Dict('agent': Box(0, 4, (2,), int64), 'target': Box(0, 4, (2,), int64))"
        print(f"--- INIT --- Observation Space defined (Agent Input): {self.observation_space}")

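        # For orientation only (illustrative values, not tied to any particular seed),
        # a single observation drawn from this space looks like:
        #     self.observation_space.sample()
        #     -> {'agent': array([2, 0]), 'target': array([4, 3])}
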
        # --- CRUCIAL API: Action Space Definition ---
        # The Action Space tells the RL agent *what actions it can choose*.
        # 4. Define the action space (output from the agent's policy: 0, 1, 2, or 3).
        self.action_space = spaces.Discrete(len(Actions))
        # 5. Print the structure of the Action Space.
        # Output Example: "--- INIT --- Action Space defined (Agent Choice): Discrete(4)"
        print(f"--- INIT --- Action Space defined (Agent Choice): {self.action_space}")

        # Internal state variables (placeholder values until reset() is called).
        self._agent_location = np.array([-1, -1], dtype=int)
        self._target_location = np.array([-1, -1], dtype=int)

        # Mapping for action translation: converts an action index (0-3) into a coordinate change vector.
        self._action_to_direction = {
            Actions.RIGHT.value: np.array([1, 0]),
            Actions.UP.value: np.array([0, 1]),
            Actions.LEFT.value: np.array([-1, 0]),
            Actions.DOWN.value: np.array([0, -1]),
        }

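        # Illustration only: looking up an action index returns the corresponding
        # direction vector, which step() later adds to the agent's [x, y] position.
        #     self._action_to_direction[Actions.UP.value]    # -> array([0, 1])
        #     self._action_to_direction[Actions.LEFT.value]  # -> array([-1, 0])
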
        # Rendering setup variables
        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode
        self.window = None  # PyGame window handle (created lazily on the first human-mode render).
        self.clock = None   # PyGame clock used to cap the framerate at metadata["render_fps"].

    def _get_obs(self):
        """Translates the internal state into the formal observation dictionary."""
        return {"agent": self._agent_location, "target": self._target_location}

    def _get_info(self):
        """Provides auxiliary information (like distance) for logging or debugging."""
        # Manhattan distance (L1 norm) between agent and target.
        distance = np.linalg.norm(
            self._agent_location - self._target_location, ord=1
        )
        return {"distance": distance}

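    # Worked example (using the sample locations from the reset() output comments below):
    # with the agent at [3, 1] and the target at [0, 4], the L1 norm is
    # |3 - 0| + |1 - 4| = 3 + 3 = 6.0, so info["distance"] would be 6.0.
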
    # ==============================================================================
    # 4. MANDATORY API CALL: reset() - Starting a New Game
    # ==============================================================================
    def reset(self, seed=None, options=None):
        """
        Called to initiate a new episode. Randomly places the agent and target.

        @API Return: (observation: Dict, info: Dict)
        """
        super().reset(seed=seed)
        # 1. Signal the start of the reset call.
        # Output Example: "--- RESET API CALL --- Initializing new episode."
        print("\n--- RESET API CALL --- Initializing new episode.")

        # 2. Randomly place the agent using the environment's internal RNG (self.np_random).
        self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)
        # 3. Print the randomly chosen starting position of the agent.
        # Output Example: " > Agent random start location set to: [3 1]"
        print(f" > Agent random start location set to: {self._agent_location}")

        # 4. Randomly place the target, ensuring it does not start on the agent's square.
        self._target_location = self._agent_location
        while np.array_equal(self._target_location, self._agent_location):
            self._target_location = self.np_random.integers(
                0, self.size, size=2, dtype=int
            )
        # 5. Print the randomly chosen target location.
        # Output Example: " > Target location set to: [0 4]"
        print(f" > Target location set to: {self._target_location}")

        # 6. Package the results for the agent.
        observation = self._get_obs()
        info = self._get_info()
        # 7. Print the initial Manhattan distance.
        # Output Example: " > Initial Info (Manhattan Distance): 6.0"
        print(f" > Initial Info (Manhattan Distance): {info['distance']}")

        # 8. Render the initial state if in human mode.
        if self.render_mode == "human":
            self._render_frame()

        return observation, info

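    # Note (standard Gymnasium behaviour, not specific to this environment): because
    # super().reset(seed=seed) re-seeds self.np_random, calling reset(seed=123) twice
    # reproduces the same agent/target layout, while reset() with no seed continues
    # the existing RNG stream and produces a fresh random layout each episode.
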
    # ==============================================================================
    # 5. MANDATORY API CALL: step(action) - Playing One Turn
    # ==============================================================================
    def step(self, action):
        """
        Executes one time step based on the agent's choice (action).

        @API Return: (observation: Dict, reward: float, terminated: bool, truncated: bool, info: Dict)
        """
        # 1. Print the start of the step call and the chosen action (e.g., '1 (UP)').
        # Output Example: "--- STEP API CALL --- Action taken: 1 (UP)"
        print(f"\n--- STEP API CALL --- Action taken: {action} ({Actions(action).name})")
        # 2. Print the agent's location before the move.
        # Output Example: " > Agent location before move: [3 1]"
        print(f" > Agent location before move: {self._agent_location}")

        # 3. Translate the discrete action into a coordinate change vector.
        direction = self._action_to_direction[action]
        # 4. Print the calculated direction vector.
        # Output Example: " > Direction vector applied: [0 1]"
        print(f" > Direction vector applied: {direction}")

        # Calculate the theoretical new position.
        new_location_pre_clip = self._agent_location + direction

        # 5. Enforce the boundaries with np.clip so the agent cannot leave the grid (0 to size-1).
        self._agent_location = np.clip(new_location_pre_clip, 0, self.size - 1)
        # 6. Print the final location after clipping.
        # Output Example: " > New location after clip (final position): [3 2]"
        print(f" > New location after clip (final position): {self._agent_location}")

        # 7. Determine if the episode is over (Terminated): the agent stands on the target.
        terminated = np.array_equal(self._agent_location, self._target_location)

        # 8. Calculate the reward: 1 point for reaching the goal, 0 otherwise (sparse reward).
        reward = 1 if terminated else 0

        # 9. Print the termination status and the resulting reward.
        # Output Example: " > Check termination: Goal reached? False. Reward: 0"
        print(f" > Check termination: Goal reached? {terminated}. Reward: {reward}")

        # 10. Gather the results for the next loop iteration.
        observation = self._get_obs()
        info = self._get_info()

        # Truncated: set to False (used for time limits, which this environment does not have).
        truncated = False

        # 11. Print the new Manhattan distance to the target.
        # Output Example: " > New Distance to Target: 5.0"
        print(f" > New Distance to Target: {info['distance']}")

        # 12. Render if necessary.
        if self.render_mode == "human":
            self._render_frame()

        # Return the required 5-tuple.
        return observation, reward, terminated, truncated, info

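    # A sketch of the standard agent loop this API supports (shown here only as a comment;
    # the runnable demo at the bottom of the file uses a fixed action sequence instead).
    # Picking actions with action_space.sample() is an illustrative assumption -- a trained
    # policy would choose the action in its place.
    #
    #     obs, info = env.reset()
    #     terminated = truncated = False
    #     while not (terminated or truncated):
    #         action = env.action_space.sample()  # random policy
    #         obs, reward, terminated, truncated, info = env.step(action)
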
    # ==============================================================================
    # 6. MANDATORY API CALL: render() - Drawing the World
    # ==============================================================================
    def render(self):
        """Public method to request a frame.

        In "rgb_array" mode this returns the frame as a NumPy array; in "human" mode
        frames are drawn automatically by reset() and step(), so nothing is returned.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()

    # Private rendering logic (standard PyGame boilerplate).
    def _render_frame(self):
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size)
            )
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))  # White background.
        pix_square_size = self.window_size / self.size  # Pixel size of a single grid cell.

        # Draw the Target (Red Square)
        pygame.draw.rect(
            canvas,
            (255, 0, 0),
            pygame.Rect(
                self._target_location * pix_square_size,
                (pix_square_size, pix_square_size),
            ),
        )

        # Draw the Agent (Blue Circle)
        center_coords = (self._agent_location + 0.5) * pix_square_size
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            center_coords,
            pix_square_size / 3,
        )

        # Draw Gridlines
        for x in range(self.size + 1):
            pygame.draw.line(canvas, 0, (pix_square_size * x, 0), (pix_square_size * x, self.window_size), width=3)
            pygame.draw.line(canvas, 0, (0, pix_square_size * x), (self.window_size, pix_square_size * x), width=3)

        if self.render_mode == "human":
            # Copy the drawn canvas onto the visible window and keep the framerate steady.
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
        else:
            # "rgb_array" mode: return the frame as an (H, W, 3) uint8 array
            # (PyGame stores pixels as (W, H, 3), hence the transpose).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    # ==============================================================================
    # 7. MANDATORY API CALL: close() - Cleaning Up
    # ==============================================================================
    def close(self):
        """Closes the environment and associated PyGame resources."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None

# ==============================================================================
# 8. USAGE EXAMPLE: How an Agent (or Test Script) Uses the API
# ==============================================================================
if __name__ == '__main__':
    # 1. Create an instance of the environment.
    env = GridWorldEnv(render_mode="human", size=5)

    # 2. CALL THE reset() API to start the episode.
    observation, info = env.reset(seed=42)

    # 3. Define a sequence of actions (a fixed demo walk; it may or may not reach the target).
    steps_to_take = [Actions.RIGHT.value, Actions.UP.value, Actions.LEFT.value, Actions.DOWN.value]

    total_reward = 0
    terminated = False

    for i, action in enumerate(steps_to_take):
        if terminated:
            print("\n--- EPISODE TERMINATED --- Stop stepping.")
            break

        # 4. CALL THE step(action) API and unpack the results.
        observation, reward, terminated, truncated, info = env.step(action)

        total_reward += reward

    print("\n=============================================")
    print(f"Episode finished. Total Reward: {total_reward}")
    print("=============================================")

    # 5. CALL THE close() API to clean up PyGame resources.
    env.close()
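
    # ------------------------------------------------------------------
    # OPTIONAL: registering the environment under ENV_ID.
    # A minimal sketch of the registration step that the ENV_ID constant at
    # the top of the file anticipates; the max_episode_steps value here is an
    # illustrative choice, not a requirement. Registration lets other code
    # construct the environment by name via gymnasium.make() instead of
    # importing GridWorldEnv directly.
    # ------------------------------------------------------------------
    gym.register(id=ENV_ID, entry_point=GridWorldEnv, max_episode_steps=100)

    # Build a fresh, registered instance by id (wrapped with a 100-step TimeLimit).
    registered_env = gym.make(ENV_ID, size=5)
    observation, info = registered_env.reset(seed=42)
    print(f"\nRegistered '{ENV_ID}' and reset it via gym.make(): distance = {info['distance']}")
    registered_env.close()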