Vladyslav Yazykov (elumixor)

delta = 0.01  # Should be low (approximately between 0.01 and 0.05)

def HVP(v):
    # Hessian-vector product: differentiate the KL gradient d_kl once more against v
    return flat_grad(d_kl @ v, parameters, retain_graph=True)

search_dir = conjugate_gradient(HVP, g)

# Maximal step length: beta = sqrt(2 * delta / (s^T H s)) keeps the quadratic KL approximation within delta
max_length = torch.sqrt(2 * delta / (search_dir @ HVP(search_dir)))
max_step = max_length * search_dir
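The conjugate_gradient and flat_grad helpers are not part of this preview. As a reference, here is a minimal conjugate-gradient solver consistent with how it is called above: it solves H x = g using only Hessian-vector products, never forming H itself (the iteration count and tolerance are my own choices):

def conjugate_gradient(hvp, b, max_iterations=10, tol=1e-10):
    # Solve H x = b for x, where hvp(v) returns H v
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - H x (x starts at zero)
    p = r.clone()  # current search direction
    rs_old = r @ r
    for _ in range(max_iterations):
        Hp = hvp(p)
        alpha = rs_old / (p @ Hp)
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        if rs_new < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x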
elumixor / TRPO_update_2.py
Last active May 24, 2020 23:02
TRPO update 2
# We will calculate the gradient w.r.t. the new probabilities (the surrogate objective),
# so the second set of probabilities is treated as a constant
L = surrogate_loss(probabilities, probabilities.detach(), advantages)
KL = kl_div(distribution, distribution)

parameters = list(actor.parameters())

# Retain the graph, because we will backpropagate through it several times
g = flat_grad(L, parameters, retain_graph=True)
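surrogate_loss, kl_div, and flat_grad are likewise not shown in this preview, and neither is d_kl, which the Hessian-vector product above relies on. The sketches below are my reconstruction of the standard TRPO formulation, not the original gist: the surrogate is the importance-sampled advantage, the KL divergence is taken between categorical action-probability tensors, and flat_grad flattens the output of torch.autograd.grad into one vector.

def flat_grad(y, x, retain_graph=False, create_graph=False):
    # Gradient of the scalar y w.r.t. the parameter list x, flattened into a single vector
    if create_graph:
        retain_graph = True
    g = torch.autograd.grad(y, x, retain_graph=retain_graph, create_graph=create_graph)
    return torch.cat([t.reshape(-1) for t in g])

def surrogate_loss(new_probabilities, old_probabilities, advantages):
    # Importance-sampled policy objective: E[(pi_new / pi_old) * advantage]
    return (new_probabilities / old_probabilities * advantages).mean()

def kl_div(p, q):
    # KL(p || q) for categorical distributions given as probability tensors
    return (p * (p.log() - q.log())).sum(-1).mean()

# The Hessian-vector product differentiates through the gradient of the KL divergence,
# so d_kl has to be created with create_graph=True (assumed, not visible in the preview)
d_kl = flat_grad(KL, parameters, create_graph=True)

Note that kl_div(distribution, distribution) evaluates to zero, but its second derivative with respect to the new-policy parameters is the Fisher information matrix, which is exactly what the Hessian-vector product needs.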
elumixor / TRPO_update_1.py
Last active May 24, 2020 23:23
TRPO update 1
def update_agent(rollouts):
    states = torch.cat([r.states for r in rollouts], dim=0)
    actions = torch.cat([r.actions for r in rollouts], dim=0).flatten()

    advantages = [estimate_advantages(states, next_states[-1], rewards)
                  for states, _, rewards, next_states in rollouts]
    advantages = torch.cat(advantages, dim=0).flatten()

    # Normalize advantages to reduce skewness and improve convergence
    advantages = (advantages - advantages.mean()) / advantages.std()
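estimate_advantages is not included in the preview either. A common implementation consistent with the call above (per-rollout states, the final next_state, and the rewards) bootstraps discounted returns from the critic; the discount factor gamma is my assumption:

def estimate_advantages(states, last_state, rewards, gamma=0.99):
    values = critic(states).flatten()                       # value estimates for the visited states
    last_value = critic(last_state.unsqueeze(0)).flatten()  # bootstrap value for the final next_state
    returns = []
    for i in reversed(range(len(rewards))):                 # discounted returns, computed backwards
        last_value = rewards[i] + gamma * last_value
        returns.insert(0, last_value)
    return torch.cat(returns) - values                      # advantage = return - value estimate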
elumixor / TRPO_critic_update.py
Last active May 24, 2020 22:45
TRPO critic update
from torch.optim import Adam

critic_optimizer = Adam(critic.parameters(), lr=0.005)

def update_critic(advantages):
    loss = .5 * (advantages ** 2).mean()  # mean squared advantage
    critic_optimizer.zero_grad()
    loss.backward()
    critic_optimizer.step()
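A short usage sketch (states, next_states, and rewards are placeholders for one rollout): the critic is trained by minimizing the mean squared advantage, pushing its value estimates toward the sampled returns, so it has to be called with advantages that still depend on the critic's outputs (i.e. not detached):

advantages = estimate_advantages(states, next_states[-1], rewards)
update_critic(advantages)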
elumixor / TRPO_critic.py
Created May 24, 2020 22:39
TRPO critic network
# The critic takes a state and returns its value estimate
critic_hidden = 32
critic = nn.Sequential(nn.Linear(obs_shape[0], critic_hidden),
                       nn.ReLU(),
                       nn.Linear(critic_hidden, 1))
elumixor / TRPO_get_action.py
Created May 24, 2020 22:32
TRPO Get Action
from torch.distributions import Categorical

def get_action(state):
    state = torch.tensor(state).float().unsqueeze(0)  # Turn state into a batch with a single element
    dist = Categorical(actor(state))  # Create a distribution from probabilities for actions
    return dist.sample().item()
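A quick usage sketch with the CartPole environment defined further down, using the old-style gym API where reset returns the observation and step returns (observation, reward, done, info):

state = env.reset()
done = False
while not done:
    action = get_action(state)
    state, reward, done, _ = env.step(action)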
elumixor / TRPO_actor.py
Created May 24, 2020 22:28
Medium TRPO Actor
import torch.nn as nn

actor_hidden = 32
actor = nn.Sequential(nn.Linear(state_size, actor_hidden),
                      nn.ReLU(),
                      nn.Linear(actor_hidden, num_actions),
                      nn.Softmax(dim=1))
elumixor / gym_generic_train.py
Last active May 24, 2020 22:27
Medium TRPO Files
from collections import namedtuple
import gym
import torch

env = gym.make('CartPole-v0')
obs_size = env.observation_space.shape[0]
num_actions = env.action_space.n

import numpy as np
import torch
from torch.utils.data import SubsetRandomSampler, DataLoader
from time import sleep
from IPython.display import clear_output, display
import os

class Trainer:
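The preview stops at the class declaration, so the training loop itself is not visible. For reference only, here is a minimal rollout-collection function that produces the structure update_agent above expects; the Rollout namedtuple and its field order are inferred from how the rollouts are unpacked there, not taken from the original Trainer:

Rollout = namedtuple('Rollout', ['states', 'actions', 'rewards', 'next_states'])

def collect_rollout(max_steps=200):
    samples = []
    state = env.reset()
    for _ in range(max_steps):
        action = get_action(state)
        next_state, reward, done, _ = env.step(action)
        samples.append((state, action, reward, next_state))
        state = next_state
        if done:
            break
    # Stack the transitions into tensors, one Rollout per episode
    states, actions, rewards, next_states = zip(*samples)
    return Rollout(torch.stack([torch.tensor(s).float() for s in states]),
                   torch.tensor(actions),
                   torch.tensor(rewards).float(),
                   torch.stack([torch.tensor(s).float() for s in next_states]))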