generate_with_input_logits and clone_dynamic_cache
A helper for Hugging Face `transformers`: run one forward pass to get the input-token logits (needed for NLL), then continue generation from the cached key/value states so the inputs are not recomputed.
```python
import torch


def generate_with_input_logits(model, tokenizer, batch2, **kwargs):
    """
    Problem: `generate` does not return logits for the input tokens, but we
    need them to compute the NLL. A plain `forward` pass does return them,
    and if we hand its KV cache to `generate`, the input logits are not
    recomputed. This helper does both: one forward pass for input logits
    and NLL, then generation continuing from the cached key/value states.
    """
    forward_out = model(**batch2, use_cache=True)
    logits = forward_out.logits  # [b, s, vocab]
    past_key_values = forward_out.past_key_values
    # Greedy next token from the last input position
    next_input_ids = forward_out.logits[:, -1].argmax(-1)[:, None]
    new_attn_mask = torch.cat(
        [batch2['attention_mask'], torch.ones_like(next_input_ids)],
        dim=1
    )

    # Shift logits and labels for NLL: predict token t from tokens 0..t-1
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = batch2['input_ids'][:, 1:].contiguous()

    # Compute NLL per token, masking padding
    shift_mask = (shift_labels != tokenizer.pad_token_id).float()
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    token_nll = loss_fct(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    ).view(shift_labels.size())

    # Average NLL per sequence (excluding padding)
    seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

    # Continue generation from the cached KV states
    input_ids = batch2['input_ids']
    # past_key_values_cropped = clone_dynamic_cache(
    #     past_key_values,  # crop=input_ids.shape[1] - 1
    # )
    n = past_key_values.get_seq_length()
    outputs = model.generate(
        input_ids=next_input_ids,      # last (greedy) token as the new input
        attention_mask=new_attn_mask,  # full mask covering inputs + new token
        past_key_values=past_key_values,
        cache_position=torch.arange(n, n + 1, dtype=torch.long, device=input_ids.device),
        output_logits=True,
        output_scores=True,
        return_dict_in_generate=True,
        **kwargs
    )

    # generate only saw next_input_ids, so prepend the original inputs to
    # sequences, and the forward pass's last-position logits to logits
    outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
    outputs.logits = (forward_out.logits[:, -1],) + outputs.logits

    return outputs, seq_nll
```
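The commented-out call above (and the gist title) reference a `clone_dynamic_cache` helper that isn't shown. A minimal sketch of what it might look like, assuming a `transformers` version where `DynamicCache` exposes `key_cache`/`value_cache` lists and a `crop()` method; the `crop` parameter is an assumption matching the commented-out call:

```python
from typing import Optional

from transformers.cache_utils import DynamicCache


def clone_dynamic_cache(cache: DynamicCache, crop: Optional[int] = None) -> DynamicCache:
    """Deep-copy a DynamicCache so generate() can mutate it without touching
    the original; optionally truncate it to `crop` tokens."""
    new_cache = DynamicCache()
    for layer_idx in range(len(cache.key_cache)):
        new_cache.update(
            cache.key_cache[layer_idx].clone(),
            cache.value_cache[layer_idx].clone(),
            layer_idx,
        )
    if crop is not None:
        new_cache.crop(crop)  # keep only the first `crop` positions
    return new_cache
```

Cloning matters because `generate` appends to the cache in place; without a copy, a second call would see a cache polluted by the first.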
The accompanying notebook benchmarks the helper against a plain `generate` call on SmolLM2-135M-Instruct.
```python
from transformers.cache_utils import DynamicCache
import numpy as np
import pandas as pd
import torch
from tqdm.auto import tqdm
from collections import defaultdict
from typing import Optional
```
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```
```python
batch2 = tokenizer.apply_chat_template(
    [{'role': 'user', 'content': 'The eifel tower is in'}],
    return_tensors='pt', padding=True, return_dict=True,
)
batch2 = {k: v.to(model.device) for k, v in batch2.items()}
{k: v.shape for k, v in batch2.items()}
```

```
{'input_ids': torch.Size([1, 33]), 'attention_mask': torch.Size([1, 33])}
```
Baseline: a plain `generate` over the full prompt.

```python
%%timeit
outg2 = model.generate(
    input_ids=batch2['input_ids'],
    attention_mask=batch2['attention_mask'],
    output_logits=True,
    output_scores=True,
    return_dict_in_generate=True,
    max_new_tokens=32 + 1,  # +1 to match the cached variant, which gets one token from forward()
    min_new_tokens=32 + 1,
)
print(outg2.sequences.shape, len(outg2.logits))
```

```
torch.Size([1, 66]) 33
(printed once per timeit run)
1.58 s ± 23.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
The same thing step by step: one forward pass for the input logits and NLL, then generation continuing from the cache.

```python
%%timeit
forward_out = model(**batch2, use_cache=True)
logits = forward_out.logits  # [b, s, vocab]
past_key_values = forward_out.past_key_values
next_input_ids = forward_out.logits[:, -1].argmax(-1)[:, None]  # greedy next token
new_attn_mask = torch.cat(
    [batch2['attention_mask'], torch.ones_like(next_input_ids)],
    dim=1
)

# Shift logits and labels for NLL: predict token t from tokens 0..t-1
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = batch2['input_ids'][:, 1:].contiguous()

# Compute NLL per token, masking padding
shift_mask = (shift_labels != tokenizer.pad_token_id).float()
loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
token_nll = loss_fct(
    shift_logits.view(-1, shift_logits.size(-1)),
    shift_labels.view(-1)
).view(shift_labels.size())

# Average NLL per sequence (excluding padding)
seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)

# Continue generation from the cached KV states.
# The cache already covers all n input positions; the new token sits at position n.
input_ids = batch2['input_ids']
n = past_key_values.get_seq_length()
outputs = model.generate(
    input_ids=next_input_ids,      # last (greedy) token as the new input
    attention_mask=new_attn_mask,  # full mask covering inputs + new token
    past_key_values=past_key_values,
    cache_position=torch.arange(n, n + 1, dtype=torch.long, device=input_ids.device),
    output_logits=True,
    output_scores=True,
    return_dict_in_generate=True,
    max_new_tokens=32,
    min_new_tokens=32,
)

print(outputs.sequences.shape, len(outputs.logits))
# generate only saw next_input_ids, so prepend the original inputs and logits
outputs.sequences = torch.concat([input_ids, outputs.sequences], 1)
outputs.logits = (forward_out.logits[:, -1],) + outputs.logits
print(outputs.sequences.shape, len(outputs.logits))
```

```
torch.Size([1, 33]) 32
torch.Size([1, 66]) 33
(printed once per timeit run)
1.56 s ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
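The commented-out `clone_dynamic_cache` call in the file hints at an alternative continuation: crop the cache to `seq_len - 1`, feed the last *input* token back in, and let `generate` pick the next token itself. A sketch under that assumption, using the hypothetical `clone_dynamic_cache` from above (not something the notebook runs):

```python
# Alternative: recompute the last input position instead of committing to
# the greedy token up front. The cache covers positions 0..n-2; the last
# input token is replayed as the "new" input, so the cache position can be
# inferred and the original attention mask already has the right length.
pkv = clone_dynamic_cache(past_key_values, crop=input_ids.shape[1] - 1)
alt = model.generate(
    input_ids=input_ids[:, -1:],              # last input token as new input
    attention_mask=batch2['attention_mask'],  # mask already covers all inputs
    past_key_values=pkv,
    output_logits=True,
    return_dict_in_generate=True,
    max_new_tokens=32,
)
```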
Wrapped up as a function (identical to `generate_with_input_logits` in the file above) and timed end to end:

```python
%%timeit
# generate_with_input_logits as defined in the file above
out3 = generate_with_input_logits(model, tokenizer, batch2, max_new_tokens=32, min_new_tokens=32)
print(out3[0].sequences.shape, len(out3[0].logits))
```

```
torch.Size([1, 66]) 33
(printed once per timeit run)
1.6 s ± 26.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
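All three variants land around 1.6 s, so the extra forward pass costs essentially nothing: `generate` has to encode the prompt anyway, and the helper additionally returns input logits and a per-sequence NLL. Two sanity checks worth running (a sketch; assumes `generate_with_input_logits` is in scope, greedy decoding on both paths, and a single unpadded sequence):

```python
import torch

# 1) With greedy decoding, the cached continuation should reproduce the
#    tokens of a plain generate() over the full prompt.
out, seq_nll = generate_with_input_logits(
    model, tokenizer, batch2,
    max_new_tokens=32, min_new_tokens=32, do_sample=False,
)
plain = model.generate(
    input_ids=batch2['input_ids'],
    attention_mask=batch2['attention_mask'],
    max_new_tokens=32 + 1,  # +1: the helper emits one token from the forward pass
    min_new_tokens=32 + 1,
    do_sample=False,
)
assert torch.equal(out.sequences, plain)

# 2) HF's built-in loss is the mean shifted cross-entropy, so for batch
#    size 1 with no padding it should equal the helper's seq_nll.
with torch.no_grad():
    ref_loss = model(**batch2, labels=batch2['input_ids']).loss
assert torch.allclose(ref_loss, seq_nll[0], atol=1e-4)
```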