Gymnasium Custom Environment: Annotated GridWorld (gridworld_env.py)

This file provides a fully commented, self-contained implementation of a simple 2D GridWorld environment, built on the standard Gymnasium API (the successor to OpenAI Gym). This environment is designed to serve as a foundational, educational example for anyone learning how to create custom environments for Reinforcement Learning (RL) agents.

Key Features Demonstrated

Environment Initialization (__init__): Defines the size of the world and sets up rendering.

Observation & Action Spaces: Clearly defines the Dict observation space (agent and target coordinates) and the Discrete(4) action space (Right, Up, Left, Down).

reset() Function: Implements the required episode reset logic, randomly placing the agent and the target, and returning the initial (observation, info) tuple.

step() Function: The core of the environment logic. It takes an agent's action and returns the crucial 5-tuple:

observation

reward (1 for reaching the target, 0 otherwise)

terminated (True if the target is reached)

truncated (always False in this simple example; in general it signals that the episode was cut off by a time limit rather than solved)

info (Auxiliary data, specifically the Manhattan distance to the target)

Rendering: Includes PyGame-based rendering logic for visualization (render_mode="human").

Usage Example: A simple if __name__ == '__main__': block demonstrates how a test script or an RL library would interact with the environment's public API (reset(), step(), close()). A standalone sketch of the same interaction loop follows this list.
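Before the full listing, here is a minimal sketch of that same interaction pattern driven by a random agent. It assumes gridworld_env.py is importable from the working directory and runs headless (render_mode=None); the random policy and the unbounded loop are illustrative only and are not part of the file below.

# Sketch: a random agent driving the environment's public API.
# Assumes gridworld_env.py sits next to this script (hypothetical usage).
from gridworld_env import GridWorldEnv

env = GridWorldEnv(render_mode=None, size=5)   # headless: no PyGame window
observation, info = env.reset(seed=0)          # (observation, info) tuple
print("agent:", observation["agent"], "target:", observation["target"])

terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()         # uniform draw from Discrete(4)
    observation, reward, terminated, truncated, info = env.step(action)
    print(f"reward={reward}, distance to target={info['distance']}")

env.close()

Because the environment only terminates on success and never truncates on its own, a purely random policy can wander for many steps even on a 5x5 grid; an RL library would normally add a time limit (see the registration sketch in the next section).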

Installation and Usage

To run the example locally, install the dependencies:

pip install gymnasium numpy pygame

Execute the file directly:

python gridworld_env.py

This will run a 5x5 grid environment, display the PyGame window, and print verbose console output detailing every step, direction vector, and termination check.
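The file also defines ENV_ID = "CustomGym/GridWorld-v0" but never registers it. The following is a minimal sketch of how the environment could be registered and then built through gym.make, assuming the file is importable as the module gridworld_env; the max_episode_steps value is an illustrative choice, and supplying it makes Gymnasium wrap the environment in a TimeLimit wrapper that reports truncated=True once the limit is reached.

# Sketch: registering the custom environment so gym.make() can construct it.
# Assumes this runs somewhere gridworld_env.py is importable (hypothetical setup).
import gymnasium as gym
from gridworld_env import ENV_ID

gym.register(
    id=ENV_ID,                                # "CustomGym/GridWorld-v0"
    entry_point="gridworld_env:GridWorldEnv", # "module:ClassName"
    max_episode_steps=100,                    # illustrative; adds a TimeLimit wrapper
)

env = gym.make(ENV_ID, size=5)                # extra kwargs are forwarded to __init__
observation, info = env.reset(seed=42)
env.close()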

# ==============================================================================
# 1. IMPORTS AND SETUP
# ==============================================================================
import numpy as np # NumPy is essential for handling array operations and random number generation.
import pygame # PyGame is used here exclusively for rendering the environment visually.
import gymnasium as gym # The core library that provides the Gym/Gymnasium API and base classes.
from gymnasium import spaces # Contains the classes used to define the Observation and Action Spaces.
from enum import Enum # Used to create simple, readable constants for our actions (e.g., RIGHT=0).
# A constant used for identifying and registering the environment later.
ENV_ID = "CustomGym/GridWorld-v0"
# ==============================================================================
# 2. ACTIONS DEFINITION (The "What Can I Do?" part)
# ==============================================================================
class Actions(Enum):
"""Maps discrete action indices to movement directions."""
RIGHT = 0 # If the agent chooses action '0', it means "Move Right".
UP = 1 # If the agent chooses action '1', it means "Move Up".
LEFT = 2 # If the agent chooses action '2', it means "Move Left".
DOWN = 3 # If the agent chooses action '3', it means "Move Down".
# ==============================================================================
# 3. THE ENVIRONMENT CLASS (The "World Rules" part)
# ==============================================================================
class GridWorldEnv(gym.Env):
"""
A simple 2D grid world where the agent (blue circle) must reach the target (red square).
It defines the rules for state transitions, rewards, and episode termination.
"""
# Required attribute: defines supported visualization modes and framerate.
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(self, render_mode=None, size=5):
"""Initializes the environment's constants and defines the spaces."""
super().__init__()
self.size = size
# 1. Store and print the initialized grid size.
print(f"--- INIT --- Grid size set to: {self.size}x{self.size}")
self.window_size = 512
# --- CRUCIAL API: Observation Space Definition ---
# The Observation Space tells the RL agent *what kind of data* it will receive.
# 2. Define the observation space (input to the agent's policy).
self.observation_space = spaces.Dict(
{
# Agent location: Box(0, size-1, shape=(2,), int) -> A vector [x, y].
"agent": spaces.Box(
0, size - 1, shape=(2,), dtype=int
),
# Target location: Identical structure to the agent location.
"target": spaces.Box(0, size - 1, shape=(2,), dtype=int),
}
)
# 3. Print the structure of the Observation Space.
# Output Example: "--- INIT --- Observation Space defined (Agent Input): Dict('agent': Box(0, 4, (2,), int), 'target': Box(0, 4, (2,), int))"
print(f"--- INIT --- Observation Space defined (Agent Input): {self.observation_space}")
# --- CRUCIAL API: Action Space Definition ---
# The Action Space tells the RL agent *what actions it can choose*.
# 4. Define the action space (output from the agent's policy: 0, 1, 2, or 3).
self.action_space = spaces.Discrete(len(Actions))
# 5. Print the structure of the Action Space.
# Output Example: "--- INIT --- Action Space defined (Agent Choice): Discrete(4)"
print(f"--- INIT --- Action Space defined (Agent Choice): {self.action_space}")
# Internal state variables
self._agent_location = np.array([-1, -1], dtype=int)
self._target_location = np.array([-1, -1], dtype=int)
# Mapping for action translation: Converts action index (0-3) to a coordinate change vector.
self._action_to_direction = {
Actions.RIGHT.value: np.array([1, 0]),
Actions.UP.value: np.array([0, 1]),
Actions.LEFT.value: np.array([-1, 0]),
Actions.DOWN.value: np.array([0, -1]),
}
# Rendering setup variables
assert render_mode is None or render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
self.window = None
self.clock = None
def _get_obs(self):
"""Translates the internal state into the formal observation dictionary."""
return {"agent": self._agent_location, "target": self._target_location}
def _get_info(self):
"""Provides auxiliary information (like distance) for logging or debugging."""
# Calculate Manhattan distance (L1 norm).
distance = np.linalg.norm(
self._agent_location - self._target_location, ord=1
)
return {"distance": distance}
# ==============================================================================
# 4. MANDATORY API CALL: reset() - Starting a New Game
# ==============================================================================
def reset(self, seed=None, options=None):
"""
Called to initiate a new episode. Randomly places the agent and target.
@API Return: (observation: Dict, info: Dict)
"""
super().reset(seed=seed)
# 1. Signal the start of the reset call.
# Output Example: "--- RESET API CALL --- Initializing new episode."
print("\n--- RESET API CALL --- Initializing new episode.")
# 2. Randomly place the agent using the environment's internal RNG (self.np_random).
self._agent_location = self.np_random.integers(0, self.size, size=2, dtype=int)
# 3. Print the randomly chosen starting position of the agent.
# Output Example: " > Agent random start location set to: [3 1]"
print(f" > Agent random start location set to: {self._agent_location}")
# 4. Randomly place the target, ensuring it does not start on the agent's square.
self._target_location = self._agent_location
while np.array_equal(self._target_location, self._agent_location):
self._target_location = self.np_random.integers(
0, self.size, size=2, dtype=int
)
# 5. Print the randomly chosen target location.
# Output Example: " > Target location set to: [0 4]"
print(f" > Target location set to: {self._target_location}")
# 6. Package the results for the agent.
observation = self._get_obs()
info = self._get_info()
# 7. Print the initial Manhattan distance.
# Output Example: " > Initial Info (Manhattan Distance): 7.0"
print(f" > Initial Info (Manhattan Distance): {info['distance']}")
# 8. Render the initial state if in human mode.
if self.render_mode == "human":
self._render_frame()
return observation, info
# ==============================================================================
# 5. MANDATORY API CALL: step(action) - Playing One Turn
# ==============================================================================
def step(self, action):
"""
Executes one time step based on the agent's choice (action).
@API Return: (observation: Dict, reward: float, terminated: bool, truncated: bool, info: Dict)
"""
# 1. Print the start of the step call and the chosen action (e.g., '1 (UP)').
# Output Example: "--- STEP API CALL --- Action taken: 1 (UP)"
print(f"\n--- STEP API CALL --- Action taken: {action} ({Actions(action).name})")
# 2. Print the agent's location before the move.
# Output Example: " > Agent location before move: [3 1]"
print(f" > Agent location before move: {self._agent_location}")
# 3. Translate the discrete action into a coordinate change vector.
direction = self._action_to_direction[action]
# 4. Print the calculated direction vector.
# Output Example: " > Direction vector applied: [0 1]"
print(f" > Direction vector applied: {direction}")
# Calculate the theoretical new position
new_location_pre_clip = self._agent_location + direction
# 5. Calculate the new position and enforce boundaries using np.clip.
# np.clip prevents the agent from moving outside the grid (0 to size-1).
self._agent_location = np.clip(
new_location_pre_clip,
0,
self.size - 1
)
# 6. Print the final location after clipping.
# Output Example: " > New location after clip (final position): [3 2]"
print(f" > New location after clip (final position): {self._agent_location}")
# 7. Determine if the episode is over (Terminated).
terminated = np.array_equal(self._agent_location, self._target_location)
# 8. Calculate Reward: 1 point for reaching the goal, 0 otherwise.
reward = 1 if terminated else 0
# 9. Print the termination status and the resulting reward.
# Output Example: " > Check termination: Goal reached? False. Reward: 0"
print(f" > Check termination: Goal reached? {terminated}. Reward: {reward}")
# 10. Gather the results for the next loop iteration.
observation = self._get_obs()
info = self._get_info()
# Truncated: Set to False (used for time limits, which we don't have here).
truncated = False
# 11. Print the new Manhattan distance to the target.
# Output Example: " > New Distance to Target: 6.0"
print(f" > New Distance to Target: {info['distance']}")
# 12. Render if necessary.
if self.render_mode == "human":
self._render_frame()
# Return the required 5-tuple.
return observation, reward, terminated, truncated, info
# ==============================================================================
# 6. MANDATORY API CALL: render() - Drawing the World
# ==============================================================================
def render(self):
"""Public method to request a frame."""
if self.render_mode == "rgb_array":
return self._render_frame()
# Private rendering logic (standard PyGame boilerplate).
def _render_frame(self):
if self.window is None and self.render_mode == "human":
pygame.init()
pygame.display.init()
self.window = pygame.display.set_mode(
(self.window_size, self.window_size)
)
if self.clock is None and self.render_mode == "human":
self.clock = pygame.time.Clock()
canvas = pygame.Surface((self.window_size, self.window_size))
canvas.fill((255, 255, 255))
pix_square_size = (self.window_size / self.size)
# Draw the Target (Red Square)
pygame.draw.rect(
canvas,
(255, 0, 0),
pygame.Rect(
self._target_location * pix_square_size,
(pix_square_size, pix_square_size),
),
)
# Draw the Agent (Blue Circle)
center_coords = (self._agent_location + 0.5) * pix_square_size
pygame.draw.circle(
canvas,
(0, 0, 255),
center_coords,
pix_square_size / 3,
)
# Draw Gridlines
for x in range(self.size + 1):
pygame.draw.line(canvas, 0, (pix_square_size * x, 0), (pix_square_size * x, self.window_size), width=3)
pygame.draw.line(canvas, 0, (0, pix_square_size * x), (self.window_size, pix_square_size * x), width=3)
if self.render_mode == "human":
self.window.blit(canvas, canvas.get_rect())
pygame.event.pump()
pygame.display.update()
self.clock.tick(self.metadata["render_fps"])
else:
return np.transpose(
np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
)
# ==============================================================================
# 7. MANDATORY API CALL: close() - Cleaning Up
# ==============================================================================
def close(self):
"""Closes the environment and associated PyGame resources."""
if self.window is not None:
pygame.display.quit()
pygame.quit()
self.window = None
self.clock = None
# ==============================================================================
# 8. USAGE EXAMPLE: How an Agent (or Test Script) Uses the API
# ==============================================================================
if __name__ == '__main__':
# 1. Create an instance of the environment
env = GridWorldEnv(render_mode="human", size=5)
# 2. CALL THE reset() API to start the episode
observation, info = env.reset(seed=42)
# 3. Define a sequence of actions
steps_to_take = [Actions.RIGHT.value, Actions.UP.value, Actions.LEFT.value, Actions.DOWN.value]
total_reward = 0
terminated = False
for i, action in enumerate(steps_to_take):
if terminated:
print("\n--- EPISODE TERMINATED --- Stop stepping.")
break
# 4. CALL THE step(action) API and unpack the results
observation, reward, terminated, truncated, info = env.step(action)
total_reward += reward
print(f"\n=============================================")
print(f"Episode finished. Total Reward: {total_reward}")
print(f"=============================================")
# 5. CALL THE close() API to clean up PyGame resources
env.close()