DPO #pseudocode
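
Minimal PyTorch pseudocode for Direct Preference Optimization (DPO). The LanguageModel class is a placeholder, and names such as num_epochs, dataloader, evaluate_model, and log_metrics are assumed to be defined elsewhere.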
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class LanguageModel(nn.Module):
    """Placeholder model; the actual transformer architecture is omitted in this pseudocode."""

    def __init__(self):
        super().__init__()
        # Define transformer layers, embeddings, etc.

    def forward(self, input_ids):
        # Implement the forward pass; should return logits of shape (batch, seq_len, vocab_size)
        return logits
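
# In practice, the policy and the frozen reference would both start from the same
# pretrained (typically supervised fine-tuned) checkpoint; the class above is just a
# placeholder for such a model.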

def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta):
    # pi_logps, ref_logps: per-sequence log-probs under the policy / reference model, shape (B,)
    # yw_idxs, yl_idxs: indices of the preferred / dispreferred completions within the batch
    # beta: temperature controlling how strongly deviation from the reference is penalized
    pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
    ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]
    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps
    # L_DPO = -log sigmoid(beta * [(log pi_w - log pi_l) - (log ref_w - log ref_l)])
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    rewards = beta * (pi_logps - ref_logps).detach()  # implicit rewards, used only for logging
    return losses.mean(), rewards
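
# Example call (illustrative only; the shapes and indices here are assumptions, not part
# of the original gist): for a batch of 4 scored sequences where items 0 and 1 are the
# preferred completions and items 2 and 3 the dispreferred ones:
#   loss, rewards = dpo_loss(pi_logps, ref_logps,
#                            yw_idxs=torch.tensor([0, 1]),
#                            yl_idxs=torch.tensor([2, 3]),
#                            beta=0.1)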

def gather_log_probs(logps, tokens):
    # Pick out the log-prob assigned to each realized token: (B, L, V) -> (B, L)
    return torch.gather(logps, -1, tokens.unsqueeze(-1)).squeeze(-1)


# Initialize models
policy_model = LanguageModel()
ref_model = LanguageModel()
ref_model.load_state_dict(policy_model.state_dict())  # Reference starts from the same weights
ref_model.eval()  # The reference model is never updated; torch.no_grad() below also disables its gradients

optimizer = optim.Adam(policy_model.parameters())
beta = 0.1  # DPO temperature: larger values keep the policy closer to the reference

# Training loop (num_epochs and dataloader are assumed to be defined elsewhere)
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids, attention_mask, yw_idxs, yl_idxs = batch

        # Forward pass
        policy_logits = policy_model(input_ids)
        with torch.no_grad():
            ref_logits = ref_model(input_ids)

        # Per-token log probabilities over the vocabulary
        policy_logps = F.log_softmax(policy_logits, dim=-1)
        ref_logps = F.log_softmax(ref_logits, dim=-1)

        # Gather the log-probs of the realized tokens, mask out padding, and sum over the
        # sequence so that each sequence contributes a single log-prob (DPO compares
        # sequence-level log-probs). A full implementation would also shift logits/labels
        # by one position and restrict the sum to response tokens.
        policy_logps = (gather_log_probs(policy_logps, input_ids) * attention_mask).sum(-1)
        ref_logps = (gather_log_probs(ref_logps, input_ids) * attention_mask).sum(-1)

        # Compute DPO loss
        loss, rewards = dpo_loss(policy_logps, ref_logps, yw_idxs, yl_idxs, beta)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

    # Evaluation and logging at the end of each epoch (helpers assumed to be defined elsewhere)
    evaluate_model(policy_model)
    log_metrics()

# Save the fine-tuned model
torch.save(policy_model.state_dict(), 'dpo_finetuned_model.pth')
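
# To reload the fine-tuned weights later (illustrative; assumes the same LanguageModel
# class is available):
#   model = LanguageModel()
#   model.load_state_dict(torch.load('dpo_finetuned_model.pth'))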