DPO (Direct Preference Optimization) #pseudocode
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class LanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Define transformer layers, embeddings, etc. (or wrap a pretrained model)

    def forward(self, input_ids):
        # Should return next-token logits of shape (batch, seq_len, vocab_size)
        raise NotImplementedError("Implement the forward pass")
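
# DPO objective (Rafailov et al., 2023):
#   L_DPO = -E[ log sigmoid( beta * ( log pi(y_w|x)/pi_ref(y_w|x)
#                                     - log pi(y_l|x)/pi_ref(y_l|x) ) ) ]
# where y_w is the preferred (chosen) response and y_l the dispreferred
# (rejected) one; dpo_loss below implements this on sequence-level log-probs.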
def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta):
    # pi_logps, ref_logps: sequence-level log-probs under the policy and the
    #   frozen reference model, shape (B,)
    # yw_idxs, yl_idxs: batch indices of the preferred / dispreferred responses
    # beta: temperature controlling the implicit KL penalty
    pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
    ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]
    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps
    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    rewards = beta * (pi_logps - ref_logps).detach()  # implicit per-example rewards
    return losses.mean(), rewards

def gather_log_probs(logps, tokens, mask):
    # Log-prob of each realized token: logits at position t predict token t + 1,
    # so shift by one before gathering, then sum over the sequence -> shape (B,).
    # (In practice, prompt tokens would also be masked out so that only response
    # tokens contribute.)
    token_logps = torch.gather(logps[:, :-1], -1, tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return (token_logps * mask[:, 1:]).sum(dim=-1)

# Initialize models
policy_model = LanguageModel()
ref_model = LanguageModel()
ref_model.load_state_dict(policy_model.state_dict())  # start from the same weights
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad_(False)  # the reference model stays frozen

optimizer = optim.Adam(policy_model.parameters())
beta = 0.1  # DPO temperature: strength of the implicit KL penalty

# Training loop. num_epochs and dataloader are assumed to be defined elsewhere;
# each batch holds the chosen and rejected sequences plus index tensors
# (yw_idxs, yl_idxs) selecting the preferred / dispreferred responses.
for epoch in range(num_epochs):
    for batch in dataloader:
        input_ids, attention_mask, yw_idxs, yl_idxs = batch

        # Forward pass (no gradients through the reference model)
        policy_logits = policy_model(input_ids)
        with torch.no_grad():
            ref_logits = ref_model(input_ids)

        # Per-token log probabilities over the vocabulary
        policy_logps = F.log_softmax(policy_logits, dim=-1)
        ref_logps = F.log_softmax(ref_logits, dim=-1)

        # Reduce to one sequence-level log-prob per example, shape (B,)
        policy_logps = gather_log_probs(policy_logps, input_ids, attention_mask)
        ref_logps = gather_log_probs(ref_logps, input_ids, attention_mask)

        # Compute DPO loss
        loss, rewards = dpo_loss(policy_logps, ref_logps, yw_idxs, yl_idxs, beta)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

    # Evaluation and logging (evaluate_model and log_metrics are placeholders)
    evaluate_model(policy_model)
    log_metrics()

# Save the fine-tuned model
torch.save(policy_model.state_dict(), 'dpo_finetuned_model.pth')
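
As a quick sanity check, dpo_loss can be exercised on dummy sequence-level log-probs; the batch size and the interleaved chosen/rejected layout below are illustrative assumptions:

B = 4  # two preference pairs -> four sequences in the batch
pi_logps = torch.randn(B)       # policy log-probs, shape (B,)
ref_logps = torch.randn(B)      # reference log-probs, shape (B,)
yw_idxs = torch.tensor([0, 2])  # indices of preferred responses
yl_idxs = torch.tensor([1, 3])  # indices of dispreferred responses
loss, rewards = dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta=0.1)
print(loss.item(), rewards.shape)  # scalar loss, rewards of shape (4,)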