@scturtle
Last active December 24, 2024 03:56
Proximal Policy Optimization
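Two self-contained CartPole-v1 training scripts follow. The first is a one-step advantage actor-critic baseline; the second reuses the same networks but stores the behavior policy's action probabilities, runs K_epoch update passes per batch, and optimizes PPO's clipped surrogate objective L_CLIP(θ) = E_t[ min( r_t(θ) Â_t, clip(r_t(θ), 1 − ε, 1 + ε) Â_t ) ], where r_t(θ) = π_θ(a_t | s_t) / π_θ_old(a_t | s_t) is the probability ratio and Â_t is the advantage estimate (here the one-step TD error).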
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gym
# Hyperparameters
num_inputs = 4
num_actions = 2
num_hidden = 256
learning_rate = 0.001
gamma = 0.98
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_actions)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

def make_batch(data):
    data = np.array(data, dtype=object)
    s, a, r, s_prime, done = map(np.stack, zip(*data))
    return (
        torch.tensor(s, dtype=torch.float32),
        torch.tensor(a, dtype=torch.int64).unsqueeze(-1),
        torch.tensor(r, dtype=torch.float32).unsqueeze(-1),
        torch.tensor(s_prime, dtype=torch.float32),
        torch.tensor(1 - done, dtype=torch.float32).unsqueeze(-1),
    )

def train(actor, critic, data, opta, optc):
    s, a, r, s_prime, done_mask = make_batch(data)
    # Optimize Critic: regress V(s) onto the one-step TD target
    values = critic(s)
    td_target = r + gamma * critic(s_prime) * done_mask
    critic_loss = F.mse_loss(values, td_target.detach())
    optc.zero_grad()
    critic_loss.backward()
    optc.step()
    # Optimize Actor: policy gradient weighted by the TD-error advantage
    logits = actor(s)
    probs = F.softmax(logits, dim=-1)
    dist = torch.distributions.Categorical(probs)
    log_probs = dist.log_prob(a.squeeze(-1))
    advantages = (td_target - values).detach().squeeze(-1)
    actor_loss = -(log_probs * advantages).mean()
    opta.zero_grad()
    actor_loss.backward()
    opta.step()

def main():
    env = gym.make("CartPole-v1")
    actor = Actor()
    critic = Critic()
    opta = optim.Adam(actor.parameters(), lr=learning_rate)
    optc = optim.Adam(critic.parameters(), lr=learning_rate)
    data = []
    running_reward = 0
    for n_epi in range(5000):
        s, _ = env.reset()
        done = False
        episode_reward = 0
        while not done:
            obs = torch.from_numpy(s).unsqueeze(0)
            with torch.no_grad():
                logits = actor(obs)
                prob = F.softmax(logits, dim=-1).squeeze(0)
            dist = torch.distributions.Categorical(prob)
            a = dist.sample().item()
            s_prime, r, terminated, truncated, _ = env.step(a)
            done = terminated or truncated  # treat time-limit truncation as episode end
            data.append((s, a, r / 100.0, s_prime, done))
            s = s_prime
            episode_reward += r
        if data:
            train(actor, critic, data, opta, optc)
            data = []
        running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
        if n_epi % 100 == 0:
            template = "running reward: {:.2f} at episode {}"
            print(template.format(running_reward, n_epi))
        if running_reward > 195:  # Condition to consider the task solved
            print("solved at episode {}!".format(n_epi))
            break
    env.close()


if __name__ == "__main__":
    main()
# https://huggingface.co/learn/deep-rl-course/unit4/policy-gradient
# https://huggingface.co/learn/deep-rl-course/unit6/advantage-actor-critic
# https://huggingface.co/learn/deep-rl-course/unit8/clipped-surrogate-objective
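Both scripts use the one-step TD error as the advantage estimate. PPO is more commonly paired with generalized advantage estimation (GAE); the helper below is a hypothetical sketch (not part of the gist) of what a drop-in replacement for the advantages/td_target computation in train could look like, assuming the (T, 1)-shaped tensors produced by make_batch and a batch collected in time order.

# Hypothetical GAE helper (not in the original gist): lambda-weighted sum of TD errors.
def compute_gae(r, values, next_values, done_mask, gamma=0.98, lam=0.95):
    # r, values, next_values, done_mask: (T, 1) tensors in time order
    deltas = (r + gamma * next_values * done_mask - values).detach()  # one-step TD errors
    advantages = torch.zeros_like(deltas)
    gae = torch.zeros(1)
    for t in reversed(range(deltas.shape[0])):  # accumulate discounted deltas backwards in time
        gae = deltas[t] + gamma * lam * done_mask[t] * gae
        advantages[t] = gae
    returns = advantages + values.detach()  # value-function regression targets
    return advantages, returns

The PPO script below extends the actor-critic above with the stored behavior probabilities, the clipped ratio, and K_epoch optimization passes per batch.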
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import gym
# Hyperparameters
num_inputs = 4
num_actions = 2
num_hidden = 256
learning_rate = 0.001
gamma = 0.98
eps_clip = 0.1
K_epoch = 10
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_actions)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.fc2 = nn.Linear(num_hidden, 1)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

def make_batch(data):
    data = np.array(data, dtype=object)
    s, a, r, s_prime, prob_a, done = map(np.stack, zip(*data))
    return (
        torch.tensor(s, dtype=torch.float32),
        torch.tensor(a, dtype=torch.int64),
        torch.tensor(r, dtype=torch.float32).unsqueeze(-1),
        torch.tensor(s_prime, dtype=torch.float32),
        # done_mask = 0 if done else 1
        torch.tensor(1 - done, dtype=torch.float32).unsqueeze(-1),
        torch.tensor(prob_a, dtype=torch.float32),
    )

def train(actor, critic, data, opta, optc):
    s, a, r, s_prime, done_mask, old_prob_a = make_batch(data)
    for _ in range(K_epoch):
        # Optimize Critic: regress V(s) onto the one-step TD target
        values = critic(s)
        td_target = r + gamma * critic(s_prime) * done_mask
        critic_loss = F.mse_loss(values, td_target.detach())
        advantages = (td_target - values).detach().squeeze(-1)
        optc.zero_grad()
        critic_loss.backward()
        optc.step()
        # Optimize Actor: PPO clipped surrogate objective
        logits = actor(s)
        probs = F.softmax(logits, dim=-1)
        dist = torch.distributions.Categorical(probs)
        prob_a = torch.exp(dist.log_prob(a))
        ratios = prob_a / old_prob_a  # pi_theta(a|s) / pi_theta_old(a|s)
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - eps_clip, 1 + eps_clip) * advantages
        actor_loss = -torch.min(surr1, surr2).mean()
        opta.zero_grad()
        actor_loss.backward()
        opta.step()

def main():
    env = gym.make("CartPole-v1")
    actor = Actor()
    critic = Critic()
    opta = optim.Adam(actor.parameters(), lr=learning_rate)
    optc = optim.Adam(critic.parameters(), lr=learning_rate)
    data = []
    running_reward = 0
    for n_epi in range(5000):
        s, _ = env.reset()
        episode_reward = 0
        while True:
            obs = torch.from_numpy(s).unsqueeze(0)
            with torch.no_grad():
                logits = actor(obs)
                prob = F.softmax(logits, dim=-1).squeeze(0)
            dist = torch.distributions.Categorical(prob)
            a = dist.sample().item()
            s_prime, r, terminated, truncated, _ = env.step(a)
            done = terminated or truncated  # treat time-limit truncation as episode end
            data.append((s, a, r / 100.0, s_prime, prob[a].item(), done))
            s = s_prime
            episode_reward += r
            if done:
                break
        if data:
            train(actor, critic, data, opta, optc)
            data = []
        running_reward = 0.05 * episode_reward + (1 - 0.05) * running_reward
        if n_epi % 10 == 0:
            template = "running reward: {:.2f} at episode {}"
            print(template.format(running_reward, n_epi))
        if running_reward > 200:  # Condition to consider the task solved
            print("solved at episode {}!".format(n_epi))
            break
    env.close()


if __name__ == "__main__":
    main()
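A quick way to sanity-check either trained policy is to roll it out greedily. The snippet below is a minimal, hypothetical evaluation helper (not part of the gist), assuming the same gym >= 0.26 API used above and the Actor defined in either script.

# Hypothetical evaluation helper (not in the original gist): greedy rollouts of a trained actor.
def evaluate(actor, episodes=10):
    env = gym.make("CartPole-v1")
    total = 0.0
    for _ in range(episodes):
        s, _ = env.reset()
        done = False
        while not done:
            with torch.no_grad():
                a = actor(torch.from_numpy(s).unsqueeze(0)).argmax(dim=-1).item()  # greedy action
            s, r, terminated, truncated, _ = env.step(a)
            done = terminated or truncated
            total += r
    env.close()
    return total / episodes  # average undiscounted return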