Skip to content

Instantly share code, notes, and snippets.

@HAKSOAT
Last active March 8, 2026 21:38
Show Gist options
  • Select an option

  • Save HAKSOAT/bba6bed4d2ae2d4a2c2492c665f69108 to your computer and use it in GitHub Desktop.

Select an option

Save HAKSOAT/bba6bed4d2ae2d4a2c2492c665f69108 to your computer and use it in GitHub Desktop.
"""
Custom SAC implementation with trainable weights for multiple actor losses.
This example demonstrates how to implement trainable loss weights in Stable-Baselines3
SAC when you have multiple actor losses that need to be balanced.
Use case:
- Critic loss: Fixed (not weighted)
- Actor loss 1: Trainable weight w1
- Actor loss 2: Trainable weight w2
- Actor loss 3: Trainable weight w3
- Actor loss 4: Trainable weight w4
The total loss is:
total_loss = critic_loss + w1*loss1 + w2*loss2 + w3*loss3 + w4*loss4 + regularization
Key implementation details:
1. Only the 4 actor losses are weighted (critic is fixed)
2. Weights are represented in log-space for numerical stability
3. Regularization prevents weights from collapsing to zero
4. Multiple backward passes are handled with retain_graph=True
"""
from typing import Any, Dict, Optional, Type, Union

import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from stable_baselines3 import SAC
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.utils import polyak_update

try:
    from stable_baselines3.common.policies import SacPolicy
except ImportError:
    # Upstream Stable-Baselines3 defines the SAC policy class in the sac package.
    from stable_baselines3.sac.policies import SACPolicy as SacPolicy
