@longern · Last active June 20, 2025 08:00
Self-Reflective Dual-Context Mixture Decoding
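
A Hugging Face transformers LogitsProcessor that decodes with two contexts at once: the main prompt drives generation as usual, while a second cached context (e.g. a "retry"/self-reflection prompt) is advanced token-by-token in parallel, and each step's next-token distribution is the weighted geometric mean of the two.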
import torch
from transformers.generation.logits_process import LogitsProcessor


def combine_logits(logits1, logits2, w1=0.5, w2=None, epsilon=1e-8):
    # Mix two next-token distributions via a weighted geometric mean of
    # their probabilities, then return normalized log-probabilities.
    w2 = 1 - w1 if w2 is None else w2
    probs1 = torch.softmax(logits1, dim=-1)
    probs2 = torch.softmax(logits2, dim=-1)
    combined_probs = (probs1 ** w1) * (probs2 ** w2)
    # Add epsilon so the log below never sees an exact zero.
    combined_probs_smooth = combined_probs + epsilon
    sum_combined = torch.sum(combined_probs_smooth, dim=-1, keepdim=True)
    logits_combined = torch.log(combined_probs_smooth) - torch.log(sum_combined)
    return logits_combined


class DualContextLogitsProcessor(LogitsProcessor):
    def __init__(self, model, prompt_ids):
        self.model = model
        self.prompt_ids = prompt_ids
        # Prefill the secondary context once and keep its KV cache.
        with torch.no_grad():
            out_retry = model(prompt_ids, use_cache=True)
        self.cached_past_key_values = out_retry.past_key_values

    def __call__(self, input_ids, next_token_logits):
        # Advance the secondary context by the newest token only,
        # reusing and updating its cached key/value state.
        last_token = input_ids[:, -1:]
        with torch.no_grad():
            output = self.model(
                last_token,
                past_key_values=self.cached_past_key_values,
                use_cache=True,
            )
        retry_logits = output.logits[:, -1]
        self.cached_past_key_values = output.past_key_values
        # Blend the primary and secondary next-token distributions.
        return combine_logits(next_token_logits, retry_logits)
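
A minimal usage sketch (not part of the gist): the model name and the retry-prompt wording below are placeholder assumptions. The processor plugs into generate() through a standard LogitsProcessorList.

# Usage sketch; model name and prompts are illustrative assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

model_name = "gpt2"  # placeholder model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Q: What is 17 * 24?\nA:"
# Assumed wording for the secondary "self-reflective" context.
retry_prompt = "Your previous answer may be wrong. Try again.\n" + prompt

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
retry_ids = tokenizer(retry_prompt, return_tensors="pt").input_ids

processor = DualContextLogitsProcessor(model, retry_ids)
output = model.generate(
    input_ids,
    logits_processor=LogitsProcessorList([processor]),
    max_new_tokens=32,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

At each decoding step, generate() produces the primary next-token logits, the processor feeds only the latest token through the cached secondary context, and the two distributions are merged by combine_logits before sampling or greedy selection.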