@infoslack
Created January 27, 2025 17:59
Group Relative Policy Optimization (GRPO) implementation
# This implementation is based on the paper: https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf
#
# pip install torch transformers
# python grpo_demo.py
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel
# GRPO Configuration parameters as per formula
G = 4 # Number of samples in the group (G in the formula)
epsilon = 0.15 # ε in the formula - Clipping limit
beta = 0.0005 # β in the formula - KL penalty weight
learning_rate = 0.001
# Example data - Simulating a single q from distribution P(Q)
question = "What is the capital of Brazil?"
possible_answers = ["Brasília", "Rio de Janeiro", "São Paulo", "Fortaleza"]
correct_answer_idx = 0 # Brasília
# Preprocessing with BERT to get input representations
tokenizer = BertTokenizer.from_pretrained("neuralmind/bert-base-portuguese-cased")
model = BertModel.from_pretrained("neuralmind/bert-base-portuguese-cased")
def get_embedding(text):
    """Convert text to embedding using BERT"""
    inputs = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).detach()
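# The pooled output has shape [1, 768] (BERT-base hidden size), which matches
# the default embedding_dim of PolicyModel defined below.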
question_embedding = get_embedding(question)
# Policy Model (π_θ in the formula)
class PolicyModel(nn.Module):
    def __init__(self, num_actions, embedding_dim=768):
        super().__init__()
        self.fc = nn.Linear(embedding_dim, num_actions)

    def forward(self, x):
        # Returns π_θ(o|q) - action probabilities given the state
        return torch.softmax(self.fc(x), dim=-1)
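# Given the [1, 768] question embedding, policy(question_embedding) returns a
# [1, num_actions] tensor of probabilities that sums to 1 across the answers.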
# Initialize π_θ (current policy) and π_θ_old (old policy)
num_actions = len(possible_answers)
policy = PolicyModel(num_actions) # π_θ in the formula
old_policy = PolicyModel(num_actions) # π_θ_old in the formula
old_policy.load_state_dict(policy.state_dict())

# Create the optimizer once, outside train_step, so Adam's moment estimates
# persist across updates instead of being reset on every call
optimizer = optim.Adam(policy.parameters(), lr=learning_rate)
def train_step():
    """
    Implements one optimization step of GRPO according to the formula:
    J_GRPO(θ) = E[q ~ P(Q), {o_i}_(i=1..G) ~ π_θ_old(O|q)]
                1/G ∑_i ( min(π_θ/π_θ_old * A_i,
                              clip(π_θ/π_θ_old, 1-ε, 1+ε) * A_i)
                          - β * D_KL(π_θ || π_ref) )
    """
    # 1. Sample G outputs from old policy π_θ_old
    with torch.no_grad():
        probs_old = old_policy(question_embedding)
        sampled_actions = torch.multinomial(probs_old.squeeze(), G, replacement=True)

    # 2. Calculate probabilities from new policy π_θ
    probs_new = policy(question_embedding)

    # 3. Calculate ratio π_θ/π_θ_old
    ratios = probs_new[0, sampled_actions] / probs_old[0, sampled_actions]
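    # Note: on the very first step old_policy is an exact copy of policy, so
    # every ratio here is 1.0; the ratios only drift away from 1 as π_θ is
    # updated while π_θ_old stays frozen for the rest of this script.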
    # 4. Calculate rewards and advantages (A_i in the formula)
    rewards = torch.tensor(
        [1.0 if idx == correct_answer_idx else -0.1 for idx in sampled_actions]
    )
    # A_i = (r_i - mean({r_1,...,r_G})) / std({r_1,...,r_G})
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
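    # Worked example of the group normalization: if the G = 4 samples score
    # rewards [1.0, -0.1, -0.1, -0.1], then mean = 0.175 and std ≈ 0.55, so
    # A ≈ [1.5, -0.5, -0.5, -0.5]: the correct answer gets a positive advantage,
    # the others a negative one. If all G samples get the same reward the std is
    # 0 and every advantage is 0, which likely explains the epochs with a policy
    # loss of 0.0000 in the output below.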
    # 5. Implement clipping as per formula
    clipped_ratios = torch.clamp(ratios, 1 - epsilon, 1 + epsilon)
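    # For example, with ε = 0.15 a ratio of 1.40 is clamped to 1.15, so a single
    # favorable sample cannot move the policy arbitrarily far from π_θ_old.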
    # 6. Calculate loss according to min(.) in formula
    loss_policy = -torch.min(ratios * advantages, clipped_ratios * advantages).mean()

    # 7. Calculate KL divergence as per formula (2)
    # D_KL(π_θ||π_ref) = π_ref(o_i|q)/π_θ(o_i|q) - log(π_ref(o_i|q)/π_θ(o_i|q)) - 1
    ratio_kl = probs_old.detach() / probs_new
    kl_penalty = (ratio_kl - torch.log(ratio_kl) - 1).mean()
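    # Note: the frozen old_policy stands in for the reference policy π_ref here,
    # and the estimator is averaged over the whole action distribution rather
    # than only the sampled o_i, a simplification of formula (2).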
    # 8. Total loss with KL penalty
    total_loss = loss_policy + beta * kl_penalty

    # 9. Update policy (optimizer is created once, above)
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return total_loss, loss_policy, kl_penalty
# Train the model
print("Starting training...")
for epoch in range(100):
    loss, policy_loss, kl = train_step()
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch + 1}")
        print(f"  Total Loss: {loss.item():.4f}")
        print(f"  Policy Loss: {policy_loss.item():.4f}")
        print(f"  KL Divergence: {kl.item():.4f}")
# Test the trained policy
with torch.no_grad():
    probs_final = policy(question_embedding)
    predicted_answer_idx = torch.argmax(probs_final).item()
    probabilities = probs_final[0].numpy()

print("\nFinal Results:")
print(f"Predicted answer: '{possible_answers[predicted_answer_idx]}'")

print("\nProbabilities for each answer:")
for answer, prob in zip(possible_answers, probabilities):
    print(f"{answer}: {prob:.4f}")
@infoslack (Author)

Output:
[staff] in ~/Projects/grpo_demo via 🐍 v3.12.8 $ python grpo_nlp.py
Starting training...
Epoch 10
  Total Loss: 0.0198
  Policy Loss: 0.0198
  KL Divergence: 0.0024
Epoch 20
  Total Loss: -0.1117
  Policy Loss: -0.1118
  KL Divergence: 0.0470
Epoch 30
  Total Loss: -0.0985
  Policy Loss: -0.0985
  KL Divergence: 0.0091
Epoch 40
  Total Loss: 0.0000
  Policy Loss: -0.0000
  KL Divergence: 0.0345
Epoch 50
  Total Loss: -0.0981
  Policy Loss: -0.0981
  KL Divergence: 0.0169
Epoch 60
  Total Loss: 0.0000
  Policy Loss: -0.0000
  KL Divergence: 0.0305
Epoch 70
  Total Loss: 0.0138
  Policy Loss: 0.0138
  KL Divergence: 0.0018
Epoch 80
  Total Loss: -0.0842
  Policy Loss: -0.0842
  KL Divergence: 0.0093
Epoch 90
  Total Loss: -0.0864
  Policy Loss: -0.0864
  KL Divergence: 0.0354
Epoch 100
  Total Loss: -0.1285
  Policy Loss: -0.1285
  KL Divergence: 0.0312

Final Results:
Predicted answer: 'Brasília'

Probabilities for each answer:
Brasília: 0.3410
Rio de Janeiro: 0.2605
São Paulo: 0.2238
Fortaleza: 0.1746
