Created
March 18, 2026 21:42
-
-
Save masterdezign/b761afb98239c3bf8c823dc5ed0a0903 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Optuna hyperparameter search for RecurrentSAC on POMDP benchmarks.

Environments (NoVel = velocity observations removed, making them POMDPs):
  - PendulumNoVel-v1
  - MountainCarContinuousNoVel-v0
  - LunarLanderContinuousNoVel-v3 (requires Box2D / swig)

Usage:
    python scripts/tune_rsac.py --env PendulumNoVel-v1 --n-trials 50 --n-timesteps 100000
    python scripts/tune_rsac.py --env MountainCarContinuousNoVel-v0 --n-trials 30 --n-timesteps 300000
"""
import argparse
import warnings

# Silence gym/sb3 deprecation chatter so Optuna trial logs stay readable.
warnings.filterwarnings("ignore")

import numpy as np
import optuna
import rl_zoo3.import_envs  # noqa: F401 — registers NoVel envs
from sb3_contrib import RecurrentSAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize

# ── Evaluation settings ──────────────────────────────────────────────────────
N_EVAL_EPISODES = 10
EVAL_FREQ = None  # evaluate only at the end of training
def make_env(env_id: str, n_envs: int, seed: int = 0) -> VecNormalize:
    """Build a seeded vectorized env wrapped in VecNormalize.

    Args:
        env_id: Registered gym environment id (NoVel envs come from rl_zoo3).
        n_envs: Number of parallel environments in the vec env.
        seed: Base random seed passed to ``make_vec_env``.

    Returns:
        The vec env with observation and reward normalization enabled.
    """
    vec_env = make_vec_env(env_id, n_envs=n_envs, seed=seed)
    return VecNormalize(vec_env, norm_obs=True, norm_reward=True)
def sample_params(trial: optuna.Trial, env_id: str) -> dict:
    """Sample one RecurrentSAC hyperparameter configuration.

    The ``suggest_*`` call order is fixed so a seeded sampler reproduces the
    same search trajectory. ``env_id`` is part of the signature for future
    per-environment search spaces; the current space is env-agnostic.

    Derived quantities (``overlap``, ``burn_in``) are recorded as trial user
    attrs so they show up alongside the sampled params in study results.

    Returns:
        Keyword arguments for the ``RecurrentSAC`` constructor, plus an
        ``n_envs`` entry that the caller pops off before model creation.
    """
    # Core SAC hyperparameters.
    lr = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    discount = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.999])
    polyak = trial.suggest_categorical("tau", [0.005, 0.01, 0.02])
    batch = trial.suggest_categorical("batch_size", [64, 128, 256])
    entropy = trial.suggest_categorical("ent_coef", ["auto", 0.001, 0.01, 0.1])
    num_envs = trial.suggest_categorical("n_envs", [1, 2, 4])

    # Recurrent-specific: segment length plus fractional overlap / burn-in,
    # converted to absolute step counts below.
    seg_len = trial.suggest_categorical("segment_len", [20, 50, 100])
    overlap_frac = trial.suggest_float("overlap_frac", 0.0, 0.4)
    overlap_steps = max(0, int(seg_len * overlap_frac))
    burn_in_frac = trial.suggest_float("burn_in_frac", 0.0, 0.3)
    burn_in_steps = int(seg_len * burn_in_frac)
    shared = trial.suggest_categorical("shared_state", [True, False])

    # Network architecture: LSTM size and MLP depth/width.
    lstm_size = trial.suggest_categorical("lstm_hidden_size", [32, 64, 128])
    depth = trial.suggest_categorical("net_arch_depth", [1, 2])
    width = trial.suggest_categorical("net_arch_width", [64, 128])

    trial.set_user_attr("overlap", overlap_steps)
    trial.set_user_attr("burn_in", burn_in_steps)

    return {
        "learning_rate": lr,
        "gamma": discount,
        "tau": polyak,
        "batch_size": batch,
        "ent_coef": entropy,
        "n_envs": num_envs,
        "segment_len": seg_len,
        "overlap": overlap_steps,
        "burn_in": burn_in_steps,
        "shared_state": shared,
        "policy_kwargs": {"net_arch": [width] * depth, "lstm_hidden_size": lstm_size},
    }
def objective(trial: optuna.Trial, env_id: str, n_timesteps: int, seed: int = 0) -> float:
    """Train RecurrentSAC with trial-sampled hyperparameters and return the
    mean evaluation reward.

    Args:
        trial: Optuna trial supplying the hyperparameter samples.
        env_id: Registered gym environment id.
        n_timesteps: Total training timesteps for this trial.
        seed: Training seed; evaluation env uses ``seed + 1000``.

    Returns:
        Mean deterministic-policy reward over ``N_EVAL_EPISODES``, or
        ``-inf`` when the trial fails for any reason (keeps the study going).
    """
    params = sample_params(trial, env_id)
    n_envs = params.pop("n_envs")
    # buffer_size in chunks; roughly enough for ~200k transitions at segment_len=50
    buffer_size = max(512, 200_000 // params["segment_len"])
    learning_starts = max(params["batch_size"] * 2, 1000)
    train_env = None
    eval_env = None
    try:
        train_env = make_env(env_id, n_envs=n_envs, seed=seed)
        eval_env = make_env(env_id, n_envs=1, seed=seed + 1000)
        model = RecurrentSAC(
            "MlpLstmPolicy",
            train_env,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            verbose=0,
            seed=seed,
            **params,
        )
        model.learn(total_timesteps=n_timesteps)
        # Sync VecNormalize stats to eval env before evaluation, and disable
        # further stat updates / reward normalization for unbiased scores.
        eval_env.obs_rms = train_env.obs_rms
        eval_env.ret_rms = train_env.ret_rms
        eval_env.training = False
        eval_env.norm_reward = False
        mean_reward, std_reward = evaluate_policy(
            model, eval_env, n_eval_episodes=N_EVAL_EPISODES, deterministic=True
        )
    except Exception as e:
        print(f"  Trial {trial.number} failed: {e}")
        return float("-inf")
    finally:
        # Fix: the original only closed envs on success, leaking subprocesses /
        # file handles whenever a trial raised. Close both unconditionally.
        for env in (train_env, eval_env):
            if env is not None:
                try:
                    env.close()
                except Exception:
                    pass  # best-effort cleanup; never mask the trial result
    print(f"  Trial {trial.number}: mean_reward={mean_reward:.2f} ± {std_reward:.2f}")
    return mean_reward
def run_study(env_id: str, n_trials: int, n_timesteps: int, seed: int = 0) -> optuna.Study:
    """Run a seeded TPE study maximizing mean evaluation reward.

    NOTE(review): a MedianPruner is attached, but ``objective`` never calls
    ``trial.report`` / ``trial.should_prune`` (EVAL_FREQ is None), so pruning
    never actually triggers — confirm whether intermediate reporting is wanted.
    """
    study = optuna.create_study(
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=0),
        study_name=f"rsac_{env_id}",
    )

    def _objective(trial: optuna.Trial) -> float:
        return objective(trial, env_id, n_timesteps, seed)

    study.optimize(_objective, n_trials=n_trials, show_progress_bar=False)
    return study
def print_results(study: optuna.Study, env_id: str) -> None:
    """Print the best trial (params + derived attrs) and the top-5 trials."""
    best = study.best_trial
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"Best result for {env_id}")
    print(rule)
    print(f"  Value (mean reward): {best.value:.2f}")
    print("  Params:")
    for name, value in best.params.items():
        print(f"    {name}: {value}")
    # Derived quantities stored by sample_params (overlap, burn_in).
    for name, value in best.user_attrs.items():
        print(f"    {name}: {value} [derived]")
    print()

    # Top-5 trials, skipping failed ones (value None or the -inf sentinel).
    completed = [
        t for t in study.trials
        if t.value is not None and t.value > float("-inf")
    ]
    completed.sort(key=lambda t: t.value, reverse=True)
    print("  Top-5 trials:")
    for t in completed[:5]:
        print(f"    Trial {t.number:3d}: {t.value:.2f}")
if __name__ == "__main__":
    # CLI: environment id, trial budget, per-trial timestep budget, seed.
    parser = argparse.ArgumentParser()
    parser.add_argument("--env", type=str, default="PendulumNoVel-v1")
    parser.add_argument("--n-trials", type=int, default=30)
    parser.add_argument("--n-timesteps", type=int, default=100_000)
    parser.add_argument("--seed", type=int, default=0)
    cli = parser.parse_args()

    # Keep Optuna quiet; per-trial results are printed by objective().
    optuna.logging.set_verbosity(optuna.logging.WARNING)

    print(f"Tuning RecurrentSAC on {cli.env}")
    print(f"  n_trials={cli.n_trials}, n_timesteps={cli.n_timesteps}, seed={cli.seed}")
    result_study = run_study(cli.env, cli.n_trials, cli.n_timesteps, cli.seed)
    print_results(result_study, cli.env)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment