Created
May 7, 2026 09:03
-
-
Save CypherpunkSamurai/ad7cf6bb58630b900ede6f2df5cfb51c to your computer and use it in GitHub Desktop.
PPO Optimizer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| ================================================================================ | |
| PPO.RL.py — Production-Grade Proximal Policy Optimization | |
| ================================================================================ | |
| WHAT IS PPO? (Explain Like I'm 5) | |
| ---------------------------------- | |
| Imagine you're teaching a dog (the "agent") to navigate a maze (the "environment"). | |
| Every step the dog takes, it gets a treat (+reward) or a smack (-reward). | |
| The dog has two "brains": | |
| 🧠 Actor → decides WHAT action to take ("go left / right / jump") | |
| 🧠 Critic → judges HOW GOOD the current situation is ("this position is worth 10 treats") | |
| PPO (Proximal Policy Optimization) is the training algorithm that teaches both brains. | |
| Its key idea: "Don't change your mind too drastically in one step." | |
| If the dog accidentally learns "always go right" too strongly, it might forget how to go left. | |
| PPO clips (limits) how much the policy can change per update — like a safety leash. | |
| Key math concepts (ELI5): | |
| - Advantage : "Was this action better or worse than average?" (Actual reward - Expected reward) | |
| - GAE : A smoother way to estimate advantage using multiple future steps | |
| - Clipped Ratio: If new policy is too different from old, clip the update to prevent instability | |
| - Entropy Bonus: Encourage exploration — don't let the agent become boring and repetitive | |
| SUPPORTS: | |
| ✅ Discrete action spaces (e.g. CartPole, Atari, Gymnasium) | |
| ✅ Continuous action spaces (e.g. MuJoCo, robotics arms, locomotion) | |
| ✅ Pluggable backbone networks (MLP default, Transformer, CNN, custom) | |
| ✅ Vectorized parallel environments | |
| ✅ Observation & reward normalization | |
| ✅ Gradient clipping | |
| ✅ Learning-rate scheduling (linear decay / cosine annealing / warmup) | |
| ✅ Value-function clipping (extra PPO stabilization trick) | |
| ✅ KL-divergence early stopping | |
| ✅ TensorBoard + W&B + CSV logging | |
| ✅ Full checkpoint save/resume | |
| ✅ Mixed precision (AMP) training | |
| ✅ Robot training (continuous, recurrent-ready) | |
| ✅ Transformer policy backbone | |
| USAGE QUICK-START: | |
| ------------------ | |
| from PPO.RL import PPOConfig, PPOAgent, make_gymnasium_env | |
| cfg = PPOConfig(env_id="CartPole-v1", total_timesteps=500_000) | |
| agent = PPOAgent(cfg) | |
| agent.learn() | |
| # Continuous (robotic) env: | |
| cfg = PPOConfig(env_id="HalfCheetah-v4", action_space_type="continuous", | |
| total_timesteps=1_000_000) | |
| agent = PPOAgent(cfg) | |
| agent.learn() | |
| # Custom environment: | |
| agent = PPOAgent(cfg, env_factory=my_env_factory) | |
| agent.learn() | |
| # Custom policy backbone (e.g. Transformer): | |
| from PPO.RL import TransformerPolicyBackbone | |
| cfg = PPOConfig(env_id="MyEnv-v0", backbone_cls=TransformerPolicyBackbone) | |
| agent = PPOAgent(cfg) | |
| AUTHORS: Built from PyTorch PPO Tutorial + tomasspangelo/proximal-policy-optimization | |
| LICENSE: MIT | |
| ================================================================================ | |
| """ | |
| from __future__ import annotations | |
| # ── Standard Library ───────────────────────────────────────────────────────── | |
| import abc | |
| import copy | |
| import csv | |
| import logging | |
| import math | |
| import os | |
| import random | |
| import time | |
| import warnings | |
| from collections import deque | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import ( | |
| Any, Callable, Dict, Generator, List, Optional, | |
| Sequence, Tuple, Type, Union, | |
| ) | |
| # ── Third-Party ─────────────────────────────────────────────────────────────── | |
| import numpy as np | |
| import scipy.signal | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.distributions import Categorical, Normal, Beta | |
| from torch.utils.tensorboard import SummaryWriter # pip install tensorboard | |
| try: | |
| import gymnasium as gym # pip install gymnasium | |
| _GYM_AVAILABLE = True | |
| except ImportError: | |
| try: | |
| import gym # fallback: classic gym | |
| _GYM_AVAILABLE = True | |
| except ImportError: | |
| _GYM_AVAILABLE = False | |
| warnings.warn("Neither 'gymnasium' nor 'gym' found. Install with: pip install gymnasium") | |
| try: | |
| import wandb # pip install wandb (optional) | |
| _WANDB_AVAILABLE = True | |
| except ImportError: | |
| _WANDB_AVAILABLE = False | |
| # ── Logging Setup ───────────────────────────────────────────────────────────── | |
# Configure the root logger at import time so INFO-level progress messages
# from this module are timestamped and tagged with the logger name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Module-wide logger used by the helpers and classes below.
logger = logging.getLogger("PPO.RL")
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §1 CONFIGURATION | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # | |
| # ELI5: Think of PPOConfig as a recipe card. | |
| # Before you cook, you write down all the ingredients and cooking times. | |
| # Here we write down all the settings PPO needs before training starts. | |
@dataclass
class PPOConfig:
    """
    All hyper-parameters and training settings for PPO in one place.

    ELI5: the recipe card. Every ingredient and cooking time PPO needs is
    written down here before training starts.

    Key parameters explained simply:
        total_timesteps : How many total "game steps" to train for.
        rollout_steps   : How many steps to collect before updating networks.
        num_envs        : How many parallel game copies to run simultaneously.
        gamma           : "How much do I care about future rewards?" 0=ignore, 1=care forever.
        gae_lambda      : Smoothing factor for advantage estimation. 0.95 is a safe default.
        clip_epsilon    : The PPO "safety leash". Max 20% policy change per update (0.2).
        n_epochs        : How many passes over the collected data during each update.
        lr_actor        : How fast the Actor brain learns. Too high → unstable, too low → slow.
        lr_critic       : How fast the Critic brain learns.
        entropy_coef    : Bonus reward for "thinking differently" → encourages exploration.
        value_coef      : How much to weight the Critic's loss vs. the Actor's loss.

    This class only stores values; it performs no validation of its own.
    """

    # ── Environment ──────────────────────────────────────────────────────────
    env_id: str = "CartPole-v1"
    """Gymnasium environment ID (e.g. 'CartPole-v1', 'HalfCheetah-v4')."""

    action_space_type: str = "discrete"
    """'discrete' for integer actions (games), 'continuous' for real-valued (robots)."""

    num_envs: int = 4
    """
    Number of parallel environments.
    ELI5: Like running 4 copies of the game simultaneously to collect data faster.
    More envs → faster data collection, but more RAM/CPU needed.
    """

    seed: int = 42
    """Random seed for reproducibility. Same seed → same training trajectory."""

    # ── Training Schedule ────────────────────────────────────────────────────
    total_timesteps: int = 1_000_000
    """Total environment steps to train for."""

    rollout_steps: int = 2048
    """
    Steps collected per environment before each policy update.
    ELI5: Collect 2048 game steps per env, then teach the agent from that data.
    Total data per update = rollout_steps × num_envs.
    """

    n_epochs: int = 10
    """
    How many times to re-use collected data for gradient updates.
    ELI5: Re-read the same textbook chapter 10 times to really memorise it.
    More epochs → more efficient data use, but risk of overfitting to stale data.
    """

    minibatch_size: int = 64
    """
    Size of each mini-batch during gradient updates.
    ELI5: Instead of studying all 2048 examples at once, study 64 at a time.
    Smaller batches → noisier gradients but sometimes better generalisation.
    """

    # ── Discount & GAE ───────────────────────────────────────────────────────
    gamma: float = 0.99
    """
    Discount factor for future rewards.
    ELI5: A reward in 100 steps is worth gamma^100 of a reward right now.
    gamma=0.99 means future rewards are almost as valuable as immediate ones.
    gamma=0 means the agent is completely short-sighted (only NOW matters).
    """

    gae_lambda: float = 0.95
    """
    GAE (Generalized Advantage Estimation) smoothing factor.
    ELI5: Instead of estimating "how good was that action?" from one future step,
    we blend estimates from many future steps. lambda controls the blend:
        lambda=0 → only use the immediate next step (low variance, high bias)
        lambda=1 → use ALL future steps (high variance, unbiased)
    0.95 is a sweet spot used in most PPO papers.
    """

    # ── PPO Clipping & Losses ────────────────────────────────────────────────
    clip_epsilon: float = 0.2
    """
    The PPO clipping parameter — the "safety leash".
    ELI5: If the new policy's probability ratio vs old is outside [0.8, 1.2],
    we CLIP it. This prevents catastrophic policy changes in one update.
    Standard value: 0.2. Larger = more aggressive updates.
    """

    clip_epsilon_schedule: str = "linear"
    """
    How clip_epsilon changes over training: 'constant', 'linear' (decay to 0), or 'none'.
    Linear decay makes training more conservative as it matures.
    """

    # Optional because None is the documented way to switch value clipping off.
    value_clip_epsilon: Optional[float] = 0.2
    """
    Clip value function loss too (same as clip_epsilon by default).
    This is an extra stabilisation trick: don't let the Critic's value estimates
    change too drastically in a single update either.
    Set to None to disable value clipping.
    """

    value_coef: float = 0.5
    """
    Weight of the Critic (value) loss in the total loss.
    Total loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy.
    """

    entropy_coef: float = 0.01
    """
    Weight of the entropy bonus in the total loss.
    ELI5: We ADD a bonus to encourage the agent to be "uncertain" (exploratory).
    If the agent becomes too confident too fast, it stops exploring better strategies.
    entropy_coef=0.01 adds a tiny exploration bonus.
    """

    entropy_coef_schedule: str = "linear"
    """
    How entropy_coef changes: 'constant', 'linear' (decay to 0), or 'none'.
    Decaying entropy → more exploration early, more exploitation later.
    """

    max_grad_norm: float = 0.5
    """
    Maximum gradient norm for gradient clipping.
    ELI5: If the gradient (the learning signal) is too large, it's scaled down.
    This prevents the network from making huge, destabilizing leaps in one step.
    Think of it as a speed limiter on learning.
    """

    target_kl: Optional[float] = 0.02
    """
    If KL-divergence between old and new policy exceeds this, stop epoch early.
    ELI5: KL measures "how different are the old and new policies?"
    If the policies diverge too much, we abort the update to prevent instability.
    Set to None to disable early stopping.
    """

    # ── Network Architecture ─────────────────────────────────────────────────
    hidden_sizes: Tuple[int, ...] = (64, 64)
    """
    Hidden layer sizes for MLP backbone.
    ELI5: The Actor and Critic each have a stack of layers. (64, 64) means
    two hidden layers each with 64 neurons. Bigger → more capacity, slower.
    """

    activation: str = "tanh"
    """
    Activation function: 'tanh', 'relu', 'elu', 'gelu'.
    ELI5: The "squishing function" applied after each layer.
    'tanh' works well for bounded state spaces. 'relu' is popular for large nets.
    """

    shared_backbone: bool = False
    """
    If True, Actor and Critic share the same feature extractor (backbone).
    ELI5: One pair of eyes (shared) sees the world, then two separate heads decide
    (action head) and judge (value head). Often faster to train.
    If False, Actor and Critic each have their own independent networks.
    """

    backbone_cls: Optional[Type[nn.Module]] = None
    """
    Optional custom backbone class. Must accept (obs_dim, hidden_sizes, activation)
    and output a feature tensor. Overrides the default MLP if provided.
    Use this to plug in a Transformer, CNN, LSTM, etc.
    """

    # ── Continuous Action Space Settings ─────────────────────────────────────
    continuous_dist: str = "gaussian"
    """
    For continuous actions: 'gaussian' (Normal) or 'beta' (Beta distribution).
    Gaussian: unbounded actions, common for MuJoCo, robotics.
    Beta: bounded [0,1] actions, good for normalized action spaces.
    """

    log_std_init: float = -0.5
    """
    Initial log-standard-deviation for Gaussian policy.
    ELI5: How "spread out" the action distribution is initially.
    More spread = more exploration early on. Becomes learnable during training.
    """

    squash_actions: bool = False
    """
    Apply tanh squashing to Gaussian actions to bound them to [-1, 1].
    ELI5: Forces continuous actions to stay within a safe range.
    Required for environments with bounded action spaces (e.g. [-1, 1]).
    """

    # ── Optimiser & LR ───────────────────────────────────────────────────────
    lr_actor: float = 3e-4
    """Actor learning rate. 3e-4 is the "magic number" from the PPO paper."""

    lr_critic: float = 1e-3
    """Critic learning rate. Usually slightly higher than actor."""

    lr_schedule: str = "linear"
    """
    LR schedule: 'constant', 'linear' (decay to 0), 'cosine', 'warmup_cosine'.
    Linear decay is the most common PPO choice (matches the paper).
    """

    optimizer: str = "adam"
    """Optimizer: 'adam', 'adamw', 'sgd'."""

    adam_eps: float = 1e-5
    """
    Adam epsilon for numerical stability.
    ELI5: Tiny number to avoid division by zero in the Adam optimizer.
    Slightly larger (1e-5 vs 1e-8) can improve stability for RL.
    """

    weight_decay: float = 0.0
    """L2 regularisation weight. 0 = no regularisation (standard for PPO)."""

    # ── Normalisation ────────────────────────────────────────────────────────
    normalise_obs: bool = True
    """
    Normalise observations using a running mean/std.
    ELI5: If observations are wildly different scales (e.g. position=0.01, velocity=100),
    learning is unstable. Normalising puts everything in the same "currency".
    """

    normalise_rewards: bool = True
    """
    Normalise rewards using a running std (NOT mean — preserves sign).
    ELI5: If some rewards are +1000 and some are -0.001, the gradient signals are noisy.
    Normalising makes learning more stable across different reward scales.
    """

    reward_clip: float = 10.0
    """Clip normalised rewards to [-reward_clip, reward_clip]. Prevents outlier reward explosions."""

    # ── Mixed Precision ──────────────────────────────────────────────────────
    use_amp: bool = False
    """
    Use Automatic Mixed Precision (AMP / float16) for faster GPU training.
    ELI5: Use 16-bit numbers instead of 32-bit where safe → up to 2× speedup on modern GPUs.
    Requires CUDA. Automatically disabled if CPU is used.
    """

    # ── Logging & Checkpointing ──────────────────────────────────────────────
    log_dir: str = "./ppo_logs"
    """Directory for TensorBoard logs, CSV logs, and checkpoints."""

    experiment_name: str = "ppo_run"
    """Name for this experiment run (used in log filenames and W&B)."""

    log_interval: int = 1
    """Log metrics every N policy updates."""

    save_interval: int = 10
    """Save checkpoint every N policy updates."""

    use_wandb: bool = False
    """Enable Weights & Biases logging (requires: pip install wandb; wandb login)."""

    wandb_project: str = "ppo_rl"
    """W&B project name."""

    resume_from: Optional[str] = None
    """Path to checkpoint file to resume training from."""

    # ── Advanced ─────────────────────────────────────────────────────────────
    ortho_init: bool = True
    """
    Use orthogonal initialisation for network weights.
    ELI5: A special way to initialise weights that helps RL training be more stable.
    Empirically shown to improve PPO performance (used in Stable-Baselines3).
    """

    norm_adv: bool = True
    """
    Normalise advantages (zero mean, unit variance) within each mini-batch.
    ELI5: Makes sure advantage estimates don't have wildly different scales,
    keeping gradient updates consistent across different parts of training.
    """

    recurrent: bool = False
    """
    [Future flag] Enable recurrent (LSTM/GRU) policy for partially observable envs.
    Not yet implemented — use a Transformer backbone for sequence modelling instead.
    """

    device: str = "auto"
    """
    Compute device: 'auto' (GPU if available), 'cpu', 'cuda', 'cuda:0', 'mps'.
    'auto' is recommended — it automatically picks the best available device.
    """
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §2 UTILITY FUNCTIONS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def get_device(cfg_device: str) -> torch.device:
    """
    Turn the configured device string into a ``torch.device``.

    ELI5: Decide WHERE the maths runs. A CUDA GPU is the sports car, Apple's
    MPS backend is the GPU on Apple Silicon, and the CPU is the dependable
    sedan that is always there.

    Args:
        cfg_device: ``"auto"`` to pick CUDA, then MPS, then CPU (first one
                    available), or any string ``torch.device`` accepts.

    Returns:
        The resolved device. The choice is also logged at INFO level.
    """
    # Explicit request: hand the string straight to torch.
    if cfg_device != "auto":
        explicit = torch.device(cfg_device)
        logger.info(f"Using device: {explicit}")
        return explicit

    # Automatic selection, in order of preference.
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        logger.info(f"Auto-selected device: CUDA ({torch.cuda.get_device_name(0)})")
        return chosen

    # Older torch builds have no `backends.mps` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None) if hasattr(torch.backends, "mps") else None
    if mps_backend is not None and mps_backend.is_available():
        chosen = torch.device("mps")
        logger.info("Auto-selected device: Apple MPS")
        return chosen

    chosen = torch.device("cpu")
    logger.info("Auto-selected device: CPU")
    return chosen
def set_seed(seed: int) -> None:
    """
    Seed every random number generator this module relies on.

    ELI5: Shuffle the deck in one fixed way, so the "random" order is the
    same every time you run the experiment.

    Args:
        seed: Integer passed to Python's ``random``, NumPy, and PyTorch
              (CPU and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner;
    # this trades a little speed for repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def discount_cumsum(x: np.ndarray, discount: float) -> np.ndarray:
    """
    Discounted cumulative sum along axis 0, accumulated from the end.

    ELI5: For rewards [r0, r1, r2, ...] this produces the "return" seen
    from every step:
        G0 = r0 + γ*r1 + γ²*r2 + ...
        G1 = r1 + γ*r2 + ...
        G2 = r2 + ...

    Args:
        x       : Array of rewards or deltas (time on axis 0).
        discount: Discount factor γ (gamma).

    Returns:
        Array with the same shape as ``x`` holding the discounted sums.
    """
    # The recurrence y[n] = x[n] + discount * y[n-1] is a one-pole IIR filter.
    # Feeding it the time-reversed signal and flipping the output back turns
    # that forward recurrence into the backward sum we want.
    numerator = [1.0]
    denominator = [1.0, -discount]
    time_reversed = x[::-1]
    accumulated = scipy.signal.lfilter(numerator, denominator, time_reversed, axis=0)
    return accumulated[::-1]
def get_activation(name: str) -> nn.Module:
    """
    Build the activation module registered under ``name``.

    ELI5: The "squishing function" between layers. Without it, a stack of
    linear layers collapses into a single linear layer; the activation adds
    the non-linearity that lets the network learn complex patterns.

    Args:
        name: One of 'tanh', 'relu', 'elu', 'gelu', 'leaky_relu', 'silu'.

    Returns:
        A freshly constructed activation module.

    Raises:
        ValueError: If ``name`` is not one of the supported keys.
    """
    # Name → zero-argument constructor; only the requested one gets built.
    activations = {
        "tanh": nn.Tanh,
        "relu": nn.ReLU,
        "elu": nn.ELU,
        "gelu": nn.GELU,
        "leaky_relu": lambda: nn.LeakyReLU(0.01),
        "silu": nn.SiLU,  # Smooth version of ReLU, used in transformers
    }
    build = activations.get(name)
    if build is None:
        raise ValueError(f"Unknown activation '{name}'. Choose from: {list(activations)}")
    return build()
def layer_init(layer: nn.Module, std: float = np.sqrt(2), bias_const: float = 0.0) -> nn.Module:
    """
    Orthogonal initialisation for a linear layer (applied in place).

    ELI5: Weights normally start out random. Orthogonal init starts them in a
    special arrangement that helps avoid vanishing/exploding gradients early
    in training.

    Args:
        layer     : Layer to initialise; must expose a ``weight`` tensor and
                    may expose a ``bias`` (e.g. ``nn.Linear``).
        std       : Gain passed to ``nn.init.orthogonal_`` (smaller = smaller
                    initial outputs).
        bias_const: Constant every bias entry is set to.

    Returns:
        The same ``layer`` object, so the call can be used inline.
    """
    nn.init.orthogonal_(layer.weight, std)

    # Layers built with bias=False carry `bias = None`; skip them instead of
    # crashing inside nn.init.constant_.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, bias_const)
    return layer
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §3 RUNNING STATISTICS (Observation & Reward Normalization) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class RunningMeanStd:
    """
    Track the running mean and variance of a stream of data.

    ELI5: Imagine measuring the heights of students one by one. Instead of
    storing every height and averaging at the end, you keep a "running
    average" that is updated with each new student. That costs constant
    memory and works for endless streams.

    Used to normalise observations/rewards so the network sees well-scaled
    inputs. Batches are merged with the parallel-variance update (the batched
    form of Welford's online algorithm).
    """

    def __init__(self, shape: Tuple[int, ...] = (), epsilon: float = 1e-8):
        """
        Args:
            shape  : Shape of a single data point (e.g. ``(obs_dim,)``).
            epsilon: Tiny pseudo-count so the first update never divides by zero.
        """
        self.mean = np.zeros(shape, dtype=np.float64)
        self.var = np.ones(shape, dtype=np.float64)
        self.count = epsilon  # Start at epsilon to avoid /0

    def update(self, x: np.ndarray) -> None:
        """
        Fold new data into the running statistics.

        Args:
            x: Either one sample of shape ``(*shape,)`` or a batch of shape
               ``(batch, *shape)``. Extra leading axes are flattened into the
               batch axis. An empty array is ignored.

        Raises:
            ValueError: If ``x`` cannot be viewed as ``(-1, *shape)``.
        """
        samples = np.asarray(x, dtype=np.float64)
        if samples.size == 0:
            # Nothing to learn from; averaging an empty batch would yield NaN.
            return

        # Always work on an explicit (batch, *shape) view. A lone sample
        # becomes a batch of one, so axis 0 is never a feature axis.
        samples = samples.reshape((-1,) + self.mean.shape)
        batch_count = samples.shape[0]
        batch_mean = samples.mean(axis=0)
        batch_var = samples.var(axis=0)

        # Parallel/Welford update — combines old stats with the batch stats.
        total_count = self.count + batch_count
        delta = batch_mean - self.mean
        new_mean = self.mean + delta * batch_count / total_count
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta**2 * self.count * batch_count / total_count
        new_var = m2 / total_count

        self.mean, self.var, self.count = new_mean, new_var, total_count

    @property
    def std(self) -> np.ndarray:
        """Standard deviation (sqrt of variance, with a small floor)."""
        return np.sqrt(self.var + 1e-8)  # epsilon for numerical safety

    def normalise(self, x: np.ndarray) -> np.ndarray:
        """Subtract the running mean and divide by the running std."""
        return (x - self.mean) / self.std

    def state_dict(self) -> Dict:
        """Serialise for checkpointing (arrays are copied, not aliased)."""
        return {
            "mean": np.array(self.mean, dtype=np.float64),
            "var": np.array(self.var, dtype=np.float64),
            "count": self.count,
        }

    def load_state_dict(self, d: Dict) -> None:
        """Restore from a dict produced by ``state_dict`` (arrays are copied)."""
        self.mean = np.array(d["mean"], dtype=np.float64)
        self.var = np.array(d["var"], dtype=np.float64)
        self.count = d["count"]
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §4 ROLLOUT BUFFER | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class RolloutBuffer:
    """
    Stores trajectories collected from the environment during a rollout.

    ELI5: This is the "notebook" where we write down everything that
    happened during the game:
        - What state were we in?
        - What action did we take?
        - What reward did we get?
        - How likely was that action (old policy)?
        - What did the Critic think the state was worth?
        - Was this a terminal (game-over) state?
    After the rollout we compute advantages and returns from these notes,
    then use them to update the Actor and Critic.

    Storage is pre-allocated with shape ``(rollout_steps, num_envs, ...)``
    and kept in default (CPU) tensors; mini-batches are moved to ``device``
    only when they are yielded.
    """

    def __init__(
        self,
        rollout_steps: int,
        num_envs: int,
        obs_shape: Tuple[int, ...],
        act_shape: Tuple[int, ...],
        device: torch.device,
        gae_lambda: float = 0.95,
        gamma: float = 0.99,
        action_space_type: str = "discrete",
    ):
        """
        Args:
            rollout_steps     : Steps per environment per rollout
            num_envs          : Number of parallel environments
            obs_shape         : Shape of a single observation
            act_shape         : Shape of a single action
            device            : Torch device mini-batches are moved to
            gae_lambda        : GAE lambda parameter (λ)
            gamma             : Discount factor (γ)
            action_space_type : 'discrete' or 'continuous' (stored only)
        """
        self.rollout_steps = rollout_steps
        self.num_envs = num_envs
        self.obs_shape = obs_shape
        self.act_shape = act_shape
        self.device = device
        self.gae_lambda = gae_lambda
        self.gamma = gamma
        self.action_space_type = action_space_type

        # Pre-allocate tensors for efficiency.
        # Shape: (rollout_steps, num_envs, *feature_shape)
        grid = (rollout_steps, num_envs)
        self.observations = torch.zeros(grid + obs_shape, dtype=torch.float32)
        self.actions = torch.zeros(grid + act_shape, dtype=torch.float32)
        self.rewards = torch.zeros(grid, dtype=torch.float32)
        self.dones = torch.zeros(grid, dtype=torch.float32)
        self.log_probs = torch.zeros(grid, dtype=torch.float32)
        self.values = torch.zeros(grid, dtype=torch.float32)

        # Computed after the rollout (GAE).
        self.advantages = torch.zeros(grid, dtype=torch.float32)
        self.returns = torch.zeros(grid, dtype=torch.float32)

        self.ptr = 0       # Current write position in the buffer
        self.full = False  # True once `rollout_steps` rows have been written

    @staticmethod
    def _to_storage(tensor: torch.Tensor) -> torch.Tensor:
        """Return a graph-free CPU copy/view of ``tensor`` for storage."""
        return tensor.detach().cpu()

    def reset(self) -> None:
        """Clear the buffer — reset the write pointer to the start."""
        self.ptr = 0
        self.full = False

    def add(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        reward: torch.Tensor,
        done: torch.Tensor,
        log_prob: torch.Tensor,
        value: torch.Tensor,
    ) -> None:
        """
        Store one timestep of experience across all parallel environments.

        ELI5: Write one row into our notebook. One row = one game step.

        Inputs are detached before being stored, so a tensor that still
        carries autograd history cannot turn the pre-allocated storage into
        part of a computation graph.

        Args:
            obs     : Current observation    (num_envs, *obs_shape)
            action  : Action taken           (num_envs, *act_shape)
            reward  : Reward received        (num_envs,)
            done    : Terminal flag          (num_envs,)
            log_prob: Log prob of action     (num_envs,)
            value   : Critic value estimate  (num_envs,)

        Raises:
            AssertionError: If the buffer already holds ``rollout_steps`` rows.
        """
        # Kept as an assert so the exception type existing callers may rely
        # on does not change.
        assert self.ptr < self.rollout_steps, (
            f"Buffer overflow at ptr={self.ptr}. Call compute_returns_and_advantages() then reset()."
        )
        row = self.ptr
        self.observations[row] = self._to_storage(obs)
        self.actions[row] = self._to_storage(action)
        self.rewards[row] = self._to_storage(reward)
        self.dones[row] = self._to_storage(done)
        self.log_probs[row] = self._to_storage(log_prob)
        self.values[row] = self._to_storage(value)

        self.ptr += 1
        self.full = self.ptr >= self.rollout_steps

    def compute_returns_and_advantages(self, last_values: torch.Tensor, last_dones: torch.Tensor) -> None:
        """
        Compute Generalised Advantage Estimates (GAE) and returns.

        ELI5 of GAE:
            Say you're at step t. How GOOD was the action you took?
            Naively: actual_return - critic_estimate = advantage.
            That is noisy, so GAE blends several future steps:

            δ_t = r_t + γ * V(s_{t+1}) * (1 - done) - V(s_t)
                  ↑ "TD error" = how much better/worse than expected (one step)
            A_t = δ_t + (γλ)δ_{t+1} + (γλ)²δ_{t+2} + ...
                  ↑ exponentially-weighted sum of future TD errors

            λ=0:    A_t = δ_t               (pure TD, low variance, biased)
            λ=1:    A_t = G_t - V(s_t)      (pure MC, unbiased, high variance)
            λ=0.95: blend that balances bias/variance

        Returns: R_t = A_t + V(s_t)  (advantage + baseline = estimated return)

        NOTE(review): the mask applied to step t is ``1 - dones[t + 1]`` (or
        ``1 - last_dones`` for the final step). That is correct when the done
        flag stored at row t+1 marks the transition out of row t as terminal;
        confirm the rollout loop stores flags with that convention.

        Args:
            last_values: Critic's value estimate for the state AFTER the last
                         stored step, used to bootstrap        (num_envs,)
            last_dones : Done flags paired with ``last_values`` (num_envs,)
        """
        bootstrap_values = last_values.detach().cpu()
        bootstrap_non_terminal = 1.0 - last_dones.detach().float().cpu()
        decay = self.gamma * self.gae_lambda
        final_step = self.rollout_steps - 1

        last_gae_lam = torch.zeros(self.num_envs, dtype=torch.float32)
        # Walk BACKWARD through time so each step reuses the next step's GAE.
        for step in reversed(range(self.rollout_steps)):
            if step == final_step:
                # After the last step, "next" comes from the bootstrap inputs.
                next_non_terminal = bootstrap_non_terminal
                next_values = bootstrap_values
            else:
                next_non_terminal = 1.0 - self.dones[step + 1]
                next_values = self.values[step + 1]

            # TD error: how much better/worse than the critic predicted.
            delta = (
                self.rewards[step]
                + self.gamma * next_values * next_non_terminal
                - self.values[step]
            )
            # Accumulate TD errors backward with (γλ) decay, cut at episode ends.
            last_gae_lam = delta + decay * next_non_terminal * last_gae_lam
            self.advantages[step] = last_gae_lam

        # Returns = advantages + baseline (critic's value estimate).
        self.returns = self.advantages + self.values

    def get_batches(self, minibatch_size: int) -> Generator[Dict[str, torch.Tensor], None, None]:
        """
        Yield shuffled mini-batches of stored experience for gradient updates.

        ELI5: Merge all parallel environments into one big list, shuffle it,
        and hand it out in groups of ``minibatch_size``. The final group may
        be smaller when the total is not a multiple of ``minibatch_size``.

        Args:
            minibatch_size: Number of samples per mini-batch (must be > 0).

        Yields:
            Dict with keys 'obs', 'actions', 'old_log_probs', 'advantages',
            'returns', 'old_values'; every tensor is moved to ``self.device``.

        Raises:
            ValueError: If ``minibatch_size`` is not positive.
        """
        if minibatch_size <= 0:
            raise ValueError(f"minibatch_size must be positive, got {minibatch_size}")

        total_samples = self.rollout_steps * self.num_envs

        # Flatten (rollout_steps, num_envs, ...) → (total_samples, ...).
        flat = {
            "obs": self.observations.reshape(total_samples, *self.obs_shape),
            "actions": self.actions.reshape(total_samples, *self.act_shape),
            "old_log_probs": self.log_probs.reshape(total_samples),
            "advantages": self.advantages.reshape(total_samples),
            "returns": self.returns.reshape(total_samples),
            "old_values": self.values.reshape(total_samples),
        }

        # One random permutation per call; slices of it index every field so
        # the fields of a sample stay aligned.
        indices = torch.randperm(total_samples)
        for start in range(0, total_samples, minibatch_size):
            mb_idx = indices[start:start + minibatch_size]
            yield {key: data[mb_idx].to(self.device) for key, data in flat.items()}
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §5 POLICY BACKBONES (Feature Extractors) | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # | |
| # ELI5: The backbone is the "eyes and ears" of the agent. | |
| # It processes raw observations (pixels, joint angles, sensor readings) | |
| # into a compact feature vector that the Actor/Critic heads can use. | |
class MLPBackbone(nn.Module):
    """
    Plain fully-connected feature extractor (the default backbone).

    ELI5: A tower of Linear layers, each followed by a non-linearity.
        Input : one observation vector (e.g. 4 numbers for CartPole)
        Output: one feature vector whose width is the last hidden size

    Good for flat vector observations (physics state, joint angles).
    For raw pixels use a CNN backbone; for sequences use a Transformer/LSTM.
    """

    def __init__(
        self,
        obs_dim: int,
        hidden_sizes: Tuple[int, ...] = (64, 64),
        activation: str = "tanh",
        ortho_init: bool = True,
    ):
        """
        Args:
            obs_dim     : Width of the input observation vector.
            hidden_sizes: Width of each hidden layer, in order. An empty tuple
                          yields an identity network with output_dim == obs_dim.
            activation  : Activation name, resolved through get_activation().
            ortho_init  : If True, every Linear layer goes through
                          layer_init(..., std=sqrt(2)).
        """
        super().__init__()

        # widths[i] → widths[i + 1] describes the i-th Linear layer.
        widths = [obs_dim, *hidden_sizes]
        stack: List[nn.Module] = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            linear = nn.Linear(fan_in, fan_out)
            if ortho_init:
                # Orthogonal init with scale=√2 is the PPO standard
                layer_init(linear, std=np.sqrt(2))
            stack.append(linear)
            stack.append(get_activation(activation))

        self.net = nn.Sequential(*stack)
        # Actor/Critic heads read this to size their input layer.
        self.output_dim = widths[-1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Observation tensor (..., obs_dim)
        Returns:
            Feature tensor (..., output_dim)
        """
        return self.net(x)
class TransformerPolicyBackbone(nn.Module):
    """
    Transformer-based backbone for sequence / token observations.

    ELI5: Transformers are great at understanding "what relates to what" in a
    sequence. If your observation is a sequence of tokens (text, time-series
    sensor data, a history of joint states), self-attention extracts rich
    relational features from it.

    Architecture:
        obs → Linear embedding → learned positional encoding
            → N × TransformerEncoderLayer (pre-norm) → mean pool → features

    Args:
        obs_dim     : Dimension of each input token (projected to d_model)
        hidden_sizes: Ignored — kept so all backbones share one constructor interface
        activation  : Feed-forward activation name. "relu" and "gelu" are handed
                      to PyTorch as strings; any other name is resolved with
                      get_activation() and passed as a callable module, so the
                      backbone also works with e.g. activation="tanh".
        ortho_init  : If True, the input projection goes through layer_init(std=√2)
        seq_len     : Number of learned positions (set to 1 for non-sequential use)
        d_model     : Internal transformer dimension (embedding size)
        nhead       : Number of attention heads
        n_layers    : Number of transformer encoder layers
        ffn_dim     : Feed-forward network dimension inside each layer
        dropout     : Dropout probability inside the encoder layers (0 disables it)
    """

    def __init__(
        self,
        obs_dim: int,
        hidden_sizes: Tuple[int, ...] = (64, 64),  # ignored — kept for interface compatibility
        activation: str = "gelu",
        ortho_init: bool = True,
        seq_len: int = 1,
        d_model: int = 128,
        nhead: int = 4,
        n_layers: int = 2,
        ffn_dim: int = 256,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.seq_len = seq_len
        self.d_model = d_model

        # Project input observations into the transformer's d_model dimension
        # ELI5: Like translating your observation into the transformer's "language"
        self.input_proj = nn.Linear(obs_dim, d_model)
        if ortho_init:
            layer_init(self.input_proj, std=np.sqrt(2))

        # Learned positional encoding — tells the transformer WHERE each token is
        # ELI5: Like numbering the pages in a book so the transformer knows order
        self.pos_encoding = nn.Parameter(torch.zeros(1, seq_len, d_model))
        nn.init.normal_(self.pos_encoding, std=0.02)

        # nn.TransformerEncoderLayer only understands the strings "relu" and
        # "gelu". ActorCritic forwards cfg.activation to every backbone, so any
        # other name is turned into a callable module instead of crashing.
        encoder_activation = activation
        if isinstance(activation, str):
            lowered = activation.lower()
            if lowered in ("relu", "gelu"):
                encoder_activation = lowered
            else:
                encoder_activation = get_activation(activation)

        # Transformer encoder — the self-attention magic
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ffn_dim,
            dropout=dropout,
            activation=encoder_activation,
            batch_first=True,  # (batch, seq, features) — easier to work with
            norm_first=True,   # Pre-norm for training stability (GPT-2 style)
        )
        # The nested-tensor fast path is never taken when norm_first=True;
        # asking for it only triggers a UserWarning, so switch it off.
        self.transformer = nn.TransformerEncoder(
            encoder_layer,
            num_layers=n_layers,
            enable_nested_tensor=False,
        )
        self.output_dim = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Observation tensor.
               2-D (batch, obs_dim)       — treated as a length-1 sequence.
               3-D (batch, seq, obs_dim)  — sequential observation.
        Returns:
            Feature tensor (batch, d_model)
        """
        if x.dim() == 2:
            # (batch, obs_dim) → (batch, 1, obs_dim) — treat as length-1 sequence
            x = x.unsqueeze(1)

        # Project to d_model
        x = self.input_proj(x)  # (batch, seq, d_model)

        # Add the positional encoding for the positions actually present
        x = x + self.pos_encoding[:, :x.size(1), :]

        # Run through transformer
        x = self.transformer(x)  # (batch, seq, d_model)

        # Pool over the sequence dimension → single feature vector per sample
        # ELI5: "Summarise the whole sequence into one compact representation"
        return x.mean(dim=1)  # (batch, d_model)
class CNNBackbone(nn.Module):
    """
    Convolutional Neural Network backbone for image/pixel observations.

    ELI5: CNNs are great at processing images. They use sliding "filters" that
    detect local patterns (edges, textures, shapes) and compose them into
    higher-level features (objects, game elements).

    Suitable for: Atari games, visual robotics, any pixel-based observation.

    Architecture with the default arguments (Nature DQN style):
        (C, H, W) image → Conv(32,8,4) → Conv(64,4,2) → Conv(64,3,1) → Flatten → Linear(512)

    Args:
        obs_dim     : Either a (C, H, W) tuple/list describing one image, or any
                      other value (e.g. the flattened size), in which case the
                      image is assumed to be (in_channels, 84, 84).
        hidden_sizes: Widths of the fully-connected layers after the convolutions
        activation  : Activation name, resolved through get_activation()
        ortho_init  : If True, conv and linear weights get orthogonal init (gain √2)
        channels    : Output channels of each conv layer
        kernels     : Kernel size of each conv layer
        strides     : Stride of each conv layer
        in_channels : Input channels (e.g. 4 stacked frames for Atari). Overridden
                      by C when obs_dim is given as (C, H, W).
    """

    def __init__(
        self,
        obs_dim: Any,  # (C, H, W) for images; any other value → 84×84 fallback
        hidden_sizes: Tuple[int, ...] = (512,),
        activation: str = "relu",
        ortho_init: bool = True,
        channels: Tuple[int, ...] = (32, 64, 64),
        kernels: Tuple[int, ...] = (8, 4, 3),
        strides: Tuple[int, ...] = (4, 2, 1),
        in_channels: int = 4,  # Number of stacked frames (e.g., 4 for Atari)
    ):
        super().__init__()

        # Resolve the input geometry. Only an explicit 3-element shape is
        # trusted; everything else keeps the historical 84×84 Atari default.
        if isinstance(obs_dim, (tuple, list)) and len(obs_dim) == 3:
            in_channels, height, width = (int(d) for d in obs_dim)
        else:
            height = width = 84
        self.input_shape = (in_channels, height, width)

        conv_layers: List[nn.Module] = []
        ch_in = in_channels
        for ch_out, k, s in zip(channels, kernels, strides):
            conv = nn.Conv2d(ch_in, ch_out, kernel_size=k, stride=s)
            if ortho_init:
                nn.init.orthogonal_(conv.weight, gain=np.sqrt(2))
                nn.init.constant_(conv.bias, 0)
            conv_layers.extend([conv, get_activation(activation)])
            ch_in = ch_out
        self.conv = nn.Sequential(*conv_layers)

        # Dummy forward to compute the flattened CNN output size
        # ELI5: Run a fake image through the convolutions to see how big the output is
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, height, width)
            flat_size = int(np.prod(self.conv(dummy).shape[1:]))

        # Fully-connected head after convolutions
        fc_layers: List[nn.Module] = []
        in_size = flat_size
        for h in hidden_sizes:
            lin = nn.Linear(in_size, h)
            if ortho_init:
                layer_init(lin, std=np.sqrt(2))
            fc_layers.extend([lin, get_activation(activation)])
            in_size = h
        self.fc = nn.Sequential(*fc_layers)
        self.output_dim = in_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Image tensor (batch, C, H, W) — values in [0, 255] or [0, 1].
               Integer dtypes (e.g. uint8 frames) are cast to float first.
        Returns:
            Feature tensor (batch, output_dim)
        """
        # Conv layers need floating-point input; raw frames are often uint8.
        if not x.is_floating_point():
            x = x.float()

        # Scale pixel values to [0, 1] if they look like [0, 255].
        # NOTE: this is a per-batch heuristic — a [0, 255] batch whose maximum
        # is <= 1 is left unscaled. The numel() guard keeps empty batches from
        # failing inside max().
        if x.numel() > 0 and x.max() > 1.0:
            x = x / 255.0

        x = self.conv(x)
        x = x.flatten(start_dim=1)
        x = self.fc(x)
        return x
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §6 ACTOR-CRITIC NETWORKS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # | |
| # ELI5: Two "brains": | |
| # Actor → the DECISION-MAKER → "given this situation, what should I do?" | |
| # Critic → the EVALUATOR → "given this situation, how good is it?" | |
| # | |
| # During rollout: use Actor to pick actions, Critic to estimate value. | |
| # During update: optimise both using PPO loss. | |
class ActorCritic(nn.Module):
    """
    Policy ("actor") and value ("critic") networks bundled in one module.

    Supported set-ups:
        - Discrete actions   → Categorical distribution over act_dim logits
        - Continuous actions → diagonal Gaussian (optionally tanh-squashed) or Beta
        - One shared backbone, or one backbone each for actor and critic
        - Any backbone class that accepts (obs_dim, hidden_sizes, activation,
          ortho_init) and exposes `output_dim` (MLP by default)

    ELI5: This is the agent's whole "brain".
        backbone    : observation → features
        actor head  : features → action probabilities / mean
        critic head : features → one number (how good is this state?)
    """

    def __init__(
        self,
        obs_dim: int,
        act_dim: int,
        cfg: PPOConfig,
    ):
        super().__init__()
        self.cfg = cfg
        self.act_dim = act_dim
        self.action_space_type = cfg.action_space_type
        self.continuous_dist = cfg.continuous_dist

        # ── Feature extractor(s) ───────────────────────────────────────────
        # ELI5: Pick WHICH kind of "eyes" the agent gets. MLP unless the
        #       config names another backbone class.
        backbone_cls = MLPBackbone if cfg.backbone_cls is None else cfg.backbone_cls
        backbone_args = {
            "obs_dim": obs_dim,
            "hidden_sizes": cfg.hidden_sizes,
            "activation": cfg.activation,
            "ortho_init": cfg.ortho_init,
        }

        if cfg.shared_backbone:
            # One backbone feeds both heads: fewer parameters, shared view of the world.
            self.shared_net = backbone_cls(**backbone_args)
            self.actor_backbone = None
            self.critic_backbone = None
            feat_dim = self.shared_net.output_dim
        else:
            # Independent backbones: more parameters, no interference between heads.
            self.actor_backbone = backbone_cls(**backbone_args)
            self.critic_backbone = backbone_cls(**backbone_args)
            self.shared_net = None
            feat_dim = self.actor_backbone.output_dim

        # ── Critic head ────────────────────────────────────────────────────
        # ELI5: One output number = "this state is worth X future reward"
        self.critic_head = nn.Linear(feat_dim, 1)
        if cfg.ortho_init:
            layer_init(self.critic_head, std=1.0)

        # ── Actor head(s) ──────────────────────────────────────────────────
        if cfg.action_space_type == "discrete":
            # One logit per action; the tiny init std keeps the initial
            # policy close to uniform.
            # ELI5: Start by choosing almost at random, then learn preferences.
            self.actor_head = nn.Linear(feat_dim, act_dim)
            if cfg.ortho_init:
                layer_init(self.actor_head, std=0.01)
            return

        if cfg.action_space_type != "continuous":
            raise ValueError(f"Unknown action_space_type: {cfg.action_space_type}")

        if cfg.continuous_dist == "gaussian":
            # Mean comes from the features.
            # ELI5: "Turn the steering wheel by about 0.3 radians (the mean)"
            self.actor_mean = nn.Linear(feat_dim, act_dim)
            if cfg.ortho_init:
                layer_init(self.actor_mean, std=0.01)
            # Spread is one learnable log-std per action dimension,
            # independent of the observation.
            # ELI5: "How uncertain am I about this action?"
            self.log_std = nn.Parameter(
                torch.full((act_dim,), cfg.log_std_init)
            )
        elif cfg.continuous_dist == "beta":
            # Beta lives on [0, 1] — handy for bounded actions.
            # Two heads produce the α and β concentration parameters.
            self.actor_alpha = nn.Linear(feat_dim, act_dim)
            self.actor_beta = nn.Linear(feat_dim, act_dim)
            if cfg.ortho_init:
                layer_init(self.actor_alpha, std=0.01)
                layer_init(self.actor_beta, std=0.01)
        else:
            raise ValueError(f"Unknown continuous_dist: {cfg.continuous_dist}")

    def _get_features(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the backbone(s) on an observation batch.

        Returns:
            actor_features : Input for the actor head
            critic_features: Input for the critic head
            (the very same tensor twice when the backbone is shared)
        """
        if self.shared_net is None:
            return self.actor_backbone(obs), self.critic_backbone(obs)
        shared = self.shared_net(obs)
        return shared, shared

    def _uses_tanh_squash(self) -> bool:
        """True for a continuous Gaussian policy with cfg.squash_actions enabled."""
        return bool(
            self.action_space_type == "continuous"
            and self.continuous_dist == "gaussian"
            and self.cfg.squash_actions
        )

    def _action_distribution(self, actor_feats: torch.Tensor):
        """Build the torch distribution object for the configured action space."""
        if self.action_space_type == "discrete":
            # ELI5: A loaded die — each face (action) has its own probability.
            dist = Categorical(logits=self.actor_head(actor_feats))
        elif self.action_space_type == "continuous":
            if self.continuous_dist == "gaussian":
                mean = self.actor_mean(actor_feats)
                # Clamp log-std so the spread can neither vanish nor explode.
                std = torch.clamp(self.log_std, min=-20.0, max=2.0).exp()
                dist = Normal(mean, std)
            elif self.continuous_dist == "beta":
                # softplus(...) + 1 keeps both concentrations above 1.
                alpha = F.softplus(self.actor_alpha(actor_feats)) + 1.0
                beta = F.softplus(self.actor_beta(actor_feats)) + 1.0
                dist = Beta(alpha, beta)
        return dist

    def get_value(self, obs: torch.Tensor) -> torch.Tensor:
        """
        Critic-only forward pass.

        ELI5: "How many future treats do you expect from this game state?"

        Args:
            obs: Observation tensor (batch, obs_dim)
        Returns:
            Value tensor (batch, 1)
        """
        critic_feats = self._get_features(obs)[1]
        return self.critic_head(critic_feats)

    def get_action_and_value(
        self,
        obs: torch.Tensor,
        action: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Sample an action (rollout) or score a given one (update).

        ELI5:
            Rollout: "What should I do?"  → call without `action`; one is sampled.
            Update : "What does the NEW me think of that OLD action?"
                     → pass the stored action; its log-prob is recomputed.

        Args:
            obs   : Observation tensor (batch, obs_dim)
            action: Optional action tensor. If None, sample from the distribution.
        Returns:
            action  : The sampled or provided action
            log_prob: Log-probability of that action under the current policy
            entropy : Entropy of the action distribution (exploration bonus)
            value   : Critic estimate with the trailing size-1 axis squeezed away
        """
        actor_feats, critic_feats = self._get_features(obs)
        value = self.critic_head(critic_feats)

        dist = self._action_distribution(actor_feats)
        squashed = self._uses_tanh_squash()
        discrete = self.action_space_type == "discrete"

        if action is None:
            # ROLLOUT: draw a fresh action; squash into [-1, 1] AFTER sampling.
            action = dist.sample()
            if squashed:
                action = torch.tanh(action)

        if squashed:
            # Recover the pre-tanh sample u from the stored action and apply
            # the change-of-variables correction:
            #   log π(a) = log π(u) - Σ log(1 - tanh²(u))
            # where log(1 - tanh²(u)) = 2·(log 2 - u - softplus(-2u)).
            pre_tanh = torch.atanh(action.clamp(-0.9999, 0.9999))
            log_jacobian = 2.0 * (np.log(2) - pre_tanh - F.softplus(-2.0 * pre_tanh))
            log_prob = dist.log_prob(pre_tanh).sum(-1) - log_jacobian.sum(-1)
        elif discrete:
            log_prob = dist.log_prob(action)
        else:
            # Independent action dimensions: joint probability is a product,
            # which is a sum in log space.
            log_prob = dist.log_prob(action).sum(-1)

        # ELI5: High entropy = still exploring. Low entropy = confident.
        entropy = dist.entropy()
        if not discrete:
            entropy = entropy.sum(-1)

        return action, log_prob, entropy, value.squeeze(-1)
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §7 ENVIRONMENT WRAPPERS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # | |
| # ELI5: Environments are the "game world". Wrappers are like controllers | |
| # that add extra features (e.g., normalisation, stacking frames, parallelism). | |
class VecEnv(abc.ABC):
    """
    Abstract base class for vectorised (parallel) environments.

    ELI5: Instead of one game copy, run N copies simultaneously.
    This multiplies data collection speed by N with minimal overhead.

    Subclasses must implement reset(), step(), observation_space and
    action_space. Deriving from abc.ABC is what makes the @abstractmethod
    markers binding: a subclass that forgets one of them cannot be instantiated.
    """

    def __init__(self, num_envs: int):
        """
        Args:
            num_envs: Number of environment copies managed by this wrapper.
        """
        self.num_envs = num_envs

    @abc.abstractmethod
    def reset(self) -> np.ndarray:
        """Reset all environments. Returns initial observations (num_envs, obs_dim)."""
        ...

    @abc.abstractmethod
    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        """
        Step all environments.
        Returns: obs, rewards, dones, infos
        """
        ...

    @property
    @abc.abstractmethod
    def observation_space(self):
        """Observation space shared by all environment copies."""
        ...

    @property
    @abc.abstractmethod
    def action_space(self):
        """Action space shared by all environment copies."""
        ...

    def close(self) -> None:
        """Release environment resources. No-op by default; override if needed."""
        return None
class SyncVecEnv(VecEnv):
    """
    Synchronous vectorised environment — runs N envs sequentially.

    ELI5: Like running 4 games one after the other in the same process.
    Simple, reliable, no multiprocessing headaches.
    For async (true parallel), use AsyncVecEnv (not implemented here — use stable-baselines3).

    Works with both step/reset conventions:
        gymnasium: reset() → (obs, info), step() → (obs, reward, terminated, truncated, info)
        gym      : reset() → obs,         step() → (obs, reward, done, info)

    Args:
        env_fns: List of callables, each returning a new gym.Env instance.
                 One callable per parallel environment.

    Raises:
        ValueError: If env_fns is empty.
    """

    def __init__(self, env_fns: List[Callable[[], Any]]):
        if len(env_fns) == 0:
            # Fail with a clear message instead of an IndexError on envs[0] below.
            raise ValueError("SyncVecEnv requires at least one environment factory.")
        super().__init__(num_envs=len(env_fns))
        self.envs = [fn() for fn in env_fns]
        # All copies are assumed to share the spaces of the first one.
        self._obs_space = self.envs[0].observation_space
        self._act_space = self.envs[0].action_space

    @property
    def observation_space(self):
        return self._obs_space

    @property
    def action_space(self):
        return self._act_space

    @staticmethod
    def _reset_env(env: Any) -> Any:
        """Reset one env and return only its observation (drops gymnasium's info)."""
        result = env.reset()
        # gymnasium returns (obs, info); gym returns obs
        return result[0] if isinstance(result, tuple) else result

    def reset(self) -> np.ndarray:
        """
        Reset all environments and return initial observations.

        ELI5: Start all games from scratch. Returns the initial "screenshot"
        of each game copy.
        """
        return np.stack([self._reset_env(env) for env in self.envs], axis=0)  # (num_envs, obs_dim)

    def step(self, actions: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict]]:
        """
        Apply actions to all environments, collect results.

        ELI5: Take one step in EACH game copy. Collect observations, rewards,
        and done-flags from all copies. If a game ends (done=True), auto-reset it.

        Args:
            actions: (num_envs, *act_shape) or (num_envs,) for discrete
        Returns:
            obs    : Next observations (num_envs, obs_dim). For finished envs this
                     is already the first observation of the NEW episode.
            rewards: Reward per env (num_envs,)
            dones  : Terminal flags (num_envs,), 1.0 for terminated OR truncated
            infos  : Extra info dicts [num_envs]. For finished envs the dict holds
                     'terminal_obs' (last observation of the old episode) and, for
                     gymnasium-style envs, 'TimeLimit.truncated' (True when the
                     episode was cut off rather than truly terminated).
        Raises:
            ValueError: If the number of actions differs from num_envs.
        """
        if len(actions) != self.num_envs:
            # zip() below would otherwise silently step only some of the envs.
            raise ValueError(
                f"Expected {self.num_envs} actions (one per env), got {len(actions)}."
            )

        obs_list, rews, dones, infos = [], [], [], []
        for env, act in zip(self.envs, actions):
            result = env.step(act)
            truncated_only = None
            if len(result) == 5:
                # gymnasium: (obs, reward, terminated, truncated, info)
                o, r, terminated, truncated, info = result
                done = bool(terminated or truncated)
                truncated_only = bool(truncated) and not bool(terminated)
            else:
                # gym: (obs, reward, done, info)
                o, r, done, info = result

            # Normalise a missing/empty info BEFORE attaching terminal data,
            # so 'terminal_obs' is not lost when the env returns None.
            if not info:
                info = {}

            if done:
                # Auto-reset: immediately restart the environment after it ends.
                # ELI5: When the game is over, start a new one automatically.
                if isinstance(info, dict):
                    info["terminal_obs"] = o
                    if truncated_only is not None:
                        # Lets consumers bootstrap the value on time-limit cut-offs.
                        info.setdefault("TimeLimit.truncated", truncated_only)
                o = self._reset_env(env)

            obs_list.append(o)
            rews.append(r)
            dones.append(done)
            infos.append(info)

        return (
            np.stack(obs_list, axis=0),
            np.array(rews, dtype=np.float32),
            np.array(dones, dtype=np.float32),
            infos,
        )

    def close(self) -> None:
        """Close all environments cleanly."""
        for env in self.envs:
            env.close()
def make_gymnasium_env(env_id: str, seed: int, idx: int, **kwargs) -> Callable:
    """
    Build a zero-argument factory for one seeded Gymnasium environment.

    ELI5: This hands you a cookie cutter, not a cookie. Nothing is created
    until the returned function is called (SyncVecEnv does that later).

    Args:
        env_id: Gymnasium environment ID (e.g., 'CartPole-v1')
        seed  : Base random seed
        idx   : Environment index; added to `seed` so every copy gets its own seed
        kwargs: Extra kwargs passed to gym.make()
    Returns:
        Callable → gym.Env. Each call makes a fresh env, resets it once with
        seed + idx, and seeds its action and observation spaces the same way.
    """
    def _make():
        env_seed = seed + idx
        created = gym.make(env_id, **kwargs)
        # Seed the env's own RNG via the first reset, then both spaces.
        created.reset(seed=env_seed)
        for space in (created.action_space, created.observation_space):
            space.seed(env_seed)
        return created

    return _make
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §8 LR SCHEDULE HELPERS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class ScheduledValue:
    """
    A value that changes according to a schedule over training.

    ELI5: Imagine turning down the volume gradually on a song.
        'constant' / 'none' : stays at the initial value forever.
        'linear'            : goes from initial → 0 in a straight line.
        'cosine'            : follows half a cosine wave from initial → 0.
        'warmup_cosine'     : ramps 0 → initial over the first 10% of steps,
                              then cosine-decays to 0.

    Every decaying schedule stays at 0 once `step` reaches `total_steps`;
    running longer than planned never makes the value climb back up.

    Used for: learning rate, clip_epsilon, entropy_coef.
    """

    def __init__(self, initial: float, schedule: str, total_steps: int):
        """
        Args:
            initial    : Starting value
            schedule   : 'constant', 'none', 'linear', 'cosine', 'warmup_cosine'
            total_steps: Total number of update steps for scheduling
                         (values below 1 are treated as 1 to avoid division by zero)
        """
        self.initial = initial
        self.schedule = schedule
        self.total_steps = max(total_steps, 1)

    def __call__(self, step: int) -> float:
        """
        Get the scheduled value at a given step.

        Args:
            step: Current update step (0-indexed)
        Returns:
            Scheduled value
        Raises:
            ValueError: If the schedule name is not recognised.
        """
        schedule = self.schedule
        if schedule in ("constant", "none"):
            return self.initial

        # Fraction of training completed, capped at 1.0. Without the cap the
        # cosine curves would swing upward again for step > total_steps.
        progress = min(step / self.total_steps, 1.0)

        if schedule == "linear":
            # ELI5: Like a candle burning down at a steady rate
            return self.initial * max(0.0, 1.0 - progress)

        if schedule == "cosine":
            # ELI5: Like a ball rolling down a smooth hill — slows at the bottom
            return self.initial * 0.5 * (1.0 + math.cos(math.pi * progress))

        if schedule == "warmup_cosine":
            # ELI5: Warm up the engine slowly, then let it run down smoothly
            warmup_steps = 0.1 * self.total_steps
            if step < warmup_steps:
                return self.initial * (step / warmup_steps)
            decay_progress = min(
                (step - warmup_steps) / (self.total_steps - warmup_steps), 1.0
            )
            return self.initial * 0.5 * (1.0 + math.cos(math.pi * decay_progress))

        raise ValueError(f"Unknown schedule: '{self.schedule}'")
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §9 LOGGER | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
class TrainingLogger:
    """
    Handles all logging: TensorBoard, CSV, Weights & Biases, and console.

    ELI5: Keeps a diary of training. Every few updates, it writes down
    reward, loss, learning rate, etc. so you can see how training is going
    and debug problems early.

    CSV note: the column set is fixed by the FIRST log_step() call. Metrics
    that first appear later are still sent to TensorBoard/W&B but are left
    out of the CSV (with a warning); metrics missing from a later call are
    written as empty cells.
    """

    def __init__(self, cfg: PPOConfig):
        """
        Args:
            cfg: Config providing log_dir, experiment_name, use_wandb and
                 wandb_project.
        """
        self.cfg = cfg
        self.log_dir = Path(cfg.log_dir) / cfg.experiment_name
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # TensorBoard writer
        self.tb_writer = SummaryWriter(log_dir=str(self.log_dir / "tensorboard"))
        logger.info(f"TensorBoard logs: {self.log_dir / 'tensorboard'}")

        # CSV log file — kept open for the whole run, closed in close().
        self.csv_path = self.log_dir / "training_log.csv"
        self._csv_file = open(self.csv_path, "w", newline="", encoding="utf-8")
        self._csv_writer = None  # Initialised on first log call (to get headers)
        logger.info(f"CSV log: {self.csv_path}")

        # Weights & Biases
        if cfg.use_wandb:
            if not _WANDB_AVAILABLE:
                warnings.warn("wandb not installed. Install with: pip install wandb")
            else:
                wandb.init(
                    project=cfg.wandb_project,
                    name=cfg.experiment_name,
                    config=vars(cfg),
                    sync_tensorboard=True,  # Auto-sync TensorBoard to W&B
                )
                logger.info("W&B run initialised.")

        # Episode reward tracking (deque = sliding window)
        # ELI5: Keep the last 100 episode rewards to compute a rolling average
        self.ep_rewards: deque = deque(maxlen=100)
        self.ep_lengths: deque = deque(maxlen=100)

    def log_step(self, step: int, metrics: Dict[str, float]) -> None:
        """
        Log a dictionary of metrics at a given training step.

        Args:
            step   : Current update step
            metrics: Dict of metric_name → value. None values are skipped for
                     TensorBoard and written as empty cells in the CSV.
        """
        # TensorBoard
        for name, value in metrics.items():
            if value is not None:
                self.tb_writer.add_scalar(name, value, step)

        # CSV (write header on first call)
        if self._csv_writer is None:
            fieldnames = ["step"] + sorted(metrics.keys())
            # extrasaction="ignore": a metric that only shows up later must
            # not abort training with a ValueError from DictWriter.
            self._csv_writer = csv.DictWriter(
                self._csv_file, fieldnames=fieldnames, extrasaction="ignore"
            )
            self._csv_writer.writeheader()
        else:
            dropped = sorted(set(metrics) - set(self._csv_writer.fieldnames))
            if dropped:
                warnings.warn(
                    f"CSV header is fixed by the first log call; not writing new metrics: {dropped}"
                )

        row = {"step": step}
        row.update({k: (f"{v:.6f}" if v is not None else "") for k, v in metrics.items()})
        self._csv_writer.writerow(row)
        self._csv_file.flush()  # keep the file readable while training runs

        # W&B
        if self.cfg.use_wandb and _WANDB_AVAILABLE:
            wandb.log({"step": step, **metrics})

    def log_episode(self, reward: float, length: int) -> None:
        """Track completed episode stats."""
        self.ep_rewards.append(reward)
        self.ep_lengths.append(length)

    @property
    def mean_reward(self) -> float:
        """Rolling mean reward over last 100 episodes (0.0 before the first episode)."""
        return float(np.mean(self.ep_rewards)) if self.ep_rewards else 0.0

    @property
    def mean_length(self) -> float:
        """Rolling mean episode length over last 100 episodes (0.0 before the first episode)."""
        return float(np.mean(self.ep_lengths)) if self.ep_lengths else 0.0

    def close(self) -> None:
        """Clean up all logging resources."""
        self.tb_writer.close()
        self._csv_file.close()
        if self.cfg.use_wandb and _WANDB_AVAILABLE:
            wandb.finish()
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §10 PPO AGENT — THE MAIN CLASS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
| class PPOAgent: | |
| """ | |
| Production-grade Proximal Policy Optimisation (PPO) agent. | |
| Implements the full PPO training loop: | |
| 1. ROLLOUT : Run current policy in environment, collect experiences. | |
| 2. GAE : Compute advantages and returns using GAE. | |
| 3. UPDATE : Update Actor and Critic using PPO loss for N epochs. | |
| 4. REPEAT : Repeat until total_timesteps reached. | |
| Supports: | |
| - Discrete and continuous action spaces | |
| - Observation and reward normalisation | |
| - LR, clip_epsilon, entropy_coef scheduling | |
| - Gradient clipping | |
| - Value function clipping | |
| - KL-divergence early stopping | |
| - Mixed precision (AMP) | |
| - Full checkpointing (save/resume) | |
| - TensorBoard + W&B + CSV logging | |
| - Pluggable policy backbones (MLP, Transformer, CNN, custom) | |
| - Vectorised parallel environments | |
| Usage: | |
| cfg = PPOConfig(env_id="CartPole-v1", total_timesteps=500_000) | |
| agent = PPOAgent(cfg) | |
| agent.learn() | |
| """ | |
def __init__(
    self,
    cfg: PPOConfig,
    env_factory: Optional[Callable[[int], Any]] = None,
):
    """
    Assemble every component of a PPO training run from a single config.

    ELI5: Before the dog can start learning we need the maze (environments),
    the two brains (actor-critic network), the teacher (optimizer), a notebook
    (rollout buffer) and a diary (logger). This constructor builds them all,
    in that order, after seeding the random number generators.

    Args:
        cfg        : PPOConfig dataclass with all hyper-parameters.
        env_factory: Optional callable(idx: int) → gym.Env, called once per
                     parallel environment. When None, cfg.env_id is built
                     through gymnasium instead.

    Raises:
        RuntimeError: If env_factory is None and gymnasium is not available.
    """
    self.cfg = cfg
    self.device = get_device(cfg.device)
    set_seed(cfg.seed)

    # ── Environments ────────────────────────────────────────────────────
    if env_factory is not None:
        # Caller-supplied factory. `i=i` freezes the index inside each thunk
        # (a plain closure would see only the last value of i).
        env_fns = [lambda i=i: env_factory(i) for i in range(cfg.num_envs)]
    elif _GYM_AVAILABLE:
        env_fns = [
            make_gymnasium_env(cfg.env_id, cfg.seed, i)
            for i in range(cfg.num_envs)
        ]
    else:
        raise RuntimeError(
            "No environment available. Install gymnasium: pip install gymnasium, "
            "or provide env_factory."
        )
    self.envs = SyncVecEnv(env_fns)

    # ── Space dimensions ────────────────────────────────────────────────
    # ELI5: "How many numbers describe the game state, and how many knobs
    # can the agent turn?"
    obs_space, act_space = self.envs.observation_space, self.envs.action_space
    self.obs_shape = obs_space.shape
    self.obs_dim = int(np.prod(self.obs_shape))
    if cfg.action_space_type == "discrete":
        # One integer choice out of `n` → scalar action per env
        self.act_dim, self.act_shape = act_space.n, ()
    else:
        # Vector of real numbers → (act_dim,) action per env
        self.act_dim = act_space.shape[0]
        self.act_shape = (self.act_dim,)
    logger.info(
        f"Environment: {cfg.env_id} | obs_shape={self.obs_shape} | "
        f"act_dim={self.act_dim} | action_type={cfg.action_space_type}"
    )

    # ── Actor-Critic network ────────────────────────────────────────────
    self.policy = ActorCritic(
        obs_dim=self.obs_dim,
        act_dim=self.act_dim,
        cfg=cfg,
    ).to(self.device)
    total_params = sum(p.numel() for p in self.policy.parameters())
    logger.info(f"Policy network: {total_params:,} parameters")

    # ── Optimizer ───────────────────────────────────────────────────────
    # ELI5: The optimizer is the "learning mechanism" that adjusts weights.
    # Unknown optimizer names fall back to Adam.
    optim_name = cfg.optimizer.lower()
    known_optimizers = {
        "adam": torch.optim.Adam,
        "adamw": torch.optim.AdamW,
        "sgd": torch.optim.SGD,
    }
    optim_cls = known_optimizers.get(optim_name, torch.optim.Adam)
    if optim_name == "sgd":
        # SGD has no `eps`; it gets a fixed momentum instead
        optim_kwargs = dict(lr=cfg.lr_actor, weight_decay=cfg.weight_decay, momentum=0.9)
    else:
        optim_kwargs = dict(lr=cfg.lr_actor, eps=cfg.adam_eps, weight_decay=cfg.weight_decay)
    self.optimizer = optim_cls(self.policy.parameters(), **optim_kwargs)

    # ── Running statistics (only built when normalisation is enabled) ───
    self.obs_rms = RunningMeanStd(shape=self.obs_shape) if cfg.normalise_obs else None
    self.reward_rms = RunningMeanStd(shape=()) if cfg.normalise_rewards else None

    # ── Rollout buffer ──────────────────────────────────────────────────
    self.buffer = RolloutBuffer(
        rollout_steps=cfg.rollout_steps,
        num_envs=cfg.num_envs,
        obs_shape=self.obs_shape,
        act_shape=self.act_shape,
        device=self.device,
        gae_lambda=cfg.gae_lambda,
        gamma=cfg.gamma,
        action_space_type=cfg.action_space_type,
    )

    # ── Update budget ───────────────────────────────────────────────────
    # One update = rollout_steps × num_envs environment steps, followed by
    # n_epochs of gradient passes over that data.
    self.steps_per_update = cfg.rollout_steps * cfg.num_envs
    self.total_updates = int(cfg.total_timesteps // self.steps_per_update)
    logger.info(
        f"Total timesteps: {cfg.total_timesteps:,} | "
        f"Updates: {self.total_updates} | "
        f"Steps/update: {self.steps_per_update:,}"
    )

    # ── Scheduled hyper-parameters (indexed by update number) ───────────
    n_updates = self.total_updates
    self.lr_schedule = ScheduledValue(cfg.lr_actor, cfg.lr_schedule, n_updates)
    self.clip_schedule = ScheduledValue(cfg.clip_epsilon, cfg.clip_epsilon_schedule, n_updates)
    self.entropy_schedule = ScheduledValue(cfg.entropy_coef, cfg.entropy_coef_schedule, n_updates)

    # ── Mixed precision ─────────────────────────────────────────────────
    # AMP is only switched on when requested AND running on a CUDA device;
    # otherwise the scaler is created in its disabled state.
    self.use_amp = cfg.use_amp and self.device.type == "cuda"
    self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)
    if self.use_amp:
        logger.info("Mixed precision (AMP) enabled.")

    # ── Logger ──────────────────────────────────────────────────────────
    self.training_logger = TrainingLogger(cfg)

    # ── Mutable training state ──────────────────────────────────────────
    self.global_step = 0   # Total env steps taken
    self.update_count = 0  # Policy-update counter
    self._ep_rewards = np.zeros(cfg.num_envs, dtype=np.float32)  # Running return of each env's current episode
    self._ep_lengths = np.zeros(cfg.num_envs, dtype=np.int32)    # Running length of each env's current episode

    # ── Optional resume (must come last: it overwrites the state above) ─
    if cfg.resume_from:
        self.load_checkpoint(cfg.resume_from)
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.1 OBSERVATION NORMALISATION | |
| # ────────────────────────────────────────────────────────────────────────── | |
def _normalise_obs(self, obs: np.ndarray, update_stats: bool = True) -> np.ndarray:
    """
    Rescale observations with the running mean/std tracker.

    ELI5: Leg angles might range over 0-360 while speeds range over 0-5.
    Putting every input on a similar scale makes the network's job easier.

    Args:
        obs         : Raw observation batch.
        update_stats: Feed `obs` into the running statistics first
                      (True while collecting rollouts, False for eval/inference).

    Returns:
        The input unchanged when observation normalisation is disabled,
        otherwise the normalised observations as float32.
    """
    tracker = self.obs_rms
    if tracker is None:
        # Normalisation switched off → pass-through
        return obs
    if update_stats:
        tracker.update(obs)
    scaled = tracker.normalise(obs)
    return scaled.astype(np.float32)
def _normalise_rewards(self, rewards: np.ndarray, update_stats: bool = True) -> np.ndarray:
    """
    Scale rewards by their running standard deviation, then clip.

    ELI5: Whether a game hands out +1 or +1000 per treat, after this step
    the learner sees numbers of a similar size.
    Note: rewards are only DIVIDED by the std (the mean is not subtracted),
    so a positive reward stays positive and a negative one stays negative.

    Args:
        rewards     : Raw reward batch.
        update_stats: Feed `rewards` into the running statistics first.

    Returns:
        The input unchanged when reward normalisation is disabled, otherwise
        float32 rewards clipped to [-cfg.reward_clip, +cfg.reward_clip].
    """
    tracker = self.reward_rms
    if tracker is None:
        # Normalisation switched off → pass-through
        return rewards
    if update_stats:
        tracker.update(rewards)
    scaled = rewards / (tracker.std + 1e-8)  # epsilon guards against std == 0
    limit = self.cfg.reward_clip
    return np.clip(scaled, -limit, limit).astype(np.float32)
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.2 ROLLOUT COLLECTION | |
| # ────────────────────────────────────────────────────────────────────────── | |
@torch.no_grad()
def collect_rollout(self, obs: np.ndarray) -> Tuple[np.ndarray, float]:
    """
    Gather cfg.rollout_steps transitions from every parallel environment.

    ELI5: Let the CURRENT policy play for a while and write down everything
    that happened (what it saw, what it did, what it got) in the buffer.
    Gradients are disabled because this is data collection, not training.

    Args:
        obs: Observations to start from, one row per environment.

    Returns:
        obs    : Observations after the final step (input for the next rollout).
        ep_time: Wall-clock seconds spent collecting.
    """
    cfg = self.cfg
    is_discrete = cfg.action_space_type == "discrete"

    self.buffer.reset()
    started_at = time.time()

    for _step in range(cfg.rollout_steps):
        self.global_step += cfg.num_envs

        # Normalise (and update the running stats), then move to the device
        policy_input = torch.FloatTensor(
            self._normalise_obs(obs, update_stats=True)
        ).to(self.device)

        # Ask the policy for an action plus its log-prob and value estimate
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            action, log_prob, _entropy, value = self.policy.get_action_and_value(policy_input)

        # Step ALL environments simultaneously
        env_action = action.cpu().numpy()
        if is_discrete:
            env_action = env_action.astype(int)
        next_obs, reward, done, infos = self.envs.step(env_action)

        # The buffer stores the normalised reward ...
        scaled_reward = self._normalise_rewards(reward, update_stats=True)

        # ... while episode statistics are tracked on the RAW reward
        self._ep_rewards += reward
        self._ep_lengths += 1
        for env_idx, (finished, _info) in enumerate(zip(done, infos)):
            if not finished:
                continue
            self.training_logger.log_episode(
                self._ep_rewards[env_idx], self._ep_lengths[env_idx]
            )
            self._ep_rewards[env_idx] = 0.0
            self._ep_lengths[env_idx] = 0

        self.buffer.add(
            obs=policy_input.cpu(),
            action=action.cpu(),
            reward=torch.FloatTensor(scaled_reward),
            done=torch.FloatTensor(done),
            log_prob=log_prob.cpu(),
            value=value.cpu(),
        )
        obs = next_obs

    # Bootstrap: "What does the Critic think the state AFTER the last step
    # is worth?" — needed for the advantages of the final few steps.
    # The running stats are deliberately NOT updated here.
    final_input = torch.FloatTensor(
        self._normalise_obs(obs, update_stats=False)
    ).to(self.device)
    with torch.cuda.amp.autocast(enabled=self.use_amp):
        last_value = self.policy.get_value(final_input).squeeze(-1).cpu()
    last_done = torch.FloatTensor(done)
    self.buffer.compute_returns_and_advantages(last_value, last_done)

    return obs, time.time() - started_at
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.3 PPO LOSS & UPDATE | |
| # ────────────────────────────────────────────────────────────────────────── | |
def _compute_ppo_loss(
    self,
    mb: Dict[str, torch.Tensor],
    clip_epsilon: float,
    entropy_coef: float,
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Compute the full PPO loss for one mini-batch.

    ELI5 of PPO loss:
    1. ACTOR LOSS (Clipped Surrogate Objective):
       - Ratio r_t(θ) = π_new(a|s) / π_old(a|s) → in log space: exp(new_logprob - old_logprob)
       - Unclipped objective: r_t * A_t (if better action, use it more)
       - Clipped objective: clip(r_t, 1-ε, 1+ε) * A_t (don't change policy too much)
       - Take the MIN of both → pessimistic bound → prevents over-optimistic updates
       - actor_loss is the NEGATED mean of that minimum
         (PyTorch minimises, but we want to maximise the objective)
    2. CRITIC LOSS (Value Function MSE):
       - Predict state value V(s), minimise (V(s) - actual_return)²
       - Optional: also clip value updates (like PPO clips policy updates)
    3. ENTROPY BONUS:
       - Entropy is "how random/exploratory is the policy right now?"
       - Subtracting it from the loss rewards the agent for staying exploratory.

    TOTAL LOSS = actor_loss + value_coef * critic_loss - entropy_coef * entropy
    (actor_loss already carries the minus sign of the surrogate objective)

    Args:
        mb          : Mini-batch dictionary from RolloutBuffer.get_batches().
                      Keys read: "obs", "actions", "old_log_probs",
                      "advantages", "returns", "old_values".
        clip_epsilon: Current clipping parameter value (may be scheduled)
        entropy_coef: Current entropy coefficient (may be scheduled)

    Returns:
        loss : Combined PPO loss scalar
        stats: Dictionary of loss components for logging
    """
    obs = mb["obs"]
    actions = mb["actions"]
    old_lp = mb["old_log_probs"]
    advantages = mb["advantages"]
    returns = mb["returns"]
    old_values = mb["old_values"]

    # Normalise advantages within this mini-batch
    # ELI5: Make advantages zero-mean and unit-variance so gradient magnitudes
    #       are consistent regardless of the reward scale.
    # A single-sample batch is left untouched: its (unbiased) std is NaN,
    # which would turn the whole loss — and then the weights — into NaN.
    if self.cfg.norm_adv and advantages.numel() > 1:
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    # For discrete, actions must be integer type for log_prob
    if self.cfg.action_space_type == "discrete":
        actions = actions.long()

    # Re-evaluate the stored actions under the CURRENT policy
    # ELI5: "With what we know now, how likely would we have taken that action?"
    _, new_log_probs, entropy, new_values = self.policy.get_action_and_value(obs, actions)

    # ── Clipped Surrogate (Actor) Loss ─────────────────────────────────
    # ELI5: If old π gave 10% chance and new π gives 15%, ratio = 1.5
    log_ratio = new_log_probs - old_lp
    ratio = log_ratio.exp()

    # Diagnostics only (no gradient needed):
    #   approx_kl     — estimator (ratio - 1) - log(ratio) of KL(old ‖ new);
    #                   it is non-negative for every sample
    #   clip_fraction — share of samples whose ratio left [1-ε, 1+ε]
    with torch.no_grad():
        approx_kl = ((ratio - 1.0) - log_ratio).mean().item()
        clip_fraction = ((ratio - 1.0).abs() > clip_epsilon).float().mean().item()

    surr1 = ratio * advantages  # Unclipped
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages  # Clipped
    # ELI5: "Take whichever update is more conservative to prevent going too far"
    actor_loss = -torch.min(surr1, surr2).mean()

    # ── Value Function (Critic) Loss ───────────────────────────────────
    if self.cfg.value_clip_epsilon is not None:
        # Clipped value loss: don't let the value estimate jump too far from
        # the estimate recorded at rollout time.
        value_clipped = old_values + torch.clamp(
            new_values - old_values,
            -self.cfg.value_clip_epsilon,
            +self.cfg.value_clip_epsilon,
        )
        v_loss_unclipped = (new_values - returns) ** 2
        v_loss_clipped = (value_clipped - returns) ** 2
        critic_loss = 0.5 * torch.max(v_loss_unclipped, v_loss_clipped).mean()
    else:
        # Standard MSE loss (no value clipping)
        critic_loss = 0.5 * F.mse_loss(new_values, returns)

    # ── Entropy Bonus ───────────────────────────────────────────────────
    # Negated because we want to MAXIMISE entropy while minimising the loss
    entropy_loss = -entropy.mean()

    # ── Combined Loss ───────────────────────────────────────────────────
    loss = (
        actor_loss
        + self.cfg.value_coef * critic_loss
        + entropy_coef * entropy_loss
    )
    stats = {
        "train/actor_loss": actor_loss.item(),
        "train/critic_loss": critic_loss.item(),
        "train/entropy": -entropy_loss.item(),
        "train/approx_kl": approx_kl,
        "train/clip_fraction": clip_fraction,
        "train/total_loss": loss.item(),
    }
    return loss, stats
def update_policy(self) -> Dict[str, float]:
    """
    Perform N epochs of PPO gradient updates on collected rollout data.

    ELI5: We collected a big batch of experiences. Now we re-read them
    multiple times (epochs), each time updating the Actor and Critic
    a little bit. We stop early if the policy changed too much (KL check).

    Returns:
        Dictionary of metrics averaged across all processed mini-batches,
        plus the scheduled values used for this update. If no mini-batch
        was processed at all (n_epochs == 0 or an empty buffer) the loss
        metrics are reported as NaN instead of raising.
    """
    cfg = self.cfg

    # Scheduled hyper-parameters for this update index
    clip_epsilon = self.clip_schedule(self.update_count)
    entropy_coef = self.entropy_schedule(self.update_count)
    lr = self.lr_schedule(self.update_count)
    for pg in self.optimizer.param_groups:
        pg["lr"] = lr

    all_stats: List[Dict[str, float]] = []
    early_stopped = False
    for epoch in range(cfg.n_epochs):
        for mb in self.buffer.get_batches(cfg.minibatch_size):
            with torch.cuda.amp.autocast(enabled=self.use_amp):
                loss, stats = self._compute_ppo_loss(mb, clip_epsilon, entropy_coef)
            all_stats.append(stats)

            # KL early stopping: checked BEFORE the gradient step, so the
            # offending mini-batch is logged but never applied.
            # ELI5: "The gap between old and new policy is getting too wide — stop now!"
            if cfg.target_kl is not None and stats["train/approx_kl"] > 1.5 * cfg.target_kl:
                logger.debug(
                    f"Early stopping at epoch {epoch} due to KL {stats['train/approx_kl']:.4f} "
                    f"> target {1.5 * cfg.target_kl:.4f}"
                )
                early_stopped = True
                break

            # Gradient update
            self.optimizer.zero_grad(set_to_none=True)
            # AMP backward pass: scales loss to prevent underflow in float16
            self.scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping so the norm
            # threshold applies to their true magnitude.
            # ELI5: "If the learning signal is shouting too loudly, turn it down"
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.policy.parameters(), cfg.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        if early_stopped:
            break

    # Average all logged stats across mini-batches
    if all_stats:
        avg_stats: Dict[str, float] = {
            key: float(np.mean([s[key] for s in all_stats]))
            for key in all_stats[0]
        }
    else:
        # Nothing was processed: keep the usual keys present so that
        # downstream logging/formatting does not fail on a missing metric.
        avg_stats = {
            key: float("nan")
            for key in (
                "train/actor_loss",
                "train/critic_loss",
                "train/entropy",
                "train/approx_kl",
                "train/clip_fraction",
                "train/total_loss",
            )
        }
    avg_stats["train/learning_rate"] = lr
    avg_stats["train/clip_epsilon"] = clip_epsilon
    avg_stats["train/entropy_coef"] = entropy_coef
    avg_stats["train/early_stopped"] = float(early_stopped)
    return avg_stats
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.4 MAIN TRAINING LOOP | |
| # ────────────────────────────────────────────────────────────────────────── | |
def learn(self) -> "PPOAgent":
    """
    Main PPO training loop.

    ELI5 of the full loop:
    1. Reset all game copies.
    2. Play the game for rollout_steps steps (collect experience).
    3. Compute how good each action was (GAE advantages + returns).
    4. Update Actor and Critic using PPO loss for n_epochs.
    5. Log metrics, save checkpoint if needed.
    6. Repeat from step 2 until total_timesteps reached.

    Resuming: `self.update_count` always holds the number of updates that
    have been COMPLETED, so a checkpoint written after update k resumes at
    update k + 1 instead of repeating update k.

    The environments and the training logger are closed when the loop
    ends, whether it finishes normally or raises.

    Returns:
        self (for method chaining)
    """
    cfg = self.cfg
    logger.info(
        f"Starting PPO training: {cfg.env_id} | "
        f"{cfg.total_timesteps:,} steps | device={self.device}"
    )
    run_dir = Path(cfg.log_dir) / cfg.experiment_name

    # ── Initial reset ───────────────────────────────────────────────────
    obs = self.envs.reset()  # (num_envs, *obs_shape)
    # Rollouts run in eval mode; make that true for the very first one too.
    self.policy.eval()
    start_time = time.time()
    # Steps restored from a checkpoint must not count towards throughput.
    start_step = self.global_step

    try:
        for update in range(self.update_count, self.total_updates):
            # Invariant: self.update_count == update here, so the schedules
            # inside update_policy() see the index of the current update.

            # ── 1. Collect rollout ──────────────────────────────────────
            obs, rollout_time = self.collect_rollout(obs)

            # ── 2. Update policy ────────────────────────────────────────
            update_t0 = time.time()
            self.policy.train()  # Switch to training mode (enables dropout, etc.)
            stats = self.update_policy()
            self.policy.eval()   # Switch back to eval mode for rollout
            update_time = time.time() - update_t0
            # This update is done: record it BEFORE any checkpoint is written.
            self.update_count = update + 1

            # ── 3. Compute throughput ───────────────────────────────────
            elapsed = time.time() - start_time
            sps = (self.global_step - start_step) / max(elapsed, 1e-6)  # Steps per second (this run)
            progress = self.global_step / cfg.total_timesteps * 100

            # ── 4. Log metrics ──────────────────────────────────────────
            if update % cfg.log_interval == 0:
                stats.update({
                    "charts/global_step": self.global_step,
                    "charts/mean_reward": self.training_logger.mean_reward,
                    "charts/mean_ep_length": self.training_logger.mean_length,
                    "charts/steps_per_sec": sps,
                    "charts/rollout_time": rollout_time,
                    "charts/update_time": update_time,
                    "charts/progress_pct": progress,
                })
                self.training_logger.log_step(self.global_step, stats)
                logger.info(
                    f"[{progress:5.1f}%] step={self.global_step:,} | "
                    f"reward={self.training_logger.mean_reward:7.2f} | "
                    f"actor_loss={stats['train/actor_loss']:+.4f} | "
                    f"critic_loss={stats['train/critic_loss']:.4f} | "
                    f"entropy={stats['train/entropy']:.4f} | "
                    f"kl={stats['train/approx_kl']:.4f} | "
                    f"sps={sps:.0f}"
                )

            # ── 5. Save checkpoint ──────────────────────────────────────
            if update % cfg.save_interval == 0:
                ckpt_path = run_dir / f"checkpoint_step{self.global_step}.pt"
                self.save_checkpoint(str(ckpt_path))

        # ── Final save ──────────────────────────────────────────────────
        self.save_checkpoint(str(run_dir / "checkpoint_final.pt"))
    finally:
        # Release log writers and environments even if training failed.
        self.training_logger.close()
        self.envs.close()

    logger.info(
        f"Training complete! Final mean reward: {self.training_logger.mean_reward:.2f} | "
        f"Total time: {(time.time() - start_time) / 60:.1f} min"
    )
    return self
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.5 EVALUATION | |
| # ────────────────────────────────────────────────────────────────────────── | |
@torch.no_grad()
def evaluate(
    self,
    n_episodes: int = 10,
    max_steps: int = 1000,
    env_factory: Optional[Callable] = None,
    render: bool = False,
    deterministic: bool = False,
) -> Dict[str, float]:
    """
    Evaluate the current policy for n_episodes.

    ELI5: After training, test how well the agent performs in a fresh copy
    of the game. Running statistics are NOT updated during evaluation.

    By default actions are SAMPLED from the policy, exactly as during
    training. Pass deterministic=True to act greedily through
    `self.predict(..., deterministic=True)` (argmax for discrete, mean for
    continuous).

    Args:
        n_episodes   : Number of test episodes (must be positive)
        max_steps    : Max steps per episode
        env_factory  : Optional custom env factory for evaluation
                       (called with index 999)
        render       : Whether to render the environment visually
                       (only used for gymnasium-created environments)
        deterministic: Use greedy actions instead of sampling

    Returns:
        Dictionary with mean/std/min/max of rewards and the mean episode length

    Raises:
        ValueError  : If n_episodes is not positive.
        RuntimeError: If no env_factory is given and gymnasium is unavailable.
    """
    if n_episodes <= 0:
        raise ValueError(f"n_episodes must be positive, got {n_episodes}")

    self.policy.eval()
    if env_factory is not None:
        eval_env = env_factory(999)
    elif _GYM_AVAILABLE:
        kwargs = {"render_mode": "human"} if render else {}
        eval_env = gym.make(self.cfg.env_id, **kwargs)
    else:
        raise RuntimeError("No environment available for evaluation.")

    is_discrete = self.cfg.action_space_type == "discrete"
    ep_rewards, ep_lengths = [], []
    try:
        for ep in range(n_episodes):
            # reset() returns (obs, info) in gymnasium, bare obs in old gym
            result = eval_env.reset()
            obs = result[0] if isinstance(result, tuple) else result
            ep_r, ep_l = 0.0, 0
            done = False
            while not done and ep_l < max_steps:
                if deterministic:
                    # predict() normalises the raw batch itself
                    action_val = self.predict(np.asarray(obs)[None], deterministic=True)[0]
                else:
                    obs_norm = self._normalise_obs(np.array([obs]), update_stats=False)[0]
                    obs_tensor = torch.FloatTensor(obs_norm).unsqueeze(0).to(self.device)
                    action, _, _, _ = self.policy.get_action_and_value(obs_tensor)
                    action_val = action.cpu().numpy()[0]
                if is_discrete:
                    action_val = int(action_val)

                # step() returns 5 values in gymnasium, 4 in old gym
                step_result = eval_env.step(action_val)
                if len(step_result) == 5:
                    obs, r, term, trunc, _ = step_result
                    done = term or trunc
                else:
                    obs, r, done, _ = step_result
                ep_r += r
                ep_l += 1
            ep_rewards.append(ep_r)
            ep_lengths.append(ep_l)
            logger.info(f" Eval episode {ep+1}/{n_episodes}: reward={ep_r:.1f}, length={ep_l}")
    finally:
        # Always release the environment (render window, simulator handles)
        eval_env.close()

    results = {
        "eval/mean_reward": float(np.mean(ep_rewards)),
        "eval/std_reward": float(np.std(ep_rewards)),
        "eval/min_reward": float(np.min(ep_rewards)),
        "eval/max_reward": float(np.max(ep_rewards)),
        "eval/mean_length": float(np.mean(ep_lengths)),
    }
    logger.info(
        f"Evaluation ({n_episodes} episodes): "
        f"mean={results['eval/mean_reward']:.2f} ± {results['eval/std_reward']:.2f} | "
        f"range=[{results['eval/min_reward']:.1f}, {results['eval/max_reward']:.1f}]"
    )
    return results
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.6 CHECKPOINT SAVE / LOAD | |
| # ────────────────────────────────────────────────────────────────────────── | |
def save_checkpoint(self, path: str) -> None:
    """
    Write the complete training state to disk.

    ELI5: Like saving a game mid-play. The file holds:
    - Network weights (policy state dict)
    - Optimizer and AMP grad-scaler state
    - Running stats (obs and reward normalisers, or None when disabled)
    - Training counters (global_step, update_count)
    - The config's attribute dictionary
    Missing parent directories are created first.

    Args:
        path: Full path to save the checkpoint file (.pt)
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    payload = dict(
        version="PPO.RL.v1",
        global_step=self.global_step,
        update_count=self.update_count,
        policy_state_dict=self.policy.state_dict(),
        optimizer_state=self.optimizer.state_dict(),
        scaler_state=self.scaler.state_dict(),
        obs_rms=self.obs_rms.state_dict() if self.obs_rms else None,
        reward_rms=self.reward_rms.state_dict() if self.reward_rms else None,
        config=self.cfg.__dict__,
    )
    torch.save(payload, path)
    logger.info(f"Checkpoint saved → {path}")
def load_checkpoint(self, path: str) -> None:
    """
    Load a previously saved checkpoint to resume training.

    ELI5: Load a saved game. Restores the network weights, optimizer state,
    running normalisers and training counters so training can continue.

    SECURITY: the checkpoint holds arbitrary Python objects (config dict,
    running statistics), so it is read with full unpickling. Only load
    checkpoint files you trust.

    Args:
        path: Path to checkpoint file (.pt)

    Raises:
        FileNotFoundError: If `path` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Checkpoint not found: {path}")

    # PyTorch >= 2.6 defaults to weights_only=True, which is expected to
    # reject the non-tensor payload written by save_checkpoint
    # (presumably NumPy-backed running stats — verify against RunningMeanStd).
    # Older releases do not know the keyword at all, hence the fallback.
    try:
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
    except TypeError:
        checkpoint = torch.load(path, map_location=self.device)

    self.policy.load_state_dict(checkpoint["policy_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state"])

    # A checkpoint written with AMP disabled stores an empty scaler state;
    # feeding that to an ENABLED GradScaler raises, so skip it and keep the
    # freshly initialised scaler instead.
    scaler_state = checkpoint.get("scaler_state")
    if scaler_state:
        self.scaler.load_state_dict(scaler_state)

    self.global_step = checkpoint["global_step"]
    self.update_count = checkpoint["update_count"]

    # Normaliser stats are only restored when both sides have them
    obs_state = checkpoint.get("obs_rms")
    if self.obs_rms and obs_state:
        self.obs_rms.load_state_dict(obs_state)
    reward_state = checkpoint.get("reward_rms")
    if self.reward_rms and reward_state:
        self.reward_rms.load_state_dict(reward_state)

    logger.info(
        f"Checkpoint loaded from {path} "
        f"(step={self.global_step}, updates={self.update_count})"
    )
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.7 CONVENIENCE — SAVE / LOAD POLICY ONLY | |
| # ────────────────────────────────────────────────────────────────────────── | |
def save_policy(self, path: str) -> None:
    """
    Save only the policy network weights (for deployment / inference).

    ELI5: If you just want to USE the trained agent (not continue training),
    you only need the network weights — not the optimizer state.
    The observation normaliser stats and the config dict are stored too,
    because inference needs the same input scaling as training.
    Missing parent directories are created first, same as save_checkpoint.

    Args:
        path: Full path of the file to write.
    """
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    torch.save({
        "policy_state_dict": self.policy.state_dict(),
        "obs_rms": self.obs_rms.state_dict() if self.obs_rms else None,
        "config": self.cfg.__dict__,
    }, path)
    logger.info(f"Policy saved → {path}")
def load_policy(self, path: str) -> None:
    """
    Load only policy weights (for inference/evaluation after training).

    Restores the network weights and, when both this agent and the file
    have them, the observation normaliser statistics.

    SECURITY: the file holds arbitrary Python objects (config dict, running
    statistics), so it is read with full unpickling. Only load files you trust.

    Args:
        path: Path to a file written by save_policy().
    """
    # PyTorch >= 2.6 defaults to weights_only=True, which is expected to
    # reject the non-tensor payload written by save_policy (presumably
    # NumPy-backed running stats — verify against RunningMeanStd).
    # Older releases do not know the keyword, hence the fallback.
    try:
        data = torch.load(path, map_location=self.device, weights_only=False)
    except TypeError:
        data = torch.load(path, map_location=self.device)

    self.policy.load_state_dict(data["policy_state_dict"])
    obs_state = data.get("obs_rms")
    if self.obs_rms and obs_state:
        self.obs_rms.load_state_dict(obs_state)
    logger.info(f"Policy weights loaded from {path}")
| # ────────────────────────────────────────────────────────────────────────── | |
| # §10.8 PREDICT (INFERENCE) — Deploy the trained policy | |
| # ────────────────────────────────────────────────────────────────────────── | |
@torch.no_grad()
def predict(
    self,
    obs: Union[np.ndarray, torch.Tensor],
    deterministic: bool = True,
) -> np.ndarray:
    """
    Predict an action for a given observation (inference mode).

    ELI5: After training is done, use this to let the agent "play" in a new situation.
    deterministic=True  → always pick the most likely action (greedy, no randomness).
    deterministic=False → sample from the distribution (stochastic, like during training).

    Input handling:
    - NumPy arrays (and plain sequences, which are converted to arrays) are
      normalised with the running observation statistics.
    - torch.Tensor inputs are used AS-IS: they are NOT normalised here, so
      pass tensors that are already on the scale the policy was trained on.
    - A single observation (1-D, or shaped exactly like the observation
      space) gets a batch dimension added automatically.

    Args:
        obs          : Single or batched observation
        deterministic: If True, argmax for discrete / mean for continuous

    Returns:
        action: Predicted action(s) as numpy array, always with a leading
                batch dimension (size 1 for a single observation)
    """
    self.policy.eval()

    if not isinstance(obs, (np.ndarray, torch.Tensor)):
        obs = np.asarray(obs)  # e.g. a list of floats
    if isinstance(obs, np.ndarray):
        obs = self._normalise_obs(obs, update_stats=False)
        obs = torch.FloatTensor(obs)

    # Add the batch dimension for one lone observation. The exact-shape
    # test covers multi-dimensional observation spaces, where a single
    # observation is not 1-D.
    obs_shape = self.obs_shape
    lone_observation = obs.dim() == 1 or (
        obs_shape is not None and tuple(obs.shape) == tuple(obs_shape)
    )
    if lone_observation:
        obs = obs.unsqueeze(0)
    obs = obs.to(self.device)

    if deterministic:
        actor_feats, _ = self.policy._get_features(obs)
        if self.cfg.action_space_type == "discrete":
            # Greedy: pick action with highest logit
            logits = self.policy.actor_head(actor_feats)
            action = logits.argmax(dim=-1)
        else:
            # Greedy: use the mean head output (no sampling)
            action = self.policy.actor_mean(actor_feats)
            if self.cfg.squash_actions:
                action = torch.tanh(action)
    else:
        # Stochastic: sample from the distribution
        action, _, _, _ = self.policy.get_action_and_value(obs)
    return action.cpu().numpy()
def __repr__(self) -> str:
    """Return a one-line summary: environment, device, action type and step budget."""
    cfg = self.cfg
    parts = (
        f"env={cfg.env_id}",
        f"device={self.device}",
        f"action_type={cfg.action_space_type}",
        f"total_timesteps={cfg.total_timesteps:,}",
    )
    return "PPOAgent(" + ", ".join(parts) + ")"
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §11 CONVENIENCE HELPERS — QUICK-START FUNCTIONS | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
def train_discrete(
    env_id: str = "CartPole-v1",
    total_timesteps: int = 500_000,
    **kwargs,
) -> PPOAgent:
    """
    Quick-start PPO training for DISCRETE action space environments.

    ELI5: One call that builds the config, builds the agent and runs the
    whole training loop on a "pick one of N buttons" style environment.
    Good for: CartPole, LunarLander, Atari (with appropriate backbone).

    Args:
        env_id          : Gymnasium environment ID
        total_timesteps : Training budget
        **kwargs        : Additional PPOConfig fields, forwarded as-is

    Returns:
        Trained PPOAgent
    """
    discrete_cfg = PPOConfig(
        env_id=env_id,
        action_space_type="discrete",
        total_timesteps=total_timesteps,
        **kwargs,
    )
    return PPOAgent(discrete_cfg).learn()
def train_continuous(
    env_id: str = "HalfCheetah-v4",
    total_timesteps: int = 1_000_000,
    **kwargs,
) -> PPOAgent:
    """
    Quick-start PPO training for CONTINUOUS action space environments.

    ELI5: One function call to train on any robot/physics sim environment.
    Good for: MuJoCo (HalfCheetah, Hopper, Ant), robotics, real-world control.

    A set of preset hyper-parameters (hidden_sizes, lr_actor, lr_critic,
    rollout_steps, n_epochs, minibatch_size) is applied first; any of them
    can be overridden through **kwargs.

    Args:
        env_id          : Gymnasium environment ID
        total_timesteps : Training budget
        **kwargs        : Any PPOConfig fields to override

    Returns:
        Trained PPOAgent
    """
    presets = dict(
        hidden_sizes=(256, 256),
        lr_actor=3e-4,
        lr_critic=1e-3,
        rollout_steps=2048,
        n_epochs=10,
        minibatch_size=64,
    )
    # Caller values win. Passing a preset key AND the same key via **kwargs
    # straight to PPOConfig would raise "got multiple values for keyword
    # argument", so the two dicts are merged first.
    overrides = {**presets, **kwargs}
    cfg = PPOConfig(
        env_id=env_id,
        action_space_type="continuous",
        total_timesteps=total_timesteps,
        **overrides,
    )
    agent = PPOAgent(cfg)
    return agent.learn()
def train_with_transformer_backbone(
    env_id: str = "CartPole-v1",
    total_timesteps: int = 500_000,
    **kwargs,
) -> PPOAgent:
    """
    Quick-start PPO with Transformer policy backbone.

    ELI5: Same one-call training as the other helpers, but the policy's
    feature extractor is TransformerPolicyBackbone instead of the default.
    Good when observations have sequential or relational structure.

    Args:
        env_id          : Gymnasium environment ID
        total_timesteps : Training budget
        **kwargs        : Additional PPOConfig fields, forwarded as-is

    Returns:
        Trained PPOAgent
    """
    transformer_cfg = PPOConfig(
        env_id=env_id,
        backbone_cls=TransformerPolicyBackbone,
        total_timesteps=total_timesteps,
        **kwargs,
    )
    return PPOAgent(transformer_cfg).learn()
| # ══════════════════════════════════════════════════════════════════════════════ | |
| # §12 ENTRY POINT — Demo / Smoke Test | |
| # ══════════════════════════════════════════════════════════════════════════════ | |
if __name__ == "__main__":
    # Demo / smoke test. Three training runs are executed back to back:
    #   1. CartPole-v1 with a discrete policy (author's note: should solve in
    #      ~200k steps; a perfect agent scores 500, the maximum).
    #   2. Pendulum-v1 with a continuous Gaussian policy.
    #   3. CartPole-v1 again, this time with the Transformer backbone.
    # ELI5: CartPole is a "balance the stick on the cart" game — the agent
    # pushes the cart left/right to keep the pole upright.
    # Runs 2 and 3 are best-effort: any exception is reported and the script
    # moves on. Run 1 is not guarded, so a failure there aborts the script.
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    # Rollout and logging settings that are identical across all three demos.
    shared = dict(
        num_envs=4,
        rollout_steps=512,
        log_dir="./ppo_logs",
        log_interval=5,
        save_interval=20,
    )

    # ── Example 1: Discrete (CartPole) ──────────────────────────────────────
    print("\n" + heavy_rule)
    print(" PPO.RL.py — Production PPO Demo")
    print(heavy_rule)
    print("\nExample 1: Discrete action space — CartPole-v1")

    cartpole_cfg = PPOConfig(
        env_id="CartPole-v1",
        action_space_type="discrete",
        total_timesteps=200_000,
        n_epochs=10,
        minibatch_size=64,
        hidden_sizes=(64, 64),
        lr_actor=2.5e-4,
        lr_schedule="linear",
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        entropy_coef=0.01,
        normalise_obs=True,
        normalise_rewards=False,  # author's note: CartPole rewards are already {0, 1}
        experiment_name="cartpole_demo",
        seed=42,
        **shared,
    )
    cartpole_agent = PPOAgent(cartpole_cfg)
    cartpole_agent.learn()

    print("\nEvaluating trained CartPole agent...")
    cartpole_eval = cartpole_agent.evaluate(n_episodes=10, max_steps=500)
    print(f" Mean reward: {cartpole_eval['eval/mean_reward']:.1f} / 500.0")

    # ── Example 2: Continuous (Pendulum) ────────────────────────────────────
    print("\n" + light_rule)
    print("Example 2: Continuous action space — Pendulum-v1")

    pendulum_cfg = PPOConfig(
        env_id="Pendulum-v1",
        action_space_type="continuous",
        total_timesteps=100_000,
        n_epochs=10,
        minibatch_size=64,
        hidden_sizes=(64, 64),
        lr_actor=3e-4,
        lr_critic=1e-3,
        gamma=0.99,
        gae_lambda=0.95,
        clip_epsilon=0.2,
        entropy_coef=0.0,
        normalise_obs=True,
        normalise_rewards=True,
        continuous_dist="gaussian",
        experiment_name="pendulum_demo",
        seed=42,
        **shared,
    )
    # Construction, training, evaluation and reporting all stay inside the
    # try so that any failure along the way is reported as a skip.
    try:
        pendulum_agent = PPOAgent(pendulum_cfg)
        pendulum_agent.learn()
        pendulum_eval = pendulum_agent.evaluate(n_episodes=5)
        print(f" Pendulum mean reward: {pendulum_eval['eval/mean_reward']:.1f}")
    except Exception as err:
        print(f" Skipped (env may not be installed): {err}")

    # ── Example 3: Transformer Backbone ─────────────────────────────────────
    print("\n" + light_rule)
    print("Example 3: Transformer backbone — CartPole-v1")

    transformer_cfg = PPOConfig(
        env_id="CartPole-v1",
        action_space_type="discrete",
        total_timesteps=100_000,
        backbone_cls=TransformerPolicyBackbone,
        experiment_name="cartpole_transformer",
        seed=0,
        **shared,
    )
    try:
        transformer_agent = PPOAgent(transformer_cfg)
        transformer_agent.learn()
        transformer_eval = transformer_agent.evaluate(n_episodes=5)
        print(f" Transformer backbone mean reward: {transformer_eval['eval/mean_reward']:.1f}")
    except Exception as err:
        print(f" Skipped: {err}")

    # ── Wrap-up ─────────────────────────────────────────────────────────────
    print("\n" + heavy_rule)
    print(" All demos complete. Check ./ppo_logs for TensorBoard & CSV logs.")
    print(" Run: tensorboard --logdir ./ppo_logs")
    print(heavy_rule)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment