@kengz
Created June 22, 2018 05:10
import torch

# Methods of a SIL (self-imitation learning) agent class that extends an A2C agent.
# The replay buffer, loss helpers, networks, and coefficients referenced via `self`
# are defined elsewhere in the class.


def train_shared(self):
    '''
    Trains the network when the actor and critic share parameters
    '''
    if self.to_train == 1:
        # onpolicy a2c update
        a2c_loss = super(SIL, self).train_shared()
        # offpolicy sil update with random minibatch
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_loss = self.policy_loss_coef * sil_policy_loss + self.val_loss_coef * sil_val_loss
            self.net.training_step(loss=sil_loss)
            total_sil_loss += sil_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss


def train_separate(self):
    '''
    Trains the network when the actor and critic are separate networks
    '''
    if self.to_train == 1:
        # onpolicy a2c update
        a2c_loss = super(SIL, self).train_separate()
        # offpolicy sil update with random minibatch
        total_sil_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            batch = self.replay_sample()
            sil_policy_loss, sil_val_loss = self.calc_sil_policy_val_loss(batch)
            sil_policy_loss = self.policy_loss_coef * sil_policy_loss
            sil_val_loss = self.val_loss_coef * sil_val_loss
            # actor update; retain the graph so the critic loss can still backpropagate
            self.net.training_step(loss=sil_policy_loss, retain_graph=True)
            self.critic.training_step(loss=sil_val_loss)
            total_sil_loss += sil_policy_loss + sil_val_loss
        sil_loss = total_sil_loss / self.training_epoch
        loss = a2c_loss + sil_loss
        self.last_loss = loss.item()
    return self.last_loss
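
In both methods, the off-policy self-imitation update is layered on top of the usual on-policy A2C step: `training_epoch` random minibatches are drawn from the replay buffer and pushed through the SIL policy and value losses, either as a single combined step on the shared network or as separate actor and critic steps. The gist does not include `calc_sil_policy_val_loss`; below is a minimal sketch of what it might look like under the self-imitation learning objective (Oh et al., 2018), where only transitions whose stored return beats the current value estimate contribute. The standalone signature, argument names, and the 0.5 scaling on the value term are illustrative assumptions, not the author's actual helper.

import torch

def calc_sil_policy_val_loss(action_dist, actions, returns, v_preds):
    '''
    Illustrative SIL losses (hypothetical signature): clip the advantage
    (return minus value estimate) at zero so only better-than-expected
    past transitions drive the update.
    '''
    clipped_advs = torch.clamp(returns - v_preds, min=0.0)         # max(R - V, 0)
    log_probs = action_dist.log_prob(actions)                      # log pi(a|s)
    sil_policy_loss = -(log_probs * clipped_advs.detach()).mean()  # imitate good past actions
    sil_val_loss = 0.5 * clipped_advs.pow(2).mean()                # pull V up toward R
    return sil_policy_loss, sil_val_loss

With `returns` and `v_preds` as tensors of shape (batch,) and a torch.distributions object for the current policy, the two returned terms plug directly into the `sil_loss` combinations used above.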