generate_with_input_logits and clone_dynamic_cache
import torch


def generate_with_input_logits(model, tokenizer, batch2, **kwargs):
    """
    Problem: `generate` does not return logits for the input tokens, but we need them
    to compute the prompt NLL. A plain forward pass does return them, and if we keep its
    KV cache we can hand it to `generate` so the input logits are not recomputed.
    This helper does both: one forward pass for the input logits / NLL, then generation
    continued from the cached key/value states.
    """
    # Forward pass over the prompt: gives the input logits and fills the KV cache
    forward_out = model(**batch2, use_cache=True)
    logits = forward_out.logits  # [batch, seq, vocab]
    past_key_values = forward_out.past_key_values

    # Greedy next token predicted from the prompt; this becomes the first generated token,
    # so it is always greedy regardless of the sampling kwargs passed to `generate`
    next_input_ids = forward_out.logits[:, -1].log_softmax(-1).argmax(-1)[:, None]
    new_attn_mask = torch.cat(
        [batch2['attention_mask'], torch.ones_like(next_input_ids)],
        dim=1,
    )

    # Shift logits and labels for NLL: predict token t from tokens 0..t-1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = batch2['input_ids'][:, 1:].contiguous()

    # Compute NLL per token, masking padding
    shift_mask = (shift_labels != tokenizer.pad_token_id).float()
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    token_nll = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
    ).view(shift_labels.size())

    # Average NLL per sequence (excluding padding)
    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    # Continue generation from the cached KV states
    input_ids = batch2['input_ids']
    # Optionally clone the cache first so `generate` cannot mutate the forward pass's cache
    # (see the clone_dynamic_cache sketch below):
    # past_key_values_cropped = clone_dynamic_cache(past_key_values)  # optionally crop=input_ids.shape[1] - 1
    n = past_key_values.get_seq_length()
    outputs = model.generate(
        input_ids=next_input_ids,  # greedy next token as the only new input
        attention_mask=new_attn_mask,  # mask covering the cached prompt plus the new token
        past_key_values=past_key_values,
        cache_position=torch.arange(n, n + 1, dtype=torch.long, device=input_ids.device),
        output_logits=True,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs
    )

    # `generate` only saw next_input_ids, so prepend the original prompt to recover the full
    # sequences, and prepend the logits that produced next_input_ids so outputs.logits lines
    # up with the generated part of outputs.sequences
    outputs.sequences = torch.cat([input_ids, outputs.sequences], 1)
    outputs.logits = (forward_out.logits[:, -1],) + outputs.logits
    return outputs, seq_nll
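
The clone_dynamic_cache helper referenced in the commented-out line above is not defined in this gist. Below is a minimal sketch of what it could look like, assuming the transformers DynamicCache API (key_cache / value_cache lists, update() and crop()); the exact attribute names may differ between transformers versions, so treat this as an illustration rather than the gist author's implementation.

from transformers import DynamicCache

def clone_dynamic_cache(cache, crop=None):
    """Deep-copy a DynamicCache so that generate() cannot mutate the original cache.
    Assumes the key_cache / value_cache list attributes; adjust for your transformers version."""
    new_cache = DynamicCache()
    for layer_idx, (k, v) in enumerate(zip(cache.key_cache, cache.value_cache)):
        # update() appends the cloned key/value states for this layer
        new_cache.update(k.clone(), v.clone(), layer_idx)
    if crop is not None:
        # crop() truncates the cache to at most `crop` tokens
        new_cache.crop(crop)
    return new_cache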
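
For context, a hypothetical call site (the model name is a placeholder, and this assumes a transformers version whose forward pass returns a cache object that supports get_seq_length()):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # the NLL masking needs a pad id
model = AutoModelForCausalLM.from_pretrained(model_name)

batch = tokenizer(["The capital of France is"], return_tensors="pt", padding=True)
outputs, seq_nll = generate_with_input_logits(
    model, tokenizer, batch, max_new_tokens=20, do_sample=False
)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))
print("prompt NLL per sequence:", seq_nll)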