Created
March 20, 2026 19:23
-
-
Save YannBerthelot/9be3ee1bb245c3edabf382a798b9a2c1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| from collections.abc import Sequence | |
| from dataclasses import fields | |
| from math import floor | |
| from typing import Any, Callable, Dict, Optional, Tuple | |
| import jax | |
| import jax.numpy as jnp | |
| from flax import struct | |
| from flax.core import FrozenDict | |
| from flax.serialization import to_state_dict | |
| from flax.training.train_state import TrainState | |
| from jax.tree_util import Partial as partial | |
| from ajax.agents.cloning import ( | |
| CloningConfig, | |
| compute_imitation_score, | |
| get_cloning_args, | |
| get_pre_trained_agent, | |
| ) | |
| from ajax.agents.SAC.state import SACConfig, SACState | |
| from ajax.agents.SAC.utils import SquashedNormal | |
| from ajax.buffers.utils import get_batch_from_buffer | |
| from ajax.environments.interaction import ( | |
| collect_experience, | |
| get_pi, | |
| init_collector_state, | |
| should_use_uniform_sampling, | |
| ) | |
| from ajax.environments.utils import check_env_is_gymnax, get_state_action_shapes | |
| from ajax.log import evaluate_and_log | |
| from ajax.logging.wandb_logging import ( | |
| LoggingConfig, | |
| start_async_logging, | |
| vmap_log, | |
| ) | |
| from ajax.networks.networks import ( | |
| get_adam_tx, | |
| get_initialized_actor_critic, | |
| predict_value, | |
| ) | |
| from ajax.state import ( | |
| AlphaConfig, | |
| EnvironmentConfig, | |
| LoadedTrainState, | |
| NetworkConfig, | |
| OptimizerConfig, | |
| Transition, | |
| ) | |
| from ajax.types import BufferType | |
| from ajax.utils import get_one | |
def get_alpha_from_params(params: FrozenDict) -> float:
    """Recover the temperature alpha from its log-space parameter.

    Alpha is stored as ``log_alpha`` so gradient steps keep the
    effective temperature strictly positive after exponentiation.
    """
    return jnp.exp(params["log_alpha"])
@struct.dataclass
class TemperatureAuxiliaries:
    """Diagnostics produced by the temperature (alpha) update."""

    alpha_loss: jax.Array  # scalar value of the temperature loss
    alpha: jax.Array  # current temperature, exp(log_alpha)
    log_alpha: jax.Array  # raw log-space parameter being optimized
@struct.dataclass
class PolicyAuxiliaries:
    """Diagnostics produced by the policy (actor) update."""

    policy_loss: jax.Array  # total actor loss (incl. imitation/std terms)
    log_pi: jax.Array  # mean log-probability of sampled actions
    q_min: jax.Array  # mean of the element-wise minimum critic Q-value
    imitation_loss: jax.Array  # mean imitation score vs. the expert policy
    raw_loss: jax.Array  # mean actor loss before auxiliary terms
@struct.dataclass
class ValueAuxiliaries:
    """Diagnostics produced by the critic (value) update."""

    critic_loss: jax.Array  # mean squared Bellman error over all critics
    q1_pred: jax.Array  # mean Q-value predicted by the first critic
    q2_pred: jax.Array  # mean Q-value predicted by the second critic
    target_q: jax.Array  # mean soft Bellman target
    log_probs: jax.Array  # mean log-probability of sampled next actions
    var_preds: jax.Array  # mean variance of Q-predictions across critics
@struct.dataclass
class AuxiliaryLogs:
    """Bundle of all per-update diagnostics, grouped for logging."""

    temperature: TemperatureAuxiliaries  # alpha-update metrics
    policy: PolicyAuxiliaries  # actor-update metrics
    value: ValueAuxiliaries  # critic-update metrics
def create_alpha_train_state(
    learning_rate: float = 3e-4,
    alpha_init: float = 1.0,
) -> TrainState:
    """
    Build the train state for the entropy temperature parameter (alpha).

    Alpha is parameterized in log space so that optimization keeps the
    effective temperature positive.

    Args:
        learning_rate (float): Learning rate for the alpha optimizer.
        alpha_init (float): Initial value for alpha.

    Returns:
        TrainState: Initialized train state holding ``log_alpha``.
    """
    initial_params = FrozenDict({"log_alpha": jnp.log(alpha_init)})
    return TrainState.create(
        apply_fn=get_alpha_from_params,  # optional convenience accessor
        params=initial_params,
        tx=get_adam_tx(learning_rate),
    )
def init_SAC(
    key: jax.Array,
    env_args: EnvironmentConfig,
    actor_optimizer_args: OptimizerConfig,
    critic_optimizer_args: OptimizerConfig,
    network_args: NetworkConfig,
    alpha_args: AlphaConfig,
    buffer: BufferType,
    window_size: int = 10,
    expert_policy: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
    residual: bool = False,
    fixed_alpha: bool = False,
    max_timesteps: Optional[int] = None,
    num_critics: int = 2,
) -> SACState:
    """
    Initialize the SAC agent's state: actor, critics, alpha, and collector.

    Args:
        key (jax.Array): Random number generator key.
        env_args (EnvironmentConfig): Environment configuration.
        actor_optimizer_args (OptimizerConfig): Optimizer configuration for the actor.
        critic_optimizer_args (OptimizerConfig): Optimizer configuration for the critics.
        network_args (NetworkConfig): Network configuration.
        alpha_args (AlphaConfig): Temperature (alpha) configuration.
        buffer (BufferType): Replay buffer.
        window_size (int): Window size forwarded to the collector state.
        expert_policy (Optional[Callable]): Optional expert policy forwarded to
            network initialization (used for imitation/residual setups).
        residual (bool): Forwarded to network initialization — presumably makes
            the actor act residually on top of the expert policy; confirm.
        fixed_alpha (bool): Forwarded to network initialization.
        max_timesteps (Optional[int]): Forwarded to network and collector init.
        num_critics (int): Number of critic networks to create.

    Returns:
        SACState: Initialized SAC agent state.
    """
    (
        rng,
        init_key,
        collector_key,
    ) = jax.random.split(key, num=3)
    # Continuous-action actor + action-value critics with tanh squashing.
    actor_state, critic_state = get_initialized_actor_critic(
        key=init_key,
        env_config=env_args,
        actor_optimizer_config=actor_optimizer_args,
        critic_optimizer_config=critic_optimizer_args,
        network_config=network_args,
        continuous=True,
        action_value=True,
        squash=True,
        num_critics=num_critics,
        expert_policy=expert_policy,
        residual=residual,
        fixed_alpha=fixed_alpha,
        max_timesteps=max_timesteps,
    )
    mode = "gymnax" if check_env_is_gymnax(env_args.env) else "brax"
    collector_state = init_collector_state(
        collector_key,
        env_args=env_args,
        mode=mode,
        buffer=buffer,
        window_size=window_size,
        max_timesteps=max_timesteps,
    )
    # AlphaConfig fields are expected to map 1:1 onto
    # create_alpha_train_state's keyword arguments.
    alpha = create_alpha_train_state(**to_state_dict(alpha_args))
    return SACState(
        rng=rng,
        eval_rng=rng,  # evaluation stream starts from the same key
        actor_state=actor_state,
        critic_state=critic_state,
        alpha=alpha,
        collector_state=collector_state,
        lambda_param=1.0,  # initial adaptive imitation weight
    )
@partial(jax.jit, static_argnames=["recurrent", "gamma", "reward_scale"])
def value_loss_function(
    critic_params: FrozenDict,
    critic_states: LoadedTrainState,
    rng: jax.Array,
    actor_state: LoadedTrainState,
    actions: jax.Array,
    observations: jax.Array,
    next_observations: jax.Array,
    dones: jax.Array,
    rewards: jax.Array,
    gamma: float,
    alpha: jax.Array,
    recurrent: bool,
    reward_scale: float = 5.0,
) -> Tuple[jax.Array, ValueAuxiliaries]:
    """
    Compute the soft Bellman (value) loss for the critic networks.

    Args:
        critic_params (FrozenDict): Parameters of the critic networks (the
            differentiated argument).
        critic_states (LoadedTrainState): Critic train states (must hold
            target parameters).
        rng (jax.Array): Random number generator key.
        actor_state (LoadedTrainState): Actor train state.
        actions (jax.Array): Actions taken (from the replay buffer).
        observations (jax.Array): Current observations.
        next_observations (jax.Array): Next observations.
        dones (jax.Array): Done flags.
        rewards (jax.Array): Rewards received.
        gamma (float): Discount factor.
        alpha (jax.Array): Temperature parameter.
        recurrent (bool): Whether the model is recurrent.
        reward_scale (float): Reward scaling factor applied before building
            the Bellman target.

    Returns:
        Tuple[jax.Array, ValueAuxiliaries]: Scalar loss and auxiliary metrics.
    """
    # Scale rewards before building the Bellman target.
    rewards = rewards * reward_scale

    # Sample next actions a' ~ pi(.|s_{t+1}) for the soft target.
    next_pi, _ = get_pi(
        actor_state=actor_state,
        actor_params=actor_state.params,
        obs=next_observations,
        done=dones,
        recurrent=recurrent,
    )
    sample_key, rng = jax.random.split(rng)
    next_actions, log_probs = next_pi.sample_and_log_prob(seed=sample_key)

    # Q(s, a) from the online critics. Actions come from the buffer, so the
    # stop_gradient is a no-op safeguard against accidental actor gradients.
    q_preds = predict_value(
        critic_state=critic_states,
        critic_params=critic_params,
        x=jnp.concatenate((observations, jax.lax.stop_gradient(actions)), axis=-1),
    )
    # Ensemble disagreement, logged and reused as an adaptive imitation weight.
    var_preds = q_preds.var(axis=0, keepdims=True)

    # Target Q-values from the (frozen) target networks.
    assert (
        critic_states.target_params is not None
    ), "Target parameters are not set in critic states."
    q_targets = predict_value(
        critic_state=critic_states,
        critic_params=critic_states.target_params,
        x=jnp.concatenate((next_observations, next_actions), axis=-1),
    )

    # Per-critic means are logged individually (assumes exactly two critics).
    q1_pred, q2_pred = jnp.split(q_preds, 2, axis=0)

    # Soft Bellman target: min over critics plus entropy bonus.
    min_q_target = jnp.min(q_targets, axis=0, keepdims=False)
    log_probs = log_probs.sum(-1, keepdims=True)
    target_q = jax.lax.stop_gradient(
        rewards + gamma * (1.0 - dones) * (min_q_target - alpha * log_probs),
    )
    # The target broadcasts against the stacked per-critic predictions.
    assert (
        target_q.shape == q_preds.shape[1:]
    ), f"{target_q.shape} != {q_preds.shape[1:]}"
    assert min_q_target.shape == log_probs.shape
    # MSE over all critics at once.
    total_loss = jnp.mean((q_preds - target_q) ** 2)
    return total_loss, ValueAuxiliaries(
        critic_loss=total_loss,
        q1_pred=q1_pred.mean().flatten(),
        q2_pred=q2_pred.mean().flatten(),
        target_q=target_q.mean().flatten(),
        log_probs=log_probs.mean().flatten(),
        var_preds=var_preds.mean().flatten(),
    )
@partial(
    jax.jit,
    static_argnames=[
        "recurrent",
        "expert_policy",
        "distance_to_stable",
        "imitation_coef_offset",
    ],
)
def policy_loss_function(
    actor_params: FrozenDict,
    actor_state: LoadedTrainState,
    critic_states: LoadedTrainState,
    observations: jax.Array,
    dones: Optional[jax.Array],
    recurrent: bool,
    alpha: jax.Array,
    rng: jax.random.PRNGKey,
    raw_observations: Optional[jax.Array] = None,
    expert_policy: Optional[Callable] = None,
    imitation_coef: float = 0.01,
    distance_to_stable: Callable = get_one,
    imitation_coef_offset: float = 1e-3,
) -> Tuple[jax.Array, PolicyAuxiliaries]:
    """
    Compute the policy loss for the actor network.

    The loss is the soft actor loss offset by the expert's Q-value baseline,
    plus (nominally) an imitation term and a stddev penalty — but note the
    hard-coded overrides below that currently zero out both extras.

    Args:
        actor_params (FrozenDict): Parameters of the actor network (the
            differentiated argument).
        actor_state (LoadedTrainState): Actor train state.
        critic_states (LoadedTrainState): Critic train states.
        observations (jax.Array): Current observations.
        dones (Optional[jax.Array]): Done flags.
        recurrent (bool): Whether the model is recurrent.
        alpha (jax.Array): Temperature parameter.
        rng (jax.random.PRNGKey): Random number generator key.
        raw_observations (Optional[jax.Array]): Unnormalized observations fed
            to the expert policy.
        expert_policy (Optional[Callable]): Expert policy. NOTE(review):
            despite the None default, this function calls it unconditionally —
            a None expert would fail here; confirm all callers pass one.
        imitation_coef (float): Nominal imitation weight (currently unused,
            see override below).
        distance_to_stable (Callable): Distance function for the imitation
            score.
        imitation_coef_offset (float): Offset passed to the imitation score.

    Returns:
        Tuple[jax.Array, PolicyAuxiliaries]: Loss and auxiliary metrics.
    """
    pi, _ = get_pi(
        actor_state=actor_state,
        actor_params=actor_params,
        obs=observations,
        done=dones,
        recurrent=recurrent,
    )
    sample_key, rng = jax.random.split(rng)
    actions, log_probs = pi.sample_and_log_prob(seed=sample_key)
    # Mean squared stddev (pre-squash for SquashedNormal); overridden to 0.0
    # further down.
    std_loss = (
        (pi.unsquashed_stddev() if isinstance(pi, SquashedNormal) else pi.stddev()) ** 2
    ).mean()
    # Q-values for the freshly sampled on-policy actions.
    q_preds = predict_value(
        critic_state=critic_states,
        critic_params=critic_states.params,
        x=jnp.hstack((observations, actions)),
    )
    # Optimistic (max over critics) estimate of the expert action's value,
    # used as a baseline in the actor loss.
    q_expert = jnp.max(
        predict_value(
            critic_state=critic_states,
            critic_params=critic_states.params,
            x=jnp.hstack((observations, expert_policy(raw_observations))),
        ),
        axis=0,
        keepdims=True,
    )
    # Pessimistic (min over critics) value of the policy's own action.
    q_min = jnp.min(q_preds, axis=0, keepdims=True).squeeze(0)
    log_probs = log_probs.sum(-1, keepdims=True)
    imitation_loss = compute_imitation_score(
        pi,
        expert_policy,
        raw_observations,
        distance_to_stable,
        imitation_coef_offset,
        q_min,
        q_expert,
    ).mean()
    assert log_probs.shape == q_min.shape, f"{log_probs.shape} != {q_min.shape}"
    # Soft actor loss, measured relative to the expert baseline.
    loss_actor = alpha * log_probs - (q_min - q_expert)
    imitation_coef = imitation_coef if imitation_coef is not None else 0.0
    # NOTE(review): the two overrides below zero out the imitation and std
    # terms, so total_loss reduces to loss_actor.mean(). They look like
    # debugging leftovers — confirm intent before removing them or the
    # (now dead) std_loss / imitation_coef computations above.
    imitation_coef = 0.0
    std_loss = 0.0
    total_loss = (loss_actor + imitation_coef * imitation_loss + std_loss).mean()
    return total_loss, PolicyAuxiliaries(
        policy_loss=total_loss,
        log_pi=log_probs.mean(),
        q_min=q_min.mean(),
        imitation_loss=imitation_loss,
        raw_loss=loss_actor.mean(),
    )
@partial(
    jax.jit,
    static_argnames=["target_entropy"],
)
def temperature_loss_function(
    log_alpha_params: FrozenDict,
    corrected_log_probs: jax.Array,
    target_entropy: float,
) -> Tuple[jax.Array, TemperatureAuxiliaries]:
    """
    Compute the loss for the temperature parameter (alpha).

    Drives log_alpha so that the policy's entropy tracks the target:
    the gradient is the (stopped) gap between current and target entropy.

    Args:
        log_alpha_params (FrozenDict): Parameters holding ``log_alpha``.
        corrected_log_probs (jax.Array): Log probabilities of actions.
        target_entropy (float): Target entropy value.

    Returns:
        Tuple[jax.Array, TemperatureAuxiliaries]: Loss and auxiliary metrics.
    """
    log_alpha = log_alpha_params["log_alpha"]
    # Only log_alpha receives gradient; the entropy gap is a constant signal.
    entropy_gap = jax.lax.stop_gradient(-corrected_log_probs - target_entropy)
    loss = jnp.mean(log_alpha * entropy_gap)
    aux = TemperatureAuxiliaries(
        alpha_loss=loss,
        alpha=jnp.exp(log_alpha),
        log_alpha=log_alpha,
    )
    return loss, aux
