Skip to content

Instantly share code, notes, and snippets.

@elumixor
Created May 24, 2020 23:17
Show Gist options
  • Save elumixor/1221af4313e14bc971b057c1137f75dc to your computer and use it in GitHub Desktop.
Save elumixor/1221af4313e14bc971b057c1137f75dc to your computer and use it in GitHub Desktop.
def update_agent(rollouts, backtrack_coeff=0.9, max_backtracks=10):
    """Run one TRPO update: fit the critic, then take a trust-region step on the actor.

    The actor step direction comes from ``conjugate_gradient`` applied to the
    Hessian-of-KL vector product and the surrogate gradient (presumably solving
    ``H x = g``). It is scaled so the quadratic KL estimate ``0.5 * s^T H s`` equals
    ``delta`` at the full step. It is then shrunk by backtracking line search until
    the surrogate improves and the measured KL stays within ``delta``. Each rejected
    candidate is undone with ``apply_update(-step)``.

    Relies on module-level names not visible here: ``actor``, ``delta``, ``iterations``,
    ``estimate_advantages``, ``normalize``, ``update_critic``, ``surrogate_loss``,
    ``kl_div``, ``flat_grad``, ``conjugate_gradient`` and ``apply_update``.

    Args:
        rollouts: Sequence of rollout records. Each must expose ``.states`` and
            ``.actions`` and unpack as ``(states, actions, rewards, next_states)``.
        backtrack_coeff: Factor the step is multiplied by after each rejected attempt.
        max_backtracks: Number of shrinks tried after the full step, so at most
            ``max_backtracks + 1`` candidate steps are evaluated.

    Returns:
        None. The effect of the call is the side effects of ``update_critic`` and
        ``apply_update``.
    """
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    # Advantages are estimated per rollout, bootstrapping from each rollout's last
    # next-state. Loop names differ from `states` above so they don't read as shadowing it.
    advantages = [
        estimate_advantages(r_states, r_next_states[-1], r_rewards)
        for r_states, _, r_rewards, r_next_states in rollouts
    ]
    advantages = normalize(torch.cat(advantages, dim=0).flatten())

    update_critic(advantages)
    # For the policy step the advantages are constants. Detaching keeps actor gradients
    # from reaching into the critic's graph (relevant if the networks share parameters).
    policy_advantages = advantages.detach()

    def selected_probs(dist):
        """Return, for each row of ``dist``, the probability of the action actually taken."""
        return dist[range(dist.shape[0]), actions]

    distribution = torch.distributions.utils.clamp_probs(actor(states))
    probabilities = selected_probs(distribution)

    # Snapshot of the pre-update policy. It is treated as a constant both in the
    # surrogate ratio and as the KL reference, so gradients are taken only w.r.t.
    # the "new" policy argument.
    old_distribution = distribution.detach()
    old_probabilities = probabilities.detach()

    L = surrogate_loss(probabilities, old_probabilities, policy_advantages)
    # KL(old || new) evaluated at new == old has value 0, but a non-zero Hessian only if
    # the reference argument is constant. Detaching it here makes that explicit instead
    # of relying on kl_div to do it internally.
    KL = kl_div(old_distribution, distribution)

    parameters = list(actor.parameters())

    g = flat_grad(L, parameters, retain_graph=True)  # the graph is reused below
    # create_graph=True so d_kl can be differentiated a second time inside HVP.
    d_kl = flat_grad(KL, parameters, create_graph=True)

    def HVP(v):
        """Hessian-of-KL times vector ``v``, via double backprop through ``d_kl``."""
        return flat_grad(d_kl @ v, parameters, retain_graph=True)

    search_dir = conjugate_gradient(HVP, g, max_iterations=iterations)
    # Scale so the quadratic KL model 0.5 * s^T H s equals delta at the full step.
    max_length = torch.sqrt(2 * delta / (search_dir @ HVP(search_dir)))
    max_step = max_length * search_dir

    if not torch.isfinite(max_step).all():
        # Zero/negative/NaN curvature yields an inf or NaN step. Applying it and then
        # "reverting" with apply_update(-step) could not restore the weights (assuming
        # apply_update adds the step: nan - nan and inf - inf are both nan), so skip
        # this update entirely.
        return

    def criterion(step):
        """Apply ``step``; keep it if the surrogate improves within the KL bound, else undo it."""
        apply_update(step)

        with torch.no_grad():
            distribution_new = torch.distributions.utils.clamp_probs(actor(states))
            probabilities_new = selected_probs(distribution_new)
            L_new = surrogate_loss(probabilities_new, old_probabilities, policy_advantages)
            KL_new = kl_div(old_distribution, distribution_new)

        if L_new - L.detach() > 0 and KL_new <= delta:
            return True

        # Step too large: roll the parameters back before trying a smaller one.
        apply_update(-step)
        return False

    # Backtracking line search: try the full step first, then shrink it geometrically.
    for i in range(max_backtracks + 1):
        if criterion((backtrack_coeff ** i) * max_step):
            break
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment