Skip to content

Instantly share code, notes, and snippets.

@graelo
Last active December 9, 2025 17:08
Show Gist options
  • Select an option

  • Save graelo/220038b4a0ef16cb5357827788d43213 to your computer and use it in GitHub Desktop.

Select an option

Save graelo/220038b4a0ef16cb5357827788d43213 to your computer and use it in GitHub Desktop.
A principled Mistral Tool Call parser for v11+ tokenizer models
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Mistral tool call parser for v11+ models.
This implementation uses token-based parsing for streaming, leveraging the
atomic nature of special token IDs ([TOOL_CALLS], [ARGS], [CALL_ID]) to
reliably detect tool call boundaries.
Supported models: Mistral-Small-3.1+, Ministral-3+, and other v11+ models.
Note: Pre-v11 models (Mistral-7B-Instruct-v0.1/v0.2/v0.3) are not supported.
These older models have limited tool calling capabilities and require complex
text-based parsing with partial JSON handling. Users should upgrade to v11+
models for reliable tool calling support.
"""
import contextlib
from collections.abc import Sequence
from enum import Enum, auto
from random import choices
from string import ascii_letters, digits
import regex as re
from pydantic import Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
DeltaFunctionCall,
DeltaMessage,
DeltaToolCall,
ExtractedToolCallInformation,
FunctionCall,
ToolCall,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
ToolParser,
)
from vllm.logger import init_logger
from vllm.tokenizers import MistralTokenizer, TokenizerLike
logger = init_logger(__name__)
ALPHANUMERIC = ascii_letters + digits
def _escape_json_control_chars(s: str) -> str:
"""Escape control characters that would break JSON serialization.
Models sometimes emit raw control characters (literal newlines, tabs, etc.)
inside JSON strings. These must be escaped for valid JSON output.
Already-escaped sequences (like the two-char '\\n') are left untouched.
"""
return s.replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t")
class MistralToolCall(ToolCall):
id: str = Field(default_factory=lambda: MistralToolCall.generate_random_id())
@staticmethod
def generate_random_id():
# Mistral Tool Call Ids must be alphanumeric with a length of 9.
# https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299
return "".join(choices(ALPHANUMERIC, k=9))
@staticmethod
def is_valid_id(id: str) -> bool:
return id.isalnum() and len(id) == 9
class StreamingState(Enum):
"""Streaming state for tool call parsing."""
CONTENT = auto() # Before any [TOOL_CALLS] token
PARSING_TOOL_NAME = auto() # After [TOOL_CALLS], parsing function name
PARSING_TOOL_ARGS = auto() # Parsing JSON arguments
COMPLETE = auto() # All tools parsed
class MistralToolParser(ToolParser):
"""
Tool call parser for Mistral v11+ models.
Supports the v11+ format: [TOOL_CALLS]name[ARGS]{...}
Optionally with call ID: [TOOL_CALLS]name[CALL_ID]id[ARGS]{...}
This parser requires MistralTokenizer (tokenizer_mode=mistral) and
models using tokenizer version 11 or higher.
"""
def __init__(self, tokenizer: TokenizerLike):
super().__init__(tokenizer)
if not isinstance(self.model_tokenizer, MistralTokenizer):
raise RuntimeError(
"MistralToolParser requires MistralTokenizer. "
"Please use tokenizer_mode='mistral' in your vLLM configuration. "
"Note: Only v11+ Mistral models are supported for tool calling."
)
self._mistral_base_tokenizer = self.model_tokenizer.tokenizer
self._version = self.model_tokenizer.version
if self._version < 11:
raise RuntimeError(
f"MistralToolParser requires tokenizer version 11 or higher, "
f"but got version {self._version}. Pre-v11 models "
"(Mistral-7B-Instruct-v0.1/v0.2/v0.3) are not supported for "
"tool calling. Please use a v11+ model such as "
"Mistral-Small-3.1 or Ministral-3."
)
# Get bot token info
self.bot_token = "[TOOL_CALLS]"
self.bot_token_id = self.vocab.get(self.bot_token)
if self.bot_token_id is None:
raise RuntimeError(
"Mistral Tool Parser could not locate the [TOOL_CALLS] token "
"in the tokenizer!"
)
# Get control tokens for v11+ format
try:
self._args_token_id = self._mistral_base_tokenizer.get_control_token(
"[ARGS]"
)
except Exception as err:
raise RuntimeError(
"Mistral Tool Parser could not locate the [ARGS] token. "
"This token is required for v11+ tool call parsing."
) from err
self._call_id_token_id: int | None = None
with contextlib.suppress(Exception):
# [CALL_ID] is optional - some models may not have it
self._call_id_token_id = self._mistral_base_tokenizer.get_control_token(
"[CALL_ID]"
)
# Regex for non-streaming parsing: name{args}
self.fn_name_regex = re.compile(r"([a-zA-Z0-9_-]+)(\{[\s\S]*?\}+)", re.DOTALL)
# Streaming state
self._streaming_state = StreamingState.CONTENT
self._current_tool_index = -1
self._current_tool_id: str | None = None
self._current_tool_name: str = ""
self._current_tool_args: str = ""
self._brace_depth = 0
# For compatibility with serving_chat.py's finish_reason detection
self.prev_tool_call_arr: list[dict] = []
def extract_tool_calls(
self,
model_output: str,
request: ChatCompletionRequest,
) -> ExtractedToolCallInformation:
"""
Extract tool calls from a complete model response.
Parses the v11+ format: [TOOL_CALLS]name{args}[TOOL_CALLS]name{args}...
"""
# Fast path: no tool call token present
if self.bot_token not in model_output:
return ExtractedToolCallInformation(
tools_called=False, tool_calls=[], content=model_output
)
try:
# Get content before tool calls
content_str = model_output.split(self.bot_token)[0]
content: str | None = content_str if content_str.strip() else None
# Parse tool calls from each segment after [TOOL_CALLS]
tool_calls: list[MistralToolCall] = []
for segment in model_output.split(self.bot_token):
if not segment.strip():
continue
matches = self.fn_name_regex.findall(segment)
for match in matches:
fn_name = match[0]
fn_args = _escape_json_control_chars(match[1])
tool_calls.append(
MistralToolCall(
type="function",
function=FunctionCall(name=fn_name, arguments=fn_args),
)
)
return ExtractedToolCallInformation(
tools_called=True,
tool_calls=tool_calls,
content=content,
)
except Exception:
logger.exception("Error in extracting tool call from response.")
return ExtractedToolCallInformation(
tools_called=False,
tool_calls=[],
content=model_output.replace(self.bot_token, "").strip(),
)
def extract_tool_calls_streaming(
self,
previous_text: str,
current_text: str,
delta_text: str,
previous_token_ids: Sequence[int],
current_token_ids: Sequence[int],
delta_token_ids: Sequence[int],
request: ChatCompletionRequest,
) -> DeltaMessage | None:
"""
Extract tool calls from streaming output using token-based parsing.
Token IDs are atomic - they cannot be split across chunks - which
eliminates a whole class of parsing bugs that affect text-based parsing.
"""
# If no tool call token seen yet, emit as content
if self.bot_token_id not in current_token_ids:
return DeltaMessage(content=delta_text)
# Check if this is the first chunk containing [TOOL_CALLS]
# If so, we may have content tokens before it in this delta
if self.bot_token_id not in previous_token_ids:
return self._stream_tool_calls_with_content(delta_token_ids)
return self._stream_tool_calls(delta_token_ids)
def _stream_tool_calls_with_content(
self, delta_token_ids: Sequence[int]
) -> DeltaMessage | None:
"""
Handle the first chunk containing [TOOL_CALLS].
Content tokens before [TOOL_CALLS] are emitted as content,
then tool call parsing begins.
"""
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
# Find where [TOOL_CALLS] appears in this delta
assert self.bot_token_id is not None # Validated in __init__
try:
bot_idx = list(delta_token_ids).index(self.bot_token_id)
except ValueError:
# Shouldn't happen, but handle gracefully
return self._stream_tool_calls(delta_token_ids)
# Decode content tokens before [TOOL_CALLS]
content_tokens = delta_token_ids[:bot_idx]
content = ""
if content_tokens:
content = self._mistral_base_tokenizer.decode(
list(content_tokens),
special_token_policy=SpecialTokenPolicy.IGNORE,
)
# Process tool call tokens (including [TOOL_CALLS] itself)
tool_tokens = delta_token_ids[bot_idx:]
tool_result = self._stream_tool_calls(tool_tokens)
# Combine content and tool calls in response
if content and tool_result and tool_result.tool_calls:
return DeltaMessage(content=content, tool_calls=tool_result.tool_calls)
elif content:
return DeltaMessage(content=content)
else:
return tool_result
def _stream_tool_calls(self, delta_token_ids: Sequence[int]) -> DeltaMessage | None:
"""
Stream tool calls using token-based parsing.
Detects [TOOL_CALLS] and [ARGS] tokens to identify tool call boundaries,
then streams function names and arguments as they arrive.
"""
from mistral_common.tokens.tokenizers.base import SpecialTokenPolicy
delta_tool_calls: list[DeltaToolCall] = []
for token_id in delta_token_ids:
if token_id == self.bot_token_id:
# Starting a new tool call
self._current_tool_index += 1
self._current_tool_id = MistralToolCall.generate_random_id()
self._current_tool_name = ""
self._current_tool_args = ""
self._brace_depth = 0
self._streaming_state = StreamingState.PARSING_TOOL_NAME
# Set flag for finish_reason detection
if not self.prev_tool_call_arr:
self.prev_tool_call_arr = [{"arguments": {}}]
# Initialize streamed_args_for_tool for this tool index
while len(self.streamed_args_for_tool) <= self._current_tool_index:
self.streamed_args_for_tool.append("")
elif token_id == self._args_token_id:
# Transition from name to arguments
if self._streaming_state == StreamingState.PARSING_TOOL_NAME:
# Emit the complete function name
delta_tool_calls.append(
DeltaToolCall(
index=self._current_tool_index,
type="function",
id=self._current_tool_id,
function=DeltaFunctionCall(
name=self._current_tool_name.strip()
).model_dump(exclude_none=True),
)
)
self._streaming_state = StreamingState.PARSING_TOOL_ARGS
elif token_id == self._call_id_token_id:
# Skip call ID tokens (they come between name and [ARGS])
# We generate our own IDs
pass
elif self._streaming_state == StreamingState.CONTENT:
# Before any tool call - shouldn't happen if bot_token_id
# is in current_token_ids, but handle gracefully
pass
elif self._streaming_state == StreamingState.PARSING_TOOL_NAME:
# Accumulate name tokens
token_str = self._mistral_base_tokenizer.decode(
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
)
self._current_tool_name += token_str
elif self._streaming_state == StreamingState.PARSING_TOOL_ARGS:
# Stream argument tokens
token_str = self._mistral_base_tokenizer.decode(
[token_id], special_token_policy=SpecialTokenPolicy.IGNORE
)
# Track brace depth for nested JSON
for char in token_str:
if char == "{":
self._brace_depth += 1
elif char == "}":
self._brace_depth -= 1
self._current_tool_args += token_str
# Update streamed_args_for_tool for vLLM's finish handling
if self._current_tool_index < len(self.streamed_args_for_tool):
self.streamed_args_for_tool[self._current_tool_index] = (
self._current_tool_args
)
# Emit arguments delta
delta_tool_calls.append(
DeltaToolCall(
index=self._current_tool_index,
function=DeltaFunctionCall(arguments=token_str).model_dump(
exclude_none=True
),
)
)
# Build response
if delta_tool_calls:
return DeltaMessage(tool_calls=delta_tool_calls)
return None
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tests for the token-based Mistral tool parser (v11+ models only).
Tests cover:
1. Non-streaming extraction for v11+ tokenizers
2. Streaming extraction with proper token-based parsing
3. Edge cases like content before tool calls, multiple tools, etc.
Note: Pre-v11 models (Mistral-7B-Instruct-v0.1/v0.2/v0.3) are not supported.
"""
import json
from collections.abc import Generator
import partial_json_parser
import pytest
from mistral_common.protocol.instruct.messages import AssistantMessage
from mistral_common.protocol.instruct.request import InstructRequest
from mistral_common.protocol.instruct.tool_calls import (
FunctionCall as MistralFunctionCall,
)
from mistral_common.protocol.instruct.tool_calls import ToolCall
from partial_json_parser.core.options import Allow
from vllm.entrypoints.openai.protocol import DeltaMessage, DeltaToolCall
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolParser,
)
from vllm.tokenizers import MistralTokenizer, TokenizerLike, get_tokenizer
from vllm.tokenizers.detokenizer_utils import detokenize_incrementally
@pytest.fixture(scope="module")
def mistral_tokenizer():
"""V11+ tokenizer using mistral-common format."""
MODEL = "mistralai/Mistral-Small-3.2-24B-Instruct-2506"
return get_tokenizer(tokenizer_name=MODEL, tokenizer_mode="mistral")
@pytest.fixture
def mistral_tool_parser(mistral_tokenizer):
return MistralToolParser(mistral_tokenizer)
def assert_tool_calls(
actual_tool_calls: list[ToolCall] | list[DeltaToolCall],
expected_tool_calls: list[ToolCall],
):
"""Assert that actual tool calls match expected ones."""
assert len(actual_tool_calls) == len(expected_tool_calls), (
f"Expected {len(expected_tool_calls)} tool calls, got {len(actual_tool_calls)}"
)
for actual, expected in zip(actual_tool_calls, expected_tool_calls):
# Check ID format
assert isinstance(actual.id, str), (
f"Tool call ID should be string, got {type(actual.id)}"
)
assert len(actual.id) == 9, (
f"Tool call ID should be 9 chars, got {len(actual.id)}"
)
assert actual.id.isalnum(), (
f"Tool call ID should be alphanumeric, got {actual.id}"
)
# Check function
assert actual.function is not None
# Handle both Pydantic model and dict-like access
func = actual.function
actual_name = getattr(func, "name", None) or (
func.get("name") if isinstance(func, dict) else None
)
actual_args = getattr(func, "arguments", None) or (
func.get("arguments") if isinstance(func, dict) else None
)
assert actual_name == expected.function.name, (
f"Expected function name '{expected.function.name}', got '{actual_name}'"
)
assert actual_args == expected.function.arguments, (
f"Expected arguments '{expected.function.arguments}', got '{actual_args}'"
)
def fix_tool_call_tokenization(
tokens: list[int],
mistral_tool_parser: MistralToolParser,
mistral_tokenizer: TokenizerLike,
) -> list[int]:
"""
Replace textual token sequences for special tokens with their IDs.
This is needed because encoding free text may produce the textual tokens
for "[TOOL_CALLS]", "[ARGS]", etc. rather than the single special token.
"""
# Build mapping of textual sequences to special token IDs
replacements: list[tuple[list[int], int]] = []
# [TOOL_CALLS] token
textual_tool_call_ids = mistral_tokenizer.encode(
text=mistral_tool_parser.bot_token,
add_special_tokens=False,
)
if mistral_tool_parser.bot_token_id is not None:
replacements.append((textual_tool_call_ids, mistral_tool_parser.bot_token_id))
# [ARGS] token
if mistral_tool_parser._args_token_id is not None:
textual_args_ids = mistral_tokenizer.encode(
text="[ARGS]",
add_special_tokens=False,
)
replacements.append((textual_args_ids, mistral_tool_parser._args_token_id))
if not tokens or not replacements:
return tokens
result_tokens = list(tokens)
# Apply each replacement (longest first to avoid partial matches)
replacements.sort(key=lambda x: -len(x[0]))
for textual_ids, special_id in replacements:
target_len = len(textual_ids)
new_result = []
i = 0
while i < len(result_tokens):
if result_tokens[i : i + target_len] == textual_ids:
new_result.append(special_id)
i += target_len
else:
new_result.append(result_tokens[i])
i += 1
result_tokens = new_result
return result_tokens
def stream_delta_message_generator(
mistral_tool_parser: MistralToolParser,
mistral_tokenizer: TokenizerLike,
tools: list[tuple[str, str]],
) -> Generator[DeltaMessage, None, None]:
"""
Generate streaming delta messages by tokenizing and processing one token at a time.
Uses encode_instruct to get proper tokenization with special tokens.
"""
assert isinstance(mistral_tokenizer, MistralTokenizer)
# Use encode_instruct to get proper special tokens
assistant_msg = AssistantMessage(
tool_calls=[
ToolCall(
function=MistralFunctionCall(
name=name,
arguments=arg,
)
)
for (name, arg) in tools
],
)
request = InstructRequest(messages=[assistant_msg])
all_token_ids = mistral_tokenizer.instruct.encode_instruct(request).tokens
# Stream tokens one at a time
previous_text = ""
previous_tokens = None
prefix_offset = 0
read_offset = 0
for i, delta_token in enumerate(all_token_ids):
delta_token_ids = [delta_token]
previous_token_ids = all_token_ids[:i]
current_token_ids = all_token_ids[: i + 1]
# Detokenize incrementally
(new_tokens, delta_text, new_prefix_offset, new_read_offset) = (
detokenize_incrementally(
tokenizer=mistral_tokenizer,
all_input_ids=current_token_ids,
prev_tokens=previous_tokens,
prefix_offset=prefix_offset,
read_offset=read_offset,
skip_special_tokens=True,
spaces_between_special_tokens=True,
)
)
current_text = previous_text + delta_text
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text,
current_text,
delta_text,
previous_token_ids,
current_token_ids,
delta_token_ids,
request=None, # type: ignore[arg-type]
)
if delta_message:
yield delta_message
previous_text = current_text
previous_tokens = (
previous_tokens + new_tokens if previous_tokens else new_tokens
)
prefix_offset = new_prefix_offset
read_offset = new_read_offset
# =============================================================================
# Non-streaming extraction tests
# =============================================================================
class TestExtractToolCallsNoTools:
"""Test extraction when no tools are called."""
def test_no_tool_call_token(self, mistral_tool_parser):
model_output = "This is a test response without any tool calls."
result = mistral_tool_parser.extract_tool_calls(model_output, request=None)
assert not result.tools_called
assert result.tool_calls == []
assert result.content == model_output
class TestExtractToolCallsV11Plus:
"""Test non-streaming extraction for v11+ tokenizers."""
@pytest.mark.parametrize(
"model_output,expected_tool_calls,expected_content",
[
# Single tool (v11+ format: name{args})
(
'[TOOL_CALLS]add_this_and_that{"a": 3.5, "b": 4}',
[
ToolCall(
function=MistralFunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
None,
),
# Weather tool
(
"[TOOL_CALLS]get_current_weather"
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}',
[
ToolCall(
function=MistralFunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius",
}
),
)
)
],
None,
),
# Multiple tool calls
(
'[TOOL_CALLS]add{"a": 3.5, "b": 4}[TOOL_CALLS]multiply{"a": 3, "b": 6}',
[
ToolCall(
function=MistralFunctionCall(
name="add",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
),
ToolCall(
function=MistralFunctionCall(
name="multiply",
arguments=json.dumps({"a": 3, "b": 6}),
)
),
],
None,
),
],
ids=["single_tool_add", "single_tool_weather", "multiple_tool_calls"],
)
def test_extract_tool_calls(
self,
mistral_tool_parser,
model_output,
expected_tool_calls,
expected_content,
):
result = mistral_tool_parser.extract_tool_calls(model_output, request=None)
assert result.tools_called
assert_tool_calls(result.tool_calls, expected_tool_calls)
assert result.content == expected_content
# =============================================================================
# Streaming extraction tests
# =============================================================================
def _test_extract_tool_calls_streaming(
tool_parser,
tokenizer,
tools,
expected_tool_calls,
expected_content,
):
"""
Helper function to test streaming extraction.
Collects all streamed deltas and verifies the final result matches expected.
"""
other_content: str = ""
function_names: list[str] = []
function_args_strs: list[str] = []
tool_call_idx: int = -1
tool_call_ids: list[str | None] = []
for delta_message in stream_delta_message_generator(tool_parser, tokenizer, tools):
# Role should never be streamed from tool parser
assert not delta_message.role
if delta_message.content:
other_content += delta_message.content
streamed_tool_calls = delta_message.tool_calls
if streamed_tool_calls and len(streamed_tool_calls) > 0:
# Only one tool call delta per message
assert len(streamed_tool_calls) == 1
tool_call = streamed_tool_calls[0]
# Verify prev_tool_call_arr is set (for finish_reason detection)
assert len(tool_parser.prev_tool_call_arr) > 0
# If new tool, set up tracking
if tool_call.index != tool_call_idx:
tool_call_idx = tool_call.index
function_args_strs.append("")
tool_call_ids.append(None)
# Track tool call ID (should be set once per tool)
if tool_call.id and not tool_call_ids[tool_call.index]:
tool_call_ids[tool_call.index] = tool_call.id
# Track function parts
if tool_call.function:
# DeltaFunctionCall may be a Pydantic model or dict-like
func = tool_call.function
func_name = getattr(func, "name", None) or (
func.get("name") if isinstance(func, dict) else None
)
func_args = getattr(func, "arguments", None) or (
func.get("arguments") if isinstance(func, dict) else None
)
if func_name:
function_names.append(func_name)
if func_args:
function_args_strs[tool_call.index] += func_args
# Verify content
assert other_content == expected_content
# Build actual tool calls from collected data
actual_tool_calls = [
ToolCall(
id=tool_call_id,
function=MistralFunctionCall(
name=function_name,
arguments=partial_json_parser.ensure_json(
function_args_str, Allow.OBJ | Allow.STR
),
),
)
for tool_call_id, function_name, function_args_str in zip(
tool_call_ids, function_names, function_args_strs
)
]
assert_tool_calls(actual_tool_calls, expected_tool_calls)
class TestStreamingExtractionV11Plus:
"""Test streaming extraction for v11+ tokenizers."""
@pytest.mark.parametrize(
"tools,expected_tool_calls,expected_content",
[
# Single tool
(
[("add", '{"a": 3, "b": 4}')],
[
ToolCall(
function=MistralFunctionCall(
name="add",
arguments=json.dumps({"a": 3, "b": 4}),
)
)
],
"",
),
# String arguments
(
[("add_two_strings", '{"a": "3", "b": "4"}')],
[
ToolCall(
function=MistralFunctionCall(
name="add_two_strings",
arguments=json.dumps({"a": "3", "b": "4"}),
)
)
],
"",
),
# Multiple tools
(
[
("add", '{"a": 3.5, "b": 4}'),
(
"get_current_weather",
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}',
),
],
[
ToolCall(
function=MistralFunctionCall(
name="add",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
),
ToolCall(
function=MistralFunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius",
}
),
)
),
],
"",
),
],
ids=["single_tool_add", "single_tool_add_strings", "multiple_tools"],
)
def test_streaming_extraction(
self,
mistral_tool_parser,
mistral_tokenizer,
tools,
expected_tool_calls,
expected_content,
):
_test_extract_tool_calls_streaming(
mistral_tool_parser,
mistral_tokenizer,
tools,
expected_tool_calls,
expected_content,
)
class TestStreamingOneChunk:
"""Test streaming when all tokens arrive in a single chunk."""
@pytest.mark.parametrize(
"model_output,expected_tool_calls,expected_content",
[
# Single tool - v11 format includes [ARGS] between name and JSON
(
'[TOOL_CALLS]add_this_and_that[ARGS]{"a": 3.5, "b": 4}',
[
ToolCall(
function=MistralFunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"",
),
# Weather tool
(
"[TOOL_CALLS]get_current_weather[ARGS]"
'{"city": "San Francisco", "state": "CA", "unit": "celsius"}',
[
ToolCall(
function=MistralFunctionCall(
name="get_current_weather",
arguments=json.dumps(
{
"city": "San Francisco",
"state": "CA",
"unit": "celsius",
}
),
)
)
],
"",
),
# Multiple tools - NOTE: This case is tricky because BPE tokenization
# may merge the closing } of the first tool with the [ of the next
# [TOOL_CALLS], making it hard to detect the second tool call when
# encoding from free text. In real inference, the model generates
# special tokens directly, avoiding this issue.
pytest.param(
'[TOOL_CALLS]add[ARGS]{"a": 3.5, "b": 4}'
'[TOOL_CALLS]multiply[ARGS]{"a": 3, "b": 6}',
[
ToolCall(
function=MistralFunctionCall(
name="add",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
),
ToolCall(
function=MistralFunctionCall(
name="multiply",
arguments=json.dumps({"a": 3, "b": 6}),
)
),
],
"",
marks=pytest.mark.xfail(reason="BPE tokenization merges }[ tokens"),
),
# Content before tool
(
'bla[TOOL_CALLS]add_this_and_that[ARGS]{"a": 3.5, "b": 4}',
[
ToolCall(
function=MistralFunctionCall(
name="add_this_and_that",
arguments=json.dumps({"a": 3.5, "b": 4}),
)
)
],
"bla",
),
],
ids=[
"single_tool_add",
"single_tool_weather",
"multiple_tool_calls",
"content_before_tool",
],
)
def test_streaming_one_chunk_v11(
self,
mistral_tool_parser,
mistral_tokenizer,
model_output,
expected_tool_calls,
expected_content,
):
"""Test v11+ streaming with all tokens in one chunk.
When all tokens arrive at once, we still produce streaming-style
output with multiple DeltaToolCall objects. We need to aggregate
these to verify the final result.
"""
if isinstance(mistral_tokenizer, MistralTokenizer):
all_token_ids = mistral_tokenizer.encode(model_output)
else:
all_token_ids = mistral_tokenizer.encode(
model_output, add_special_tokens=False
)
all_token_ids = fix_tool_call_tokenization(
all_token_ids, mistral_tool_parser, mistral_tokenizer
)
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=model_output,
delta_text=model_output,
previous_token_ids=[],
current_token_ids=all_token_ids,
delta_token_ids=all_token_ids,
request=None,
)
assert isinstance(delta_message, DeltaMessage)
# Aggregate streaming deltas into final tool calls
# Each tool call starts with a name delta, followed by argument deltas
tool_call_data: dict[int, dict] = {} # index -> {id, name, arguments}
for tc in delta_message.tool_calls or []:
idx = tc.index
if idx not in tool_call_data:
tool_call_data[idx] = {"id": None, "name": "", "arguments": ""}
if tc.id:
tool_call_data[idx]["id"] = tc.id
func = tc.function
func_name = getattr(func, "name", None) or (
func.get("name") if isinstance(func, dict) else None
)
func_args = getattr(func, "arguments", None) or (
func.get("arguments") if isinstance(func, dict) else None
)
if func_name:
tool_call_data[idx]["name"] = func_name
if func_args:
tool_call_data[idx]["arguments"] += func_args
# Verify we got the expected number of tool calls
assert len(tool_call_data) == len(expected_tool_calls), (
f"Expected {len(expected_tool_calls)} tool calls, got {len(tool_call_data)}"
)
# Verify each tool call
for i, expected in enumerate(expected_tool_calls):
actual = tool_call_data[i]
assert actual["name"] == expected.function.name, (
f"Expected name '{expected.function.name}', got '{actual['name']}'"
)
assert actual["arguments"] == expected.function.arguments, (
f"Expected args '{expected.function.arguments}', "
f"got '{actual['arguments']}'"
)
assert actual["id"] is not None and len(actual["id"]) == 9
# Verify content before tool calls is preserved
actual_content = delta_message.content or ""
assert actual_content == expected_content, (
f"Expected content '{expected_content}', got '{actual_content}'"
)
# =============================================================================
# Edge case tests
# =============================================================================
class TestEdgeCases:
"""Test edge cases and error handling."""
def test_tool_call_id_format(self, mistral_tool_parser):
"""Verify generated tool call IDs are valid."""
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall,
)
for _ in range(100):
tool_id = MistralToolCall.generate_random_id()
assert len(tool_id) == 9
assert tool_id.isalnum()
assert MistralToolCall.is_valid_id(tool_id)
def test_invalid_tool_id_validation(self):
"""Test tool ID validation."""
from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import (
MistralToolCall,
)
assert not MistralToolCall.is_valid_id("")
assert not MistralToolCall.is_valid_id("12345678") # Too short
assert not MistralToolCall.is_valid_id("1234567890") # Too long
assert not MistralToolCall.is_valid_id("abc-def-g") # Contains hyphen
def test_empty_model_output(self, mistral_tool_parser):
"""Test handling of empty output."""
result = mistral_tool_parser.extract_tool_calls("", request=None)
assert not result.tools_called
assert result.tool_calls == []
assert result.content == ""
def test_raw_control_chars_escaped(self, mistral_tool_parser):
"""Test that raw control characters in arguments are escaped.
Models sometimes emit literal newlines/tabs inside JSON strings.
These must be escaped to produce valid JSON output.
"""
# Arguments with literal newline (not the escaped \n sequence)
model_output = '[TOOL_CALLS]run_agent{"prompt": "line1\nline2\ttabbed"}'
result = mistral_tool_parser.extract_tool_calls(model_output, request=None)
assert result.tools_called
assert len(result.tool_calls) == 1
# The raw newline/tab should be escaped
args = result.tool_calls[0].function.arguments
assert "\\n" in args
assert "\\t" in args
assert "\n" not in args # No literal newline
assert "\t" not in args # No literal tab
def test_already_escaped_chars_unchanged(self, mistral_tool_parser):
"""Test that already-escaped sequences are not double-escaped."""
# Arguments with properly escaped \n (two chars: backslash + n)
model_output = r'[TOOL_CALLS]run_agent{"prompt": "line1\nline2"}'
result = mistral_tool_parser.extract_tool_calls(model_output, request=None)
assert result.tools_called
assert len(result.tool_calls) == 1
args = result.tool_calls[0].function.arguments
# Should still be \n, not \\n
assert "\\n" in args
assert "\\\\n" not in args
class TestTokenBasedDetection:
"""Test that token-based detection works correctly."""
def test_bot_token_id_exists(self, mistral_tool_parser):
"""Verify bot token ID is properly set."""
assert mistral_tool_parser.bot_token_id is not None
assert isinstance(mistral_tool_parser.bot_token_id, int)
def test_args_token_id_exists(self, mistral_tool_parser):
"""Verify [ARGS] token ID is properly set for v11+."""
assert mistral_tool_parser._args_token_id is not None
assert isinstance(mistral_tool_parser._args_token_id, int)
def test_streaming_uses_token_ids(self, mistral_tool_parser, mistral_tokenizer):
"""Test that streaming correctly uses token IDs for detection."""
# Content without tool call
content_text = "Hello, how can I help you?"
content_tokens = mistral_tokenizer.encode(
content_text, add_special_tokens=False
)
delta_message = mistral_tool_parser.extract_tool_calls_streaming(
previous_text="",
current_text=content_text,
delta_text=content_text,
previous_token_ids=[],
current_token_ids=content_tokens,
delta_token_ids=content_tokens,
request=None,
)
assert delta_message is not None
assert delta_message.content == content_text
assert not delta_message.tool_calls
class TestParserInitialization:
"""Test parser initialization and validation."""
def test_rejects_pre_v11_tokenizer(self):
"""Test that parser rejects pre-v11 MistralTokenizer."""
# Get a pre-v11 MistralTokenizer (Mistral-7B-v0.3 uses version 3)
pre_v11_tokenizer = get_tokenizer(
tokenizer_name="mistralai/Mistral-7B-Instruct-v0.3",
tokenizer_mode="mistral",
)
assert isinstance(pre_v11_tokenizer, MistralTokenizer)
assert pre_v11_tokenizer.version < 11
with pytest.raises(
RuntimeError, match="requires tokenizer version 11 or higher"
):
MistralToolParser(pre_v11_tokenizer)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment