PPO pseudocode
# Pseudocode for estimating the advantage function in RLHF using PPO
import copy

import numpy as np
import torch
import torch.nn.functional as F
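
# Hyperparameters referenced throughout the pseudocode. The values below are
# illustrative assumptions added for completeness; they are not from the
# original gist and would be tuned in practice.
GAMMA = 0.99                # discount factor for returns
ENTROPY_COEFFICIENT = 0.01  # weight on the entropy bonus in the policy loss
PPO_EPOCHS = 4              # PPO update epochs per batch of trajectories
NUM_ITERATIONS = 1000       # outer training iterations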
def compute_ppo_loss(policy, old_policy, token_sequences, advantages, returns, clip_epsilon=0.2):
    """
    Compute the PPO clipped-surrogate loss for a language-model policy update.

    policy:          current policy (language model) being optimized
    old_policy:      policy that generated the token sequences
    token_sequences: list of generated token sequences
    advantages:      estimated advantages for each token prediction
    returns:         estimated returns for each token prediction
                     (not used by the policy loss; see compute_value_loss)
    clip_epsilon:    PPO clipping parameter
    """
    total_loss = 0
    for sequence, sequence_advantages in zip(token_sequences, advantages):
        for t in range(len(sequence) - 1):  # -1 because we're predicting the next token
            current_sequence = sequence[:t + 1]
            next_token = sequence[t + 1]

            # Log probabilities of the next token under the current and old policies
            current_log_prob = policy.log_prob(current_sequence, next_token)
            old_log_prob = old_policy.log_prob(current_sequence, next_token)

            # Probability ratio between the current and old policies
            ratio = torch.exp(current_log_prob - old_log_prob)

            # Clipped surrogate objective
            surrogate1 = ratio * sequence_advantages[t]
            surrogate2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * sequence_advantages[t]
            policy_loss = -torch.min(surrogate1, surrogate2)

            # Entropy bonus (optional, to encourage exploration)
            entropy = policy.entropy(current_sequence).mean()

            # Combine the clipped policy loss with the entropy bonus
            total_loss += policy_loss - ENTROPY_COEFFICIENT * entropy
    return total_loss / len(token_sequences)
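
# Toy illustration of the clipping term (the numbers below are made up, not
# from the gist): with clip_epsilon = 0.2 and a positive advantage, a ratio of
# 1.5 is clipped to 1.2, so the objective gives no extra reward for pushing the
# new policy further than 1 + clip_epsilon away from the old one.
_ratio = torch.tensor(1.5)
_advantage = torch.tensor(1.0)
_surrogate1 = _ratio * _advantage                         # 1.5
_surrogate2 = torch.clamp(_ratio, 0.8, 1.2) * _advantage  # 1.2 (clipped)
assert torch.min(_surrogate1, _surrogate2) == _surrogate2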
def compute_value_loss(value_function, token_sequences, returns):
    """
    Compute the value function loss.

    value_function:  current value function
    token_sequences: list of generated token sequences
    returns:         estimated returns for each token prediction
    """
    total_loss = 0
    for sequence, sequence_returns in zip(token_sequences, returns):
        for t in range(len(sequence)):
            current_sequence = sequence[:t + 1]
            predicted_value = value_function(current_sequence)
            actual_return = sequence_returns[t]
            # Compute MSE loss between the predicted value and the observed return
            value_loss = F.mse_loss(predicted_value, actual_return)
            total_loss += value_loss
    return total_loss / len(token_sequences)
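
# Minimal illustration of the critic objective with made-up numbers: if the
# value head predicts 0.4 for a prefix whose observed return is 1.0, the
# squared-error loss at that position is (1.0 - 0.4) ** 2 = 0.36.
_toy_value_loss = F.mse_loss(torch.tensor(0.4), torch.tensor(1.0))  # ~0.36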
def estimate_advantages(trajectories, reward_model, value_function):
    """
    Estimate per-step returns and advantages for each trajectory.

    trajectories:   list of trajectories, each a list of (state, action, reward) tuples
    reward_model:   learned reward function from RLHF
    value_function: current estimate of the value function
    """
    advantages = []
    returns = []
    for trajectory in trajectories:
        trajectory_rewards = []
        trajectory_returns = []
        trajectory_advantages = []

        # Compute returns (discounted sums of rewards), working backwards
        R = 0
        for t in reversed(range(len(trajectory))):
            state, action, _ = trajectory[t]
            # Use the learned reward model to estimate the reward
            reward = reward_model.predict(state, action)
            R = reward + GAMMA * R  # GAMMA is the discount factor
            trajectory_rewards.insert(0, reward)
            trajectory_returns.insert(0, R)

        # Estimate advantages as one-step TD errors
        for t in range(len(trajectory)):
            state, _, _ = trajectory[t]
            if t + 1 < len(trajectory):
                next_state, _, _ = trajectory[t + 1]
                td_error = trajectory_rewards[t] + GAMMA * value_function(next_state) - value_function(state)
            else:
                td_error = trajectory_rewards[t] - value_function(state)
            trajectory_advantages.append(td_error)

        # Normalize advantages (optional, but often helps with training stability)
        trajectory_advantages = np.asarray(trajectory_advantages)
        trajectory_advantages = (trajectory_advantages - np.mean(trajectory_advantages)) / (np.std(trajectory_advantages) + 1e-8)

        advantages.extend(trajectory_advantages)
        returns.extend(trajectory_returns)
    return advantages, returns
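
# Worked example of the backward return recursion with made-up rewards
# [1.0, 0.0, 2.0] and GAMMA = 0.99:
#   R_2 = 2.0
#   R_1 = 0.0 + 0.99 * 2.0  = 1.98
#   R_0 = 1.0 + 0.99 * 1.98 = 2.9602
_toy_rewards = [1.0, 0.0, 2.0]
_toy_returns = []
_R = 0.0
for _r in reversed(_toy_rewards):
    _R = _r + 0.99 * _R
    _toy_returns.insert(0, _R)
# _toy_returns is approximately [2.9602, 1.98, 2.0]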
def update_policy_ppo(policy, value_function, trajectories, reward_model):
    """
    Main PPO update loop.
    """
    # Snapshot the behavior policy before the PPO epochs so the probability
    # ratios in compute_ppo_loss are taken against a fixed reference
    old_policy = copy.deepcopy(policy)
    for _ in range(PPO_EPOCHS):
        # Sample mini-batches of trajectories
        mini_batches = sample_mini_batches(trajectories)
        for mini_batch in mini_batches:
            # Estimate advantages and returns
            advantages, returns = estimate_advantages(mini_batch, reward_model, value_function)
            # Compute the PPO loss (at this level of abstraction the mini-batch
            # doubles as the token sequences) and update the policy
            ppo_loss = compute_ppo_loss(policy, old_policy, mini_batch, advantages, returns)
            policy.update(ppo_loss)
            # Update the value function
            value_loss = compute_value_loss(value_function, mini_batch, returns)
            value_function.update(value_loss)
# Main training loop.
# train_reward_model, initialize_policy, initialize_value_function,
# generate_trajectories, sample_mini_batches, and human_feedback_data are
# placeholders for components defined elsewhere in the RLHF pipeline.
reward_model = train_reward_model(human_feedback_data)
policy = initialize_policy()
value_function = initialize_value_function()

for iteration in range(NUM_ITERATIONS):
    trajectories = generate_trajectories(policy)
    update_policy_ppo(policy, value_function, trajectories, reward_model)