@kengz
Last active February 18, 2019 23:50
import torch

# math_util is a project-level helper module that is not shown in this gist


def calc_sil_policy_val_loss(self, batch):
    '''
    Calculate the SIL policy losses for actor and critic
    sil_policy_loss = mean(-log_prob * max(R - v_pred, 0))
    sil_val_loss = norm(max(R - v_pred, 0) ** 2) / 2
    This is called on a randomly sampled batch from experience replay
    '''
    # Discounted returns R for every transition in the batch
    returns = math_util.calc_returns(batch, self.gamma)
    v_preds = self.calc_v(batch['states'])
    # Keep only positive advantages: imitate transitions that beat the value estimate
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)
    log_probs = self.calc_log_probs(batch)
    sil_policy_loss = torch.mean(- log_probs * clipped_advs)
    sil_val_loss = torch.norm(clipped_advs ** 2) / 2
    # Move the losses to the GPU when the network is configured to use it
    if torch.cuda.is_available() and self.net.gpu:
        sil_policy_loss = sil_policy_loss.cuda()
        sil_val_loss = sil_val_loss.cuda()
    return sil_policy_loss, sil_val_loss
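
To see what the two losses evaluate to on concrete numbers, here is a minimal pure-Python sketch that mirrors the expressions in the function above on plain lists. It is illustrative only: `discounted_returns` and `sil_losses` are hypothetical names, and because `math_util.calc_returns` is not shown in this gist, the sketch assumes it computes discounted returns that reset at episode boundaries. The value loss reproduces the expression as coded (the L2 norm of the squared clipped advantages, halved).

```python
import math


def discounted_returns(rewards, dones, gamma):
    """Assumed behavior of calc_returns: R_t = r_t + gamma * R_{t+1}, reset when done."""
    returns = [0.0] * len(rewards)
    future = 0.0
    for t in reversed(range(len(rewards))):
        # (1 - done) zeroes the bootstrap term at the last step of an episode
        future = rewards[t] + gamma * future * (1.0 - dones[t])
        returns[t] = future
    return returns


def sil_losses(log_probs, returns, v_preds):
    """Mirror the SIL loss expressions from the function above on plain lists."""
    # max(R - v_pred, 0): only transitions that beat the value estimate contribute
    clipped = [max(r - v, 0.0) for r, v in zip(returns, v_preds)]
    # mean(-log_prob * clipped_adv)
    policy_loss = sum(-lp * adv for lp, adv in zip(log_probs, clipped)) / len(clipped)
    # norm(clipped_adv ** 2) / 2, i.e. sqrt(sum(adv ** 4)) / 2
    val_loss = math.sqrt(sum((adv ** 2) ** 2 for adv in clipped)) / 2
    return clipped, policy_loss, val_loss


# Made-up three-step episode, purely for illustration
rewards = [1.0, 0.0, 2.0]
dones = [0.0, 0.0, 1.0]
returns = discounted_returns(rewards, dones, gamma=0.5)
clipped, policy_loss, val_loss = sil_losses(
    log_probs=[-1.0, -2.0, -0.5], returns=returns, v_preds=[1.0, 2.0, 1.0])
print(returns, clipped, policy_loss, val_loss)
```

The middle transition has a return below its predicted value, so its clipped advantage is zero and it contributes nothing to either loss. That is the point of the clamp: only experiences that turned out better than expected are imitated.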