class CustomSacPolicy(SacPolicy):
    """
    SAC policy that additionally stores learnable weights for several actor losses.

    The weights are kept in log-space so the effective weight ``exp(log_w)`` is
    always strictly positive.

    Attributes:
        log_loss_weights: ``nn.Parameter`` of shape ``(n_actor_losses,)`` holding
            the log of each actor-loss weight. It is initialised to zeros, i.e.
            every effective weight starts at ``exp(0) = 1.0``.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Schedule,
        net_arch: Optional[Union[list, Dict[str, list]]] = None,
        activation_fn: Type[nn.Module] = nn.ReLU,
        use_sde: bool = False,
        log_std_init: float = -3,
        use_expln: bool = False,
        clip_mean: float = 2.0,
        features_extractor_class: Optional[Type] = None,
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        normalize_images: bool = True,
        n_actor_losses: int = 4,  # Number of actor losses to weight
        **kwargs: Any,
    ):
        """
        Initialize the custom SAC policy.

        Args:
            observation_space: Observation space forwarded to the parent policy.
            action_space: Action space forwarded to the parent policy.
            lr_schedule: Learning-rate schedule forwarded to the parent policy.
            net_arch: Network architecture forwarded to the parent policy.
            activation_fn: Activation module class forwarded to the parent policy.
            use_sde: Forwarded to the parent policy.
            log_std_init: Forwarded to the parent policy.
            use_expln: Forwarded to the parent policy.
            clip_mean: Forwarded to the parent policy.
            features_extractor_class: Features extractor class. ``None`` means
                "use the parent policy's own default".
            features_extractor_kwargs: Forwarded to the parent policy.
            normalize_images: Forwarded to the parent policy.
            n_actor_losses: Number of actor losses to weight (must be >= 1).
            **kwargs: Any further keyword arguments accepted by the parent
                policy (for example ``optimizer_class`` or ``n_critics``).

        Raises:
            ValueError: If ``n_actor_losses`` is smaller than 1.
        """
        if n_actor_losses < 1:
            raise ValueError(f"n_actor_losses must be >= 1, got {n_actor_losses}")

        parent_kwargs: Dict[str, Any] = dict(
            net_arch=net_arch,
            activation_fn=activation_fn,
            use_sde=use_sde,
            log_std_init=log_std_init,
            use_expln=use_expln,
            clip_mean=clip_mean,
            features_extractor_kwargs=features_extractor_kwargs,
            normalize_images=normalize_images,
        )
        # Forwarding None would replace the parent's default extractor class
        # with a non-callable, so only pass an explicit choice through.
        if features_extractor_class is not None:
            parent_kwargs["features_extractor_class"] = features_extractor_class
        parent_kwargs.update(kwargs)

        super().__init__(observation_space, action_space, lr_schedule, **parent_kwargs)

        # Assigning an nn.Parameter as an attribute registers it with the module,
        # so it shows up in parameters() and state_dict() without further calls.
        # NOTE: this parameter is created after the parent has built its actor and
        # critic optimizers, so it needs an optimizer of its own to be trained.
        self.log_loss_weights = nn.Parameter(
            torch.zeros(n_actor_losses, dtype=torch.float32),
            requires_grad=True,
        )

    @property
    def loss_weights(self) -> torch.Tensor:
        """Return the effective (strictly positive) weights ``exp(log_loss_weights)``."""
        return torch.exp(self.log_loss_weights)

    def _get_constructor_parameters(self) -> Dict[str, Any]:
        """Extend the parent's constructor parameters with ``n_actor_losses``.

        Keeps the size of ``log_loss_weights`` consistent when the policy is
        rebuilt from its saved constructor parameters.
        """
        data = super()._get_constructor_parameters()
        data["n_actor_losses"] = int(self.log_loss_weights.numel())
        return data
class CustomSAC(SAC):
    """
    SAC algorithm with trainable weights for multiple actor losses.

    The critic is trained on its own (unweighted) loss. The actor is trained on

        sum_i w_i * actor_loss_i  -  weight_loss_regularization * sum_i log(w_i)

    where ``w_i = exp(log_loss_weights[i])`` lives on the policy. The ``-log(w)``
    term grows without bound as a weight approaches zero, which is what keeps
    the weights from collapsing.

    NOTE: for a strictly positive loss ``L_i`` the weight settles around
    ``weight_loss_regularization / L_i``. A loss that is negative on average
    (the plain SAC actor loss can be) has no such equilibrium and its weight
    keeps growing; shift or rescale such a loss in ``_compute_actor_losses``.

    Override ``_compute_actor_losses`` to plug in real auxiliary losses.
    """

    def __init__(
        self,
        policy: Union[str, Type[SacPolicy]] = "MlpPolicy",
        env: Union[GymEnv, str, None] = None,
        learning_rate: Union[float, Schedule] = 3e-4,
        buffer_size: int = 1_000_000,
        learning_starts: int = 10_000,
        batch_size: int = 256,
        tau: float = 0.005,
        gamma: float = 0.99,
        train_freq: Union[int, tuple] = 1,
        gradient_steps: int = 1,
        action_noise: Optional[Any] = None,
        replay_buffer_class: Optional[Type] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        ent_coef: Union[str, float] = "auto",
        target_update_interval: int = 1,
        target_entropy: Optional[float] = None,
        use_sde: bool = False,
        sde_sample_freq: int = -1,
        use_sde_at_warmup: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[torch.device, str] = "auto",
        _init_setup_model: bool = True,
        weight_loss_regularization: float = 0.1,
        n_actor_losses: int = 4,
    ):
        """
        Initialize CustomSAC.

        All arguments except the last two are forwarded unchanged to ``SAC``.

        Args:
            weight_loss_regularization: Strength of the ``-log(w)`` term that
                keeps the loss weights away from zero.
            n_actor_losses: Number of actor losses to weight (default: 4).

        Raises:
            ValueError: If ``n_actor_losses`` is smaller than 1.
        """
        if n_actor_losses < 1:
            raise ValueError(f"n_actor_losses must be >= 1, got {n_actor_losses}")
        self.weight_loss_regularization = weight_loss_regularization
        self.n_actor_losses = n_actor_losses

        # "MlpPolicy" is remapped to the policy that owns the loss weights.
        if policy == "MlpPolicy":
            policy = CustomSacPolicy

        # Work on a copy so the caller's dict is never mutated.
        policy_kwargs = {} if policy_kwargs is None else dict(policy_kwargs)
        if isinstance(policy, type) and issubclass(policy, CustomSacPolicy):
            policy_kwargs["n_actor_losses"] = n_actor_losses

        super().__init__(
            policy=policy,
            env=env,
            learning_rate=learning_rate,
            buffer_size=buffer_size,
            learning_starts=learning_starts,
            batch_size=batch_size,
            tau=tau,
            gamma=gamma,
            train_freq=train_freq,
            gradient_steps=gradient_steps,
            action_noise=action_noise,
            replay_buffer_class=replay_buffer_class,
            replay_buffer_kwargs=replay_buffer_kwargs,
            optimize_memory_usage=optimize_memory_usage,
            ent_coef=ent_coef,
            target_update_interval=target_update_interval,
            target_entropy=target_entropy,
            use_sde=use_sde,
            sde_sample_freq=sde_sample_freq,
            use_sde_at_warmup=use_sde_at_warmup,
            policy_kwargs=policy_kwargs,
            verbose=verbose,
            seed=seed,
            device=device,
            _init_setup_model=_init_setup_model,
        )

    def _setup_model(self) -> None:
        """Build the networks, then create the optimizer for the loss weights.

        Raises:
            TypeError: If the instantiated policy is not a ``CustomSacPolicy``.
        """
        super()._setup_model()
        if not isinstance(self.policy, CustomSacPolicy):
            raise TypeError("Policy must be CustomSacPolicy to use trainable loss weights")
        # The actor and critic optimizers do not contain ``log_loss_weights``,
        # so without this dedicated optimizer the weights would never change.
        self.loss_weight_optimizer = torch.optim.Adam(
            [self.policy.log_loss_weights], lr=self.lr_schedule(1)
        )

    def _get_torch_save_params(self):
        """Also persist the state of the loss-weight optimizer."""
        state_dicts, saved_variables = super()._get_torch_save_params()
        return list(state_dicts) + ["loss_weight_optimizer"], saved_variables

    def _compute_actor_losses(
        self,
        replay_data: Any,
        pi_actions: torch.Tensor,
        log_pi: torch.Tensor,
        ent_coef: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return the unweighted actor losses as a 1-D tensor, one entry per weight.

        This default is a template: entry 0 is the standard SAC actor loss and
        every remaining entry is a constant-zero placeholder. Override it to
        supply real auxiliary losses (behaviour cloning, imitation, ...).

        Args:
            replay_data: Batch sampled from the replay buffer.
            pi_actions: Actions sampled from the current actor for
                ``replay_data.observations``.
            log_pi: Log-probabilities of ``pi_actions``, reshaped to a column.
            ent_coef: Current entropy coefficient (detached from its parameter).
        """
        q_values_pi = torch.cat(self.critic(replay_data.observations, pi_actions), dim=1)
        # Pessimistic estimate: minimum over all critics.
        min_qf_pi, _ = torch.min(q_values_pi, dim=1, keepdim=True)
        sac_loss = (ent_coef * log_pi - min_qf_pi).mean()

        n_losses = int(self.policy.log_loss_weights.numel())
        placeholders = [sac_loss.new_zeros(()) for _ in range(n_losses - 1)]
        # torch.stack keeps each entry attached to its computation graph.
        return torch.stack([sac_loss] + placeholders)

    def train(self, gradient_steps: int, batch_size: int = 64) -> None:
        """
        Run ``gradient_steps`` SAC updates with learnable actor-loss weights.

        Each step updates, in order: the entropy coefficient (when it is
        learned), the critic on its unweighted loss, then the actor together
        with the loss weights on the weighted actor objective, and finally the
        target critic every ``target_update_interval`` steps.

        Args:
            gradient_steps: Number of update steps to perform.
            batch_size: Number of transitions sampled per step.

        Raises:
            ValueError: If ``_compute_actor_losses`` returns a tensor whose shape
                differs from that of the loss weights.
        """
        # Switch to train mode (affects batch norm / dropout).
        self.policy.set_training_mode(True)

        optimizers = [self.actor.optimizer, self.critic.optimizer, self.loss_weight_optimizer]
        if self.ent_coef_optimizer is not None:
            optimizers.append(self.ent_coef_optimizer)
        self._update_learning_rate(optimizers)

        log_weights = self.policy.log_loss_weights
        last_stats: Optional[Dict[str, Any]] = None

        for gradient_step in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            if self.use_sde:
                self.actor.reset_noise()

            # Actions of the current actor; the log-prob is made a column so it
            # lines up with the (batch, 1) Q-values instead of broadcasting to
            # a (batch, batch) matrix.
            pi_actions, log_pi = self.actor.action_log_prob(replay_data.observations)
            log_pi = log_pi.reshape(-1, 1)

            # ------------------------------------------------------------
            # Entropy coefficient
            # ------------------------------------------------------------
            if self.ent_coef_optimizer is not None and self.log_ent_coef is not None:
                # Detached so the other losses cannot change the coefficient.
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(
                    self.log_ent_coef * (log_pi + self.target_entropy).detach()
                ).mean()
                self.ent_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.ent_coef_optimizer.step()
            else:
                ent_coef = self.ent_coef_tensor

            # ------------------------------------------------------------
            # Critic loss (fixed, not weighted)
            # ------------------------------------------------------------
            with torch.no_grad():
                next_actions, next_log_prob = self.actor.action_log_prob(
                    replay_data.next_observations
                )
                next_q_values = torch.cat(
                    self.critic_target(replay_data.next_observations, next_actions), dim=1
                )
                next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True)
                # Entropy-regularised value of the next state.
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                target_q_values = (
                    replay_data.rewards
                    + (1 - replay_data.dones) * self.gamma * next_q_values
                )

            current_q_values = self.critic(replay_data.observations, replay_data.actions)
            critic_loss = 0.5 * sum(
                F.mse_loss(current_q, target_q_values) for current_q in current_q_values
            )

            # The critic is stepped on its own loss only, so actor terms never
            # leak into the critic update.
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # ------------------------------------------------------------
            # Weighted actor losses
            # ------------------------------------------------------------
            # Built after the critic step so the graph uses the updated critic
            # and a single backward pass suffices (no retain_graph).
            actor_losses = self._compute_actor_losses(replay_data, pi_actions, log_pi, ent_coef)
            if actor_losses.shape != log_weights.shape:
                raise ValueError(
                    f"Expected actor losses of shape {tuple(log_weights.shape)}, "
                    f"got {tuple(actor_losses.shape)}"
                )

            # Recomputed every step so the latest weights are used.
            weights = self.policy.loss_weights
            weighted_actor_loss = (weights * actor_losses).sum()
            # -log(w) penalises small weights; a +log(w) term would instead
            # drive every weight toward zero.
            regularization = -self.weight_loss_regularization * log_weights.sum()
            actor_objective = weighted_actor_loss + regularization

            self.actor.optimizer.zero_grad()
            self.loss_weight_optimizer.zero_grad()
            actor_objective.backward()
            self.actor.optimizer.step()
            self.loss_weight_optimizer.step()

            # ------------------------------------------------------------
            # Target networks
            # ------------------------------------------------------------
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.critic_target.parameters(), self.tau)
                # Copy batch-norm running statistics verbatim (tau = 1).
                polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

            last_stats = {
                "critic_loss": critic_loss.item(),
                "weighted_actor_loss": weighted_actor_loss.item(),
                "ent_coef": ent_coef.item(),
                "actor_losses": actor_losses.detach().tolist(),
                "loss_weights": weights.detach().tolist(),
            }

        self._n_updates += gradient_steps
        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")

        # Nothing to report when no gradient step was taken.
        if last_stats is None:
            return

        # Values of the final gradient step, logged as scalars so every output
        # format can display them.
        self.logger.record("train/ent_coef", last_stats["ent_coef"])
        self.logger.record("train/critic_loss", last_stats["critic_loss"])
        self.logger.record("train/weighted_actor_loss", last_stats["weighted_actor_loss"])
        per_loss = zip(last_stats["loss_weights"], last_stats["actor_losses"])
        for index, (weight, loss) in enumerate(per_loss, start=1):
            self.logger.record(f"train/loss_weight_{index}", weight)
            self.logger.record(f"train/actor_loss_{index}", loss)
# Example usage
def main() -> None:
    """Train CustomSAC on Pendulum-v1, roll out the policy, and print the loss weights."""
    # Environment used for both training and the short evaluation rollout
    pendulum = gym.make("Pendulum-v1")

    # Agent configured with 4 weighted actor losses
    agent = CustomSAC(
        policy="MlpPolicy",
        env=pendulum,
        learning_rate=3e-4,
        batch_size=256,
        buffer_size=100_000,
        learning_starts=1_000,
        verbose=1,
        weight_loss_regularization=0.1,
        n_actor_losses=4,
    )
    agent.learn(total_timesteps=10_000)

    # Deterministic rollout for 100 steps, resetting whenever an episode ends
    observation, _ = pendulum.reset()
    for _step in range(100):
        action, _state = agent.predict(observation, deterministic=True)
        observation, _reward, terminated, truncated, _info = pendulum.step(action)
        if terminated or truncated:
            observation, _ = pendulum.reset()
    pendulum.close()

    # Report the learned weights
    print("\nFinal loss weights for the 4 actor losses:")
    final_weights = agent.policy.loss_weights.detach().cpu().numpy()
    for index, weight in enumerate(final_weights, start=1):
        print(f" Actor loss {index} weight: {weight:.4f}")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment