SAC training loop
# Two methods from an SAC agent class in SLM Lab; they assume the library's
# util, policy_util and logger modules plus numpy (np) and torch are imported
# at module level.
def train_alpha(self, alpha_loss):
    '''Custom method to train the alpha variable'''
    self.alpha_lr_scheduler.step(epoch=self.body.env.clock.frame)
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    # cache alpha as a detached tensor for use in the loss computations
    self.alpha = self.log_alpha.detach().exp()
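# Illustrative sketch of the calc_alpha_loss helper invoked in train() below, assuming
# the standard SAC temperature objective (Haarnoja et al., 2018); self.target_entropy
# is an assumed attribute (commonly -action_dim) and the body may differ from the
# actual SLM Lab implementation.
def calc_alpha_loss(self, log_probs):
    '''Sketch: tune alpha so the policy entropy tracks a target entropy'''
    return -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()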
def train(self):
    '''Train actor critic by computing the loss in batch efficiently'''
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        for _ in range(self.training_iter):
            batch = self.sample()
            clock.set_batch_size(len(batch))
            states = batch['states']
            actions = self.guard_q_actions(batch['actions'])
            q_targets = self.calc_q_targets(batch)
            # Q-value loss for both Q nets
            q1_preds = self.calc_q(states, actions, self.q1_net)
            q1_loss = self.calc_reg_loss(q1_preds, q_targets)
            self.q1_net.train_step(q1_loss, self.q1_optim, self.q1_lr_scheduler, clock=clock, global_net=self.global_q1_net)
            q2_preds = self.calc_q(states, actions, self.q2_net)
            q2_loss = self.calc_reg_loss(q2_preds, q_targets)
            self.q2_net.train_step(q2_loss, self.q2_optim, self.q2_lr_scheduler, clock=clock, global_net=self.global_q2_net)
            # policy loss
            action_pd = policy_util.init_action_pd(self.body.ActionPD, self.calc_pdparam(states))
            log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
            policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
            self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
            # alpha loss
            alpha_loss = self.calc_alpha_loss(log_probs)
            self.train_alpha(alpha_loss)
            loss = q1_loss + q2_loss + policy_loss + alpha_loss
            # update target networks
            self.update_nets()
            # update PER priorities if available
            self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
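For reference, the helpers calc_q_targets and calc_policy_loss called in the loop compute the soft Bellman target and the reparameterized policy objective of SAC. The sketch below shows what they typically compute, assuming target-network attributes named target_q1_net/target_q2_net, a self.gamma discount, and batch keys 'rewards', 'next_states', 'dones'; it is an illustrative version under those assumptions, not SLM Lab's exact code.

def calc_q_targets(self, batch):
    '''Sketch: soft Bellman target using the twin target Q nets and an entropy bonus'''
    with torch.no_grad():
        action_pd = policy_util.init_action_pd(self.body.ActionPD, self.calc_pdparam(batch['next_states']))
        next_log_probs, next_actions = self.calc_log_prob_action(action_pd, reparam=True)
        next_q1 = self.calc_q(batch['next_states'], next_actions, self.target_q1_net)
        next_q2 = self.calc_q(batch['next_states'], next_actions, self.target_q2_net)
        next_v = torch.min(next_q1, next_q2) - self.alpha * next_log_probs
        return batch['rewards'] + self.gamma * (1 - batch['dones'].float()) * next_v

def calc_policy_loss(self, batch, log_probs, reparam_actions):
    '''Sketch: minimize alpha * log pi - min(Q1, Q2) over reparameterized actions'''
    q1 = self.calc_q(batch['states'], reparam_actions, self.q1_net)
    q2 = self.calc_q(batch['states'], reparam_actions, self.q2_net)
    return (self.alpha * log_probs - torch.min(q1, q2)).mean()

Taking the minimum of the two Q estimates (clipped double-Q) counteracts value overestimation, which is also why the PER priority update in train() uses torch.min(q1_preds, q2_preds).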