|
# ==================================================================================== |
|
# PPO MJÖLNIR INTERCEPT TRAINING SCRIPT (V12 - Drift Truncation & Distance Observation) |
|
# ==================================================================================== |
|
|
|
# ------------------------------------------------------------------------------------ |
|
# --- INSTALLATION & SETUP (CRITICAL FOR COLAB VIDEO) ⚙️ --- |
|
# ------------------------------------------------------------------------------------ |
|
|
|
!pip install stable_baselines3[extra] gymnasium[classic_control] -q |
|
!apt-get install -y swig |
|
!pip install 'gymnasium[box2d]' -q |
|
!pip install pyvirtualdisplay -q |
|
!apt-get install -y xvfb xserver-xephyr tightvncserver # FIX: Replaced vnc4server |
|
!pip install moviepy -q |
|
|
|
# ------------------------------------------------------------------------------------ |
|
# --- IMPORTS 🧰 --- |
|
# ------------------------------------------------------------------------------------ |
|
|
|
import base64
import os
import time
from typing import Callable

import gymnasium as gym
import numpy as np
from gymnasium import spaces
from IPython.display import HTML
from pyvirtualdisplay import Display
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder, DummyVecEnv
|
|
|
# Initialize the virtual (off-screen) display used when rendering video frames.
display = Display(visible=0, size=(1000, 750))
display.start()

# Define Paths and Budget
TOTAL_TIMESTEPS = 500_000
LOG_DIR = "./ppo_mjollnir_logs/"
# os.path.join avoids the doubled separator that "{LOG_DIR}/videos/" produced.
VIDEO_DIR = os.path.join(LOG_DIR, "videos")
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(VIDEO_DIR, exist_ok=True)
|
|
|
# ------------------------------------------------------------------------------------ |
|
## 1. Custom Environment: MjollnirEnvV12 (Drift Truncation & Distance Observation) |
|
# ------------------------------------------------------------------------------------ |
|
|
|
class MjollnirEnvV12(gym.Env):
    """Gymnasium environment: a thrust-driven hammer must reach and smash an asteroid.

    V12 includes V11 features PLUS:
    1. Normalized distance to the target appended to the observation (12 values).
    2. Hard truncation that ends the episode if the hammer drifts too far away.

    Observation (float32, shape (12,)):
        hammer position / size (2), relative position / size (2),
        relative velocity (2), hammer velocity (2), hammer angle (1),
        angle error towards the asteroid (1), asteroid radius (1),
        distance / size (1).

    Action (float32, shape (3,), each component in [-1, 1]):
        forward/aft thrust, lateral thrust, torque.
    """

    metadata = {"render_modes": ["rgb_array", "human"], "render_fps": 15}

    def __init__(self, render_mode=None, size=1000, dt=1/30):
        """Set physics constants, spaces and rendering handles.

        Args:
            render_mode: None, "rgb_array" or "human".
            size: Side length of the square play area; also the pixel size of
                the rendered canvas.
            dt: Integration time step applied in ``step``.
        """
        super().__init__()
        self.size = size
        self.dt = dt

        # Actuator limits and rigid-body properties of the hammer.
        self.max_force = 240.0
        self.max_torque = 15.0
        self.hammer_mass = 5.0
        self.hammer_inertia = 0.75

        # Interception geometry and reward-shaping parameters.
        self.smash_radius_base = 60.0
        self.smash_velocity = 20.0
        self.terminal_accel_radius = 250.0

        # Range from which the asteroid radius is sampled in reset().
        self.ASTEROID_MIN_R = 30.0
        self.ASTEROID_MAX_R = 60.0

        self.max_steps = 1500
        self.current_step = 0

        # 12 values: 2x hammer pos, 4x relative pos/vel, 2x hammer vel,
        # hammer angle, angle error, asteroid radius, normalized distance.
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(12,), dtype=np.float32)

        # 3 values: forward/aft thrust, lateral thrust, torque.
        self.action_space = spaces.Box(-1.0, 1.0, shape=(3,), dtype=np.float32)

        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._font = None  # created lazily on the first rendered frame

        # Hammer orientation state.
        self._hammer_angle = 0.0
        self._hammer_angular_vel = 0.0

    def _get_info(self):
        """Return auxiliary quantities derived from the current state.

        Returns:
            dict with keys "distance", "normalized_distance", "hammer_speed",
            "angle_diff" (wrapped into [-pi, pi)), "angular_vel" and
            "lateral_speed".
        """
        delta_pos = self._asteroid_pos - self._hammer_pos
        distance = np.linalg.norm(delta_pos)

        # Bearing to the asteroid relative to the hammer angle, wrapped to [-pi, pi).
        target_angle = np.arctan2(delta_pos[1], delta_pos[0])
        angle_diff = target_angle - self._hammer_angle
        angle_diff = (angle_diff + np.pi) % (2 * np.pi) - np.pi

        current_speed = np.linalg.norm(self._hammer_vel)
        lateral_speed = 0.0
        if current_speed > 1e-5:
            # NOTE(review): "forward" here is along _hammer_angle, whereas
            # step() applies forward thrust along _hammer_angle + pi/2, so
            # this value is lateral w.r.t. the facing axis, not the thrust
            # axis. Confirm which convention is intended.
            forward_dir = np.array([np.cos(self._hammer_angle), np.sin(self._hammer_angle)])
            lateral_dir = np.array([-forward_dir[1], forward_dir[0]])
            lateral_speed = np.dot(self._hammer_vel, lateral_dir)

        return {
            "distance": distance,
            "normalized_distance": distance / self.size,
            "hammer_speed": current_speed,
            "angle_diff": angle_diff,
            "angular_vel": self._hammer_angular_vel,
            "lateral_speed": lateral_speed,
        }

    def _get_obs(self):
        """Build the 12-value float32 observation vector from the current state."""
        info = self._get_info()
        delta_pos = self._asteroid_pos - self._hammer_pos
        delta_vel = self._asteroid_vel - self._hammer_vel

        return np.concatenate([
            self._hammer_pos / self.size,      # normalized hammer position (2)
            delta_pos / self.size,             # normalized relative position (2)
            delta_vel,                         # relative velocity (2)
            self._hammer_vel,                  # hammer velocity (2)
            [self._hammer_angle],              # hammer angle (1)
            [info["angle_diff"]],              # angle error to target (1)
            [self._asteroid_radius],           # asteroid radius (1)
            [info["normalized_distance"]],     # normalized distance (1)
        ], dtype=np.float32)

    def reset(self, seed=None, options=None):
        """Start a new episode with a random hammer pose and an edge-spawned asteroid.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``.
            options: Unused.

        Returns:
            (observation, info) for the initial state.
        """
        super().reset(seed=seed)
        self.current_step = 0

        self._asteroid_radius = self.np_random.uniform(self.ASTEROID_MIN_R, self.ASTEROID_MAX_R)
        self._asteroid_mass = self._asteroid_radius

        # Hammer starts at rest in the central 60% of the play area.
        self._hammer_pos = self.np_random.uniform(low=self.size * 0.2, high=self.size * 0.8, size=(2,))
        self._hammer_vel = np.zeros(2, dtype=np.float32)
        self._hammer_angle = self.np_random.uniform(0, 2 * np.pi)
        self._hammer_angular_vel = 0.0

        # Asteroid starts on one of the four edges (V8: safe spawn logic).
        # Spawns on the x=0 and y=0 edges stay at least `metric_zone` away
        # from the origin corner -- presumably to keep the asteroid clear of
        # the HUD text drawn there; confirm.
        metric_zone = 250
        start_edge = self.np_random.choice([0, 1, 2, 3])

        if start_edge == 0:    # x = 0 edge
            x_pos = 0
            y_pos = self.np_random.uniform(low=metric_zone, high=self.size)
        elif start_edge == 1:  # x = size edge
            x_pos = self.size
            y_pos = self.np_random.uniform(low=0, high=self.size)
        elif start_edge == 2:  # y = size edge
            x_pos = self.np_random.uniform(low=0, high=self.size)
            y_pos = self.size
        else:                  # y = 0 edge
            x_pos = self.np_random.uniform(low=metric_zone, high=self.size)
            y_pos = 0

        self._asteroid_pos = np.array([x_pos, y_pos], dtype=np.float32)
        self._asteroid_vel = np.zeros(2, dtype=np.float32)

        obs = self._get_obs()
        info = self._get_info()

        if self.render_mode == "human":
            self._render_frame()

        return obs, info

    def step(self, action):
        """Advance the simulation by one ``dt`` and compute the shaped reward.

        Args:
            action: Array-like of 3 values (forward thrust, lateral thrust,
                torque). Values are saturated to [-1, 1] before use.

        Returns:
            (observation, reward, terminated, truncated, info), where ``info``
            is the ``_get_info`` dict evaluated after the physics update.
        """
        self.current_step += 1
        dist_before = self._get_info()["distance"]

        # Saturate commands to the declared action bounds so max_force and
        # max_torque are true physical limits even for out-of-range inputs.
        action = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0)
        forward_action = action[0]
        lateral_action = action[1]
        torque_action = action[2]

        # --- 1. Hammer dynamics (linear & angular) ---
        # Thrust axes are rotated by +90 degrees relative to _hammer_angle
        # ("align hammer thrust with render orientation").
        # NOTE(review): the angle error used for reward and the smash check is
        # measured against _hammer_angle itself, i.e. perpendicular to this
        # forward-thrust axis. Confirm that this is intended.
        thrust_angle = self._hammer_angle + np.pi / 2
        forward_vec = np.array([np.cos(thrust_angle), np.sin(thrust_angle)])
        lateral_vec = np.array([-np.sin(thrust_angle), np.cos(thrust_angle)])

        forward_force = forward_vec * forward_action * self.max_force
        lateral_force = lateral_vec * lateral_action * self.max_force * 0.5  # half-strength side thrust
        total_force = forward_force + lateral_force

        # Semi-implicit Euler: velocity first, then position with the new velocity.
        self._hammer_vel += (total_force / self.hammer_mass) * self.dt
        self._hammer_pos += self._hammer_vel * self.dt

        torque = torque_action * self.max_torque
        self._hammer_angular_vel += (torque / self.hammer_inertia) * self.dt
        self._hammer_angle += self._hammer_angular_vel * self.dt
        self._hammer_angle %= (2 * np.pi)

        # --- 2. Asteroid dynamics (inertial movement) ---
        self._asteroid_pos += self._asteroid_vel * self.dt

        # --- 3. Reward calculation ---
        reward = 0.0
        info_after = self._get_info()
        dist_after = info_after["distance"]
        hammer_speed = info_after["hammer_speed"]
        angle_diff = info_after["angle_diff"]

        # R_distance: reward closing the gap.
        distance_change = dist_before - dist_after
        reward += 50.0 * distance_change / self.size

        # R_retreat (V11): flat penalty whenever the gap widened.
        if distance_change < 0:
            reward -= 1.0

        # R_proximity: inverse normalized distance; epsilon avoids division by zero.
        epsilon = 1e-5
        reward += 0.05 * (1.0 / (dist_after / self.size + epsilon))

        # R_orientation: +0.5 when facing the target, -0.5 when facing away.
        reward += 0.5 * np.cos(angle_diff)

        # R_action: fuel/energy penalty.
        reward -= 0.01 * (forward_action**2 + lateral_action**2 + 0.1 * torque_action**2)

        # R_time: per-step penalty.
        reward -= 0.05

        # R_terminal: inside the terminal radius, reward speed up to smash_velocity.
        if dist_after < self.terminal_accel_radius:
            velocity_fraction = np.clip(hammer_speed / self.smash_velocity, 0.0, 1.0)
            reward += 2.0 * velocity_fraction

        # --- 4. Termination / truncation ---
        terminated = False
        truncated = False
        collision_distance = self.smash_radius_base + self._asteroid_radius

        # Contact ends the episode; the outcome depends on orientation only.
        if dist_after <= collision_distance:
            if np.abs(angle_diff) < np.deg2rad(30):
                reward += 10000.0  # successful smash
            else:
                reward -= 100.0    # contact at a bad angle
            terminated = True

        # V12: hard truncation when the hammer drifts away.
        max_drift_dist = 2.0 * self.size  # allowed distance from the asteroid
        max_boundary = 3.0 * self.size    # allowed |x| and |y| of the hammer

        if dist_after > max_drift_dist:
            reward -= 500.0
            truncated = True

        x_out_of_bounds = self._hammer_pos[0] < -max_boundary or self._hammer_pos[0] > max_boundary
        y_out_of_bounds = self._hammer_pos[1] < -max_boundary or self._hammer_pos[1] > max_boundary
        if x_out_of_bounds or y_out_of_bounds:
            reward -= 1000.0
            truncated = True

        # Step budget.
        truncated = truncated or (self.current_step >= self.max_steps)

        if self.render_mode == "human":
            self._render_frame()

        return self._get_obs(), float(reward), terminated, truncated, info_after

    def render(self):
        """Return an RGB frame in "rgb_array" mode; otherwise return None."""
        if self.render_mode == "rgb_array":
            return self._render_frame()

    def _render_frame(self):
        """Draw the current state.

        In "human" mode the frame is shown in the pygame window and paced to
        ``metadata["render_fps"]``. In "rgb_array" mode the frame is returned
        as a (height, width, 3) array.
        """
        import pygame

        if self.window is None:
            pygame.init()
            self.window = pygame.display.set_mode((self.size, self.size))
            pygame.display.set_caption("Mjölnir Intercept-v12")
        if self.clock is None:
            self.clock = pygame.time.Clock()
        if self._font is None:
            # Built once instead of on every frame.
            self._font = pygame.font.Font(None, 36)

        canvas = pygame.Surface((self.size, self.size))
        canvas.fill((20, 20, 40))  # space background

        info = self._get_info()
        distance = info["distance"]
        hammer_speed = info["hammer_speed"]
        angle_diff = info["angle_diff"]
        angular_vel = info["angular_vel"]

        # 1. Target (asteroid).
        asteroid_px = tuple(int(v) for v in self._asteroid_pos.astype(int))
        pygame.draw.circle(canvas, (150, 75, 0), asteroid_px, int(self._asteroid_radius))

        # 2. Agent (hammer): head above the sprite centre, handle below it.
        hammer_px_arr = self._hammer_pos.astype(int)
        hammer_px = tuple(int(v) for v in hammer_px_arr)

        head_w, head_h = 20, 50
        handle_l = 50

        head_surf = pygame.Surface((head_w, head_h), pygame.SRCALPHA)
        head_surf.fill((150, 150, 255))
        handle_surf = pygame.Surface((5, handle_l), pygame.SRCALPHA)
        handle_surf.fill((100, 50, 0))

        # Square sprite large enough that rotating about its centre never clips.
        max_dim = max(head_w, head_h + handle_l)
        hammer_surf = pygame.Surface((max_dim * 2, max_dim * 2), pygame.SRCALPHA)
        sprite_rect = hammer_surf.get_rect()
        hammer_surf.blit(head_surf, (sprite_rect.centerx - head_w / 2, sprite_rect.centery - head_h))
        hammer_surf.blit(handle_surf, (sprite_rect.centerx - 5 / 2, sprite_rect.centery))

        rotated = pygame.transform.rotate(hammer_surf, np.rad2deg(self._hammer_angle) + 90)
        canvas.blit(rotated, rotated.get_rect(center=hammer_px))

        # 3. Smash radius (warning zone) around the hammer.
        collision_distance = self.smash_radius_base + self._asteroid_radius
        pygame.draw.circle(canvas, (255, 100, 0), hammer_px, int(collision_distance), 2)

        # 4. Orientation marker along _hammer_angle.
        heading = np.array([np.cos(self._hammer_angle), np.sin(self._hammer_angle)])
        tip_arr = (hammer_px_arr + heading * 75).astype(int)
        tip_px = tuple(int(v) for v in tip_arr)
        pygame.draw.line(canvas, (255, 255, 255), hammer_px, tip_px, 3)

        # 5. HUD text, one row every 40 px starting at y = 10.
        speed_color = (0, 255, 0) if hammer_speed >= self.smash_velocity else (255, 255, 0)
        angle_color = (0, 255, 0) if np.abs(np.rad2deg(angle_diff)) < 30 else (255, 100, 100)
        hud_rows = (
            (f"DIST: {distance:.1f} / {self.size:.1f}", (255, 255, 255)),
            (f"SPD: {hammer_speed:.1f} (Max Force: {self.max_force:.1f})", speed_color),
            (f"ANGLE ERR: {np.rad2deg(angle_diff):.1f}°", angle_color),
            (f"ROT VEL: {angular_vel:.2f} rad/s", (100, 200, 255)),
            (f"LAT SPD: {info['lateral_speed']:.2f} m/s", (255, 150, 0)),
        )
        for row, (text, color) in enumerate(hud_rows):
            canvas.blit(self._font.render(text, True, color), (10, 10 + 40 * row))

        if self.render_mode == "human":
            # Present the frame: without this the window stays blank.
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            self.clock.tick(self.metadata["render_fps"])
            return None

        if self.render_mode == "rgb_array":
            # surfarray is indexed (x, y, channel); swap to (row, column, channel).
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )

    def close(self):
        """Shut pygame down and drop cached handles so rendering can be re-initialised."""
        if self.window is not None:
            import pygame

            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
            self._font = None
|
|
|
# ------------------------------------------------------------------------------------ |
|
## 2. PPO Training and Evaluation (V12 ENVIRONMENT) 🚀 |
|
# ------------------------------------------------------------------------------------ |
|
|
|
# --- Helper Functions (No change) --- |
|
def linear_schedule(initial_value: float) -> Callable[[float], float]:
    """Build a schedule that scales ``initial_value`` by the remaining progress.

    Args:
        initial_value: Value returned when ``progress_remaining`` is 1.0.

    Returns:
        A function mapping ``progress_remaining`` to
        ``initial_value * progress_remaining``.
    """
    def func(progress_remaining: float) -> float:
        return initial_value * progress_remaining

    return func
|
|
|
def show_video():
    """Return an inline HTML5 player for the most recently modified MP4 in VIDEO_DIR.

    The file is embedded as a base64 data URI. If the directory contains no
    ``.mp4`` file, a message is printed and a placeholder paragraph is
    returned instead.

    Returns:
        IPython.display.HTML object.
    """
    mp4_names = [name for name in os.listdir(VIDEO_DIR) if name.endswith('.mp4')]
    if not mp4_names:
        print("No video file found.")
        return HTML('<p>No video file found.</p>')

    newest = max(mp4_names, key=lambda name: os.path.getmtime(os.path.join(VIDEO_DIR, name)))

    # Context manager guarantees the file handle is released after reading.
    with open(os.path.join(VIDEO_DIR, newest), 'rb') as video_file:
        b64 = base64.b64encode(video_file.read()).decode()

    return HTML(f'<video width="100%" controls><source src="data:video/mp4;base64,{b64}" type="video/mp4"></video>')
|
|
|
# --- Hyperparameters ---
INIT_LR = 0.0001     # initial learning rate
N_STEPS = 2048       # rollout length passed to PPO as n_steps
BATCH_SIZE = 64      # minibatch size
N_EPOCHS = 10        # optimisation epochs per rollout
GAMMA = 0.99         # discount factor
ENT_COEF = 0.005     # entropy bonus coefficient
# Separate 64-64 networks for policy (pi) and value function (vf).
# Passed as a plain dict: the list-wrapped form [dict(pi=..., vf=...)] is
# deprecated since Stable-Baselines3 1.8.
POLICY_KWARGS = dict(net_arch=dict(pi=[64, 64], vf=[64, 64]))
|
|
|
|
|
# --- Model Initialization and Training ---
env_id = "MjollnirIntercept-v12"

# Register the V12 class under env_id so gym.make / make_vec_env can build it
# by name; max_episode_steps matches the environment's own 1500-step budget.
gym.register(id=env_id, entry_point=MjollnirEnvV12, max_episode_steps=1500)

print(f"Starting Mjölnir training for {TOTAL_TIMESTEPS} timesteps using {env_id}...")

# Ten environment copies stepped together to collect rollouts.
train_env = make_vec_env(env_id, n_envs=10, env_kwargs=dict(dt=1/30))

model = PPO(
    "MlpPolicy",
    train_env,
    # Learning rate decays linearly from INIT_LR as training progresses.
    learning_rate=linear_schedule(INIT_LR),
    n_steps=N_STEPS,
    batch_size=BATCH_SIZE,
    n_epochs=N_EPOCHS,
    gamma=GAMMA,
    ent_coef=ENT_COEF,
    policy_kwargs=POLICY_KWARGS,
    verbose=1,
)

model.learn(total_timesteps=TOTAL_TIMESTEPS)
|
|
|
# --- Video Recording and Evaluation ---
print("\n--- Recording Trained Agent Performance ---")

# Single V12 environment rendering RGB frames for the recorder.
eval_env = DummyVecEnv([lambda: gym.make(env_id, render_mode="rgb_array", dt=1/30)])
video_length = 6000

eval_env = VecVideoRecorder(
    eval_env,
    VIDEO_DIR,
    record_video_trigger=lambda step: step == 0,  # start recording at the first step
    video_length=video_length,
    name_prefix="ppo-mjollnir-agent-v12",
)

try:
    # VecEnv API: reset() returns only the batched observation, and step()
    # returns (obs, rewards, dones, infos). Using it directly matters: the
    # earlier "try 5-tuple / except ValueError / step again" fallback executed
    # every step twice, so the policy acted on stale observations.
    obs = eval_env.reset()

    for _ in range(video_length):
        action, _ = model.predict(obs, deterministic=True)
        obs, _, dones, _ = eval_env.step(action)

        # Record a single episode; the VecEnv has already auto-reset by now.
        if dones[0]:
            break
finally:
    # Closing the recorder finalises the video file even if the rollout fails.
    eval_env.close()

# --- Display Results ---
print("\n--- Video of Trained Agent ---")
show_video()