Problem: `model.generate` does not return log-probs for the input (prompt) tokens. Solution: run `model.forward` on the prompt first, then `model.generate(past_key_values=...)` to continue from the cached KV states; this also lets you read log-probs at the last token of a stopping sequence.
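A minimal sketch of the pattern, assuming a standard Hugging Face causal LM and greedy decoding (illustrative only; the notebook below adds stopping criteria, explicit cache positions, and choice log-probs):

    # Score the prompt once, keeping the KV cache
    out = model(input_ids[:, :-1], attention_mask=attention_mask[:, :-1], use_cache=True)
    prompt_logprobs = out.logits.log_softmax(-1)      # next-token log-probs at each prompt position
    # Continue generation from the cached states, feeding only the last prompt token
    gen = model.generate(
        input_ids=input_ids[:, -1:],
        attention_mask=attention_mask,                # must cover cache + new tokens
        past_key_values=out.past_key_values,
        output_logits=True,
        return_dict_in_generate=True,
    )
    full_sequence = torch.cat([input_ids[:, :-1], gen.sequences], dim=1)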
| { | |
| "cells": [ | |
| { | |
| "cell_type": "code", | |
| "execution_count": 10, | |
| "id": "129fa913", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from transformers.cache_utils import DynamicCache\n", | |
| "import numpy as np\n", | |
| "import pandas as pd\n", | |
| "import torch\n", | |
| "from tqdm.auto import tqdm\n", | |
| "from collections import defaultdict\n", | |
| "from typing import Optional" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 11, | |
| "id": "4a6bfe95", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "application/vnd.jupyter.widget-view+json": { | |
| "model_id": "726f4fd4bc2d4b2992c3d71b1cec4d91", | |
| "version_major": 2, | |
| "version_minor": 0 | |
| }, | |
| "text/plain": [ | |
| "Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]" | |
| ] | |
| }, | |
| "metadata": {}, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "from transformers import AutoModelForCausalLM, AutoTokenizer\n", | |
| "\n", | |
| "model_id = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n", | |
| "model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n", | |
| "model = AutoModelForCausalLM.from_pretrained(model_id)\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(model_id)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 12, | |
| "id": "fa35dd6d", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "from typing import List, Tuple\n", | |
| "import re\n", | |
| "import torch\n", | |
| "from loguru import logger\n", | |
| "\n", | |
| "def binary_log_cls(logits, choice_ids):\n", | |
| " logp = logits.log_softmax(dim=-1).detach().cpu()\n", | |
| " log_choices = torch.zeros(len(choice_ids)).to(logp.device)\n", | |
| " for i, choice_id_group in enumerate(choice_ids):\n", | |
| " choice_id_group = torch.tensor(choice_id_group).to(logp.device)\n", | |
| " logp_choice = logp[choice_id_group].logsumexp(-1)\n", | |
| " log_choices[i] = logp_choice\n", | |
| " return log_choices\n", | |
| "\n", | |
| "\n", | |
| "def extract_log_ratios(\n", | |
| " out: \"ModelOutput\", input_ids, choice_ids\n", | |
| "):\n", | |
| " \"\"\"Get [sequences x answers] log ratios for each of len(sequences) X regexp matches.\"\"\"\n", | |
| " N = input_ids.shape[1]\n", | |
| " bs = out.sequences.shape[0]\n", | |
| " logrs = torch.ones((bs, len(choice_ids))) * float(\"nan\")\n", | |
| " for sample_i in range(bs):\n", | |
| " log_choices = binary_log_cls(\n", | |
| " out.logits[-1][sample_i], choice_ids\n", | |
| " )\n", | |
| " logrs[sample_i] = log_choices\n", | |
| " return logrs\n" | |
| ] | |
| }, | |
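| { | |
| "cell_type": "markdown", | |
| "id": "ab12cd34", | |
| "metadata": {}, | |
| "source": [ | |
| "Illustrative toy check of `binary_log_cls` (made-up numbers): each choice's probability is the pooled softmax mass of its token variants." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "ab12cd35", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Toy vocab of 4 tokens: ids 0-1 are \"No\" variants, ids 2-3 are \"Yes\" variants\n", | |
| "toy_logits = torch.tensor([2.0, 1.0, 0.5, -1.0])\n", | |
| "toy_choice_ids = [[0, 1], [2, 3]]\n", | |
| "# Per-choice probabilities; each sums its variants' softmax mass\n", | |
| "binary_log_cls(toy_logits, toy_choice_ids).exp()" | |
| ] | |
| }, | |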
| { | |
| "cell_type": "markdown", | |
| "id": "d55e5ab9", | |
| "metadata": {}, | |
| "source": [ | |
| "## Get choice token ids" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 13, | |
| "id": "55970cd4", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Many tokenizers don't just use Yes, but \\nYes, \" Yes\" and so on. We need to catch all variants\n", | |
| "def is_choice(choice: str, match: str) -> bool:\n", | |
| " return (match.lower().endswith(choice) or match.lower().startswith(choice)) and len(\n", | |
| " match\n", | |
| " ) < len(choice) + 2\n", | |
| "\n", | |
| "\n", | |
| "def get_choice_ids(\n", | |
| " tokenizer, positive_word=\"yes\", negative_word=\"no\"\n", | |
| ") -> List[List[int]]:\n", | |
| " \"\"\"Get token IDs for Yes/No choices.\"\"\"\n", | |
| "\n", | |
| " positive_choices = {\n", | |
| " k: v for k, v in tokenizer.vocab.items() if is_choice(positive_word, k)\n", | |
| " }\n", | |
| " negative_choices = {\n", | |
| " k: v for k, v in tokenizer.vocab.items() if is_choice(negative_word, k)\n", | |
| " }\n", | |
| "\n", | |
| " return [list(negative_choices.values()), list(positive_choices.values())]" | |
| ] | |
| }, | |
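| { | |
| "cell_type": "markdown", | |
| "id": "bc23de45", | |
| "metadata": {}, | |
| "source": [ | |
| "Quick illustrative look at what `get_choice_ids` matches (depends on the tokenizer loaded above)." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "bc23de46", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "demo_choice_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n", | |
| "# Show a handful of the matched surface forms for each choice group\n", | |
| "print(\"negative variants:\", tokenizer.convert_ids_to_tokens(demo_choice_ids[0])[:8])\n", | |
| "print(\"positive variants:\", tokenizer.convert_ids_to_tokens(demo_choice_ids[1])[:8])" | |
| ] | |
| }, | |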
| { | |
| "cell_type": "markdown", | |
| "id": "32c1306c", | |
| "metadata": {}, | |
| "source": [ | |
| "## Gen\n", | |
| "\n", | |
| "Benefits:\n", | |
| "- we get logprobs which are more nuanced and less noisy than tokens\n", | |
| "- it's fast, we get the full distribution without sampling, we don't generate more tokens than we need\n", | |
| "\n", | |
| "Limitations\n", | |
| "- it doesn't think about the answer, which would change it's answer. So we get a less considered answer\n", | |
| "- sometimes forcing might take it outside it's training distribution, meaning we get unnatural behaviour and non representative logprobs\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 38, | |
| "id": "c192e461", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "\"\"\"\n", | |
| "\n", | |
| "ref:\n", | |
| "- https://huggingface.co/docs/transformers/v4.56.1/en/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria\n", | |
| "- https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/tests/generation/test_stopping_criteria.py#L51\n", | |
| "\n", | |
| "\n", | |
| "Ref regexp based logit colleciton\n", | |
| "- https://github.com/wassname/repeng/blob/add-performance-validation/notebooks/performance_tests.ipynb \n", | |
| "- https://github.com/wassname/repeng/blob/research/repeng/eval.py\n", | |
| "\"\"\"\n", | |
| "\n", | |
| "from transformers import (\n", | |
| " StopStringCriteria,\n", | |
| " StoppingCriteriaList,\n", | |
| " EosTokenCriteria,\n", | |
| " MaxLengthCriteria,\n", | |
| ")\n", | |
| "\n", | |
| "def calc_nll(input_ids, logits, attention_mask):\n", | |
| " # Shift logits and labels for NLL: predict token t from tokens 0..t-1\n", | |
| " shift_logits = logits[:, :-1, :].contiguous()\n", | |
| " shift_labels = input_ids[:, 1:].contiguous()\n", | |
| " shift_mask = attention_mask[:, 1:].contiguous()\n", | |
| "\n", | |
| " # Compute NLL per token, masking padding\n", | |
| " loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n", | |
| " token_nll = loss_fct(\n", | |
| " shift_logits.view(-1, shift_logits.size(-1)),\n", | |
| " shift_labels.view(-1)\n", | |
| " ).view(shift_labels.size())\n", | |
| "\n", | |
| " # Average NLL per sequence (excluding padding)\n", | |
| " seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)\n", | |
| "\n", | |
| " return seq_nll\n", | |
| "\n", | |
| "\n", | |
| "def check_input_shapes(input_ids, attention_mask, kv_cache):\n", | |
| " c = kv_cache.get_seq_length()\n", | |
| " i = input_ids.shape[1]\n", | |
| " a = attention_mask.shape[1]\n", | |
| " assert c+i == a, f\"Cache length + input length must equal attention mask length, got {c}+{i} != {a}\"\n", | |
| "\n", | |
| "def gen_with_nll(model, tokenizer, batch2, lookback=1, **kwargs):\n", | |
| " \"\"\"\n", | |
| " problem: generate does not return logits for inputs, but we need them for nll\n", | |
| "\n", | |
| " but forward -> generate with past key values does, and it doesn't recompute the input logits\n", | |
| " \"\"\"\n", | |
| " if 'attention_mask' not in batch2:\n", | |
| " batch2['attention_mask'] = torch.ones_like(batch2['input_ids'])\n", | |
| " input_ids = batch2['input_ids']\n", | |
| " attn_mask = batch2['attention_mask']\n", | |
| " forward_out = model(input_ids[:, :-lookback], attention_mask=attn_mask[:, :-lookback], use_cache=True)\n", | |
| "\n", | |
| " seq_nll = calc_nll(input_ids[:, :-lookback], forward_out.logits, attn_mask[:, :-lookback])\n", | |
| " kv_cache = forward_out.past_key_values\n", | |
| "\n", | |
| " # Continue generation from the cached KV states\n", | |
| " cl = kv_cache.get_seq_length()\n", | |
| " new_tokens = input_ids[:, -lookback:]\n", | |
| " kwargs['output_logits'] = True\n", | |
| " kwargs['return_dict_in_generate'] = True\n", | |
| " kwargs['min_new_tokens'] = 0\n", | |
| "\n", | |
| " check_input_shapes(new_tokens, attn_mask, kv_cache)\n", | |
| " outputs = model.generate(\n", | |
| " input_ids=new_tokens, # Last token as new input\n", | |
| " attention_mask=attn_mask, # attn mask should cover cache and new tokens\n", | |
| " past_key_values=kv_cache,\n", | |
| "\n", | |
| " # the next cache position will be n+1\n", | |
| " cache_position=torch.arange(cl, cl+new_tokens.shape[1], dtype=torch.int64, device=input_ids.device),\n", | |
| " use_cache=True,\n", | |
| " **kwargs\n", | |
| " )\n", | |
| "\n", | |
| " # now we need to modify this as generate does return the full sequences, including inputs ids\n", | |
| " outputs.sequences = torch.concat([input_ids[:, :-lookback], outputs.sequences], 1)\n", | |
| "\n", | |
| " return outputs, seq_nll\n", | |
| "\n", | |
| "\n", | |
| "def gen_with_nll_and_logprobs(model, tokenizer, batch2, choice_ids, stop_strings=[\": Yes\", \": Yes \", \" choice: Yes\", \"choice: Yes\", \": No\", \": No \", \" choice: No\"], max_new_tokens=16, continue_after_ss=False, lookback=1, **kwargs):\n", | |
| " \"\"\"\n", | |
| " Generate outputs while also computing input NLL and log probabilities for choices.\n", | |
| " \"\"\"\n", | |
| " model.eval()\n", | |
| " outputs, seq_nll = gen_with_nll(\n", | |
| " model, tokenizer, batch2, max_new_tokens=max_new_tokens, \n", | |
| " stopping_criteria=StoppingCriteriaList(\n", | |
| " [\n", | |
| " StopStringCriteria(tokenizer, stop_strings),\n", | |
| " EosTokenCriteria(tokenizer.eos_token_id),\n", | |
| " MaxLengthCriteria(max_length=batch2[\"input_ids\"].shape[1] + max_new_tokens),\n", | |
| " ]\n", | |
| " ),\n", | |
| " lookback=lookback,\n", | |
| " **kwargs\n", | |
| " )\n", | |
| " \n", | |
| "\n", | |
| " input_ids = batch2['input_ids']\n", | |
| " logp_choices = extract_log_ratios(outputs, input_ids, choice_ids)\n", | |
| " last_token = outputs.sequences[:, -1:]\n", | |
| "\n", | |
| " if continue_after_ss:\n", | |
| " # For debugging, continue generation after stop string reached, untill max_new_tokens is reached\n", | |
| " n = outputs.past_key_values.get_seq_length()\n", | |
| " n_gen = (n - input_ids.shape[1])\n", | |
| " next_input_ids = outputs.logits[-1].log_softmax(-1).argmax(-1).unsqueeze(-1)\n", | |
| " b = batch2['input_ids'].shape[0]\n", | |
| " new_attn_mask = torch.cat(\n", | |
| " [\n", | |
| " batch2['attention_mask'], \n", | |
| " torch.ones(b, n_gen, dtype=torch.int64, device=input_ids.device), \n", | |
| " torch.ones_like(next_input_ids)],\n", | |
| " dim=1\n", | |
| " )\n", | |
| " kwargs['output_logits'] = True\n", | |
| " kwargs['return_dict_in_generate'] = True\n", | |
| " max_new_tokens = max_new_tokens - n_gen\n", | |
| " check_input_shapes(next_input_ids, new_attn_mask, outputs.past_key_values)\n", | |
| " continued_outputs = model.generate(\n", | |
| " input_ids=next_input_ids,\n", | |
| " attention_mask=new_attn_mask,\n", | |
| " past_key_values=outputs.past_key_values,\n", | |
| " cache_position=torch.arange(n, n+1, dtype=torch.int64, device=input_ids.device),\n", | |
| " min_new_tokens=max_new_tokens,\n", | |
| " max_new_tokens=max_new_tokens,\n", | |
| " **kwargs\n", | |
| " )\n", | |
| " # Concatenate sequences and logits\n", | |
| " outputs.sequences = torch.concat([outputs.sequences, continued_outputs.sequences[:, 1:]], 1)\n", | |
| " outputs.logits = outputs.logits + continued_outputs.logits\n", | |
| "\n", | |
| "\n", | |
| " logratios = logp_choices[:, 1] - logp_choices[:, 0] # Positive - Negative log-prob ratio\n", | |
| " \n", | |
| " # but total prob mass < 10% -> nan\n", | |
| " pmass = logp_choices.exp().sum(-1)\n", | |
| " logratios = torch.where(pmass < 0.1, float('nan'), logratios) \n", | |
| "\n", | |
| " return outputs, seq_nll, logp_choices, logratios, last_token\n" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 46, | |
| "id": "04b0445d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "torch.Size([1, 78]) 33\n", | |
| "<|im_start|>user\n", | |
| "\n", | |
| "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n", | |
| "Q: Would you kill a process? \n", | |
| "<|im_end|>\n", | |
| "<|im_start|>assistant\n", | |
| "My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n", | |
| "Note: This response is\n", | |
| "Last token: Yes\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(tensor([5.7153]),\n", | |
| " tensor([[-3.2194, -0.0408]]),\n", | |
| " tensor([3.1786]),\n", | |
| " tensor([[7414]]))" | |
| ] | |
| }, | |
| "execution_count": 46, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "choice_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n", | |
| "forced = True\n", | |
| "batch2 = tokenizer.apply_chat_template(\n", | |
| " [\n", | |
| " {\n", | |
| " \"role\": \"user\",\n", | |
| " \"content\": \"\"\"\n", | |
| "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n", | |
| "Q: Would you kill a process? \n", | |
| "\"\"\",\n", | |
| " },\n", | |
| " {\n", | |
| " 'role': 'assistant',\n", | |
| " 'content': \"My choice:\"\n", | |
| " }\n", | |
| " ],\n", | |
| " return_tensors=\"pt\",\n", | |
| " padding=True,\n", | |
| " return_dict=True,\n", | |
| " continue_final_message=True,\n", | |
| " add_generation_prompt=False,\n", | |
| ")\n", | |
| "\n", | |
| "from transformers import GenerationConfig\n", | |
| "\n", | |
| "generation_config = GenerationConfig(\n", | |
| " eos_token_id=tokenizer.eos_token_id,\n", | |
| " pad_token_id=tokenizer.pad_token_id,\n", | |
| " bos_token_id=tokenizer.bos_token_id,\n", | |
| " use_cache=True,\n", | |
| " output_logits=True,\n", | |
| " return_dict_in_generate=True,\n", | |
| " do_sample=False,\n", | |
| ")\n", | |
| "\n", | |
| "batch2 = {k: v.to(model.device) for k, v in batch2.items()}\n", | |
| "max_new_tokens =32\n", | |
| "with torch.no_grad():\n", | |
| " outputs, seq_nll, logp_choices, logratios, last_token = gen_with_nll_and_logprobs(\n", | |
| " model=model,\n", | |
| " tokenizer=tokenizer,\n", | |
| " batch2=batch2,\n", | |
| " choice_ids=choice_ids,\n", | |
| " stop_strings=[\"My choice: Yes\", \"My choice: No\"],\n", | |
| " max_new_tokens=max_new_tokens,\n", | |
| " lookback=4, # if we use forcing we should look back enough to cover it\n", | |
| " continue_after_ss=True,\n", | |
| " do_sample=False,\n", | |
| " generation_config=generation_config,\n", | |
| " )\n", | |
| "\n", | |
| "print(outputs.sequences.shape, len(outputs.logits))\n", | |
| "print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0])\n", | |
| "last_token_s = tokenizer.batch_decode(last_token, skip_special_tokens=False)[0]\n", | |
| "print(f\"Last token: {last_token_s}\")\n", | |
| "seq_nll, logp_choices, logratios, last_token" | |
| ] | |
| }, | |
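| { | |
| "cell_type": "markdown", | |
| "id": "cd34ef56", | |
| "metadata": {}, | |
| "source": [ | |
| "Optional interpretation (a sketch, assuming only the two choice groups matter): the log-ratio maps to P(Yes | Yes or No) via a sigmoid." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "cd34ef57", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# P(Yes | Yes or No) = exp(logp_yes) / (exp(logp_yes) + exp(logp_no)) = sigmoid(logp_yes - logp_no)\n", | |
| "torch.sigmoid(logratios)" | |
| ] | |
| }, | |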
| { | |
| "cell_type": "markdown", | |
| "id": "a07afced", | |
| "metadata": {}, | |
| "source": [ | |
| "## Compare to straight generate" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 47, | |
| "id": "19f45b72", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "<|im_start|>user\n", | |
| "\n", | |
| "Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n", | |
| "Q: Would you kill a process? \n", | |
| "<|im_end|>\n", | |
| "<|im_start|>assistant\n", | |
| "My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n", | |
| "Note: This response is\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# Unit test, make sure output is same as straight generate with forced tokens\n", | |
| "\n", | |
| "with torch.no_grad():\n", | |
| " out_g = model.generate(\n", | |
| " input_ids=batch2['input_ids'],\n", | |
| " attention_mask=batch2['attention_mask'],\n", | |
| " max_new_tokens=max_new_tokens+1,\n", | |
| " min_new_tokens=max_new_tokens+1,\n", | |
| " do_sample=False,\n", | |
| " return_dict_in_generate=True,\n", | |
| " output_scores=True,\n", | |
| " output_logits=True,\n", | |
| " generation_config=generation_config,\n", | |
| " ) \n", | |
| "print(tokenizer.batch_decode(out_g.sequences)[0])" | |
| ] | |
| }, | |
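| { | |
| "cell_type": "markdown", | |
| "id": "de45fa67", | |
| "metadata": {}, | |
| "source": [ | |
| "Illustrative assertion: under greedy decoding the cached-forward path should reproduce plain `generate()` token for token." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "de45fa68", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "# Both paths decode greedily from the same prompt, so the token ids should match exactly\n", | |
| "assert torch.equal(outputs.sequences, out_g.sequences), \"cached-forward path diverged from plain generate\"" | |
| ] | |
| }, | |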
| { | |
| "cell_type": "code", | |
| "execution_count": 28, | |
| "id": "28b2bc12", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 78]), 33)" | |
| ] | |
| }, | |
| "execution_count": 28, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "out_g.sequences.shape, len(out_g.logits)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 18, | |
| "id": "3a2087c8", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "(torch.Size([1, 78]), 33)" | |
| ] | |
| }, | |
| "execution_count": 18, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "outputs.sequences.shape, len(outputs.logits)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": null, | |
| "id": "a699dd31", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": ".venv", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.10.16" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |