Skip to content

Instantly share code, notes, and snippets.

View elumixor's full-sized avatar
🏠
Working from home

Vladyslav Yazykov elumixor

🏠
Working from home
  • Prague
  • 16:19 (UTC +02:00)
View GitHub Profile
def flat_grad(y, x, retain_graph=False, create_graph=False):
    """Return the gradient of ``y`` with respect to ``x`` as a single 1-D tensor.

    Args:
        y: Tensor to differentiate; passed as ``outputs`` to
            ``torch.autograd.grad``.
        x: Tensor or sequence of tensors to differentiate with respect to;
            passed as ``inputs`` to ``torch.autograd.grad``.
        retain_graph: Keep the autograd graph alive after the call. Forced to
            ``True`` whenever ``create_graph`` is set.
        create_graph: Build a graph of the derivative itself so the result can
            be differentiated again.

    Returns:
        A 1-D tensor holding the per-input gradients, each flattened and
        concatenated in the order of ``x``.
    """
    if create_graph:
        retain_graph = True

    grads = torch.autograd.grad(
        y, x, retain_graph=retain_graph, create_graph=create_graph
    )
    # reshape (rather than view) also flattens non-contiguous gradients, e.g.
    # ones that flowed back through a transpose.
    return torch.cat([grad.reshape(-1) for grad in grads])
def conjugate_gradient(A, b, delta=0., max_iterations=float('inf')):
x = torch.zeros_like(b)
r = b.clone()
p = b.clone()
i = 0
while i < max_iterations:
AVP = A(p)
dot_old = r @ r
def apply_update(grad_flattened):
    """Add a flat update vector to the parameters of the global ``actor`` in place.

    ``grad_flattened`` is consumed in the order of ``actor.parameters()``:
    each parameter takes the next ``numel()`` entries, reshaped to that
    parameter's shape, and has them added to its current values.

    Args:
        grad_flattened: 1-D tensor holding the concatenated per-parameter
            updates.
    """
    offset = 0
    # Update the parameters in place without recording the operation in
    # autograd; this replaces writing through ``.data``, which bypasses
    # autograd's in-place modification tracking.
    with torch.no_grad():
        for param in actor.parameters():
            numel = param.numel()
            update = grad_flattened[offset:offset + numel].reshape(param.shape)
            param.add_(update)
            offset += numel
import gym
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.optim import Adam
from torch.distributions import Categorical
from collections import namedtuple
# Module-level environment instance, created by id ('CartPole-v0') via gym.make.
env = gym.make('CartPole-v0')