SAC loss functions
import torch

# The methods below are defined on an SAC agent class; `policy_util` refers to the
# accompanying policy utility module, assumed imported at module level alongside torch.

def calc_q(self, state, action, net):
    '''Forward-pass to calculate the predicted state-action-value from the given Q network (q1_net, q2_net, or a target net).'''
    q_pred = net(state, action).view(-1)
    return q_pred

def calc_q_targets(self, batch):
    '''Q_tar = r + gamma * (target_Q(s', a') - alpha * log pi(a'|s'))'''
    next_states = batch['next_states']
    with torch.no_grad():
        pdparams = self.calc_pdparam(next_states)
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        next_log_probs, next_actions = self.calc_log_prob_action(action_pd)
        next_actions = self.guard_q_actions(next_actions)  # non-reparam discrete actions need to be converted into one-hot
        next_target_q1_preds = self.calc_q(next_states, next_actions, self.target_q1_net)
        next_target_q2_preds = self.calc_q(next_states, next_actions, self.target_q2_net)
        next_target_q_preds = torch.min(next_target_q1_preds, next_target_q2_preds)
        q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * (next_target_q_preds - self.alpha * next_log_probs)
    return q_targets

def calc_reg_loss(self, preds, targets):
    '''Calculate the regression loss for V and Q values, using the same loss function from net_spec'''
    assert preds.shape == targets.shape, f'{preds.shape} != {targets.shape}'
    reg_loss = self.net.loss_fn(preds, targets)
    return reg_loss

def calc_policy_loss(self, batch, log_probs, reparam_actions):
    '''policy_loss = alpha * log pi(f(a)|s) - min(Q1, Q2)(s, f(a)), where f(a) = reparametrized action'''
    states = batch['states']
    q1_preds = self.calc_q(states, reparam_actions, self.q1_net)
    q2_preds = self.calc_q(states, reparam_actions, self.q2_net)
    q_preds = torch.min(q1_preds, q2_preds)
    policy_loss = (self.alpha * log_probs - q_preds).mean()
    return policy_loss

def calc_alpha_loss(self, log_probs):
    '''alpha_loss = -E[log_alpha * (log pi(a|s) + target_entropy)], for automatic entropy-temperature tuning'''
    alpha_loss = - (self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
    return alpha_loss
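
For context, here is a minimal sketch of how these loss methods could be wired into a single SAC update. This sketch is not part of the gist: the optimizer attributes (q1_optim, q2_optim, policy_optim, alpha_optim), the reparam=True flag on calc_log_prob_action, and the update_target_nets helper are assumptions about the surrounding agent class, shown only to illustrate how the pieces fit together.

def train_step(self, batch):
    '''Hypothetical sketch of one SAC update using the loss methods above.
    Optimizer attributes and helper names here are assumptions, not from the gist.'''
    states = batch['states']
    actions = self.guard_q_actions(batch['actions'])  # one-hot conversion for discrete actions

    # 1. Critic update: regress both Q-nets toward the shared soft Bellman target
    q_targets = self.calc_q_targets(batch)
    for q_net, q_optim in ((self.q1_net, self.q1_optim), (self.q2_net, self.q2_optim)):
        q_loss = self.calc_reg_loss(self.calc_q(states, actions, q_net), q_targets)
        q_optim.zero_grad()
        q_loss.backward()
        q_optim.step()

    # 2. Policy update: minimize alpha * log pi - min(Q1, Q2) using reparametrized actions
    pdparams = self.calc_pdparam(states)
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
    policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
    self.policy_optim.zero_grad()
    policy_loss.backward()  # grads also flow into the Q-nets, but only policy params are stepped
    self.policy_optim.step()

    # 3. Temperature update: tune alpha so the policy entropy tracks target_entropy
    alpha_loss = self.calc_alpha_loss(log_probs)
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.detach().exp()  # keep alpha in sync with log_alpha

    # 4. Soft-update (Polyak-average) the target Q networks; helper assumed
    self.update_target_nets()

Note that both critics are regressed toward the same soft Bellman target, while the policy and temperature updates reuse the fresh log-probabilities from the reparametrized actions.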