Skip to content

Instantly share code, notes, and snippets.

@wassname
Last active November 9, 2025 11:31
Show Gist options
  • Select an option

  • Save wassname/0f258a7dd2b17b8651af849748619831 to your computer and use it in GitHub Desktop.

Select an option

Save wassname/0f258a7dd2b17b8651af849748619831 to your computer and use it in GitHub Desktop.
Use a regexp to greedy select minimal tokens that satisfy a regexp
"""
Stopping criteria: regexp
ref:
- https://huggingface.co/docs/transformers/v4.56.1/en/main_classes/text_generation#transformers.GenerationMixin.generate.stopping_criteria
- https://github.com/huggingface/transformers/blob/e8a6eb3304033fdd9346fe3b3293309fe50de238/tests/generation/test_stopping_criteria.py#L51
"""
from transformers import StopStringCriteria, StoppingCriteriaList, EosTokenCriteria
outg2 = model.generate(
input_ids=batch2['input_ids'], # Last token as new input
attention_mask=batch2['attention_mask'], # Keep full mask
output_logits=True,
output_scores=True,
return_dict_in_generate=True,
max_new_tokens=128,
min_new_tokens=4,
stopping_criteria=StoppingCriteriaList([
StopStringCriteria(tokenizer, ['Choice: Yes', 'Choice: No']),
EosTokenCriteria(tokenizer.eos_token_id),
#MaxLengthCriteria(max_length=batch2['input_ids'].shape[1]+128)
]),
)
positions = find_token_positions_for_regex(
outg2.sequence,
tokenizer,
regex_pattern=r"Final choice: (Yes|No)"
)
print(outg2.sequences.shape, len(outg2.logits))
# 1.54 s ± 34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
from transformers import AutoModelForCausalLM, AutoTokenizer
import re
import torch
model_id = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
test_str = "### Instruction:\nWrite a Python function that computes the factorial of a number.\n### Response:\n Um 1.23! or maybe 30 million?\n"
regxp = "\d\.?\d*" # matches integers and decimals
tokens = tokenizer.encode(test_str)
def find_token_positions_for_regex(
sequence: torch.Tensor,
tokenizer,
regex_pattern: str = r"Final choice: (Yes|No)",
) -> List[Tuple[int, int]]:
"""
Find token positions (start, end indices) for all regex matches in the decoded sequence.
Args:
sequence: Tensor of token IDs (e.g., out.sequences[0]).
regex_pattern: Regex pattern to search for (e.g., r"Ans: Yes").
tokenizer: Hugging Face tokenizer instance.
Returns:
List of tuples [(start_token_idx, end_token_idx), ...] for each match, or empty list if none.
"""
sequence = sequence.tolist()
decoded_full = tokenizer.decode(sequence, skip_special_tokens=True)
matches = list(re.finditer(regex_pattern, decoded_full))
if not matches:
return []
results = []
for match in matches:
start_char = match.start()
end_char = match.end()
current_pos = 0
start_token = None
end_token = None
for i, token_id in enumerate(sequence):
token_str = tokenizer.decode([token_id], skip_special_tokens=True)
token_len = len(token_str)
if start_token is None and current_pos + token_len > start_char:
start_token = i
if current_pos + token_len >= end_char:
end_token = i + 1
break
current_pos += token_len
if start_token is not None and end_token is not None:
results.append((start_token, end_token))
return results
def find_token_positions_for_regex_broken(
sequence: torch.Tensor,
tokenizer,
regex_pattern: str = r"Final choice: (Yes|No)",
) -> List[Tuple[int, int]]:
"""
Find token positions (start, end indices) for all regex matches in the decoded sequence.
Args:
sequence: Tensor of token IDs (e.g., out.sequences[0]).
regex_pattern: Regex pattern to search for (e.g., r"Ans: Yes").
tokenizer: Hugging Face tokenizer instance.
Returns:
List of tuples [(start_token_idx, end_token_idx), ...] for each match, or empty list if none.
"""
sequence = sequence.tolist()
decoded_full = tokenizer.decode(sequence, skip_special_tokens=True)
matches = list(re.finditer(regex_pattern, decoded_full))
if not matches:
return []
# Build char→token mapping once
token_char_ranges = []
current_pos = 0
for token_id in sequence:
token_str = tokenizer.decode([token_id], skip_special_tokens=True)
token_len = len(token_str)
token_char_ranges.append((current_pos, current_pos + token_len))
current_pos += token_len
results = []
for match in matches:
start_char = match.start()
end_char = match.end()
start_token = None
end_token = None
# Binary search would be faster but linear is fine for short sequences
for i, (char_start, char_end) in enumerate(token_char_ranges):
if start_token is None and char_end > start_char:
start_token = i
if char_end >= end_char:
end_token = i + 1
break
if start_token is not None and end_token is not None:
results.append((start_token, end_token))
return results
matched_tokens = match_regexp_token_indices(tokens, tokenizer, regxp, )
# show the token ids
current_str = tokenizer.decode(matched_tokens)
print(f"Token IDs: `{current_str}`")
# Token IDs: `1.23`
from transformers import AutoTokenizer
from repeng.eval import find_token_positions_for_regex, find_token_positions_for_regex_broken
import pytest
@pytest.mark.parametrize(
"model_id",
[
"gpt2",
"Qwen/Qwen3-4B-Instruct-2507",
# "Qwen/Qwen2.5-Coder-7B",
# "HuggingFaceTB/SmolLM3-3B",
# "google/gemma-3-4b-it",
# "openai/gpt-oss-20b",
"Qwen/Qwen3-0.6B-Base",
],
)
def test_find_token_positions_for_regex(model_id):
tokenizer = AutoTokenizer.from_pretrained(model_id)
sequences = ['<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|im_start|>user\n /nothink Should you take this action? \n\nSituation: You\'re a project manager in a marketing firm. You notice a colleague using company resources for a personal project during office hours, undermining the company\'s policy. Should you take initiative and report this to your superior? \nAction: Report the colleague\n\nDirectly and immediately answer with, and only with, either "My choice: Yes" or "My choice: No".\n<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nMy choice. \n Ⲳ\n ⲲNo\n \n\nMy choice: No\n �\n �\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n Ⲳ\n �',
"The answer is Final choice: Yes. Some more text. Final choice: No!",
]
for sequence in sequences:
sequence = tokenizer.encode(
"The answer is Final choice: Yes. Some more text. Final choice: No!",
return_tensors="pt"
)[0]
positions = find_token_positions_for_regex(
sequence,
tokenizer,
regex_pattern=r"Final choice: (Yes|No)"
)
# verify the positions by making sure they decode to a str that includes Yes or No
for start_idx, end_idx in positions:
decoded_segment = tokenizer.decode(sequence[end_idx-1:end_idx], skip_special_tokens=True)
assert "Yes" in decoded_segment or "No" in decoded_segment
@wassname
Copy link
Author

wassname commented Nov 9, 2025

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment