@elumixor
Last active May 24, 2020 23:23
TRPO update 1
import torch

def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    # Estimate advantages per rollout, then concatenate into a single batch
    advantages = [estimate_advantages(states, next_states[-1], rewards)
                  for states, _, rewards, next_states in rollouts]
    advantages = torch.cat(advantages, dim=0).flatten()

    # Normalize advantages to reduce skewness and improve convergence
    advantages = (advantages - advantages.mean()) / advantages.std()

    update_critic(advantages)

    distribution = actor(states)
    # Important! We clamp the probabilities so they do not reach zero
    # (a zero probability would give log(0) = -inf later on)
    distribution = torch.distributions.utils.clamp_probs(distribution)
    probabilities = distribution[range(distribution.shape[0]), actions]
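
The helpers estimate_advantages and update_critic, as well as the actor network, are defined elsewhere in the accompanying tutorial and are not part of this snippet. Purely for illustration, here is a minimal sketch of what the advantage estimation could look like, assuming a critic value network and a discount factor gamma (both names are assumptions, not taken from this gist): it accumulates discounted returns backwards through the rollout, bootstrapping from the critic's value of the state that follows the rollout, and subtracts the critic's value estimates as a baseline.

import torch
import torch.nn as nn

# Hypothetical critic: a small value network; the input size of 4 is an
# assumption (e.g. CartPole observations), not something this gist specifies
critic = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))

def estimate_advantages(states, last_state, rewards, gamma=0.99):
    values = critic(states).flatten()
    last_value = critic(last_state.unsqueeze(0)).flatten()
    # Walk the rollout backwards, accumulating discounted returns and
    # bootstrapping from the critic's estimate of the post-rollout state
    returns = torch.zeros_like(rewards)
    for i in reversed(range(len(rewards))):
        last_value = returns[i] = rewards[i] + gamma * last_value
    # Advantage = empirical discounted return minus the critic's baseline
    return returns - values

Normalizing the concatenated advantages to zero mean and unit variance, as update_agent does above, is a common variance-reduction trick: it does not change the expected direction of the policy update, but it keeps step sizes comparable across batches.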