@kalomaze
Created April 4, 2025 14:41
subsequence cross entropy loss function
import torch


class IntermediateSequenceAverageLoss(torch.nn.Module):
    """
    Custom loss function that calculates the average of intermediate sequence averages,
    with proper token shifting for causal language modeling.

    For a sequence of length n, this calculates:
    1. Average loss of token 1 predicting token 2
    2. Average loss of tokens 1-2 predicting tokens 2-3
    3. Average loss of tokens 1-3 predicting tokens 2-4
    ...
    n-1. Average loss of tokens 1-(n-1) predicting tokens 2-n

    Then takes the average of all these averages.
    """

    def __init__(self, ignore_index=-100):
        super().__init__()
        self.ignore_index = ignore_index
        self.ce_loss = torch.nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')

    def forward(self, logits, labels):
        # logits shape: [batch_size, seq_len, vocab_size]
        # labels shape: [batch_size, seq_len]
        batch_size, seq_len, vocab_size = logits.shape

        # Apply proper shifting for causal language modeling:
        # drop the last logit position and the first label position so that
        # position i of shift_logits predicts position i of shift_labels.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # Get new sequence length after shifting
        shifted_seq_len = shift_logits.size(1)

        # Reshape for cross entropy
        shift_logits_flat = shift_logits.view(-1, vocab_size)
        shift_labels_flat = shift_labels.view(-1)

        # Calculate token-wise loss (without reduction)
        token_losses = self.ce_loss(shift_logits_flat, shift_labels_flat)

        # Reshape back to [batch_size, shifted_seq_len]
        token_losses = token_losses.view(batch_size, shifted_seq_len)

        # Calculate the cumulative average loss for each prefix length,
        # then average those averages
        batch_losses = []
        for i in range(batch_size):
            seq_losses = token_losses[i]
            mask = (shift_labels[i] != self.ignore_index).float()

            # Cumulative sums of masked losses and of valid-token counts
            cum_losses = torch.cumsum(seq_losses * mask, dim=0)
            cum_valid_tokens = torch.cumsum(mask, dim=0)

            # Average loss for each prefix length (avoiding division by zero)
            prefix_averages = cum_losses / (cum_valid_tokens + 1e-10)

            # Only consider prefixes that end on a valid token
            valid_averages = prefix_averages[mask.bool()]
            if len(valid_averages) > 0:
                # Average of all prefix averages
                batch_losses.append(valid_averages.mean())
            else:
                # Fall back to zero loss if the sequence has no valid tokens
                batch_losses.append(torch.tensor(0.0, device=logits.device))

        # Average across the batch
        return torch.stack(batch_losses).mean()
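A minimal usage sketch, not part of the original gist: random logits and labels stand in for model output, and the closed-form check assumes no ignored tokens. Because the loss is a mean of prefix means, each shifted token j effectively receives weight (H_n - H_j) / n, where H_m is the m-th harmonic number and n is the shifted sequence length, so earlier tokens are weighted more heavily.

# Usage sketch with dummy tensors; the harmonic-weight check below assumes
# every label is valid (no ignore_index positions).
import torch

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 8, 32
logits = torch.randn(batch_size, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))

loss_fn = IntermediateSequenceAverageLoss()
loss = loss_fn(logits, labels)

# Closed-form equivalent: averaging the prefix averages equals a per-token
# weighted cross entropy where shifted token j gets weight (H_n - H_j) / n.
token_ce = torch.nn.functional.cross_entropy(
    logits[:, :-1].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
    reduction='none',
).view(batch_size, seq_len - 1)
n = seq_len - 1
harmonic = torch.cat([torch.zeros(1), torch.cumsum(1.0 / torch.arange(1, n + 1), dim=0)])
weights = (harmonic[n] - harmonic[:n]) / n
closed_form = (token_ce * weights).sum(dim=1).mean()

print(loss.item(), closed_form.item())               # nearly identical values
print(torch.allclose(loss, closed_form, atol=1e-5))  # True, up to the 1e-10 epsilon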