@wassname
Last active November 9, 2025 00:07
generate_with_input_logits and clone_dynamic_cache
import torch


def generate_with_input_logits(model, tokenizer, batch2, **kwargs):
    """
    Problem: `generate` does not return logits for the input tokens, but we need
    them for the input NLL. Running `forward` first, then `generate` with the
    resulting past_key_values, gives both without recomputing the input logits.
    """
    forward_out = model(**batch2, use_cache=True)
    logits = forward_out.logits  # [b, s, vocab]
    past_key_values = forward_out.past_key_values
    # Greedy next token from the last input position, shape [b, 1]
    next_input_ids = forward_out.logits[:, -1].log_softmax(-1).argmax(-1)[:, None]
    new_attn_mask = torch.cat(
        [batch2['attention_mask'], torch.ones_like(next_input_ids)],
        dim=1
    )

    # Shift logits and labels for NLL: predict token t from tokens 0..t-1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = batch2['input_ids'][:, 1:].contiguous()

    # Compute NLL per token, masking padding
    shift_mask = (shift_labels != tokenizer.pad_token_id).float()
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    token_nll = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    ).view(shift_labels.size())

    # Average NLL per sequence (excluding padding)
    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    # Continue generation from the cached KV states
    input_ids = batch2['input_ids']
    # past_key_values_cropped = clone_dynamic_cache(
    #     past_key_values,  # crop=input_ids.shape[1] - 1
    # )
    n = past_key_values.get_seq_length()
    outputs = model.generate(
        input_ids=next_input_ids,  # greedy next token as the single new input
        attention_mask=new_attn_mask,  # full mask including the new token
        past_key_values=past_key_values,
        cache_position=torch.arange(n, n + 1, dtype=torch.long, device=input_ids.device),
        output_logits=True,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs
    )

    # generate only returns the sequence/logits from next_input_ids onward,
    # so prepend the original prompt and the logits from the forward pass
    outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
    outputs.logits = (forward_out.logits[:, -1],) + outputs.logits
    return outputs, seq_nll
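
The commented-out call above refers to `clone_dynamic_cache`, which is named in the gist title but not shown here. A minimal sketch of what it could look like (a hypothetical reconstruction, not from the gist; it assumes a transformers version where `DynamicCache` exposes `key_cache`/`value_cache` tensor lists and a `crop()` method, which holds for many 4.x releases):

from transformers.cache_utils import DynamicCache

def clone_dynamic_cache(cache, crop=None):
    """Deep-copy a DynamicCache so generate() can mutate the copy without
    corrupting the original; optionally keep only the first `crop` positions."""
    new_cache = DynamicCache()
    for key, value in zip(cache.key_cache, cache.value_cache):
        new_cache.key_cache.append(key.clone())
        new_cache.value_cache.append(value.clone())
    if crop is not None:
        new_cache.crop(crop)
    return new_cache

This matters if you want to reuse one forward pass across several generate calls, since generate appends to the cache in place.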
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "129fa913",
"metadata": {},
"outputs": [],
"source": [
"\n",
"from transformers.cache_utils import DynamicCache\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"from collections import defaultdict\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4a6bfe95",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_id=\"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_id)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f555052f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': torch.Size([1, 33]), 'attention_mask': torch.Size([1, 33])}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\n",
"batch2 = tokenizer.apply_chat_template([{'role': 'user', 'content': 'The eifel tower is in'}], return_tensors='pt', padding=True, return_dict=True)\n",
"batch2 = {k: v.to(model.device) for k, v in batch2.items()}\n",
"{k: v.shape for k, v in batch2.items()}"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "0930bb2d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"1.58 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"outg2 = model.generate(\n",
" input_ids=batch2['input_ids'], # Last token as new input\n",
" attention_mask=batch2['attention_mask'], # Keep full mask\n",
" output_logits=True,\n",
" output_scores=True,\n",
" return_dict_in_generate=True,\n",
" max_new_tokens=32+1,\n",
" min_new_tokens=32+1,\n",
")\n",
"print(outg2.sequences.shape, len(outg2.logits))\n",
"# 1.54 s ± 34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "187ae64d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 33]) 32\n",
"torch.Size([1, 66]) 33\n",
"1.56 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%time\n",
"forward_out = model(**batch2, use_cache=True)\n",
"logits = forward_out.logits # [b, s, vocab]\n",
"past_key_values = forward_out.past_key_values\n",
"next_input_ids = forward_out.logits[:, -1].log_softmax(-1).argmax(-1)[None]\n",
"new_attn_mask = torch.cat(\n",
" [batch2['attention_mask'], torch.ones_like(next_input_ids)],\n",
" dim=1\n",
")\n",
"\n",
"# Shift logits and labels for NLL: predict token t from tokens 0..t-1\n",
"shift_logits = logits[:, :-1, :].contiguous()\n",
"shift_labels = batch2['input_ids'][:, 1:].contiguous()\n",
"\n",
"# Compute NLL per token, masking padding\n",
"shift_mask = (shift_labels != tokenizer.pad_token_id).float()\n",
"loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n",
"token_nll = loss_fct(\n",
" shift_logits.view(-1, shift_logits.size(-1)),\n",
" shift_labels.view(-1)\n",
").view(shift_labels.size())\n",
"\n",
"# Average NLL per sequence (excluding padding)\n",
"seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)\n",
"\n",
"# Continue generation from the cached KV states\n",
"# Cache must be seq_len-1 since we're passing the last input token as new input\n",
"input_ids = batch2['input_ids']\n",
"n = past_key_values.get_seq_length()\n",
"outputs = model.generate(\n",
" input_ids=next_input_ids, # Last token as new input\n",
" attention_mask=new_attn_mask, # Keep full mask\n",
" past_key_values=past_key_values,\n",
" cache_position=torch.arange(n, n+1, dtype=torch.long, device=input_ids.device),\n",
" output_logits=True,\n",
" output_scores=True,\n",
" return_dict_in_generate=True,\n",
" max_new_tokens=32,\n",
" min_new_tokens=32,\n",
")\n",
"\n",
"\n",
"print(outputs.sequences.shape, len(outputs.logits))\n",
"# now we need to modify this as generate does return the full sequences, including inputs ids\n",
"outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)\n",
"outputs.logits = (forward_out.logits[:, -1],) + outputs.logits\n",
"print(outputs.sequences.shape, len(outputs.logits))\n",
"# 1.56 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "ffdaa26a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"torch.Size([1, 66]) 33\n",
"1.6 s ± 26.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"%%timeit\n",
"def generate_with_input_logits(model, tokenizer, batch2, **kwargs):\n",
" \"\"\"\n",
" problem: generate does not return logits for inputs, but we need them for nll\n",
"\n",
" but forward -> generate with past key values does, and it doesn't recompute the input logits\n",
"\n",
" so this is a helper that does both\n",
" \"\"\"\n",
" forward_out = model(**batch2, use_cache=True)\n",
" logits = forward_out.logits # [b, s, vocab]\n",
" past_key_values = forward_out.past_key_values\n",
" next_input_ids = forward_out.logits[:, -1].log_softmax(-1).argmax(-1)[:, None]\n",
" new_attn_mask = torch.cat(\n",
" [batch2['attention_mask'], torch.ones_like(next_input_ids)],\n",
" dim=1\n",
" )\n",
" \n",
" # Shift logits and labels for NLL: predict token t from tokens 0..t-1\n",
" shift_logits = logits[:, :-1, :].contiguous()\n",
" shift_labels = batch2['input_ids'][:, 1:].contiguous()\n",
" \n",
" # Compute NLL per token, masking padding\n",
" shift_mask = (shift_labels != tokenizer.pad_token_id).float()\n",
" loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n",
" token_nll = loss_fct(\n",
" shift_logits.view(-1, shift_logits.size(-1)),\n",
" shift_labels.view(-1)\n",
" ).view(shift_labels.size())\n",
" \n",
" # Average NLL per sequence (excluding padding)\n",
" seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)\n",
"\n",
" # Continue generation from the cached KV states\n",
" input_ids = batch2['input_ids']\n",
" n = past_key_values.get_seq_length()\n",
" outputs = model.generate(\n",
" input_ids=next_input_ids, # Last token as new input\n",
" attention_mask=new_attn_mask, # Keep full mask\n",
" past_key_values=past_key_values,\n",
" cache_position=torch.arange(n, n+1, dtype=torch.long, device=input_ids.device),\n",
" output_logits=True,\n",
" output_scores=True,\n",
" return_dict_in_generate=True,\n",
" **kwargs\n",
" )\n",
"\n",
" # now we need to modify this as generate does return the full sequences, including inputs ids\n",
" outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)\n",
" outputs.logits = (forward_out.logits[:, -1],) + outputs.logits\n",
"\n",
" return outputs, seq_nll\n",
"\n",
"\n",
"out3 = generate_with_input_logits(model, tokenizer, batch2, max_new_tokens=32, min_new_tokens=32)\n",
"print(out3[0].sequences.shape, len(out3[0].logits))"
]
},
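{
"cell_type": "markdown",
"id": "9f0c1d2e",
"metadata": {},
"source": [
"After the prepend, `outputs.logits[t]` scores the token at `outputs.sequences[:, input_len + t]`, so the per-token NLL of the generated continuation can be read off directly. A minimal sketch reusing `out3` and `batch2` from above (variable names here are illustrative):\n",
"\n",
"```python\n",
"gen_logits = torch.stack(out3[0].logits, dim=1)  # [b, n_new, vocab]\n",
"gen_ids = out3[0].sequences[:, batch2['input_ids'].shape[1]:]  # [b, n_new]\n",
"gen_nll = torch.nn.functional.cross_entropy(\n",
"    gen_logits.transpose(1, 2), gen_ids, reduction='none'\n",
")  # [b, n_new]\n",
"```"
]
},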
{
"cell_type": "code",
"execution_count": null,
"id": "c73db330",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}