subsequence cross entropy loss function
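In equation form, the class below computes the following (stated here as a sketch, assuming every position carries a valid label; positions equal to ignore_index are simply excluded from each average):

\[
\mathcal{L} \;=\; \frac{1}{n-1} \sum_{k=1}^{n-1} \frac{1}{k} \sum_{j=1}^{k} \mathrm{CE}\bigl(p_\theta(\cdot \mid x_{1:j}),\, x_{j+1}\bigr)
\]

In effect, earlier tokens receive more weight than in standard cross entropy, since token j appears in every prefix average from length j onward.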
import torch


class IntermediateSequenceAverageLoss(torch.nn.Module):
    """
    Custom loss function that calculates the average of intermediate sequence averages,
    with proper token shifting for causal language modeling.

    For a sequence of length n, this calculates:
    1. Average loss of token 1 predicting token 2
    2. Average loss of tokens 1-2 predicting tokens 2-3
    3. Average loss of tokens 1-3 predicting tokens 2-4
    ...
    n-1. Average loss of tokens 1-(n-1) predicting tokens 2-n

    Then takes the average of all these averages.
    """

    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

    def forward(self, logits, labels):
        # logits shape: [batch_size, seq_len, vocab_size]
        # labels shape: [batch_size, seq_len]
        batch_size, seq_len, vocab_size = logits.shape

        # Apply proper shifting for causal language modeling:
        # drop the last logit position and the first label position so that
        # position t of shift_logits predicts position t of shift_labels.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # New sequence length after shifting
        shifted_seq_len = shift_logits.size(1)

        # Flatten for cross entropy
        shift_logits_flat = shift_logits.view(-1, vocab_size)
        shift_labels_flat = shift_labels.view(-1)

        # Token-wise loss (no reduction), reshaped back to [batch_size, shifted_seq_len]
        token_losses = self.ce_loss(shift_logits_flat, shift_labels_flat)
        token_losses = token_losses.view(batch_size, shifted_seq_len)

        # Calculate the cumulative average loss for each prefix length,
        # then average those averages
        batch_losses = []
        for i in range(batch_size):
            seq_losses = token_losses[i]
            mask = (shift_labels[i] != self.ignore_index).float()

            # Cumulative sum of losses and of valid-token counts
            cum_losses = torch.cumsum(seq_losses * mask, dim=0)
            cum_valid_tokens = torch.cumsum(mask, dim=0)

            # Average loss for each prefix length (avoiding division by zero)
            prefix_averages = cum_losses / (cum_valid_tokens + 1e-10)

            # Only consider positions with valid tokens
            valid_averages = prefix_averages[mask.bool()]

            if len(valid_averages) > 0:
                # Average of all prefix averages
                batch_losses.append(valid_averages.mean())
            else:
                # Fall back to zero loss if no valid tokens
                batch_losses.append(torch.tensor(0.0, device=logits.device))

        # Average across the batch
        return torch.stack(batch_losses).mean()