Created January 21, 2025 02:59
import math

import torch
import torch.nn.functional as F
from transformers import Trainer


class RescaleDescentTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize all buffers
        self.tokens_buffer = []  # for raw token loss
        self.weighted_tokens_buffer = []  # for entropy-weighted token loss
        self.unigram_rate_buffer = []
        self.bigram_rate_buffer = []
        self.trigram_rate_buffer = []
        self.weighted_unigram_buffer = []
        self.weighted_bigram_buffer = []
        self.weighted_trigram_buffer = []
        self.moving_avg_window = 10

    # treat everything here as potentially nonsense slop code and a biohazard to your training pipelines
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer Trainer versions pass
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs["labels"]

        # Standard AR loss calculation
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        standard_loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

        # Get probabilities and mask padding
        probs = F.softmax(shift_logits, dim=-1)
        non_pad_mask = (shift_labels != -100)
        valid_probs = probs[non_pad_mask]
        valid_labels = shift_labels[non_pad_mask]

        # Calculate entropy and weighting (clamp keeps log(0) from producing NaNs)
        entropy = -torch.sum(valid_probs * torch.log(valid_probs.clamp_min(1e-12)), dim=-1)
        max_entropy = math.log(valid_probs.size(-1))
        entropy_weight = entropy / max_entropy

        # Get token probabilities and calculate both losses
        token_probs = valid_probs[torch.arange(len(valid_labels), device=valid_probs.device), valid_labels]
        token_loss = (1 - token_probs).mean()
        weighted_token_loss = (1 - token_probs * entropy_weight).mean()
        # Stochastic sampling for n-gram metrics
        sampled = torch.multinomial(valid_probs, num_samples=1).squeeze(-1)
        matches = (sampled == valid_labels).float()

        # Regular n-gram rates
        unigram_rate = matches.mean()
        if len(matches) > 1:
            bigram_matches = matches[:-1] * matches[1:]
            bigram_rate = bigram_matches.mean()
        else:
            bigram_rate = torch.tensor(0.0, device=logits.device)
        if len(matches) > 2:
            trigram_matches = matches[:-2] * matches[1:-1] * matches[2:]
            trigram_rate = trigram_matches.mean()
        else:
            trigram_rate = torch.tensor(0.0, device=logits.device)

        # Weighted n-gram rates
        weighted_matches = matches * entropy_weight
        weighted_unigram_rate = weighted_matches.mean()
        if len(matches) > 1:
            bigram_entropy_weight = (entropy_weight[:-1] + entropy_weight[1:]) / 2
            weighted_bigram_rate = (bigram_matches * bigram_entropy_weight).mean()
        else:
            weighted_bigram_rate = torch.tensor(0.0, device=logits.device)
        if len(matches) > 2:
            trigram_entropy_weight = (entropy_weight[:-2] + entropy_weight[1:-1] + entropy_weight[2:]) / 3
            weighted_trigram_rate = (trigram_matches * trigram_entropy_weight).mean()
        else:
            weighted_trigram_rate = torch.tensor(0.0, device=logits.device)
        # Average token and weighted token loss for final scaling
        # change this to standard_loss * token_loss for the non-weirdo entropy case;
        # i also toyed with a stochastic variant which also worked
        combined = standard_loss * weighted_token_loss
        outputs.loss = combined

        # Update all buffers
        self.tokens_buffer.append(token_loss.item())
        self.weighted_tokens_buffer.append(weighted_token_loss.item())
        self.unigram_rate_buffer.append(unigram_rate.item())
        self.bigram_rate_buffer.append(bigram_rate.item())
        self.trigram_rate_buffer.append(trigram_rate.item())
        self.weighted_unigram_buffer.append(weighted_unigram_rate.item())
        self.weighted_bigram_buffer.append(weighted_bigram_rate.item())
        self.weighted_trigram_buffer.append(weighted_trigram_rate.item())

        if len(self.tokens_buffer) >= self.moving_avg_window:
            metrics = {
                "standard_loss": standard_loss.item(),
                "token_loss": sum(self.tokens_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "weighted_token_loss": sum(self.weighted_tokens_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "unigram_rate": sum(self.unigram_rate_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "bigram_rate": sum(self.bigram_rate_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "trigram_rate": sum(self.trigram_rate_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "weighted_unigram_rate": sum(self.weighted_unigram_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "weighted_bigram_rate": sum(self.weighted_bigram_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "weighted_trigram_rate": sum(self.weighted_trigram_buffer[-self.moving_avg_window:]) / self.moving_avg_window,
                "combined_loss": combined.item(),
            }
            self.log(metrics)
            # Trim all buffers
            self.tokens_buffer = self.tokens_buffer[-self.moving_avg_window:]
            self.weighted_tokens_buffer = self.weighted_tokens_buffer[-self.moving_avg_window:]
            self.unigram_rate_buffer = self.unigram_rate_buffer[-self.moving_avg_window:]
            self.bigram_rate_buffer = self.bigram_rate_buffer[-self.moving_avg_window:]
            self.trigram_rate_buffer = self.trigram_rate_buffer[-self.moving_avg_window:]
            self.weighted_unigram_buffer = self.weighted_unigram_buffer[-self.moving_avg_window:]
            self.weighted_bigram_buffer = self.weighted_bigram_buffer[-self.moving_avg_window:]
            self.weighted_trigram_buffer = self.weighted_trigram_buffer[-self.moving_avg_window:]

        return (combined, outputs) if return_outputs else combined
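
For context, a minimal sketch of how this trainer might be wired up end to end. The gpt2 checkpoint, the two-sentence toy dataset, and the TrainingArguments values below are placeholders chosen for illustration; they are not part of the original gist.

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

# Placeholder checkpoint purely for illustration
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize(batch):
    enc = tokenizer(batch["text"], truncation=True, padding="max_length", max_length=32)
    # Mask padding positions with -100 so they hit the ignore_index paths in compute_loss
    enc["labels"] = [
        [tok if keep == 1 else -100 for tok, keep in zip(ids, attn)]
        for ids, attn in zip(enc["input_ids"], enc["attention_mask"])
    ]
    return enc

train_dataset = Dataset.from_dict(
    {"text": ["hello world", "the quick brown fox jumps over the lazy dog"]}
).map(tokenize, batched=True, remove_columns=["text"])

trainer = RescaleDescentTrainer(
    model=model,
    args=TrainingArguments(
        output_dir="rescale-descent-test",  # hypothetical output path
        per_device_train_batch_size=2,
        max_steps=20,  # enough steps to fill the 10-step moving-average window
        logging_steps=1,
        report_to="none",
    ),
    train_dataset=train_dataset,
)
trainer.train()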