Last active
March 8, 2026 21:38
-
-
Save HAKSOAT/bba6bed4d2ae2d4a2c2492c665f69108 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Custom SAC implementation with trainable weights for multiple actor losses.

This example shows how to add trainable loss weights to Stable-Baselines3 SAC
when several actor losses have to be balanced against each other.

Use case:
- Critic loss: fixed (not weighted)
- Actor loss 1: trainable weight w1
- Actor loss 2: trainable weight w2
- Actor loss 3: trainable weight w3
- Actor loss 4: trainable weight w4

The total loss is:
    total_loss = critic_loss + w1*loss1 + w2*loss2 + w3*loss3 + w4*loss4 + regularization

Key implementation details:
1. Only the 4 actor losses are weighted; the critic loss is not.
2. Weights are stored in log space (w_i = exp(s_i)), which keeps them positive
   and numerically stable.
3. A regularization term on the log-weights must keep the weights from
   collapsing to zero. The gradient of w_i * loss_i with respect to s_i is
   w_i * loss_i, which for a positive loss always pushes s_i down. A term of
   +lambda * sum(s_i) pushes the weights toward zero as well. The regularizer
   therefore has to be -lambda * sum(s_i), which gives the equilibrium
   w_i = lambda / loss_i.
4. Each gradient step runs one backward pass per optimizer. Reusing a single
   graph with retain_graph=True across optimizer steps does not work, because
   optimizer.step() modifies in place the parameters that the graph saved for
   its backward pass.
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import SAC
from stable_baselines3.common.torch_layers import FlattenExtractor
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.utils import polyak_update
# SB3 defines the SAC policy as ``SACPolicy`` in ``stable_baselines3.sac.policies``;
# ``stable_baselines3.common.policies`` has no ``SacPolicy`` (ImportError). The alias
# keeps the name used throughout this file.
from stable_baselines3.sac.policies import SACPolicy as SacPolicy
class CustomSacPolicy(SacPolicy):
    """
    SAC policy with learnable log-weights for multiple actor losses.

    Extends the standard SAC policy with one extra parameter,
    ``log_loss_weights``. The algorithm reads the weights through
    :attr:`loss_weights` and is responsible for optimizing them.

    Attributes:
        log_loss_weights: ``nn.Parameter`` of shape ``(n_actor_losses,)`` holding
            the log of the loss weights; the actual weights are
            ``exp(log_loss_weights)``. It is created after the parent constructor
            has returned, so it is NOT part of ``actor.optimizer`` or
            ``critic.optimizer``; the training algorithm must give it its own
            optimizer. Because it is a registered parameter, it is included in
            the policy ``state_dict`` and moved by ``.to(device)``.
        n_actor_losses: Number of actor losses (and therefore weights).
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Dict[str, list]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Optional[Type] = None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        n_actor_losses: int = 4,  # Number of actor losses to weight
        **kwargs: Any,
    ):
        """
        Initialize the custom SAC policy.

        Args:
            features_extractor_class: Feature extractor class. ``None`` selects
                SB3's ``FlattenExtractor``.
            n_actor_losses: Number of actor losses to weight (must be >= 1).
            **kwargs: Forwarded to the SB3 SAC policy, e.g. ``optimizer_class``,
                ``optimizer_kwargs``, ``n_critics``, ``share_features_extractor``.

        Raises:
            ValueError: If ``n_actor_losses`` is smaller than 1.
        """
        if n_actor_losses < 1:
            raise ValueError(f"n_actor_losses must be >= 1, got {n_actor_losses}")
        if features_extractor_class is None:
            # The parent instantiates features_extractor_class(...) directly, so a
            # None default would crash; fall back to SB3's default extractor.
            features_extractor_class = FlattenExtractor
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            lr_schedule=lr_schedule,
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_class=features_extractor_class,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
            **kwargs,
        )
        self.n_actor_losses = n_actor_losses
        # Zero init => exp(0) = 1.0, i.e. equal initial weighting. Assigning an
        # nn.Parameter attribute registers it on the module; no explicit
        # register_parameter call is needed.
        self.log_loss_weights = nn.Parameter(
            torch.zeros(n_actor_losses, dtype=torch.float32)
        )

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Add ``n_actor_losses`` so a saved policy is rebuilt with the same size."""
        data = super()._get_constructor_parameters()
        data.update(n_actor_losses=self.n_actor_losses)
        return data

    @property
    def loss_weights(self) -> torch.Tensor:
        """Actual (positive) weights, ``exp(log_loss_weights)``; keeps the graph."""
        return torch.exp(self.log_loss_weights)
class CustomSAC(SAC):
    """
    SAC algorithm with trainable weights for multiple actor losses.

    The critic is trained as in standard SAC, and its loss is never weighted.
    The actor minimises ``sum_i w_i * L_i``, where ``L_i`` are the scalar losses
    returned by :meth:`_compute_actor_losses` and ``w_i = exp(s_i)`` with
    ``s_i = policy.log_loss_weights[i]``.

    The log-weights have their own optimizer (``self.loss_weights_optimizer``)
    and minimise::

        sum_i w_i * L_i.detach() - weight_loss_regularization * sum_i s_i

    The gradient with respect to ``s_i`` is ``w_i * L_i - lambda``. For a
    positive loss the weight therefore settles at ``w_i = lambda / L_i`` instead
    of collapsing to zero. If ``L_i <= 0`` (for example the placeholder zero
    losses, or a SAC policy loss that becomes negative), there is no finite
    optimum and ``w_i`` keeps growing. Weighted losses should be non-negative
    for this scheme to be meaningful.

    The critic, actor, loss weights and (optionally) the entropy coefficient
    are each updated with their own backward pass. No computation graph is
    reused after an ``optimizer.step()`` has modified parameters it saved.
    """

    def __init__(
        self,
        policy: Union[str, Type[SacPolicy]] = "MlpPolicy",
        env: Union[GymEnv, str, None] = None,
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 10_000,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, tuple] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[Any] = None,
        replay_buffer_class: Optional[Type] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Optional[float] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
        weight_loss_regularization: float = 0.1,
        n_actor_losses: int = 4,
    ):
        """
        Initialize CustomSAC.

        All arguments except the last two are forwarded unchanged to SB3's SAC.
        ``"MlpPolicy"`` is replaced by :class:`CustomSacPolicy`.

        Args:
            weight_loss_regularization: Strength ``lambda`` of the log-weight
                regularizer that keeps the weights from collapsing; for a
                positive loss ``L_i`` the weight tends to ``lambda / L_i``.
            n_actor_losses: Number of actor losses to weight (default: 4).
                Forwarded to the policy through ``policy_kwargs``.
        """
        self.weight_loss_regularization = weight_loss_regularization
        self.n_actor_losses = n_actor_losses
        # Built in _setup_model, once the policy (and its log-weights) exists.
        self.loss_weights_optimizer: Optional[torch.optim.Optimizer] = None
        # Copy first so the caller's dict is not mutated, then pass the count on.
        policy_kwargs = dict(policy_kwargs or {})
        policy_kwargs["n_actor_losses"] = n_actor_losses
        # Use CustomSacPolicy by default.
        if policy == "MlpPolicy":
            policy = CustomSacPolicy
        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            ent_coef=ent_coef,
            target_update_interval=target_update_interval,
            target_entropy=target_entropy,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

    def _setup_model(self) -> None:
        """
        Build the SAC components, then an optimizer for the loss log-weights.

        ``policy.log_loss_weights`` belongs to neither ``actor.optimizer`` nor
        ``critic.optimizer``, so without this optimizer it would never change.
        It uses the policy's ``optimizer_class`` / ``optimizer_kwargs`` and the
        initial learning rate ``lr_schedule(1)``.
        """
        super()._setup_model()
        if isinstance(self.policy, CustomSacPolicy):
            self.loss_weights_optimizer = self.policy.optimizer_class(
                [self.policy.log_loss_weights],
                lr=self.lr_schedule(1),
                **self.policy.optimizer_kwargs,
            )

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """Also save/restore the loss-weight optimizer state with the model."""
        state_dicts, saved_variables = super()._get_torch_save_params()
        if self.loss_weights_optimizer is not None:
            state_dicts = [*state_dicts, "loss_weights_optimizer"]
        return state_dicts, saved_variables

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run ``gradient_steps`` SAC updates with learnable actor-loss weights.

        Each step samples ``batch_size`` transitions from the replay buffer. It
        then updates, in order: the entropy coefficient (if learned), the
        critic, the actor and the loss weights. Every ``target_update_interval``
        steps it also Polyak-updates the target critic. Averaged statistics are
        recorded in the logger afterwards.

        Args:
            gradient_steps: Number of gradient steps to perform.
            batch_size: Minibatch size sampled from the replay buffer per step.
        """
        # Ensure we're using the custom policy (it owns the log-weights).
        assert isinstance(self.policy, CustomSacPolicy), \
            "Policy must be CustomSacPolicy to use trainable loss weights"
        if gradient_steps <= 0:
            return

        # Train mode (affects batch norm / dropout) and learning-rate schedule.
        self.policy.set_training_mode(True)
        optimizers = [self.actor.optimizer, self.critic.optimizer, self.loss_weights_optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        ent_coefs: List[float] = []
        ent_coef_losses: List[float] = []
        critic_losses: List[float] = []
        weighted_actor_losses: List[float] = []
        weights_objectives: List[float] = []
        actor_loss_history: List[torch.Tensor] = []

        for gradient_step in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            # gSDE: fresh exploration noise matrix per gradient step.
            if self.use_sde:
                self.actor.reset_noise()

            # Actions from the current policy, reused by the entropy and actor losses.
            pi_actions, log_pi = self.actor.action_log_prob(replay_data.observations)
            # Column vector so it lines up with the (batch, 1) Q-values.
            log_pi = log_pi.reshape(-1, 1)

            ent_coef, ent_coef_loss = self._entropy_coef_step(log_pi)
            ent_coefs.append(ent_coef.item())
            if ent_coef_loss is not None:
                ent_coef_losses.append(ent_coef_loss)

            critic_losses.append(self._critic_step(replay_data, ent_coef))

            # torch.stack keeps each component's graph for the actor update.
            actor_losses = torch.stack(
                list(self._compute_actor_losses(replay_data, pi_actions, log_pi, ent_coef))
            )
            actor_loss_history.append(actor_losses.detach())
            weighted_actor_loss, weights_objective = self._actor_and_weights_step(actor_losses)
            weighted_actor_losses.append(weighted_actor_loss)
            weights_objectives.append(weights_objective)

            # Update target networks using an exponential moving average.
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Batch-norm running stats are copied, not averaged.
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

        self._n_updates += gradient_steps

        current_weights = self.policy.loss_weights.detach().cpu().numpy()
        mean_actor_losses = torch.stack(actor_loss_history).mean(dim=0).tolist()

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/ent_coef", float(np.mean(ent_coefs)))
        if ent_coef_losses:
            self.logger.record("train/ent_coef_loss", float(np.mean(ent_coef_losses)))
        self.logger.record("train/critic_loss", float(np.mean(critic_losses)))
        self.logger.record("train/weighted_actor_loss", float(np.mean(weighted_actor_losses)))
        self.logger.record("train/loss_weights_objective", float(np.mean(weights_objectives)))
        # Whole weight vector, plus one scalar per weight (easier to plot).
        self.logger.record("train/loss_weights", current_weights)
        for i, (loss_value, weight) in enumerate(zip(mean_actor_losses, current_weights)):
            self.logger.record(f"train/actor_loss_{i + 1}", loss_value)
            self.logger.record(f"train/loss_weight_{i + 1}", float(weight))

    def _entropy_coef_step(self, log_pi: torch.Tensor) -> Tuple[torch.Tensor, Optional[float]]:
        """
        Return this step's entropy coefficient, updating it first if it is learned.

        Returns:
            ``(ent_coef, ent_coef_loss)``. ``ent_coef`` is a tensor without
            gradient. ``ent_coef_loss`` is ``None`` when the coefficient is fixed.
        """
        if self.ent_coef_optimizer is None:
            # Fixed coefficient. ``self.ent_coef`` itself may be the string
            # "auto" or a float, so use SB3's tensor copy instead.
            return self.ent_coef_tensor, None
        # Detached so this step's losses treat the coefficient as a constant.
        ent_coef = torch.exp(self.log_ent_coef.detach())
        ent_coef_loss = -(self.log_ent_coef * (log_pi + self.target_entropy).detach()).mean()
        self.ent_coef_optimizer.zero_grad()
        ent_coef_loss.backward()
        self.ent_coef_optimizer.step()
        return ent_coef, ent_coef_loss.item()

    def _critic_step(self, replay_data: Any, ent_coef: torch.Tensor) -> float:
        """Do one unweighted critic update and return the critic loss value."""
        with torch.no_grad():
            # Select next actions according to the current policy.
            next_actions, next_log_prob = self.actor.action_log_prob(
                replay_data.next_observations
            )
            # Next Q-values from the target critics; keep the minimum.
            next_q_values = torch.cat(
                self.critic_target(replay_data.next_observations, next_actions), dim=1
            )
            next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True)
            # Add the entropy term.
            next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
            target_q_values = (
                replay_data.rewards
                + (1 - replay_data.dones) * self.gamma * next_q_values
            )

        # Current Q-value estimates for each critic network.
        current_q_values = self.critic(replay_data.observations, replay_data.actions)
        # Critic loss (NOT weighted).
        critic_loss = 0.5 * sum(
            F.mse_loss(current_q, target_q_values) for current_q in current_q_values
        )
        self.critic.optimizer.zero_grad()
        critic_loss.backward()
        self.critic.optimizer.step()
        return critic_loss.item()

    def _compute_actor_losses(
        self,
        replay_data: Any,
        pi_actions: torch.Tensor,
        log_pi: torch.Tensor,
        ent_coef: torch.Tensor,
    ) -> Sequence[torch.Tensor]:
        """
        Return the actor loss components, one scalar tensor per learnable weight.

        Override this to plug in the real auxiliary losses (behavior cloning,
        imitation, ...). The result must contain exactly
        ``policy.log_loss_weights.numel()`` scalar tensors.

        Args:
            replay_data: Replay-buffer minibatch for this gradient step.
            pi_actions: Actions sampled from the current policy for
                ``replay_data.observations`` (carries gradient to the actor).
            log_pi: Log-probabilities of ``pi_actions``, shape ``(batch, 1)``.
            ent_coef: Entropy coefficient for this step (no gradient).
        """
        # Evaluated with the critic as already updated in this step.
        qf_pi = torch.cat(self.critic(replay_data.observations, pi_actions), dim=1)
        min_qf_pi, _ = torch.min(qf_pi, dim=1, keepdim=True)
        # Loss 1: standard SAC actor loss (maximize Q-value and entropy).
        sac_loss = (ent_coef * log_pi - min_qf_pi).mean()
        # Losses 2..n: placeholders, to be replaced with real losses. A loss that
        # is constantly zero gives its weight no finite optimum.
        n_placeholders = self.policy.log_loss_weights.numel() - 1
        placeholders = [torch.zeros((), device=self.device) for _ in range(n_placeholders)]
        return [sac_loss, *placeholders]

    def _actor_and_weights_step(self, actor_losses: torch.Tensor) -> Tuple[float, float]:
        """
        Update the actor on the weighted losses, then update the log-weights.

        Args:
            actor_losses: Stacked loss components, same shape as
                ``policy.log_loss_weights``.

        Returns:
            ``(weighted_actor_loss, loss_weights_objective)`` as Python floats.

        Raises:
            ValueError: If the number/shape of losses does not match the weights.
        """
        log_loss_weights = self.policy.log_loss_weights
        if actor_losses.shape != log_loss_weights.shape:
            raise ValueError(
                f"Expected {log_loss_weights.numel()} scalar actor losses, "
                f"got a tensor of shape {tuple(actor_losses.shape)}"
            )
        # Recomputed every step so the current parameter values are used.
        loss_weights = self.policy.loss_weights

        # Actor update: the weights act as constants.
        weighted_actor_loss = (loss_weights.detach() * actor_losses).sum()
        self.actor.optimizer.zero_grad()
        weighted_actor_loss.backward()
        self.actor.optimizer.step()

        # Weight update: the losses act as constants. The "- lambda * sum(s)"
        # term balances w_i * L_i so the weights do not collapse to zero.
        loss_weights_objective = (
            (loss_weights * actor_losses.detach()).sum()
            - self.weight_loss_regularization * log_loss_weights.sum()
        )
        self.loss_weights_optimizer.zero_grad()
        loss_weights_objective.backward()
        self.loss_weights_optimizer.step()

        return weighted_actor_loss.item(), loss_weights_objective.item()
# Example usage
if __name__ == "__main__":
    # Pendulum-v1 has a continuous action space, which SAC requires.
    env = gym.make("Pendulum-v1")

    # Demo agent with 4 weighted actor losses.
    sac_settings = dict(
        policy="MlpPolicy",
        env=env,
        learning_rate=3e-4,
        batch_size=256,
        buffer_size=100_000,
        learning_starts=1_000,
        verbose=1,
        weight_loss_regularization=0.1,
        n_actor_losses=4,
    )
    model = CustomSAC(**sac_settings)
    model.learn(total_timesteps=10_000)

    # Run the deterministic policy for 100 steps, restarting finished episodes.
    obs, _ = env.reset()
    step_count = 0
    while step_count < 100:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        episode_over = terminated or truncated
        if episode_over:
            obs, _ = env.reset()
        step_count += 1
    env.close()

    # Report the weights the agent ended up with.
    final_weights = model.policy.loss_weights.detach().cpu().numpy()
    print("\nFinal loss weights for the 4 actor losses:")
    for i, weight in enumerate(final_weights):
        print(f" Actor loss {i+1} weight: {weight:.4f}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment