GitHub gists by Vladyslav Yazykov (elumixor)

import gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.distributions import Categorical
from collections import namedtuple
env = gym.make('CartPole-v0')
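
# The functions below reference several globals that this preview does not show
# (actor, critic, delta, normalize, update_critic, Rollout). A minimal sketch of
# plausible definitions follows; the architectures and hyperparameters here are
# illustrative assumptions, not necessarily the original ones.
state_size = env.observation_space.shape[0]
num_actions = env.action_space.n

# Small softmax policy (actor) and state-value network (critic)
actor = nn.Sequential(nn.Linear(state_size, 32), nn.ReLU(),
                      nn.Linear(32, num_actions), nn.Softmax(dim=1))
critic = nn.Sequential(nn.Linear(state_size, 32), nn.ReLU(),
                       nn.Linear(32, 1))
critic_optimizer = Adam(critic.parameters(), lr=5e-3)

delta = 0.01  # trust-region size: maximum allowed KL divergence per update

Rollout = namedtuple('Rollout', ['states', 'actions', 'rewards', 'next_states'])

def normalize(x):
    # Zero-mean, unit-variance advantages stabilize the updates
    return (x - x.mean()) / (x.std() + 1e-8)

def update_critic(advantages):
    # Fit the critic by minimizing the mean squared advantage
    loss = 0.5 * (advantages ** 2).mean()
    critic_optimizer.zero_grad()
    loss.backward()
    critic_optimizer.step()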
def apply_update(grad_flattened):
    # Write a flattened parameter update back into the actor's parameters
    n = 0
    for p in actor.parameters():
        numel = p.numel()
        g = grad_flattened[n:n + numel].view(p.shape)
        p.data += g
        n += numel
def conjugate_gradient(A, b, delta=0., max_iterations=float('inf')):
    # Iteratively solve A x = b, where A is given as a matrix-vector product function
    x = torch.zeros_like(b)
    r = b.clone()
    p = b.clone()
    i = 0
    while i < max_iterations:
        AVP = A(p)
        dot_old = r @ r
        alpha = dot_old / (p @ AVP)
        x_new = x + alpha * p
        if (x - x_new).norm() <= delta:
            return x_new
        r = r - alpha * AVP
        p = r + (r @ r) / dot_old * p
        x, i = x_new, i + 1
    return x
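
# A quick illustrative check (not part of the original file): solve a small
# symmetric positive-definite system and compare against a direct solve.
# A_test and b_test are made-up values.
A_test = torch.tensor([[4., 1.], [1., 3.]])
b_test = torch.tensor([1., 2.])
x_test = conjugate_gradient(lambda v: A_test @ v, b_test, max_iterations=10)
# x_test should be close to torch.inverse(A_test) @ b_test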
def flat_grad(y, x, retain_graph=False, create_graph=False):
    # Gradient of y w.r.t. the parameters in x, flattened into a single vector
    if create_graph:
        retain_graph = True
    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    g = torch.cat([t.view(-1) for t in g])
    return g
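
# Illustrative use of flat_grad (not in the original file): with create_graph=True
# the returned gradient is itself differentiable, which is what makes the
# Hessian-vector products used by TRPO possible. Toy example on a quadratic:
w = torch.randn(3, requires_grad=True)
g_w = flat_grad((w ** 2).sum(), [w], create_graph=True)  # gradient 2 * w
hvp_w = flat_grad(g_w @ torch.ones(3), [w])              # Hessian (2I) @ ones = 2 * ones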
def kl_div(p, q):
    # KL(p || q) for categorical distributions; p is treated as a constant
    p = p.detach()
    return (p * (p.log() - q.log())).sum(-1).mean()
def surrogate_loss(new_probabilities, old_probabilities, advantages):
    return (new_probabilities / old_probabilities * advantages).mean()
def estimate_advantages(states, last_state, rewards):
    # Advantage = discounted return (bootstrapped from the last state) - value estimate
    values = critic(states)
    last_value = critic(last_state.unsqueeze(0))
    next_values = torch.zeros_like(rewards)
    for i in reversed(range(rewards.shape[0])):
        last_value = next_values[i] = rewards[i] + 0.99 * last_value
    advantages = next_values - values
    return advantages
def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(states, next_states[-1], rewards) for states, _, rewards, next_states in rollouts]
    advantages = normalize(torch.cat(advantages, dim=0).flatten())

    update_critic(advantages)

    distribution = actor(states)
    distribution = torch.distributions.utils.clamp_probs(distribution)
    probabilities = distribution[range(distribution.shape[0]), actions]

    # Surrogate loss and KL divergence at the current parameters; the old
    # probabilities enter the loss as constants
    L = surrogate_loss(probabilities, probabilities.detach(), advantages)
    KL = kl_div(distribution, distribution)

    parameters = list(actor.parameters())
    g = flat_grad(L, parameters, retain_graph=True)
    d_kl = flat_grad(KL, parameters, create_graph=True)

    def HVP(v):  # Hessian-vector product of the KL (Fisher information times v)
        return flat_grad(d_kl @ v, parameters, retain_graph=True)

    search_dir = conjugate_gradient(HVP, g, max_iterations=10)  # iteration cap assumed
    max_length = torch.sqrt(2 * delta / (search_dir @ HVP(search_dir)))
    max_step = max_length * search_dir

    def criterion(step):
        # Apply parameters' update, then check improvement and the KL constraint
        apply_update(step)
        with torch.no_grad():
            distribution_new = actor(states)
            distribution_new = torch.distributions.utils.clamp_probs(distribution_new)
            probabilities_new = distribution_new[range(distribution_new.shape[0]), actions]
            L_new = surrogate_loss(probabilities_new, probabilities, advantages)
            KL_new = kl_div(distribution, distribution_new)
        if L_new - L > 0 and KL_new <= delta:
            return True
        apply_update(-step)  # roll back a rejected step
        return False

    # Backtracking line search: shrink the step until the criterion accepts it
    i = 0
    while not criterion((0.9 ** i) * max_step) and i < 10:
        i += 1
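
# A minimal end-to-end sketch of how these pieces could be driven (the sampling
# loop is not shown in this preview). It assumes the classic gym API, where
# env.reset() returns an observation and env.step() returns
# (next_state, reward, done, info), plus the Rollout namedtuple defined above.
def get_action(state):
    state = torch.as_tensor(state, dtype=torch.float).unsqueeze(0)
    return Categorical(actor(state)).sample().item()

def train(epochs=50, num_rollouts=10):
    mean_total_rewards = []
    for epoch in range(epochs):
        rollouts, rollout_rewards = [], []
        for _ in range(num_rollouts):
            state, done, samples = env.reset(), False, []
            while not done:
                action = get_action(state)
                next_state, reward, done, _ = env.step(action)
                samples.append((state, action, reward, next_state))
                state = next_state
            # Stack the transitions into tensors; rewards get a trailing dim to
            # match the critic's (N, 1) output in estimate_advantages
            states, actions, rewards, next_states = zip(*samples)
            states = torch.as_tensor(np.array(states), dtype=torch.float)
            next_states = torch.as_tensor(np.array(next_states), dtype=torch.float)
            actions = torch.as_tensor(actions)
            rewards = torch.as_tensor(rewards, dtype=torch.float).unsqueeze(1)
            rollouts.append(Rollout(states, actions, rewards, next_states))
            rollout_rewards.append(rewards.sum().item())
        update_agent(rollouts)
        mean_total_rewards.append(np.mean(rollout_rewards))
        print(f'Epoch {epoch}: mean total reward {mean_total_rewards[-1]:.1f}')
    plt.plot(mean_total_rewards)  # learning curve
    plt.show()

train()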