Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active November 9, 2025 00:07
Show Gist options
  • Select an option

  • Save wassname/2a8e4169a8910ccf9174a17e24d42b66 to your computer and use it in GitHub Desktop.

Select an option

Save wassname/2a8e4169a8910ccf9174a17e24d42b66 to your computer and use it in GitHub Desktop.
generate_with_input_logits and clone_dynamic_cache
import torch
def generate_with_input_logits(model, tokenizer, batch2, **kwargs):
    """Return ``generate`` outputs plus the per-sequence NLL of the prompt.

    ``model.generate`` does not return logits for the prompt tokens, but they
    are needed to score the prompt (negative log-likelihood). This helper runs
    one forward pass over the prompt with ``use_cache=True``, computes the
    prompt NLL from those logits, then continues decoding from the returned
    KV cache so the prompt is not recomputed.

    Args:
        model: causal LM, called as ``model(**batch2, use_cache=True)`` and as
            ``model.generate(...)``. Its ``past_key_values`` must expose
            ``get_seq_length()`` (presumably a Hugging Face ``DynamicCache`` --
            confirm for your model).
        tokenizer: not read; kept for backward compatibility. Padding is taken
            from ``batch2["attention_mask"]`` instead of ``pad_token_id``.
        batch2: mapping with ``input_ids`` ([batch, seq]) and optionally
            ``attention_mask`` (same shape, 1 = real token). If the mask is
            absent every token is treated as real. Rows are assumed to be
            left-padded so the last column is a real token -- TODO confirm.
        **kwargs: forwarded to ``model.generate`` (e.g. ``max_new_tokens``).

    Returns:
        ``(outputs, seq_nll)``.
        ``outputs`` is the ``generate`` result, modified as follows:
            ``sequences`` is the prompt followed by all new tokens.
            ``logits`` and ``scores`` each get one extra leading entry for
            the first new token.
        ``seq_nll`` is a [batch] tensor with the mean per-token NLL over
        prompt targets. A target counts only when it and its context
        position are both real. Rows with no such target get 0.

    Note:
        The first new token is always chosen greedily from the prompt's
        last-position logits, even if ``kwargs`` request sampling.
        ``generate`` extends the forward pass's KV cache in place.
    """
    input_ids = batch2["input_ids"]
    attention_mask = batch2.get("attention_mask")
    if attention_mask is None:
        # The forward pass treats every token as real without a mask; match that.
        attention_mask = torch.ones_like(input_ids)

    forward_out = model(**batch2, use_cache=True)
    logits = forward_out.logits  # [batch, seq, vocab]
    past_key_values = forward_out.past_key_values
    last_logits = logits[:, -1]  # prediction for the first new token

    # --- prompt NLL: position t predicts token t + 1 ---
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    # Count a target only if both it and its context position are real tokens.
    # Using the attention mask (not pad_token_id) has three benefits:
    #   - real EOS tokens are kept when pad_token_id == eos_token_id;
    #   - it still works when pad_token_id is None;
    #   - it skips the first real token of a left-padded row, whose only
    #     context is padding.
    real = attention_mask.bool()
    shift_mask = (real[:, 1:] & real[:, :-1]).float()
    token_nll = torch.nn.functional.cross_entropy(
        # Upcast: cross-entropy in bf16/fp16 loses precision.
        shift_logits.reshape(-1, shift_logits.size(-1)).float(),
        shift_labels.reshape(-1),
        reduction="none",
    ).view(shift_labels.shape)
    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    # --- continue generation from the cached prompt ---
    # argmax over raw logits equals argmax over log-probs (softmax is monotonic).
    next_input_ids = last_logits.argmax(dim=-1, keepdim=True)  # [batch, 1]
    # The mask passed to generate must cover the cached prompt plus the new token.
    new_attn_mask = torch.cat(
        [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
        dim=1,
    )
    n = past_key_values.get_seq_length()
    outputs = model.generate(
        input_ids=next_input_ids,  # only the token not yet in the cache
        attention_mask=new_attn_mask,
        past_key_values=past_key_values,  # extended in place by generate
        cache_position=torch.arange(n, n + 1, dtype=torch.long, device=input_ids.device),
        output_logits=True,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs,
    )

    # generate's sequences start at next_input_ids; prepend the prompt so the
    # result looks like generate had been called on the full prompt.
    outputs.sequences = torch.cat([input_ids, outputs.sequences], dim=1)
    # The first new token was predicted by the forward pass, not by generate.
    outputs.logits = (last_logits,) + outputs.logits
    if outputs.scores is not None:
        # Keep scores aligned with logits/sequences. No logits processors ran
        # for the first token, so its raw logits stand in as its score.
        outputs.scores = (last_logits,) + outputs.scores
    return outputs, seq_nll
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment