@tiandiao123
Created September 25, 2024 08:16
gae.py
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
# Define the policy network
class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)
# Define the value network
class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.fc(x)
# Function to compute GAE (Generalized Advantage Estimation)
def compute_gae(rewards, values, gamma=0.99, lambda_=0.95):
    advantages = []
    gae = 0
    values = values + [0]  # bootstrap with a value of 0 for the terminal state
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lambda_ * gae
        advantages.insert(0, gae)
    return advantages
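
# Worked sanity check (illustrative numbers only, not from any actual rollout):
# with rewards = [1.0, 1.0, 1.0], values = [0.5, 0.5, 0.5] and the defaults
# gamma=0.99, lambda_=0.95, compute_gae walks the episode backwards:
#   t=2: delta = 1.0 + 0.99*0.0 - 0.5 = 0.5    -> advantages[2] = 0.5
#   t=1: delta = 1.0 + 0.99*0.5 - 0.5 = 0.995  -> advantages[1] = 0.995 + 0.99*0.95*0.5   ~= 1.465
#   t=0: delta = 0.995                         -> advantages[0] = 0.995 + 0.99*0.95*1.465 ~= 2.373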
# PPO training function
def train_ppo(env, policy_net, value_net, policy_optimizer, value_optimizer,
              epochs=10, gamma=0.99, lambda_=0.95, epsilon=0.2):
    for epoch in range(epochs):
        state = env.reset()  # assumes the classic Gym API: reset() returns the observation
        log_probs, values, rewards, states, actions = [], [], [], [], []

        # Collect one trajectory
        done = False
        while not done:
            state_tensor = torch.FloatTensor(state).unsqueeze(0)
            dist = Categorical(policy_net(state_tensor))
            action = dist.sample()
            next_state, reward, done, _ = env.step(action.item())  # classic 4-tuple step()

            log_probs.append(dist.log_prob(action))
            values.append(value_net(state_tensor).item())
            rewards.append(reward)
            states.append(state)
            actions.append(action)
            state = next_state

        # Compute GAE advantages and the corresponding returns
        advantages = compute_gae(rewards, values, gamma, lambda_)
        returns = [adv + val for adv, val in zip(advantages, values)]

        # Convert lists to tensors (concatenate the per-step (1,)-shaped tensors into flat vectors)
        log_probs = torch.cat(log_probs).detach()
        returns = torch.FloatTensor(returns)
        advantages = torch.FloatTensor(advantages)
        states = torch.FloatTensor(np.array(states))
        actions = torch.cat(actions)

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-10)

        # Update policy network with the clipped surrogate objective
        for _ in range(4):  # PPO typically performs several update epochs per batch
            dist = Categorical(policy_net(states))
            new_log_probs = dist.log_prob(actions)
            ratio = (new_log_probs - log_probs).exp()
            surr1 = ratio * advantages
            surr2 = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            policy_optimizer.zero_grad()
            policy_loss.backward()
            policy_optimizer.step()

        # Update value network
        value_optimizer.zero_grad()
        value_loss = nn.MSELoss()(value_net(states).squeeze(), returns)
        value_loss.backward()
        value_optimizer.step()
# Example usage (assuming you have an environment `env`):
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
policy_net = PolicyNetwork(state_dim, action_dim)
value_net = ValueNetwork(state_dim)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
value_optimizer = optim.Adam(value_net.parameters(), lr=1e-3)
train_ppo(env, policy_net, value_net, policy_optimizer, value_optimizer)
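
# A minimal way to obtain a runnable `env` for the example above (an assumption,
# not part of the original gist): any discrete-action environment exposing the
# classic Gym API that train_ppo expects, i.e. reset() -> obs and
# step(action) -> (obs, reward, done, info). For instance:
#
#   import gym  # gym < 0.26 keeps the 4-tuple step() used above
#   env = gym.make("CartPole-v1")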