@wassname
Last active November 10, 2025 12:34
Problem: `model.generate` does not return logprobs for the input tokens. Solution: call `model.forward` first, then `model.generate(past_key_values=...)`; this also lets us read the logprobs on the last token of a stop string.
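A minimal sketch of the pattern (assuming a causal LM `model` and `tokenizer` are already loaded; variable names and token counts here are only illustrative):

    import torch

    inputs = tokenizer("My choice:", return_tensors="pt")
    with torch.no_grad():
        # forward over all but the last prompt token: gives the input logits plus a KV cache
        fwd = model(
            input_ids=inputs["input_ids"][:, :-1],
            attention_mask=inputs["attention_mask"][:, :-1],
            use_cache=True,
        )
        input_logprobs = fwd.logits.log_softmax(-1)  # logits at position i predict prompt token i+1

        # generate continues from the cache, so the prompt logits are not recomputed
        out = model.generate(
            input_ids=inputs["input_ids"][:, -1:],    # only the last prompt token
            attention_mask=inputs["attention_mask"],  # mask must cover cache + new tokens
            past_key_values=fwd.past_key_values,
            max_new_tokens=8,
            return_dict_in_generate=True,
            output_logits=True,
        )

The notebook below does the same thing, plus stop strings, a lookback window for forced answers, and Yes/No logprob extraction.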
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"id": "129fa913",
"metadata": {},
"outputs": [],
"source": [
"from transformers.cache_utils import DynamicCache\n",
"import numpy as np\n",
"import pandas as pd\n",
"import torch\n",
"from tqdm.auto import tqdm\n",
"from collections import defaultdict\n",
"from typing import Optional"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4a6bfe95",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "726f4fd4bc2d4b2992c3d71b1cec4d91",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"model_id = \"HuggingFaceTB/SmolLM2-135M-Instruct\"\n",
"model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n",
"model = AutoModelForCausalLM.from_pretrained(model_id)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fa35dd6d",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Tuple\n",
"import re\n",
"import torch\n",
"from loguru import logger\n",
"\n",
"def binary_log_cls(logits, choice_ids):\n",
" logp = logits.log_softmax(dim=-1).detach().cpu()\n",
" log_choices = torch.zeros(len(choice_ids)).to(logp.device)\n",
" for i, choice_id_group in enumerate(choice_ids):\n",
" choice_id_group = torch.tensor(choice_id_group).to(logp.device)\n",
" logp_choice = logp[choice_id_group].logsumexp(-1)\n",
" log_choices[i] = logp_choice\n",
" return log_choices\n",
"\n",
"\n",
"def extract_log_ratios(\n",
" out: \"ModelOutput\", input_ids, choice_ids\n",
"):\n",
" \"\"\"Get [sequences x answers] log ratios for each of len(sequences) X regexp matches.\"\"\"\n",
" N = input_ids.shape[1]\n",
" bs = out.sequences.shape[0]\n",
" logrs = torch.ones((bs, len(choice_ids))) * float(\"nan\")\n",
" for sample_i in range(bs):\n",
" log_choices = binary_log_cls(\n",
" out.logits[-1][sample_i], choice_ids\n",
" )\n",
" logrs[sample_i] = log_choices\n",
" return logrs\n"
]
},
{
"cell_type": "markdown",
"id": "d55e5ab9",
"metadata": {},
"source": [
"## Get choice token ids"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "55970cd4",
"metadata": {},
"outputs": [],
"source": [
"# Many tokenizers don't just use Yes, but \\nYes, \" Yes\" and so on. We need to catch all variants\n",
"def is_choice(choice: str, match: str) -> bool:\n",
" return (match.lower().endswith(choice) or match.lower().startswith(choice)) and len(\n",
" match\n",
" ) < len(choice) + 2\n",
"\n",
"\n",
"def get_choice_ids(\n",
" tokenizer, positive_word=\"yes\", negative_word=\"no\"\n",
") -> List[List[int]]:\n",
" \"\"\"Get token IDs for Yes/No choices.\"\"\"\n",
"\n",
" positive_choices = {\n",
" k: v for k, v in tokenizer.vocab.items() if is_choice(positive_word, k)\n",
" }\n",
" negative_choices = {\n",
" k: v for k, v in tokenizer.vocab.items() if is_choice(negative_word, k)\n",
" }\n",
"\n",
" return [list(negative_choices.values()), list(positive_choices.values())]"
]
},
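{
"cell_type": "markdown",
"id": "9c1f2b3a",
"metadata": {},
"source": [
"A quick sanity check on the cell above (illustrative; assumes the `tokenizer` loaded earlier): decode the ids returned by `get_choice_ids` to confirm the surface variants (\"Yes\", \" Yes\", \"yes\", ...) are all covered."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d2e4f5b",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check: list the surface forms behind each choice id group\n",
"no_ids, yes_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n",
"print(\"yes variants:\", tokenizer.convert_ids_to_tokens(yes_ids))\n",
"print(\"no variants: \", tokenizer.convert_ids_to_tokens(no_ids))"
]
},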
{
"cell_type": "markdown",
"id": "32c1306c",
"metadata": {},
"source": [
"## Gen\n",
"\n",
"Benefits:\n",
"- we get logprobs which are more nuanced and less noisy than tokens\n",
"- it's fast, we get the full distribution without sampling, we don't generate more tokens than we need\n",
"\n",
"Limitations\n",
"- it doesn't think about the answer, which would change it's answer. So we get a less considered answer\n",
"- sometimes forcing might take it outside it's training distribution, meaning we get unnatural behaviour and non representative logprobs\n"
]
},
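{
"cell_type": "markdown",
"id": "7b3c9d1e",
"metadata": {},
"source": [
"Concretely, the score computed below is a log-odds ratio read from the logits at the answer position:\n",
"\n",
"`logratio = log p(Yes) - log p(No)`\n",
"\n",
"where `p(Yes)` sums the probability over all \"Yes\" surface variants (via logsumexp in `binary_log_cls`), and likewise for `p(No)`. If the combined probability mass on the Yes/No variants is small (below 10% here), the ratio is treated as unreliable and set to NaN."
]
},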
{
"cell_type": "code",
"execution_count": 38,
"id": "c192e461",
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"\n",
"ref:\n",
"- https://huggingface.co/docs/transformers/v4.56.1/en/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria\n",
"- https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/tests/generation/test_stopping_criteria.py#L51\n",
"\n",
"\n",
"Ref regexp based logit colleciton\n",
"- https://github.com/wassname/repeng/blob/add-performance-validation/notebooks/performance_tests.ipynb \n",
"- https://github.com/wassname/repeng/blob/research/repeng/eval.py\n",
"\"\"\"\n",
"\n",
"from transformers import (\n",
" StopStringCriteria,\n",
" StoppingCriteriaList,\n",
" EosTokenCriteria,\n",
" MaxLengthCriteria,\n",
")\n",
"\n",
"def calc_nll(input_ids, logits, attention_mask):\n",
" # Shift logits and labels for NLL: predict token t from tokens 0..t-1\n",
" shift_logits = logits[:, :-1, :].contiguous()\n",
" shift_labels = input_ids[:, 1:].contiguous()\n",
" shift_mask = attention_mask[:, 1:].contiguous()\n",
"\n",
" # Compute NLL per token, masking padding\n",
" loss_fct = torch.nn.CrossEntropyLoss(reduction='none')\n",
" token_nll = loss_fct(\n",
" shift_logits.view(-1, shift_logits.size(-1)),\n",
" shift_labels.view(-1)\n",
" ).view(shift_labels.size())\n",
"\n",
" # Average NLL per sequence (excluding padding)\n",
" seq_nll = (token_nll * shift_mask).sum(dim=1) / shift_mask.sum(dim=1).clamp(min=1)\n",
"\n",
" return seq_nll\n",
"\n",
"\n",
"def check_input_shapes(input_ids, attention_mask, kv_cache):\n",
" c = kv_cache.get_seq_length()\n",
" i = input_ids.shape[1]\n",
" a = attention_mask.shape[1]\n",
" assert c+i == a, f\"Cache length + input length must equal attention mask length, got {c}+{i} != {a}\"\n",
"\n",
"def gen_with_nll(model, tokenizer, batch2, lookback=1, **kwargs):\n",
" \"\"\"\n",
" problem: generate does not return logits for inputs, but we need them for nll\n",
"\n",
" but forward -> generate with past key values does, and it doesn't recompute the input logits\n",
" \"\"\"\n",
" if 'attention_mask' not in batch2:\n",
" batch2['attention_mask'] = torch.ones_like(batch2['input_ids'])\n",
" input_ids = batch2['input_ids']\n",
" attn_mask = batch2['attention_mask']\n",
" forward_out = model(input_ids[:, :-lookback], attention_mask=attn_mask[:, :-lookback], use_cache=True)\n",
"\n",
" seq_nll = calc_nll(input_ids[:, :-lookback], forward_out.logits, attn_mask[:, :-lookback])\n",
" kv_cache = forward_out.past_key_values\n",
"\n",
" # Continue generation from the cached KV states\n",
" cl = kv_cache.get_seq_length()\n",
" new_tokens = input_ids[:, -lookback:]\n",
" kwargs['output_logits'] = True\n",
" kwargs['return_dict_in_generate'] = True\n",
" kwargs['min_new_tokens'] = 0\n",
"\n",
" check_input_shapes(new_tokens, attn_mask, kv_cache)\n",
" outputs = model.generate(\n",
" input_ids=new_tokens, # Last token as new input\n",
" attention_mask=attn_mask, # attn mask should cover cache and new tokens\n",
" past_key_values=kv_cache,\n",
"\n",
" # the next cache position will be n+1\n",
" cache_position=torch.arange(cl, cl+new_tokens.shape[1], dtype=torch.int64, device=input_ids.device),\n",
" use_cache=True,\n",
" **kwargs\n",
" )\n",
"\n",
" # now we need to modify this as generate does return the full sequences, including inputs ids\n",
" outputs.sequences = torch.concat([input_ids[:, :-lookback], outputs.sequences], 1)\n",
"\n",
" return outputs, seq_nll\n",
"\n",
"\n",
"def gen_with_nll_and_logprobs(model, tokenizer, batch2, choice_ids, stop_strings=[\": Yes\", \": Yes \", \" choice: Yes\", \"choice: Yes\", \": No\", \": No \", \" choice: No\"], max_new_tokens=16, continue_after_ss=False, lookback=1, **kwargs):\n",
" \"\"\"\n",
" Generate outputs while also computing input NLL and log probabilities for choices.\n",
" \"\"\"\n",
" model.eval()\n",
" outputs, seq_nll = gen_with_nll(\n",
" model, tokenizer, batch2, max_new_tokens=max_new_tokens, \n",
" stopping_criteria=StoppingCriteriaList(\n",
" [\n",
" StopStringCriteria(tokenizer, stop_strings),\n",
" EosTokenCriteria(tokenizer.eos_token_id),\n",
" MaxLengthCriteria(max_length=batch2[\"input_ids\"].shape[1] + max_new_tokens),\n",
" ]\n",
" ),\n",
" lookback=lookback,\n",
" **kwargs\n",
" )\n",
" \n",
"\n",
" input_ids = batch2['input_ids']\n",
" logp_choices = extract_log_ratios(outputs, input_ids, choice_ids)\n",
" last_token = outputs.sequences[:, -1:]\n",
"\n",
" if continue_after_ss:\n",
" # For debugging, continue generation after stop string reached, untill max_new_tokens is reached\n",
" n = outputs.past_key_values.get_seq_length()\n",
" n_gen = (n - input_ids.shape[1])\n",
" next_input_ids = outputs.logits[-1].log_softmax(-1).argmax(-1).unsqueeze(-1)\n",
" b = batch2['input_ids'].shape[0]\n",
" new_attn_mask = torch.cat(\n",
" [\n",
" batch2['attention_mask'], \n",
" torch.ones(b, n_gen, dtype=torch.int64, device=input_ids.device), \n",
" torch.ones_like(next_input_ids)],\n",
" dim=1\n",
" )\n",
" kwargs['output_logits'] = True\n",
" kwargs['return_dict_in_generate'] = True\n",
" max_new_tokens = max_new_tokens - n_gen\n",
" check_input_shapes(next_input_ids, new_attn_mask, outputs.past_key_values)\n",
" continued_outputs = model.generate(\n",
" input_ids=next_input_ids,\n",
" attention_mask=new_attn_mask,\n",
" past_key_values=outputs.past_key_values,\n",
" cache_position=torch.arange(n, n+1, dtype=torch.int64, device=input_ids.device),\n",
" min_new_tokens=max_new_tokens,\n",
" max_new_tokens=max_new_tokens,\n",
" **kwargs\n",
" )\n",
" # Concatenate sequences and logits\n",
" outputs.sequences = torch.concat([outputs.sequences, continued_outputs.sequences[:, 1:]], 1)\n",
" outputs.logits = outputs.logits + continued_outputs.logits\n",
"\n",
"\n",
" logratios = logp_choices[:, 1] - logp_choices[:, 0] # Positive - Negative log-prob ratio\n",
" \n",
" # but total prob mass < 10% -> nan\n",
" pmass = logp_choices.exp().sum(-1)\n",
" logratios = torch.where(pmass < 0.1, float('nan'), logratios) \n",
"\n",
" return outputs, seq_nll, logp_choices, logratios, last_token\n"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "04b0445d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([1, 78]) 33\n",
"<|im_start|>user\n",
"\n",
"Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
"Q: Would you kill a process? \n",
"<|im_end|>\n",
"<|im_start|>assistant\n",
"My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n",
"Note: This response is\n",
"Last token: Yes\n"
]
},
{
"data": {
"text/plain": [
"(tensor([5.7153]),\n",
" tensor([[-3.2194, -0.0408]]),\n",
" tensor([3.1786]),\n",
" tensor([[7414]]))"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"choice_ids = get_choice_ids(tokenizer, positive_word=\"yes\", negative_word=\"no\")\n",
"forced = True\n",
"batch2 = tokenizer.apply_chat_template(\n",
" [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"\"\"\n",
"Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
"Q: Would you kill a process? \n",
"\"\"\",\n",
" },\n",
" {\n",
" 'role': 'assistant',\n",
" 'content': \"My choice:\"\n",
" }\n",
" ],\n",
" return_tensors=\"pt\",\n",
" padding=True,\n",
" return_dict=True,\n",
" continue_final_message=True,\n",
" add_generation_prompt=False,\n",
")\n",
"\n",
"from transformers import GenerationConfig\n",
"\n",
"generation_config = GenerationConfig(\n",
" eos_token_id=tokenizer.eos_token_id,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" bos_token_id=tokenizer.bos_token_id,\n",
" use_cache=True,\n",
" output_logits=True,\n",
" return_dict_in_generate=True,\n",
" do_sample=False,\n",
")\n",
"\n",
"batch2 = {k: v.to(model.device) for k, v in batch2.items()}\n",
"max_new_tokens =32\n",
"with torch.no_grad():\n",
" outputs, seq_nll, logp_choices, logratios, last_token = gen_with_nll_and_logprobs(\n",
" model=model,\n",
" tokenizer=tokenizer,\n",
" batch2=batch2,\n",
" choice_ids=choice_ids,\n",
" stop_strings=[\"My choice: Yes\", \"My choice: No\"],\n",
" max_new_tokens=max_new_tokens,\n",
" lookback=4, # if we use forcing we should look back enough to cover it\n",
" continue_after_ss=True,\n",
" do_sample=False,\n",
" generation_config=generation_config,\n",
" )\n",
"\n",
"print(outputs.sequences.shape, len(outputs.logits))\n",
"print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=False)[0])\n",
"last_token_s = tokenizer.batch_decode(last_token, skip_special_tokens=False)[0]\n",
"print(f\"Last token: {last_token_s}\")\n",
"seq_nll, logp_choices, logratios, last_token"
]
},
{
"cell_type": "markdown",
"id": "a07afced",
"metadata": {},
"source": [
"## Compare to straight generate"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "19f45b72",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<|im_start|>user\n",
"\n",
"Reply in this exact format, and only in the format \"My choice: Yes\" or \"My choice: No\". \n",
"Q: Would you kill a process? \n",
"<|im_end|>\n",
"<|im_start|>assistant\n",
"My choice: Yes 🚀💥🔥 (Process termination is a common and necessary action in system management to maintain stability, security, and performance.) \n",
"Note: This response is\n"
]
}
],
"source": [
"# Unit test, make sure output is same as straight generate with forced tokens\n",
"\n",
"with torch.no_grad():\n",
" out_g = model.generate(\n",
" input_ids=batch2['input_ids'],\n",
" attention_mask=batch2['attention_mask'],\n",
" max_new_tokens=max_new_tokens+1,\n",
" min_new_tokens=max_new_tokens+1,\n",
" do_sample=False,\n",
" return_dict_in_generate=True,\n",
" output_scores=True,\n",
" output_logits=True,\n",
" generation_config=generation_config,\n",
" ) \n",
"print(tokenizer.batch_decode(out_g.sequences)[0])"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "28b2bc12",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 78]), 33)"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"out_g.sequences.shape, len(out_g.logits)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "3a2087c8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([1, 78]), 33)"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"outputs.sequences.shape, len(outputs.logits)"
]
},
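{
"cell_type": "markdown",
"id": "5e6f7a8b",
"metadata": {},
"source": [
"With greedy decoding both paths should produce the same tokens; tiny numerical differences between the cached and uncached paths could in principle flip an argmax, so treat the check below as a sanity check rather than a guarantee."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a",
"metadata": {},
"outputs": [],
"source": [
"# Compare the two-step (forward + cached generate) output against plain generate\n",
"print(\"shapes match:    \", outputs.sequences.shape == out_g.sequences.shape)\n",
"print(\"tokens identical:\", torch.equal(outputs.sequences, out_g.sequences))"
]
},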
{
"cell_type": "code",
"execution_count": null,
"id": "a699dd31",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}