Vladyslav Yazykov elumixor

@elumixor
elumixor / TRPO_update_1.py
Last active May 24, 2020 23:23
TRPO update 1
def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(states, next_states[-1], rewards) for states, _, rewards, next_states in rollouts]
    advantages = torch.cat(advantages, dim=0).flatten()

    # Normalize advantages to reduce skewness and improve convergence
    advantages = (advantages - advantages.mean()) / advantages.std()
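
For context, update_agent unpacks each rollout as (states, actions, rewards, next_states). The collection code is not shown in these gists; a minimal sketch of such a container, with the field names inferred from the unpacking above:

from collections import namedtuple

# Field names inferred from how update_agent unpacks a rollout
Rollout = namedtuple("Rollout", ["states", "actions", "rewards", "next_states"])

# e.g. after collecting one episode:
# rollouts.append(Rollout(torch.stack(states), torch.as_tensor(actions),
#                         torch.as_tensor(rewards), torch.stack(next_states)))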
@elumixor
elumixor / TRPO_update_2.py
Last active May 24, 2020 23:02
TRPO update 2
# We will calculate the gradient w.r.t. the new probabilities (the surrogate function),
# so the second (old) probabilities should be treated as a constant
L = surrogate_loss(probabilities, probabilities.detach(), advantages)
# KL of the current policy with itself: zero in value, but differentiating it twice
# w.r.t. the new parameters gives the Fisher matrix (the Hessian of the KL) used in HVP
KL = kl_div(distribution, distribution)

parameters = list(actor.parameters())

# Retain the graph, because we will use it several times
g = flat_grad(L, parameters, retain_graph=True)
# Gradient of the KL; create_graph=True so it can be differentiated again inside HVP
d_kl = flat_grad(KL, parameters, create_graph=True)

def HVP(v):
    return flat_grad(d_kl @ v, parameters, retain_graph=True)
search_dir = conjugate_gradient(HVP, g)
delta = 0.01  # Should be low (approximately between 0.01 and 0.05)
max_length = torch.sqrt(2 * delta / (search_dir @ HVP(search_dir)))
max_step = max_length * search_dir
def criterion(step):
    # Apply parameters' update
    apply_update(step)
    with torch.no_grad():
        distribution_new = actor(states)
        distribution_new = torch.distributions.utils.clamp_probs(distribution_new)
        probabilities_new = distribution_new[range(distribution_new.shape[0]), actions]
        L_new = surrogate_loss(probabilities_new, probabilities, advantages)
        KL_new = kl_div(distribution, distribution_new)
    # Accept the step only if the surrogate loss improved and the KL constraint holds
    if L_new - L > 0 and KL_new <= delta:
        return True
    # Otherwise revert the update so the line search can try a smaller step
    apply_update(-step)
    return False

i = 0
while not criterion((0.9 ** i) * max_step) and i < 10:
    i += 1
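
For reference, max_length above is the largest step size β along the search direction s that still satisfies the quadratic approximation of the KL constraint (standard TRPO derivation; here s is search_dir and H s is HVP(search_dir)):

\frac{1}{2}\,(\beta s)^\top H\,(\beta s) \le \delta
\quad\Longrightarrow\quad
\beta_{\max} = \sqrt{\frac{2\delta}{s^\top H s}}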
def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(states, next_states[-1], rewards) for states, _, rewards, next_states in rollouts]
    advantages = normalize(torch.cat(advantages, dim=0).flatten())

    update_critic(advantages)

    distribution = actor(states)
def estimate_advantages(states, last_state, rewards):
    values = critic(states)
    last_value = critic(last_state.unsqueeze(0))
    next_values = torch.zeros_like(rewards)
    # Discounted returns, computed backwards from the bootstrapped last value
    for i in reversed(range(rewards.shape[0])):
        last_value = next_values[i] = rewards[i] + 0.99 * last_value
    # Advantage = discounted return minus the critic's value estimate
    advantages = next_values - values
    return advantages
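
As a quick sanity check of the recursion above (toy numbers, not from the gists): with rewards [1, 0, 2], a discount of 0.99 and a bootstrapped last value of 0.5, the backward pass gives roughly [3.45, 2.47, 2.50]:

import torch

rewards = torch.tensor([1.0, 0.0, 2.0])
last_value = 0.5  # stand-in for critic(last_state)
next_values = torch.zeros_like(rewards)
for i in reversed(range(rewards.shape[0])):
    last_value = next_values[i] = rewards[i] + 0.99 * last_value
print(next_values)  # ≈ tensor([3.4453, 2.4700, 2.4950])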
def surrogate_loss(new_probabilities, old_probabilities, advantages):
    return (new_probabilities / old_probabilities * advantages).mean()

def kl_div(p, q):
    # Detach p so gradients flow only through q (the new distribution)
    p = p.detach()
    return (p * (p.log() - q.log())).sum(-1).mean()
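
The snippets above also call flat_grad, conjugate_gradient and apply_update, which are not shown here. A minimal sketch of how they could look under the standard TRPO formulation; the bodies below are assumptions that follow the call signatures used above, not the author's code:

import torch

def flat_grad(y, x, retain_graph=False, create_graph=False):
    # Flattened gradient of the scalar y w.r.t. the list of parameters x
    if create_graph:
        retain_graph = True
    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    return torch.cat([t.flatten() for t in g])

def conjugate_gradient(hvp, b, max_iterations=10, tol=1e-10):
    # Solve H x = b using only Hessian-vector products (H is never materialized)
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    rr = r @ r
    for _ in range(max_iterations):
        Ap = hvp(p)
        alpha = rr / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = r @ r
        if rr_new < tol:
            break
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x

def apply_update(flat_step):
    # Distribute a flat step vector over the actor's parameters, slice by slice
    n = 0
    for p in actor.parameters():  # `actor` is the policy network used in the snippets above
        numel = p.numel()
        p.data += flat_step[n:n + numel].view(p.shape)
        n += numel

conjugate_gradient solves H x = g without ever forming H, which is why the update only needs the Hessian-vector product HVP rather than the full Hessian of the KL.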