@partial(
    jax.jit,
    static_argnames=["recurrent", "gamma", "reward_scale"],
)
def update_value_functions(
    agent_state: SACState,
    observations: jax.Array,
    actions: jax.Array,
    next_observations: jax.Array,
    dones: Optional[jax.Array],
    recurrent: bool,
    rewards: jax.Array,
    gamma: float,
    reward_scale: float = 1.0,
) -> Tuple[SACState, Dict[str, Any]]:
    """
    Apply one gradient step to the critic networks using the value loss.

    Args:
        agent_state (SACState): Current SAC agent state.
        observations (jax.Array): Current observations.
        actions (jax.Array): Actions taken.
        next_observations (jax.Array): Next observations.
        dones (Optional[jax.Array]): Done flags.
        recurrent (bool): Whether the model is recurrent.
        rewards (jax.Array): Rewards received.
        gamma (float): Discount factor.
        reward_scale (float): Reward scaling factor.

    Returns:
        Tuple[SACState, Dict[str, Any]]: Updated agent state and metrics.
    """
    loss_key, next_rng = jax.random.split(agent_state.rng)
    # Current temperature, recovered from its log-space parameter.
    alpha = jnp.exp(agent_state.alpha.params["log_alpha"])
    grad_fn = jax.value_and_grad(value_loss_function, has_aux=True)
    (_, aux), grads = grad_fn(
        agent_state.critic_state.params,
        agent_state.critic_state,
        loss_key,
        agent_state.actor_state,
        actions,
        observations,
        next_observations,
        dones,
        rewards,
        gamma,
        alpha,
        recurrent,
        reward_scale,
    )
    new_critic_state = agent_state.critic_state.apply_gradients(grads=grads)
    return agent_state.replace(rng=next_rng, critic_state=new_critic_state), aux
def update_lambda(lambda_old, mean_imitation_loss, target=0.2, eta=0.01):
    """Dual-ascent update of the imitation Lagrange multiplier.

    Multiplicatively scales ``lambda_old`` by exp(eta * (loss - target)),
    then clips the result to [1e-6, 1e6] to avoid runaway values.
    """
    violation = mean_imitation_loss - target
    scaled = lambda_old * jnp.exp(eta * jax.lax.stop_gradient(violation))
    return jnp.clip(scaled, 1e-6, 1e6)
@partial(
    jax.jit,
    static_argnames=[
        "recurrent",
        "expert_policy",
        "distance_to_stable",
        "imitation_coef_offset",
    ],
)
def update_policy(
    agent_state: SACState,
    observations: jax.Array,
    done: Optional[jax.Array],
    recurrent: bool,
    raw_observations: jax.Array,
    expert_policy: Optional[Callable] = None,
    imitation_coef: float = 1e-3,
    distance_to_stable: Callable = get_one,
    imitation_coef_offset: float = 1e-3,
) -> Tuple[SACState, Dict[str, Any]]:
    """
    Apply one gradient step to the actor network using the policy loss.

    Args:
        agent_state (SACState): Current SAC agent state.
        observations (jax.Array): Current observations.
        done (Optional[jax.Array]): Done flags.
        recurrent (bool): Whether the model is recurrent.
        raw_observations (jax.Array): Unnormalized observations fed to the
            expert policy.
        expert_policy (Optional[Callable]): Expert policy forwarded to the
            loss (see policy_loss_function for its requirements).
        imitation_coef (float): Imitation weight forwarded to the loss.
        distance_to_stable (Callable): Distance function for the imitation
            score.
        imitation_coef_offset (float): Offset passed to the imitation score.

    Returns:
        Tuple[SACState, Dict[str, Any]]: Updated agent state and metrics.
    """
    rng, policy_key = jax.random.split(agent_state.rng)
    value_and_grad_fn = jax.value_and_grad(
        policy_loss_function, has_aux=True, argnums=0
    )
    log_alpha = agent_state.alpha.params["log_alpha"]
    # Floor alpha at 0.1 so the entropy bonus never vanishes in this update.
    alpha = jnp.maximum(jnp.exp(log_alpha), 0.1)
    (loss, aux), grads = value_and_grad_fn(
        agent_state.actor_state.params,
        agent_state.actor_state,
        agent_state.critic_state,
        observations,
        done,
        recurrent,
        alpha,
        policy_key,
        raw_observations=raw_observations,
        expert_policy=expert_policy,
        imitation_coef=imitation_coef,
        distance_to_stable=distance_to_stable,
        imitation_coef_offset=imitation_coef_offset,
    )
    updated_actor_state = agent_state.actor_state.apply_gradients(grads=grads)
    agent_state = agent_state.replace(
        rng=rng,
        actor_state=updated_actor_state,
    )
    return agent_state, aux
@partial(
    jax.jit,
    static_argnames=["target_entropy", "recurrent"],
)
def update_temperature(
    agent_state: SACState,
    observations: jax.Array,
    dones: Optional[jax.Array],
    target_entropy: float,
    recurrent: bool,
) -> Tuple[SACState, Dict[str, Any]]:
    """
    Apply one gradient step to the temperature parameter (alpha).

    Args:
        agent_state (SACState): Current SAC agent state.
        observations (jax.Array): Current observations.
        dones (Optional[jax.Array]): Done flags.
        target_entropy (float): Target entropy value.
        recurrent (bool): Whether the model is recurrent.

    Returns:
        Tuple[SACState, Dict[str, Any]]: Updated agent state and metrics.
    """
    loss_fn = jax.value_and_grad(temperature_loss_function, has_aux=True)
    pi, _ = get_pi(
        actor_state=agent_state.actor_state,
        actor_params=agent_state.actor_state.params,
        obs=observations,
        done=dones,
        recurrent=recurrent,
    )
    rng, sample_key = jax.random.split(agent_state.rng)
    # Entropy of the squashed policy, estimated with its own sample key.
    # NOTE(review): this is passed where the loss names its argument
    # `corrected_log_probs` (entropy = -E[log pi], a sign flip) — the loss
    # as written expects this pairing; confirm the convention is intended.
    current_entropy = pi.effective_entropy(sample_key)
    (_, aux), grads = loss_fn(
        agent_state.alpha.params,
        current_entropy,
        target_entropy,
    )
    new_alpha_state = agent_state.alpha.apply_gradients(grads=grads)
    new_agent_state = agent_state.replace(
        rng=rng,
        alpha=new_alpha_state,
    )
    return new_agent_state, jax.lax.stop_gradient(aux)
@partial(
    jax.jit,
    static_argnames=["tau"],
)
def update_target_networks(
    agent_state: SACState,
    tau: float,
) -> SACState:
    """
    Polyak-average the critic target networks toward the online networks.

    Args:
        agent_state (SACState): Current SAC agent state.
        tau (float): Soft update coefficient.

    Returns:
        SACState: Agent state with refreshed critic targets.
    """
    return agent_state.replace(
        critic_state=agent_state.critic_state.soft_update(tau=tau),
    )
@partial(
    jax.jit,
    static_argnames=[
        "recurrent",
        "buffer",
        "gamma",
        "tau",
        "action_dim",
        "num_critic_updates",
        "reward_scale",
        "target_update_frequency",
        "transition_mix_fraction",
        "expert_policy",
        # "imitation_coef",
        "distance_to_stable",
        "imitation_coef_offset",
        "target_entropy",
    ],
)
def update_agent(
    agent_state: SACState,
    _: Any,
    buffer: BufferType,
    recurrent: bool,
    gamma: float,
    action_dim: int,
    target_entropy: float,
    tau: float,
    num_critic_updates: int = 1,
    target_update_frequency: int = 1,
    reward_scale: float = 5.0,
    additional_transition: Optional[Any] = None,
    transition_mix_fraction: float = 1.0,  # part of original buffer sample to keep TODO : add control over this hyperparameter
    expert_policy: Optional[Callable] = None,
    imitation_coef: float = 0.0,
    distance_to_stable: Callable = get_one,
    imitation_coef_offset: float = 0.0,
    policy_update_start: int = 20_000,
) -> Tuple[SACState, AuxiliaryLogs]:
    """
    Run one full SAC update: critics, actor, temperature, and targets.

    Args:
        agent_state (SACState): Current SAC agent state.
        _ (Any): Placeholder for scan compatibility.
        buffer (BufferType): Replay buffer (may be None when training purely
            from `additional_transition`).
        recurrent (bool): Whether the model is recurrent.
        gamma (float): Discount factor.
        action_dim (int): Action dimensionality (currently unused here).
        target_entropy (float): Target entropy for the alpha update.
        tau (float): Soft update coefficient for target networks.
        num_critic_updates (int): Number of critic updates per step.
        target_update_frequency (int): Frequency of target updates
            (currently unused — targets are updated every call, see TODO).
        reward_scale (float): Reward scaling factor.
        additional_transition (Optional[Any]): Extra online transitions to
            mix into (or substitute for) the buffer batch.
        transition_mix_fraction (float): Fraction of the batch to keep from
            the buffer when mixing in `additional_transition`.
        expert_policy (Optional[Callable]): Expert policy for the actor loss.
        imitation_coef (float): NOTE(review): ignored here — the policy
            update below uses the adaptive `agent_state.lambda_param`
            instead; confirm intended.
        distance_to_stable (Callable): Distance function for imitation.
        imitation_coef_offset (float): Offset for the imitation score.
        policy_update_start (int): Timestep before which actor updates are
            discarded (critics/alpha still update).

    Returns:
        Tuple[SACState, AuxiliaryLogs]: Updated agent state and diagnostics.
    """
    # Sample buffer
    sample_key, rng = jax.random.split(agent_state.rng)
    if buffer is not None and agent_state.collector_state.buffer_state is not None:
        (
            observations,
            terminated,
            truncated,
            next_observations,
            rewards,
            actions,
            raw_observations,
        ) = get_batch_from_buffer(
            buffer,
            agent_state.collector_state.buffer_state,
            sample_key,
        )
        original_transition = Transition(
            observations, actions, rewards, terminated, truncated, next_observations
        )
        if additional_transition is not None and transition_mix_fraction < 1.0:
            assert transition_mix_fraction >= 0 and transition_mix_fraction <= 1.0, (
                "transition_mix_fraction should be between 0 and 1, got",
                transition_mix_fraction,
            )
            # Split the batch between buffer samples and fresh online samples.
            len_original = len(observations)
            n_samples_from_original = floor(transition_mix_fraction * len_original)
            n_samples_from_transition = len_original - n_samples_from_original
            # NOTE: runs at trace time only (inside jit), not every step.
            print(
                f"samples from buffer:{n_samples_from_original} online"
                f" samples:{n_samples_from_transition}"
            )
            # Subsample the online transitions to the remaining batch slots.
            # NOTE(review): reuses sample_key already consumed by the buffer
            # draw above — confirm the correlation is acceptable.
            additional_transition = jax.tree.map(
                lambda x: jax.random.choice(
                    sample_key, x, shape=(n_samples_from_transition,)
                ),
                additional_transition,
            )
            transition = jax.tree.map(
                lambda x, y: (
                    None
                    if (x is None or y is None)
                    else jnp.concatenate([x[:n_samples_from_original], y], axis=0)
                ),
                original_transition,
                additional_transition,
                is_leaf=lambda x: x is None,
            )
        else:
            transition = original_transition
    elif additional_transition is not None:
        # If no buffer is provided, use the collector state to get the latest transition
        transition = additional_transition
        terminated = additional_transition.terminated
        truncated = additional_transition.truncated
    agent_state = agent_state.replace(rng=rng)
    dones = jnp.logical_or(transition.terminated, transition.truncated)

    # Update Q functions
    def critic_update_step(carry, _):
        # One critic gradient step; scanned num_critic_updates times.
        agent_state = carry
        agent_state, aux_value = update_value_functions(
            observations=transition.obs,
            actions=transition.action,
            next_observations=transition.next_obs,
            rewards=transition.reward,
            dones=dones,
            agent_state=agent_state,
            recurrent=recurrent,
            gamma=gamma,
            reward_scale=reward_scale,
        )
        return agent_state, aux_value

    agent_state, aux_value = jax.lax.scan(
        critic_update_step,
        agent_state,
        None,
        length=num_critic_updates,
    )
    # Update policy
    # NOTE(review): `raw_observations` is only bound in the buffer-sampling
    # branch above; the `elif additional_transition` path would raise a
    # NameError here — confirm that path is never taken.
    new_agent_state, aux_policy = update_policy(
        observations=transition.obs,
        done=dones,
        agent_state=agent_state,
        recurrent=recurrent,
        raw_observations=raw_observations,
        expert_policy=expert_policy,
        imitation_coef=agent_state.lambda_param,
        distance_to_stable=distance_to_stable,
        imitation_coef_offset=imitation_coef_offset,
    )
    # Discard the actor update until policy_update_start timesteps.
    agent_state = jax.lax.cond(
        agent_state.collector_state.timestep >= policy_update_start,
        lambda: new_agent_state,
        lambda: agent_state,
    )
    # new_lambda = update_lambda(
    #     agent_state.lambda_param, aux_policy.imitation_loss, target=0.1
    # )
    # Adaptive imitation weight: critic ensemble variance scaled by 1/alpha.
    new_lambda = aux_value.var_preds.mean() / jnp.exp(
        agent_state.alpha.params["log_alpha"]
    )
    agent_state = agent_state.replace(lambda_param=new_lambda)
    # Adjust temperature
    agent_state, aux_temperature = update_temperature(
        agent_state,
        observations=transition.obs,
        target_entropy=target_entropy,
        recurrent=recurrent,
        dones=dones,
    )
    # Update target networks
    # TODO : Only update every update_target_network steps
    agent_state = update_target_networks(agent_state, tau=tau)
    # Flatten value metrics (scanned over critic updates) so all logged
    # leaves share a consistent shape.
    aux = AuxiliaryLogs(
        temperature=aux_temperature,
        policy=aux_policy,
        value=ValueAuxiliaries(
            **{key: val.flatten() for key, val in to_state_dict(aux_value).items()}
        ),
    )
    return agent_state, aux
| @partial( | |
| jax.jit, | |
| static_argnames=[ | |
| "env_args", | |
| "mode", | |
| "recurrent", | |
| "buffer", | |
| "log_frequency", | |
| "num_episode_test", | |
| "log_fn", | |
| "log", | |
| "verbose", | |
| "action_dim", | |
| "lstm_hidden_size", | |
| "agent_config", | |
| "horizon", | |
| "total_timesteps", | |
| "n_epochs", | |
| "transition_mix_fraction", | |
| "expert_policy", | |
| "imitation_coef", | |
| "distance_to_stable", | |
| "imitation_coef_offset", | |
| "action_scale", | |
| "early_termination_condition", | |
| ], | |
| ) | |
| def training_iteration( | |
| agent_state: SACState, | |
| _: Any, | |
| env_args: EnvironmentConfig, | |
| mode: str, | |
| recurrent: bool, | |
| buffer: BufferType, | |
| agent_config: SACConfig, | |
| action_dim: int, | |
| total_timesteps: int, | |
| lstm_hidden_size: Optional[int] = None, | |
| log_frequency: int = 1000, | |
| horizon: int = 10000, | |
| num_episode_test: int = 10, | |
| log_fn: Optional[Callable] = None, | |
| index: Optional[int] = None, | |
| log: bool = False, | |
| verbose: bool = False, | |
| n_epochs: int = 1, | |
| transition_mix_fraction: float = 1.0, | |
| expert_policy: Optional[Callable] = None, | |
| imitation_coef: float = 1e-3, | |
| distance_to_stable: Callable = get_one, | |
| imitation_coef_offset: float = 1e-3, | |
| action_scale: float = 1.0, | |
| early_termination_condition: Optional[Callable] = None, | |
| ) -> tuple[SACState, None]: | |
| """ | |
| Perform one training iteration, including experience collection and agent updates. | |
| Args: | |
| agent_state (SACState): Current SAC agent state. | |
| _ (Any): Placeholder for scan compatibility. | |
| env_args (EnvironmentConfig): Environment configuration. | |
| mode (str): Environment mode ("gymnax" or "brax"). | |
| recurrent (bool): Whether the model is recurrent. | |
| buffer (BufferType): Replay buffer. | |
| agent_config (SACConfig): SAC agent configuration. | |
| action_dim (int): Action dimensionality. | |
| lstm_hidden_size (Optional[int]): LSTM hidden size for recurrent models. | |
| log_frequency (int): Frequency of logging and evaluation. | |
| num_episode_test (int): Number of episodes for evaluation. | |
| Returns: | |
| Tuple[SACState, None]: Updated agent state. | |
| """ | |
| # collector_state = agent_state.collector_state | |
| timestep = agent_state.collector_state.timestep | |
| uniform = should_use_uniform_sampling(timestep, agent_config.learning_starts * 2) | |
| # always_expert = should_use_uniform_sampling(timestep, agent_config.learning_starts) | |
| # always_expert = ( | |
| # should_use_uniform_sampling(timestep, agent_config.learning_starts) | |
| # and expert_policy is not None | |
| # ) | |
| collect_scan_fn = partial( | |
| collect_experience, | |
| recurrent=recurrent, | |
| mode=mode, | |
| env_args=env_args, | |
| buffer=buffer, | |
| uniform=uniform, | |
| expert_policy=expert_policy, | |
| action_scale=action_scale, | |
| always_expert=False, # always_expert if expert_policy is not None else False, | |
| ) | |
| agent_state, transition = jax.lax.scan( | |
| collect_scan_fn, agent_state, xs=None, length=1 | |
| ) | |
| timestep = agent_state.collector_state.timestep | |
def do_update(agent_state):
    """Run `n_epochs` gradient updates on the agent and aggregate auxiliary logs.

    Returns the updated agent state and an auxiliary-log pytree whose leaves
    are shaped (1,) so the output structure matches `skip_update`'s NaN logs
    (both are branches of the same jax.lax.cond).
    """
    update_scan_fn = partial(
        update_agent,
        buffer=buffer,
        recurrent=recurrent,
        gamma=agent_config.gamma,
        action_dim=action_dim,
        target_entropy=agent_config.target_entropy,
        tau=agent_config.tau,
        reward_scale=agent_config.reward_scale,
        # Mix the freshly collected transition into each update batch only
        # when the batch is not drawn entirely from the buffer; squeeze(0)
        # drops the length-1 scan axis added by collect_experience.
        additional_transition=(
            jax.tree.map(lambda x: x.squeeze(0), transition)
            if transition_mix_fraction < 1.0
            else None
        ),
        transition_mix_fraction=transition_mix_fraction,
        expert_policy=expert_policy,
        # NOTE(review): `imitation_coef` appears to be a schedule object whose
        # `.at(frac)` evaluates the coefficient at the current training
        # progress fraction — confirm against its definition.
        imitation_coef=imitation_coef.at(
            agent_state.collector_state.train_time_fraction
        ),
        distance_to_stable=distance_to_stable,
        imitation_coef_offset=imitation_coef_offset,
    )
    agent_state, aux = jax.lax.scan(
        update_scan_fn, agent_state, xs=None, length=n_epochs
    )
    aux = jax.tree.map(
        lambda x: x[-1].reshape((1,)), aux
    )  # keep only the final state across epochs
    # Flatten the value-auxiliary leaves so they match the (1,)-shaped layout
    # expected downstream by evaluate_and_log.
    aux = aux.replace(
        value=ValueAuxiliaries(
            **{key: val.flatten() for key, val in to_state_dict(aux.value).items()}
        )
    )
    return agent_state, aux
def fill_with_nan(dataclass):
    """Build an instance of the given dataclass type with every leaf set to NaN.

    Recurses into fields whose declared type is itself a dataclass, so the
    returned pytree mirrors the structure of a real auxiliary-log object while
    carrying only shape-(1,) NaN placeholders. Used by `skip_update` so both
    branches of a jax.lax.cond return identically structured pytrees.

    NOTE(review): relies on `field.type` being the actual class object; this
    breaks if `from __future__ import annotations` is ever enabled (annotations
    become strings) — confirm the module never opts into lazy annotations.

    Args:
        dataclass: The dataclass *type* (not an instance) to fill.

    Returns:
        An instance of `dataclass` whose non-dataclass fields are jnp NaN
        arrays of shape (1,).
    """
    nan = jnp.ones(1) * jnp.nan
    filled = {}  # renamed from `dict`: never shadow the builtin
    for field in fields(dataclass):
        field_type = field.type
        if hasattr(field_type, "__dataclass_fields__"):
            # Nested dataclass field: fill it recursively.
            filled[field.name] = fill_with_nan(field_type)
        else:
            filled[field.name] = nan
    return dataclass(**filled)
def skip_update(agent_state):
    # No-op branch for the jax.lax.cond below, taken before `learning_starts`
    # timesteps have elapsed: returns the state unchanged plus NaN-filled
    # logs so both cond branches produce identically structured pytrees.
    return agent_state, fill_with_nan(AuxiliaryLogs)
| agent_state, aux = jax.lax.cond( | |
| timestep >= agent_config.learning_starts, | |
| do_update, | |
| skip_update, | |
| operand=agent_state, | |
| ) | |
| agent_state, metrics_to_log = evaluate_and_log( | |
| agent_state, | |
| aux, | |
| index, | |
| mode, | |
| env_args, | |
| num_episode_test, | |
| recurrent, | |
| lstm_hidden_size, | |
| log, | |
| verbose, | |
| log_fn, | |
| log_frequency, | |
| total_timesteps, | |
| expert_policy=expert_policy, | |
| imitation_coef=imitation_coef, | |
| action_scale=action_scale, | |
| early_termination_condition=early_termination_condition, | |
| train_frac=agent_state.collector_state.train_time_fraction, | |
| ) | |
| return agent_state, metrics_to_log | |
def make_train(
    env_args: EnvironmentConfig,
    actor_optimizer_args: OptimizerConfig,
    critic_optimizer_args: OptimizerConfig,
    network_args: NetworkConfig,
    buffer: BufferType,
    agent_config: SACConfig,
    alpha_args: AlphaConfig,
    total_timesteps: int,
    num_episode_test: int,
    run_ids: Optional[Sequence[str]] = None,
    logging_config: Optional[LoggingConfig] = None,
    cloning_args: Optional[CloningConfig] = None,
    expert_policy: Optional[Callable] = None,
    early_termination_condition: Optional[Callable] = None,
    residual: bool = False,
    fixed_alpha: bool = False,
    num_critics: int = 2,
):
    """
    Create the JIT-compiled training function for the SAC agent.

    Args:
        env_args (EnvironmentConfig): Environment configuration.
        actor_optimizer_args (OptimizerConfig): Actor optimizer configuration.
        critic_optimizer_args (OptimizerConfig): Critic optimizer configuration.
        network_args (NetworkConfig): Network configuration (recurrence is
            inferred from `network_args.lstm_hidden_size`).
        buffer (BufferType): Replay buffer.
        agent_config (SACConfig): SAC agent configuration.
        alpha_args (AlphaConfig): Entropy-coefficient (alpha) configuration.
        total_timesteps (int): Total environment timesteps for training.
        num_episode_test (int): Number of episodes for evaluation during training.
        run_ids (Optional[Sequence[str]]): Run ids forwarded to the vmapped logger.
        logging_config (Optional[LoggingConfig]): Logging configuration; if None,
            logging is disabled entirely.
        cloning_args (Optional[CloningConfig]): Cloning/imitation configuration
            used for optional pre-training and batch mixing.
        expert_policy (Optional[Callable]): Expert policy used for cloning,
            imitation-regularized updates, and evaluation.
        early_termination_condition (Optional[Callable]): Optional predicate
            forwarded to evaluation.
        residual (bool): Forwarded to init_SAC; presumably makes the actor act
            residually on top of the expert — confirm in init_SAC.
        fixed_alpha (bool): If True, keep alpha fixed instead of learning it.
        num_critics (int): Number of critic networks.

    Returns:
        Callable: JIT-compiled `train(key, index)` that returns the final
        agent state and the per-iteration logged metrics.
    """
    mode = "gymnax" if check_env_is_gymnax(env_args.env) else "brax"
    log = logging_config is not None
    log_fn = partial(vmap_log, run_ids=run_ids, logging_config=logging_config)

    # Start async logging if logging is enabled.
    if logging_config is not None:
        start_async_logging()

    # Fix: `@partial(jax.jit)` was a redundant wrapper — identical to `@jax.jit`.
    @jax.jit
    def train(key, index: Optional[int] = None):
        """Train the SAC agent from a PRNG key; `index` tags the logging run."""
        init_key, expert_key = jax.random.split(key)
        agent_state = init_SAC(
            key=init_key,
            env_args=env_args,
            actor_optimizer_args=actor_optimizer_args,
            critic_optimizer_args=critic_optimizer_args,
            network_args=network_args,
            alpha_args=alpha_args,
            buffer=buffer,
            expert_policy=expert_policy,
            residual=residual,
            fixed_alpha=fixed_alpha,
            max_timesteps=total_timesteps,
            num_critics=num_critics,
        )

        # Optionally pre-train the agent by behavior cloning on the expert.
        (
            cloning_parameters,
            pre_train_n_steps,
        ) = get_cloning_args(cloning_args, total_timesteps)
        if pre_train_n_steps > 0:
            agent_state = get_pre_trained_agent(
                agent_state,
                expert_policy,
                expert_key,
                env_args,
                cloning_args,
                mode,
                agent_config,
                actor_optimizer_args,
                critic_optimizer_args,
            )

        # One training iteration per n_envs steps of environment interaction.
        num_updates = total_timesteps // env_args.n_envs
        _, action_shape = get_state_action_shapes(env_args.env)
        training_iteration_scan_fn = partial(
            training_iteration,
            buffer=buffer,
            recurrent=network_args.lstm_hidden_size is not None,
            action_dim=action_shape[0],
            agent_config=agent_config,
            mode=mode,
            env_args=env_args,
            num_episode_test=num_episode_test,
            log_fn=log_fn,
            index=index,
            log=log,
            total_timesteps=total_timesteps,
            log_frequency=(
                logging_config.log_frequency if logging_config is not None else None
            ),
            horizon=(logging_config.horizon if logging_config is not None else None),
            expert_policy=expert_policy,
            early_termination_condition=early_termination_condition,
            **cloning_parameters,
        )
        agent_state, out = jax.lax.scan(
            f=training_iteration_scan_fn,
            init=agent_state,
            xs=None,
            length=num_updates,
        )
        return agent_state, out

    return train
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment