@juanqui
Created April 25, 2026 16:54
vLLM patch for Qwen3.5/3.6 tool-call & reasoning parser fixes (9 PRs + 1 local wire fix). Companion image: ghcr.io/juanqui/vllm-qwen35-toolfix:0.19.1rc1.dev328-g18013df6a
# vllm-qwen35-toolfix.patch
#
# What this patches:
# 8 vLLM source files (~1.3 KLOC of changes) baked into the local image
# ghcr.io/juanqui/vllm-qwen35-toolfix:0.19.1rc1.dev328-g18013df6a, on top
# of the upstream vllm/vllm-openai:cu130-nightly base.
#
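# Example serve invocation against the companion image (the model name
# and parser-flag values below are illustrative assumptions, not taken
# from the journal):
#   docker run --gpus all -p 8000:8000 \
#     ghcr.io/juanqui/vllm-qwen35-toolfix:0.19.1rc1.dev328-g18013df6a \
#     --model <qwen3.5-or-3.6-checkpoint> \
#     --reasoning-parser qwen3 --tool-call-parser qwen3_xml \
#     --enable-auto-tool-choice
#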
# Origin (per /mnt/nvme2/experiments/2026-04-23-qwen36-27b-fp8-bench/journal/{14,15}*.md):
# - PR #35687 min_new_tokens allow-list -> config/model.py,
# entrypoints/openai/{chat_completion,completion}/protocol.py (mapping example below)
# - PR #40783 qwen3 reasoning parser rewrite -> reasoning/qwen3_reasoning_parser.py,
# parser/abstract_parser.py
# - PR #40785 qwen3_coder rewrite -> tool_parsers/qwen3coder_tool_parser.py
# - PR #40787 qwen3_xml rewrite -> tool_parsers/qwen3xml_tool_parser.py
# - PR #36224 (partial) is_reasoning_end window + <function= implicit end
# -> serving.py, qwen3_reasoning_parser.py
# - PR #35982 enable_thinking=false stream -> qwen3_reasoning_parser.py
# - PR #38996 Python "None" vs JSON "null" -> qwen3coder_tool_parser.py,
# qwen3xml_tool_parser.py
# - PR #38890 XML logger format mismatch -> qwen3xml_tool_parser.py
# - PR #39055 rescue <tool_call> in <think> -> qwen3_reasoning_parser.py
# - local wire-level exclude_none=True -> entrypoints/openai/chat_completion/serving.py (chunk example below)
#
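# Illustrations for a few of the fixes above (payloads are made up;
# shapes follow the patched code):
# - PR #35687: a generation_config.json entry such as
#     {"min_new_tokens": 2, "max_new_tokens": 32768}
#   now maps to SamplingParams(min_tokens=2, max_tokens=32768) as the
#   server defaults, instead of min_new_tokens being dropped.
# - PR #38996: <parameter=limit>None</parameter> now serializes as
#   "limit": null in the reconstructed JSON arguments rather than the
#   string "None".
# - PR #36224/#39055: output like <think>plan...<tool_call><function=...
#   now ends reasoning at the implicit <tool_call> / <function= boundary
#   and routes the call to the tool parser instead of streaming the XML
#   as reasoning.
# - wire-level exclude_none=True: streaming chunks shrink from
#     data: {"choices":[{"delta":{"content":"hi","tool_calls":null},...}]}
#   to
#     data: {"choices":[{"delta":{"content":"hi"},...}]}
#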
# Not in this diff (separate Dockerfile-level changes):
# - pip install instanttensor==0.1.8 (runtime helper, not a source patch)
#
# How to apply:
# From a vllm checkout root: patch -p1 < vllm-qwen35-toolfix.patch
# Or: git apply vllm-qwen35-toolfix.patch
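# To pre-check against a drifted checkout (standard git/patch flags;
# a sketch, not the exact commands used here):
#   git apply --check vllm-qwen35-toolfix.patch      # verify only, no changes
#   patch -p1 --dry-run < vllm-qwen35-toolfix.patch  # reports fuzz/offsets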
#
# Baseline: vllm/vllm-openai:cu130-nightly (locally tagged 2026-04-XX, may
# drift from upstream HEAD; expect minor fuzz against newer commits).
#
# Generated: 2026-04-25 from vllm-openai-it-minfix4:local
#
diff -urN a/vllm/config/model.py b/vllm/config/model.py
--- a/vllm/config/model.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/config/model.py 2026-04-24 12:21:21.000000000 -0500
@@ -1424,6 +1424,7 @@
"top_p",
"min_p",
"max_new_tokens",
+ "min_new_tokens",
]
if any(p in config for p in available_params):
diff_sampling_param = {
@@ -1435,6 +1436,10 @@
diff_sampling_param["max_tokens"] = diff_sampling_param.pop(
"max_new_tokens"
)
+ if "min_new_tokens" in diff_sampling_param:
+ diff_sampling_param["min_tokens"] = diff_sampling_param.pop(
+ "min_new_tokens"
+ )
else:
diff_sampling_param = {}
diff -urN a/vllm/entrypoints/openai/chat_completion/protocol.py b/vllm/entrypoints/openai/chat_completion/protocol.py
--- a/vllm/entrypoints/openai/chat_completion/protocol.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/entrypoints/openai/chat_completion/protocol.py 2026-04-24 12:21:21.000000000 -0500
@@ -534,6 +534,9 @@
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
+ _min_tokens = self.min_tokens
+ if _min_tokens == 0:
+ _min_tokens = default_sampling_params.get("min_tokens", 0)
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
@@ -550,7 +553,7 @@
prompt_logprobs=prompt_logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens,
- min_tokens=self.min_tokens,
+ min_tokens=_min_tokens,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
diff -urN a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py
--- a/vllm/entrypoints/openai/chat_completion/serving.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/entrypoints/openai/chat_completion/serving.py 2026-04-24 18:52:08.000000000 -0500
@@ -189,18 +189,6 @@
)
)
- def _effective_chat_template_kwargs(
- self, request: ChatCompletionRequest
- ) -> dict[str, Any]:
- return (
- request.build_chat_params(
- self.chat_template,
- self.chat_template_content_format,
- )
- .with_defaults(self.default_chat_template_kwargs)
- .chat_template_kwargs
- )
-
async def render_chat_request(
self,
request: ChatCompletionRequest,
@@ -243,7 +231,10 @@
# Streaming response
tokenizer = self.renderer.tokenizer
assert tokenizer is not None
- chat_template_kwargs = self._effective_chat_template_kwargs(request)
+ chat_template_kwargs = self._prepare_extra_chat_template_kwargs(
+ request.chat_template_kwargs,
+ self.default_chat_template_kwargs,
+ )
reasoning_parser: ReasoningParser | None = None
if self.reasoning_parser_cls:
reasoning_parser = self.reasoning_parser_cls(
@@ -334,9 +325,21 @@
# non-reasoning outputs.
reasoning_ended = True
elif reasoning_parser:
- reasoning_ended = reasoning_parser.is_reasoning_end(
- prompt_token_ids or []
- )
+ # PR #36224: only check the tail of the prompt to avoid
+ # false positives from tool definitions (which contain
+ # <tool_call> in the schema) or prior-turn </think>
+ # tokens in multi-turn history. Use fast single-token
+ # check when available; else fall back to is_reasoning_end
+ # on a 20-token tail.
+ pti = prompt_token_ids or []
+ if hasattr(reasoning_parser, "end_token_id"):
+ reasoning_ended = (
+ reasoning_parser.end_token_id in pti[-10:]
+ )
+ else:
+ reasoning_ended = reasoning_parser.is_reasoning_end(
+ pti[-20:]
+ )
else:
reasoning_ended = None
@@ -566,20 +569,6 @@
and self._should_stream_with_auto_tool_parsing(request)
)
- # Determine whether required/named tool_choice should fall back to
- # the auto tool_parser path instead of the standard JSON-based parsing.
- # This happens when the parser declares supports_required_and_named=False
- # (e.g. GLM models that output XML instead of JSON).
- tool_choice_uses_parser = (
- self.tool_parser is not None
- and not self.tool_parser.supports_required_and_named
- and request.tools
- and (
- request.tool_choice == "required"
- or isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam)
- )
- )
-
all_previous_token_ids: list[list[int]] | None
function_name_returned = [False] * num_choices
if self.tool_call_id_type == "kimi_k2":
@@ -592,12 +581,7 @@
# Only one of these will be used, thus previous_texts and
# all_previous_token_ids will not be used twice in the same iteration.
- if (
- is_mistral_grammar_path
- or tool_choice_auto
- or tool_choice_uses_parser
- or reasoning_parser
- ):
+ if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
# These are only required in "auto" tool choice case
all_previous_token_ids = [[] for _ in range(num_choices)]
reasoning_end_arr = [False] * num_choices
@@ -684,7 +668,7 @@
total_tokens=num_prompt_tokens,
)
- data = chunk.model_dump_json(exclude_unset=True)
+ data = chunk.model_dump_json(exclude_unset=True, exclude_none=True)
yield f"data: {data}\n\n"
# Send response to echo the input portion of the
@@ -720,7 +704,7 @@
total_tokens=num_prompt_tokens,
)
- data = chunk.model_dump_json(exclude_unset=True)
+ data = chunk.model_dump_json(exclude_unset=True, exclude_none=True)
yield f"data: {data}\n\n"
first_iteration = False
@@ -735,10 +719,22 @@
and prompt_is_reasoning_end_arr[i] is None
):
# only check once per choice, because prompt_token_ids
- # are the same for all deltas in that choice
- prompt_is_reasoning_end_arr[i] = (
- reasoning_parser.is_reasoning_end(res.prompt_token_ids)
- )
+ # are the same for all deltas in that choice.
+ # PR #36224: narrow the window (10-token tail for the fast
+ # single-token check, 20-token tail for the fallback) to
+ # avoid false positives from prior-turn </think> in
+ # multi-turn history or <tool_call> in tool definitions.
+ if hasattr(reasoning_parser, "end_token_id"):
+ prompt_is_reasoning_end_arr[i] = (
+ reasoning_parser.end_token_id
+ in res.prompt_token_ids[-10:]
+ )
+ else:
+ prompt_is_reasoning_end_arr[i] = (
+ reasoning_parser.is_reasoning_end(
+ res.prompt_token_ids[-20:]
+ )
+ )
if finish_reason_sent[i]:
continue
@@ -792,12 +788,7 @@
delta_message: DeltaMessage | None
# just update previous_texts and previous_token_ids
- if (
- is_mistral_grammar_path
- or tool_choice_auto
- or tool_choice_uses_parser
- or reasoning_parser
- ):
+ if is_mistral_grammar_path or tool_choice_auto or reasoning_parser:
assert previous_texts is not None
assert all_previous_token_ids is not None
previous_text = previous_texts[i]
@@ -846,9 +837,7 @@
if result.tools_called:
tools_streamed[i] = True
# handle streaming deltas for tools with named tool_choice
- # Skip when tool_choice_uses_parser so it falls through
- # to the auto tool_parser branches below.
- elif tool_choice_function_name and not tool_choice_uses_parser:
+ elif tool_choice_function_name:
# When encountering think end id in prompt_token_ids
# i.e {"enable_thinking": False},
# check BEFORE calling the parser to avoid a spurious
@@ -886,6 +875,7 @@
):
reasoning_end_arr[i] = True
if delta_message and delta_message.content:
+ # This needs to be added to the next `delta_text`
current_text = delta_message.content
delta_message.content = None
else:
@@ -930,12 +920,7 @@
)
tools_streamed[i] = True
- # Skip when tool_choice_uses_parser so it falls through
- # to the auto tool_parser branches below.
- elif (
- request.tool_choice == "required"
- and not tool_choice_uses_parser
- ):
+ elif request.tool_choice == "required":
assert previous_texts is not None
previous_text = previous_texts[i]
current_text = previous_text + delta_text
@@ -1005,10 +990,7 @@
# update the previous values for the next iteration
if (
- is_mistral_grammar_path
- or tool_choice_auto
- or tool_choice_uses_parser
- or reasoning_parser
+ is_mistral_grammar_path or tool_choice_auto or reasoning_parser
) and not self.use_harmony:
assert previous_texts is not None
assert all_previous_token_ids is not None
@@ -1199,7 +1181,7 @@
total_tokens=num_prompt_tokens + completion_tokens,
)
- data = chunk.model_dump_json(exclude_unset=True)
+ data = chunk.model_dump_json(exclude_unset=True, exclude_none=True)
yield f"data: {data}\n\n"
# once the final token is handled, if stream_options.include_usage
diff -urN a/vllm/entrypoints/openai/completion/protocol.py b/vllm/entrypoints/openai/completion/protocol.py
--- a/vllm/entrypoints/openai/completion/protocol.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/entrypoints/openai/completion/protocol.py 2026-04-24 12:21:21.000000000 -0500
@@ -293,6 +293,9 @@
if self.kv_transfer_params:
# Pass in kv_transfer_params via extra_args
extra_args["kv_transfer_params"] = self.kv_transfer_params
+ _min_tokens = self.min_tokens
+ if _min_tokens == 0:
+ _min_tokens = default_sampling_params.get("min_tokens", 0)
return SamplingParams.from_optional(
n=self.n,
presence_penalty=self.presence_penalty,
@@ -308,7 +311,7 @@
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=max_tokens if not echo_without_generation else 1,
- min_tokens=self.min_tokens,
+ min_tokens=_min_tokens,
prompt_logprobs=prompt_logprobs,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=self.spaces_between_special_tokens,
diff -urN a/vllm/parser/abstract_parser.py b/vllm/parser/abstract_parser.py
--- a/vllm/parser/abstract_parser.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/parser/abstract_parser.py 2026-04-24 16:55:50.000000000 -0500
@@ -563,6 +563,13 @@
return False
return self._reasoning_parser.is_reasoning_end(input_ids)
+ def is_reasoning_end_streaming(
+ self, input_ids: list[int], delta_ids: list[int]
+ ) -> bool:
+ if self._reasoning_parser is None:
+ return False
+ return self._reasoning_parser.is_reasoning_end_streaming(input_ids, delta_ids)
+
def extract_content_ids(self, input_ids: list[int]) -> list[int]:
if self._reasoning_parser is None:
return input_ids
@@ -610,8 +617,13 @@
current_token_ids=current_token_ids,
delta_token_ids=delta_token_ids,
)
- # Hand off remaining content to tool parser
- if self._tool_parser and self.is_reasoning_end(delta_token_ids):
+ # Hand off remaining content to tool parser.
+ # Use is_reasoning_end_streaming for delta checks: it correctly
+ # detects <tool_call> in the current delta without the
+ # paired-token guard that is_reasoning_end applies for prompts.
+ if self._tool_parser and self.is_reasoning_end_streaming(
+ current_token_ids, delta_token_ids
+ ):
state.reasoning_ended = True
current_token_ids = self.extract_content_ids(delta_token_ids)
if delta_message and delta_message.content:
@@ -628,6 +640,12 @@
state.previous_token_ids = []
delta_text = current_text
delta_token_ids = current_token_ids
+ # Preserve any reasoning text produced by extract_reasoning_streaming
+ # in the same delta as the reasoning→tool-call transition. Without
+ # this, the assignment below would silently drop that last fragment.
+ reasoning_from_transition = (
+ delta_message.reasoning if delta_message is not None else None
+ )
delta_message = self.extract_tool_calls_streaming(
previous_text=state.previous_text,
current_text=current_text,
@@ -637,6 +655,11 @@
delta_token_ids=delta_token_ids,
request=request, # type: ignore[arg-type]
)
+ if reasoning_from_transition:
+ if delta_message is not None:
+ delta_message.reasoning = reasoning_from_transition
+ else:
+ delta_message = DeltaMessage(reasoning=reasoning_from_transition)
# No parsers: pass through as content
if self._reasoning_parser is None and self._tool_parser is None:
diff -urN a/vllm/reasoning/qwen3_reasoning_parser.py b/vllm/reasoning/qwen3_reasoning_parser.py
--- a/vllm/reasoning/qwen3_reasoning_parser.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/reasoning/qwen3_reasoning_parser.py 2026-04-24 18:57:24.000000000 -0500
@@ -1,11 +1,24 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Sequence
+import re
+from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING
from vllm.entrypoints.openai.engine.protocol import DeltaMessage
from vllm.reasoning.basic_parsers import BaseThinkingReasoningParser
+from vllm.tool_parsers.utils import partial_tag_overlap
+
+
+# PR #39055: matches a complete <tool_call>...</tool_call> block, or a
+# truncated trailing <tool_call>... at the end of reasoning. Used by
+# _split_embedded_tool_calls() to promote XML tool calls out of the
+# reasoning channel into content when the model emits them inside the
+# <think> block.
+_EMBEDDED_TOOL_CALL_RE = re.compile(
+ r"<tool_call>(.*?)</tool_call>|<tool_call>.*$",
+ re.DOTALL,
+)
if TYPE_CHECKING:
from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
@@ -31,8 +44,23 @@
use an older chat template where the model generates <think> itself.
This parser handles both styles: if <think> appears in the generated output
it is stripped before extraction (non-streaming) or skipped (streaming).
+
+ NOTE: Qwen3.5 models may emit <tool_call> inside the thinking block
+ without closing </think> first. <tool_call> is treated as an implicit
+ end of reasoning, matching the approach in KimiK2ReasoningParser.
+
+ PR #36224 also recognizes <function= (text-level) as a secondary
+ implicit end-of-reasoning boundary, since the model sometimes omits
+ the <tool_call> wrapper and emits <function=name> directly inside
+ <think>. Without this the entire tool call XML (including
+ </parameter> closing tags) gets streamed as reasoning content.
"""
+ # PR #36224: secondary implicit end-of-reasoning boundary.
+ # <function= is not a single special token, so this must be matched
+ # at text level rather than token id.
+ _FUNCTION_PREFIX = "<function="
+
def __init__(self, tokenizer: "TokenizerLike", *args, **kwargs):
super().__init__(tokenizer, *args, **kwargs)
@@ -41,6 +69,57 @@
# pure content when the user explicitly disables it.
self.thinking_enabled = chat_kwargs.get("enable_thinking", True)
+ self._tool_call_tag = "<tool_call>"
+ self._tool_call_token_id = self.vocab.get(self._tool_call_tag)
+ self._tool_call_end_tag = "</tool_call>"
+ self._tool_call_end_token_id = self.vocab.get(self._tool_call_end_tag)
+
+ # PR #36224: set when <function= is detected in streaming output as
+ # an implicit boundary. is_reasoning_end() returns True afterwards
+ # so the serving layer transitions to tool parsing.
+ self._function_prefix_ended: bool = False
+
+
+ @staticmethod
+ def _split_embedded_tool_calls(
+ reasoning: str | None,
+ content: str | None,
+ ) -> tuple[str | None, str | None]:
+ """PR #39055 — promote XML tool-call blocks out of reasoning into
+ content. Qwen3.5/3.6 sometimes emit a complete
+ <tool_call>...</tool_call> inside <think>; the downstream tool
+ parser only inspects the content channel, so without this fix
+ the call is lost. Only blocks that contain <function= are moved
+ (so prose mentioning the tag literally isn't promoted).
+ """
+ if (
+ not reasoning
+ or "<tool_call>" not in reasoning
+ or "<function=" not in reasoning
+ ):
+ return reasoning, content
+
+ extracted_blocks: list[str] = []
+
+ def _collect_or_keep(match: "re.Match[str]") -> str:
+ block = match.group(0)
+ if "<function=" not in block:
+ return block
+ extracted_blocks.append(block.strip())
+ return ""
+
+ remaining_reasoning = _EMBEDDED_TOOL_CALL_RE.sub(_collect_or_keep, reasoning)
+ remaining_reasoning = remaining_reasoning.strip() or None
+
+ if not extracted_blocks:
+ return reasoning, content
+
+ content_parts = ["\n\n".join(extracted_blocks)]
+ if content:
+ content_parts.append(content)
+ merged_content = "\n\n".join(part for part in content_parts if part) or None
+ return remaining_reasoning, merged_content
+
@property
def start_token(self) -> str:
"""The token that starts reasoning content."""
@@ -51,6 +130,62 @@
"""The token that ends reasoning content."""
return "</think>"
+ def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+ # PR #36224: honor the text-level <function= boundary recorded by
+ # extract_reasoning_streaming; <function= is not a single token,
+ # so the token-id scan below can never detect it.
+ if self._function_prefix_ended:
+ return True
+ start_token_id = self.start_token_id
+ end_token_id = self.end_token_id
+ tool_call_token_id = self._tool_call_token_id
+ tool_call_end_token_id = self._tool_call_end_token_id
+
+ for i in range(len(input_ids) - 1, -1, -1):
+ token_id = input_ids[i]
+ if token_id == start_token_id:
+ return False
+ if token_id == end_token_id:
+ return True
+ if tool_call_token_id is not None and token_id == tool_call_token_id:
+ # Skip <tool_call> tokens that are paired with a subsequent
+ # </tool_call> — these appear in system-prompt tool examples
+ # and must not be mistaken for an implicit reasoning end.
+ # Unpaired <tool_call> (model output) still signals the end.
+ if tool_call_end_token_id is not None and any(
+ input_ids[j] == tool_call_end_token_id
+ for j in range(i + 1, len(input_ids))
+ ):
+ continue
+ return True
+ return False
+
+ def is_reasoning_end_streaming(
+ self, input_ids: Sequence[int], delta_ids: Iterable[int]
+ ) -> bool:
+ if super().is_reasoning_end_streaming(input_ids, delta_ids):
+ return True
+ if self._tool_call_token_id is not None:
+ return self._tool_call_token_id in delta_ids
+ return False
+
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+ """
+ Extract content token ids from the input_ids.
+ """
+ result = super().extract_content_ids(input_ids)
+ if result:
+ return result
+ # Fall back: content starts at the FIRST <tool_call>
+ # (implicit reasoning end).
+ if (
+ self._tool_call_token_id is not None
+ and self._tool_call_token_id in input_ids
+ ):
+ tool_call_index = input_ids.index(self._tool_call_token_id)
+ return input_ids[tool_call_index:]
+ return []
+
def extract_reasoning(
self, model_output: str, request: "ChatCompletionRequest | ResponsesRequest"
) -> tuple[str | None, str | None]:
@@ -78,19 +213,34 @@
model_output_parts[2] if model_output_parts[1] else model_output_parts[0]
)
- if self.end_token not in model_output:
- if not self.thinking_enabled:
- # Thinking explicitly disabled — treat everything as content.
- return None, model_output
- # Thinking enabled but no </think>: output was truncated.
- # Everything generated so far is reasoning.
- return model_output, None
-
- # Extract reasoning content from the model output.
- reasoning, _, content = model_output.partition(self.end_token)
-
- final_content = content or None
- return reasoning, final_content
+ if self.end_token in model_output:
+ reasoning, _, content = model_output.partition(self.end_token)
+ # PR #39055: rescue any complete <tool_call>...</tool_call>
+ # blocks the model emitted inside <think>.
+ return self._split_embedded_tool_calls(
+ reasoning or None, content or None
+ )
+
+ if not self.thinking_enabled:
+ # Thinking explicitly disabled — treat everything as content.
+ return None, model_output
+
+ # No </think> — check for implicit reasoning end via <tool_call>
+ # (single token, fast path) or <function= (PR #36224, text-level
+ # fallback when the model omits the <tool_call> wrapper). Pick
+ # whichever appears first.
+ tool_call_index = model_output.find(self._tool_call_tag)
+ function_prefix_index = model_output.find(self._FUNCTION_PREFIX)
+ candidates = [i for i in (tool_call_index, function_prefix_index) if i != -1]
+ if candidates:
+ implicit_end = min(candidates)
+ reasoning = model_output[:implicit_end]
+ content = model_output[implicit_end:]
+ return reasoning or None, content or None
+ # Thinking enabled but no </think>: output was truncated.
+ # Everything generated so far is reasoning. PR #39055: still
+ # try to rescue any embedded tool-call blocks from reasoning.
+ return self._split_embedded_tool_calls(model_output, None)
def extract_reasoning_streaming(
self,
@@ -139,9 +289,89 @@
if not delta_text:
# Nothing left after stripping start token.
return None
- elif self.end_token_id in previous_token_ids:
- # End token already passed: everything is content now.
+
+ # If thinking already ended, everything is content.
+ # PR #35982: also force content path when thinking is explicitly disabled
+ # (chat_template_kwargs.enable_thinking=false). Belt-and-braces for the
+ # case where the serving layer's prompt_is_reasoning_end detection misses,
+ # which is what triggers issue #40816 (final answer emitted in
+ # delta.reasoning even when thinking is disabled).
+ if (not self.thinking_enabled or
+ self.end_token_id in previous_token_ids or
+ (self._tool_call_token_id is not None and
+ self._tool_call_token_id in previous_token_ids) or
+ (bool(self._tool_call_tag) and
+ self._tool_call_tag in previous_text)):
return DeltaMessage(content=delta_text)
- else:
- # No end token yet: still in reasoning phase.
- return DeltaMessage(reasoning=delta_text)
+
+ # Implicit reasoning end via <tool_call>.
+ has_tool_call_id = (
+ self._tool_call_token_id is not None
+ and self._tool_call_token_id in delta_token_ids
+ )
+ just_completed_tool_call_tag = (
+ bool(self._tool_call_tag)
+ and self._tool_call_tag in current_text
+ and self._tool_call_tag not in previous_text
+ )
+
+ if has_tool_call_id or just_completed_tool_call_tag:
+ if self._tool_call_tag and self._tool_call_tag in current_text:
+ tag_start_idx = current_text.find(self._tool_call_tag)
+ delta_start_idx = len(previous_text)
+
+ if tag_start_idx >= delta_start_idx:
+ reasoning_len = tag_start_idx - delta_start_idx
+ reasoning = delta_text[:reasoning_len]
+ content = delta_text[reasoning_len:]
+ else:
+ # Part of the tag was already emitted as reasoning.
+ # We MUST emit the full tag as content for the tool parser,
+ # but we avoid emitting it as reasoning in this delta.
+ reasoning = None
+ content = current_text[tag_start_idx:]
+
+ return DeltaMessage(
+ reasoning=reasoning if reasoning else None,
+ content=content if content else None,
+ )
+
+ # PR #36224: <function= as implicit end-of-reasoning when the
+ # model emits the tool-call XML directly without the
+ # <tool_call> wrapper. <function= is multi-token text so we can
+ # only detect it at text level (not token id).
+ if (
+ self._FUNCTION_PREFIX in current_text
+ and self._FUNCTION_PREFIX not in previous_text
+ ):
+ tag_start_idx = current_text.find(self._FUNCTION_PREFIX)
+ delta_start_idx = len(previous_text)
+ self._function_prefix_ended = True
+ if tag_start_idx >= delta_start_idx:
+ reasoning_len = tag_start_idx - delta_start_idx
+ reasoning = delta_text[:reasoning_len]
+ content = delta_text[reasoning_len:]
+ else:
+ reasoning = None
+ content = current_text[tag_start_idx:]
+ return DeltaMessage(
+ reasoning=reasoning if reasoning else None,
+ content=content if content else None,
+ )
+
+ # Avoid leaking fragments of <tool_call> into reasoning by
+ # checking for a partial tag overlap at the end of the text.
+ # Previously a partially generated tag could be emitted twice:
+ # once as reasoning and again as content.
+
+ overlap = partial_tag_overlap(current_text, self._tool_call_tag)
+ if overlap > 0:
+ sendable_reasoning_len = len(delta_text) - overlap
+ if sendable_reasoning_len > 0:
+ return DeltaMessage(reasoning=delta_text[:sendable_reasoning_len])
+ # Return an empty message instead of None to satisfy tests
+ # and indicate that processing is ongoing but no new content is ready.
+ return DeltaMessage()
+
+ # No end token yet: still in reasoning phase.
+ return DeltaMessage(reasoning=delta_text)
diff -urN a/vllm/tool_parsers/qwen3coder_tool_parser.py b/vllm/tool_parsers/qwen3coder_tool_parser.py
--- a/vllm/tool_parsers/qwen3coder_tool_parser.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/tool_parsers/qwen3coder_tool_parser.py 2026-04-24 18:50:06.000000000 -0500
@@ -25,7 +25,7 @@
Tool,
ToolParser,
)
-from vllm.tool_parsers.utils import find_tool_properties
+from vllm.tool_parsers.utils import find_tool_properties, partial_tag_overlap
logger = init_logger(__name__)
@@ -109,13 +109,17 @@
# Store accumulated parameters for type conversion
self.accumulated_params = {}
self.streaming_request = None
+ self._sent_content_idx = 0
+ self.current_tool_index = 0
def _convert_param_value(
self, param_value: str, param_name: str, param_config: dict, func_name: str
) -> Any:
"""Convert parameter value based on its type in the schema."""
- # Handle null value for any type
- if param_value.lower() == "null":
+ # Handle null value for any type. PR #38996: Qwen3.5's chat template
+ # uses Jinja's `| string` filter for scalar tool-call arguments,
+ # producing Python repr ("None") instead of JSON ("null"). Accept both.
+ if param_value.lower() in ("null", "none"):
return None
if param_name not in param_config:
@@ -372,6 +376,22 @@
# Check if this tool call has ended
tool_ends = current_text.count(self.tool_call_end_token)
if tool_ends > self.current_tool_index:
+ # Find the end of the tool call that just finished and update
+ # _sent_content_idx to prevent it from leaking into content.
+ search_idx = 0
+ for _ in range(self.current_tool_index + 1):
+ search_idx = current_text.find(self.tool_call_start_token,
+ search_idx)
+ if search_idx == -1:
+ break
+ end_idx = current_text.find(self.tool_call_end_token,
+ search_idx)
+ if end_idx != -1:
+ self._sent_content_idx = max(
+ self._sent_content_idx,
+ end_idx + len(self.tool_call_end_token))
+ search_idx += len(self.tool_call_start_token)
+
# This tool has ended, advance to next
self.current_tool_index += 1
self.header_sent = False
@@ -380,47 +400,55 @@
self.json_closed = False
self.accumulated_params = {}
- # Check if there are more tool calls
- tool_starts = current_text.count(self.tool_call_start_token)
- if self.current_tool_index >= tool_starts:
- # No more tool calls
- self.is_tool_call_started = False
+ # Always reset is_tool_call_started when a tool call ends.
+ # This allows correctly sending any content between or after
+ # tool calls.
+ self.is_tool_call_started = False
# Continue processing next tool
return None
+ content_message = None
# Handle normal content before tool calls
if not self.is_tool_call_started:
# Check if tool call is starting
+ tool_starts_count = current_text.count(self.tool_call_start_token)
if (
self.tool_call_start_token_id in delta_token_ids
- or self.tool_call_start_token in delta_text
+ or tool_starts_count > self.current_tool_index
):
self.is_tool_call_started = True
# Return any content before the tool call
- if self.tool_call_start_token in delta_text:
- content_before = delta_text[
- : delta_text.index(self.tool_call_start_token)
- ]
+ last_start = current_text.find(self.tool_call_start_token, self._sent_content_idx)
+ if last_start != -1 and last_start > self._sent_content_idx:
+ content_before = current_text[self._sent_content_idx:last_start]
+ self._sent_content_idx = last_start
if content_before:
- return DeltaMessage(content=content_before)
- return None
+ content_message = DeltaMessage(content=content_before)
else:
+ overlap = partial_tag_overlap(current_text, self.tool_call_start_token)
+ sendable_idx = len(current_text) - overlap
+
# Check if we're between tool calls - skip whitespace
if (
current_text.rstrip().endswith(self.tool_call_end_token)
and delta_text.strip() == ""
):
# We just ended a tool call, skip whitespace
+ self._sent_content_idx = len(current_text)
return None
- # Normal content, no tool call
- return DeltaMessage(content=delta_text)
+
+ if sendable_idx > self._sent_content_idx:
+ content = current_text[self._sent_content_idx:sendable_idx]
+ self._sent_content_idx = sendable_idx
+ if content:
+ return DeltaMessage(content=content)
+ return None
# Check if we're between tool calls (waiting for next one)
# Count tool calls we've seen vs processed
tool_starts_count = current_text.count(self.tool_call_start_token)
if self.current_tool_index >= tool_starts_count:
- # We're past all tool calls, shouldn't be here
- return None
+ return content_message
# We're in a tool call, find the current tool call portion
# Need to find the correct tool call based on current_tool_index
@@ -434,8 +462,7 @@
idx += len(self.tool_call_start_token)
if self.current_tool_index >= len(tool_start_positions):
- # No more tool calls to process yet
- return None
+ return content_message
tool_start_idx = tool_start_positions[self.current_tool_index]
# Find where this tool call ends (or current position if not ended yet)
@@ -447,6 +474,7 @@
tool_start_idx : tool_end_idx + len(self.tool_call_end_token)
]
+ tool_call_fragments = None
# Looking for function header
if not self.header_sent:
if self.tool_call_prefix in tool_text:
@@ -479,21 +507,16 @@
# accesses streamed_args_for_tool[index].
self.streamed_args_for_tool.append("")
- # Send header with function info
- return DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_index,
- id=self.current_tool_id,
- function=DeltaFunctionCall(
- name=self.current_function_name, arguments=""
- ),
- type="function",
- )
- ]
+ tool_call_fragments = DeltaToolCall(
+ index=self.current_tool_index,
+ id=self.current_tool_id,
+ function=DeltaFunctionCall(name=self.current_function_name, arguments=""),
+ type="function",
)
- return None
+ if not self.header_sent:
+ return content_message
+ arguments_to_emit = ""
# We've sent header, now handle function body
if self.in_function:
# Always send opening brace first, regardless of whether
@@ -504,16 +527,8 @@
if not self.json_started:
self.json_started = True
self.streamed_args_for_tool[self.current_tool_index] += "{"
- return DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_index,
- function=DeltaFunctionCall(arguments="{"),
- )
- ]
- )
+ arguments_to_emit += "{"
- # Find all parameter start positions in current tool_text
param_starts = []
search_idx = 0
while True:
@@ -614,15 +629,7 @@
self.current_tool_index,
len(self.streamed_args_for_tool),
)
-
- return DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_index,
- function=DeltaFunctionCall(arguments=combined),
- )
- ]
- )
+ arguments_to_emit += combined
# Check for function end AFTER processing parameters.
# This ordering is critical: with speculative decoding a
@@ -664,20 +671,24 @@
self.current_tool_index,
len(self.streamed_args_for_tool),
)
-
- result = DeltaMessage(
- tool_calls=[
- DeltaToolCall(
- index=self.current_tool_index,
- function=DeltaFunctionCall(arguments="}"),
- )
- ]
- )
-
+ arguments_to_emit += "}"
self.in_function = False
self.json_closed = True
self.accumulated_params = {}
- return result
+ if tool_call_fragments or arguments_to_emit:
+ if not tool_call_fragments:
+ tool_call_fragments = DeltaToolCall(
+ index=self.current_tool_index,
+ function=DeltaFunctionCall(arguments=arguments_to_emit),
+ )
+ else:
+ tool_call_fragments.function.arguments += arguments_to_emit
+
+ if content_message:
+ content_message.tool_calls = [tool_call_fragments]
+ return content_message
+ else:
+ return DeltaMessage(tool_calls=[tool_call_fragments])
- return None
+ return content_message
diff -urN a/vllm/tool_parsers/qwen3xml_tool_parser.py b/vllm/tool_parsers/qwen3xml_tool_parser.py
--- a/vllm/tool_parsers/qwen3xml_tool_parser.py 2026-04-23 00:18:59.000000000 -0500
+++ b/vllm/tool_parsers/qwen3xml_tool_parser.py 2026-04-24 18:50:55.000000000 -0500
@@ -56,6 +56,7 @@
# state for streaming
self.tool_call_index = 0
self.current_call_id = None
+ self.id_emitted = False
self.last_completed_call_id = None
self.current_function_name = None
self.current_function_open = False
@@ -112,58 +113,25 @@
if (
self.current_call_id is not None
and self.function_end_token in xml_chunk
+ and self.current_function_open
):
- # - Added '}' (non-empty parameter ending)
- # - Added '{}' (empty parameter function)
- has_function_close = any(
- (
- td.tool_calls
- and any(
- (
- tc.function
- and tc.id == self.current_call_id
- and isinstance(tc.function.arguments, str)
- and (tc.function.arguments in ("}", "{}"))
- )
- for tc in td.tool_calls
- )
- )
- for td in new_deltas
- )
- if not has_function_close:
- # Close potentially unclosed element
- if self.current_param_name:
- self._end_element("parameter")
- if self.current_function_name:
- self._end_element("function")
+ # Close potentially unclosed element
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.current_function_name:
+ self._end_element("function")
# If this chunk contains </tool_call>
# but didn't generate final empty delta, then complete it
if (
self.current_call_id is not None
and self.tool_call_end_token in xml_chunk
):
- has_toolcall_close = any(
- (
- td.tool_calls
- and any(
- (
- tc.type == "function"
- and tc.function
- and tc.function.arguments == ""
- and tc.id == self.current_call_id
- )
- for tc in td.tool_calls
- )
- )
- for td in new_deltas
- )
- if not has_toolcall_close:
- # Close potentially unclosed element
- if self.current_param_name:
- self._end_element("parameter")
- if self.current_function_name:
- self._end_element("function")
- self._end_element("tool_call")
+ # Close potentially unclosed elements
+ if self.current_param_name:
+ self._end_element("parameter")
+ if self.current_function_open:
+ self._end_element("function")
+ self._end_element("tool_call")
except Exception as e:
logger.warning("Error with fallback parsing: %s", e)
# Merge newly generated deltas into single response
@@ -173,8 +141,8 @@
return result_delta
else:
# No complete elements, check if there's unoutput text content
- if self.text_content_buffer and self.tool_call_index == 0:
- # Has text content but no tool_call yet, output text content
+ if self.text_content_buffer:
+ # Output buffered text content
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer to avoid duplicate output
@@ -251,16 +219,15 @@
# Found complete XML element, process it
try:
preprocessed_element = self._preprocess_xml_chunk(element)
- # Check if this is the first tool_call start
+ # Check if a new tool_call starts and we have buffered text content
if (
(
preprocessed_element.strip().startswith("<tool_call>")
or preprocessed_element.strip().startswith("<function name=")
)
- and self.tool_call_index == 0
- ) and self.text_content_buffer:
- # First tool_call starts,
- # output previously collected text content first
+ and self.text_content_buffer
+ ):
+ # Output previously collected text content first
text_delta = DeltaMessage(content=self.text_content_buffer)
self._emit_delta(text_delta)
# Clear buffer for potential subsequent text content
@@ -286,7 +253,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
@@ -441,10 +408,10 @@
if delta.tool_calls:
# For tool_calls, we need to intelligently merge arguments
for tool_call in delta.tool_calls:
- # Find if there's already a tool_call with the same call_id
+ # Find if there's already a tool_call with the same index
existing_call = None
for existing in merged_tool_calls:
- if existing.id == tool_call.id:
+ if existing.index == tool_call.index:
existing_call = existing
break
@@ -593,6 +560,12 @@
"""Emit Delta response (streaming output)"""
self.deltas.append(delta)
+ def _get_call_id_for_delta(self) -> str | None:
+ if not self.id_emitted:
+ self.id_emitted = True
+ return self.current_call_id
+ return None
+
def _auto_close_open_parameter_if_needed(self, incoming_tag: str | None = None):
"""Before starting to process new elements,
if there are unclosed tags from before,
@@ -648,7 +621,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(
name=function_name, arguments=""
@@ -679,7 +652,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_start
@@ -697,7 +670,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(
name=None, arguments=json_continue
@@ -740,7 +713,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
@@ -762,9 +735,12 @@
original_data = original_data[:-1]
self.current_param_value += original_data
- # convert parameter value by param_type
+ # convert parameter value by param_type (PR #38890 — pass param/func name for log context)
converted_value = self._convert_param_value(
- self.current_param_value, param_type
+ self.current_param_value,
+ param_type,
+ self.current_param_name or "",
+ self.current_function_name or "",
)
output_data = self._convert_for_json_streaming(converted_value, param_type)
@@ -775,7 +751,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments=delta_data),
)
@@ -832,7 +808,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(
name=None, arguments=output_arguments
@@ -855,11 +831,36 @@
param_type = self._get_param_type(param_name)
- # convert complete parameter value by param_type
- converted_value = self._convert_param_value(param_value, param_type)
+ # convert complete parameter value by param_type (PR #38890)
+ converted_value = self._convert_param_value(
+ param_value,
+ param_type,
+ param_name,
+ self.current_function_name or "",
+ )
+
+ # PR #38996: if the converted value is null (Python None or
+ # JSON null spelled in the model output), emit the JSON literal
+ # ``null`` here rather than treating it as a string. During
+ # streaming the value may have been partially emitted (e.g. as
+ # an empty string for object types); we emit the full literal
+ # at param-end where we know the value is complete.
+ if converted_value is None:
+ if not self.start_quote_emitted:
+ delta = DeltaMessage(
+ tool_calls=[
+ DeltaToolCall(
+ index=self.tool_call_index - 1,
+ id=self._get_call_id_for_delta(),
+ type="function",
+ function=DeltaFunctionCall(arguments="null"),
+ )
+ ]
+ )
+ self._emit_delta(delta)
# Decide whether to add end quote based on parameter type
- if param_type in ["string", "str", "text", "varchar", "char", "enum"]:
+ elif param_type in ["string", "str", "text", "varchar", "char", "enum"]:
# For empty string parameters, need special handling
if not param_value and not self.start_quote_emitted:
# No start quote output,
@@ -868,7 +869,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments='""'),
)
@@ -881,7 +882,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments='"'),
)
@@ -904,7 +905,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments="}"),
)
@@ -917,7 +918,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments="{}"),
)
@@ -940,7 +941,7 @@
tool_calls=[
DeltaToolCall(
index=self.tool_call_index - 1,
- id=self.current_call_id,
+ id=self._get_call_id_for_delta(),
type="function",
function=DeltaFunctionCall(name=None, arguments=""),
)
@@ -1003,9 +1004,18 @@
properties = find_tool_properties(self.tools, self.current_function_name)
if param_name in properties and isinstance(properties[param_name], dict):
- return self.repair_param_type(
- str(properties[param_name].get("type", "string"))
- )
+ prop = properties[param_name]
+ param_type = prop.get("type")
+ if param_type is None and "anyOf" in prop:
+ # Handle anyOf schemas (common in Qwen 3.6)
+ for option in prop["anyOf"]:
+ if isinstance(option, dict) and "type" in option:
+ opt_type = str(option["type"])
+ if opt_type in ["object", "array", "arr", "sequence"]:
+ return opt_type
+ return "string"
+
+ return self.repair_param_type(str(param_type or "string"))
return "string"
def repair_param_type(self, param_type: str) -> str:
@@ -1036,16 +1046,27 @@
else:
return "string"
- def _convert_param_value(self, param_value: str, param_type: str) -> Any:
+ def _convert_param_value(
+ self,
+ param_value: str,
+ param_type: str,
+ param_name: str = "",
+ func_name: str = "",
+ ) -> Any:
"""Convert value based on parameter type
Args:
param_value: Parameter value
param_type: Parameter type
+ param_name: Parameter name (used for diagnostic logging — PR #38890)
+ func_name: Function/tool name (used for diagnostic logging — PR #38890)
Returns:
Converted value
"""
- if param_value.lower() == "null":
+ # PR #38996: accept both JSON-style "null" and Python-style "None"
+ # (Qwen3.5's chat template uses `| string` for scalar args, which
+ # produces Python repr instead of JSON literals).
+ if param_value.lower() in ("null", "none"):
return None
param_type = param_type.strip().lower()
@@ -1065,6 +1086,8 @@
"Parsed value '%s' of parameter '%s' is not an integer "
"in tool '%s', degenerating to string.",
param_value,
+ param_name,
+ func_name,
)
return param_value
elif param_type.startswith("num") or param_type.startswith("float"):
@@ -1080,6 +1103,8 @@
"Parsed value '%s' of parameter '%s' is not a float "
"in tool '%s', degenerating to string.",
param_value,
+ param_name,
+ func_name,
)
return param_value
elif param_type in ["boolean", "bool", "binary"]:
@@ -1126,6 +1151,7 @@
if self.current_call_id:
self.last_completed_call_id = self.current_call_id
self.current_call_id = None
+ self.id_emitted = False
self.current_function_name = None
self.current_function_open = False
self.parameters = {}