"""
Optuna hyperparameter search for RecurrentSAC on POMDP benchmarks.
Environments (NoVel = velocity observations removed, making them POMDPs):
- PendulumNoVel-v1
- MountainCarContinuousNoVel-v0
- LunarLanderContinuousNoVel-v3 (requires Box2D / swig)
Usage:
python scripts/tune_rsac.py --env PendulumNoVel-v1 --n-trials 50 --n-timesteps 100000
python scripts/tune_rsac.py --env MountainCarContinuousNoVel-v0 --n-trials 30 --n-timesteps 300000
"""
import argparse
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import optuna
import rl_zoo3.import_envs # noqa: F401 — registers NoVel envs
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.vec_env import VecNormalize
from sb3_contrib import RecurrentSAC
# ── Evaluation settings ────────────────────────────────────────────────────────
N_EVAL_EPISODES = 10
EVAL_FREQ = None # evaluate only at the end of training

def make_env(env_id: str, n_envs: int, seed: int = 0) -> VecNormalize:
    env = make_vec_env(env_id, n_envs=n_envs, seed=seed)
    return VecNormalize(env, norm_obs=True, norm_reward=True)

def sample_params(trial: optuna.Trial, env_id: str) -> dict:
    """Sample hyperparameters for a trial."""
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True)
    gamma = trial.suggest_categorical("gamma", [0.9, 0.95, 0.98, 0.99, 0.999])
    tau = trial.suggest_categorical("tau", [0.005, 0.01, 0.02])
    batch_size = trial.suggest_categorical("batch_size", [64, 128, 256])
    ent_coef = trial.suggest_categorical("ent_coef", ["auto", 0.001, 0.01, 0.1])
    n_envs = trial.suggest_categorical("n_envs", [1, 2, 4])
    # Recurrent-specific
    segment_len = trial.suggest_categorical("segment_len", [20, 50, 100])
    overlap_frac = trial.suggest_float("overlap_frac", 0.0, 0.4)
    overlap = max(0, int(segment_len * overlap_frac))
    burn_in_frac = trial.suggest_float("burn_in_frac", 0.0, 0.3)
    burn_in = int(segment_len * burn_in_frac)
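    # Worked example (hypothetical draw): segment_len=50, overlap_frac=0.2,
    # burn_in_frac=0.1 gives overlap = int(50 * 0.2) = 10 steps shared between
    # consecutive training segments and burn_in = int(50 * 0.1) = 5 steps used
    # only to warm up the LSTM hidden state before the loss is computed.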
    shared_state = trial.suggest_categorical("shared_state", [True, False])
    lstm_hidden_size = trial.suggest_categorical("lstm_hidden_size", [32, 64, 128])
    net_arch_depth = trial.suggest_categorical("net_arch_depth", [1, 2])
    net_arch_width = trial.suggest_categorical("net_arch_width", [64, 128])
    net_arch = [net_arch_width] * net_arch_depth
    trial.set_user_attr("overlap", overlap)
    trial.set_user_attr("burn_in", burn_in)
    return dict(
        learning_rate=learning_rate,
        gamma=gamma,
        tau=tau,
        batch_size=batch_size,
        ent_coef=ent_coef,
        n_envs=n_envs,
        segment_len=segment_len,
        overlap=overlap,
        burn_in=burn_in,
        shared_state=shared_state,
        policy_kwargs=dict(net_arch=net_arch, lstm_hidden_size=lstm_hidden_size),
    )

def objective(trial: optuna.Trial, env_id: str, n_timesteps: int, seed: int = 0) -> float:
    params = sample_params(trial, env_id)
    n_envs = params.pop("n_envs")
    # buffer_size in chunks; roughly enough for ~200k transitions at segment_len=50
    buffer_size = max(512, 200_000 // params["segment_len"])
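    # Arithmetic check: at segment_len=50 this is 200_000 // 50 = 4_000 segments,
    # i.e. about 4_000 * 50 = 200k stored transitions; the max() keeps a floor
    # of 512 segments.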
    learning_starts = max(params["batch_size"] * 2, 1000)
    try:
        train_env = make_env(env_id, n_envs=n_envs, seed=seed)
        eval_env = make_env(env_id, n_envs=1, seed=seed + 1000)
        model = RecurrentSAC(
            "MlpLstmPolicy",
            train_env,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            verbose=0,
            seed=seed,
            **params,
        )
        model.learn(total_timesteps=n_timesteps)
        # Sync VecNormalize stats to eval env before evaluation
        eval_env.obs_rms = train_env.obs_rms
        eval_env.ret_rms = train_env.ret_rms
        eval_env.training = False
        eval_env.norm_reward = False
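        # With norm_reward=False the eval env reports raw episode rewards, so
        # trial scores stay comparable regardless of each trial's running
        # reward-normalization statistics.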
        mean_reward, std_reward = evaluate_policy(
            model, eval_env, n_eval_episodes=N_EVAL_EPISODES, deterministic=True
        )
        train_env.close()
        eval_env.close()
    except Exception as e:
        print(f" Trial {trial.number} failed: {e}")
        return float("-inf")
    print(f" Trial {trial.number}: mean_reward={mean_reward:.2f} ± {std_reward:.2f}")
    return mean_reward

def run_study(env_id: str, n_trials: int, n_timesteps: int, seed: int = 0) -> optuna.Study:
    sampler = optuna.samplers.TPESampler(seed=seed)
    pruner = optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=0)
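    # Note: objective() never calls trial.report() / trial.should_prune(), so
    # the MedianPruner is effectively inert here; intermediate reporting (e.g.
    # periodic evaluations during model.learn) would be needed for pruning to fire.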
    study = optuna.create_study(
        direction="maximize",
        sampler=sampler,
        pruner=pruner,
        study_name=f"rsac_{env_id}",
    )
    study.optimize(
        lambda trial: objective(trial, env_id, n_timesteps, seed),
        n_trials=n_trials,
        show_progress_bar=False,
    )
    return study

def print_results(study: optuna.Study, env_id: str) -> None:
    best = study.best_trial
    print(f"\n{'=' * 60}")
    print(f"Best result for {env_id}")
    print(f"{'=' * 60}")
    print(f" Value (mean reward): {best.value:.2f}")
    print(" Params:")
    for k, v in best.params.items():
        print(f" {k}: {v}")
    for k, v in best.user_attrs.items():
        print(f" {k}: {v} [derived]")
    print()
    # Top-5 trials
    trials = sorted(
        [t for t in study.trials if t.value is not None and t.value > float("-inf")],
        key=lambda t: t.value,
        reverse=True,
    )
    print(" Top-5 trials:")
    for t in trials[:5]:
        print(f" Trial {t.number:3d}: {t.value:.2f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--env", type=str, default="PendulumNoVel-v1")
parser.add_argument("--n-trials", type=int, default=30)
parser.add_argument("--n-timesteps", type=int, default=100_000)
parser.add_argument("--seed", type=int, default=0)
args = parser.parse_args()
optuna.logging.set_verbosity(optuna.logging.WARNING)
print(f"Tuning RecurrentSAC on {args.env}")
print(f" n_trials={args.n_trials}, n_timesteps={args.n_timesteps}, seed={args.seed}")
study = run_study(args.env, args.n_trials, args.n_timesteps, args.seed)
print_results(study, args.env)