# Gist by @kalomaze, created January 21, 2025
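# RescaleDescentTrainer: a HuggingFace Trainer subclass that rescales the
# standard cross-entropy loss by an entropy-weighted per-token loss
# (combined = standard_loss * weighted_token_loss) and logs moving averages
# of stochastic unigram/bigram/trigram agreement rates between sampled and
# target tokens.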
import math

import torch
import torch.nn.functional as F
from transformers import Trainer


class RescaleDescentTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize all metric buffers
        self.tokens_buffer = []           # raw token loss
        self.weighted_tokens_buffer = []  # entropy-weighted token loss
        self.unigram_rate_buffer = []
        self.bigram_rate_buffer = []
        self.trigram_rate_buffer = []
        self.weighted_unigram_buffer = []
        self.weighted_bigram_buffer = []
        self.weighted_trigram_buffer = []
        self.moving_avg_window = 10

    # treat everything here as potentially nonsense slop code and a biohazard to your training pipelines
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]

        # Standard autoregressive cross-entropy loss
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        standard_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Get probabilities and mask out padding positions
        probs = F.softmax(shift_logits, dim=-1)
        non_pad_mask = shift_labels != -100
        valid_probs = probs[non_pad_mask]          # [num_valid_tokens, vocab]
        valid_labels = shift_labels[non_pad_mask]  # [num_valid_tokens]

        # Per-token entropy, normalized by the maximum possible entropy (log of vocab size);
        # the small epsilon guards against log(0)
        entropy = -torch.sum(valid_probs * torch.log(valid_probs + 1e-12), dim=-1)
        max_entropy = math.log(valid_probs.size(-1))
        entropy_weight = entropy / max_entropy

        # Probability assigned to the target token, and both token-level losses
        token_probs = valid_probs[torch.arange(len(valid_labels), device=valid_labels.device), valid_labels]
        token_loss = (1 - token_probs).mean()
        weighted_token_loss = (1 - token_probs * entropy_weight).mean()

        # Stochastic sampling for n-gram agreement metrics
        sampled = torch.multinomial(valid_probs, num_samples=1).squeeze(-1)
        matches = (sampled == valid_labels).float()

        # Unweighted n-gram match rates
        unigram_rate = matches.mean()
        if len(matches) > 1:
            bigram_matches = matches[:-1] * matches[1:]
            bigram_rate = bigram_matches.mean()
        else:
            bigram_rate = torch.tensor(0.0, device=logits.device)
        if len(matches) > 2:
            trigram_matches = matches[:-2] * matches[1:-1] * matches[2:]
            trigram_rate = trigram_matches.mean()
        else:
            trigram_rate = torch.tensor(0.0, device=logits.device)

        # Entropy-weighted n-gram match rates
        weighted_matches = matches * entropy_weight
        weighted_unigram_rate = weighted_matches.mean()
        if len(matches) > 1:
            bigram_entropy_weight = (entropy_weight[:-1] + entropy_weight[1:]) / 2
            weighted_bigram_rate = (bigram_matches * bigram_entropy_weight).mean()
        else:
            weighted_bigram_rate = torch.tensor(0.0, device=logits.device)
        if len(matches) > 2:
            trigram_entropy_weight = (entropy_weight[:-2] + entropy_weight[1:-1] + entropy_weight[2:]) / 3
            weighted_trigram_rate = (trigram_matches * trigram_entropy_weight).mean()
        else:
            weighted_trigram_rate = torch.tensor(0.0, device=logits.device)

        # Scale the standard loss by the entropy-weighted token loss.
        # Change this to standard_loss * token_loss for the non-weirdo (non-entropy) case;
        # I also toyed with a stochastic variant, which also worked.
        combined = standard_loss * weighted_token_loss
        outputs.loss = combined

        # Update all buffers
        self.tokens_buffer.append(token_loss.item())
        self.weighted_tokens_buffer.append(weighted_token_loss.item())
        self.unigram_rate_buffer.append(unigram_rate.item())
        self.bigram_rate_buffer.append(bigram_rate.item())
        self.trigram_rate_buffer.append(trigram_rate.item())
        self.weighted_unigram_buffer.append(weighted_unigram_rate.item())
        self.weighted_bigram_buffer.append(weighted_bigram_rate.item())
        self.weighted_trigram_buffer.append(weighted_trigram_rate.item())

        if len(self.tokens_buffer) >= self.moving_avg_window:
            window = self.moving_avg_window
            metrics = {
                "standard_loss": standard_loss.item(),
                "token_loss": sum(self.tokens_buffer[-window:]) / window,
                "weighted_token_loss": sum(self.weighted_tokens_buffer[-window:]) / window,
                "unigram_rate": sum(self.unigram_rate_buffer[-window:]) / window,
                "bigram_rate": sum(self.bigram_rate_buffer[-window:]) / window,
                "trigram_rate": sum(self.trigram_rate_buffer[-window:]) / window,
                "weighted_unigram_rate": sum(self.weighted_unigram_buffer[-window:]) / window,
                "weighted_bigram_rate": sum(self.weighted_bigram_buffer[-window:]) / window,
                "weighted_trigram_rate": sum(self.weighted_trigram_buffer[-window:]) / window,
                "combined_loss": combined.item(),
            }
            self.log(metrics)
            # Trim all buffers to the moving-average window
            self.tokens_buffer = self.tokens_buffer[-window:]
            self.weighted_tokens_buffer = self.weighted_tokens_buffer[-window:]
            self.unigram_rate_buffer = self.unigram_rate_buffer[-window:]
            self.bigram_rate_buffer = self.bigram_rate_buffer[-window:]
            self.trigram_rate_buffer = self.trigram_rate_buffer[-window:]
            self.weighted_unigram_buffer = self.weighted_unigram_buffer[-window:]
            self.weighted_bigram_buffer = self.weighted_bigram_buffer[-window:]
            self.weighted_trigram_buffer = self.weighted_trigram_buffer[-window:]

        return (combined, outputs) if return_outputs else combined
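

# --- Minimal usage sketch (illustrative assumptions, not part of the original gist) ---
# Model name, dataset, and hyperparameters below are placeholders; the only real
# requirement is that each batch carries "labels" padded with -100, which
# compute_loss above relies on for masking.
if __name__ == "__main__":
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Toy dataset: tokenize raw text and copy input ids into labels, masking padding with -100
    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")

    def tokenize(batch):
        enc = tokenizer(batch["text"], truncation=True, max_length=256, padding="max_length")
        enc["labels"] = [
            [tok if mask == 1 else -100 for tok, mask in zip(ids, attn)]
            for ids, attn in zip(enc["input_ids"], enc["attention_mask"])
        ]
        return enc

    train_ds = raw.map(tokenize, batched=True, remove_columns=raw.column_names)

    args = TrainingArguments(
        output_dir="rescale-descent-run",
        per_device_train_batch_size=2,
        num_train_epochs=1,
        logging_steps=10,
        report_to="none",
    )

    trainer = RescaleDescentTrainer(model=model, args=args, train_dataset=train_ds)
    trainer.train()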