Instantly share code, notes, and snippets.
Created
May 14, 2025 14:44
-
Star
0
(0)
You must be signed in to star a gist -
Fork
0
(0)
You must be signed in to fork a gist
-
Save bbrowning/4734240ce96b4264340caa9584e47c9e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
index fbbbc1fb2..1f953706b 100644 | |
--- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
+++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
@@ -52,6 +52,27 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall( | |
name="get_weather", | |
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', | |
) | |
+PYTHON_TAGS_FUNCTION_OUTPUT="<|python_start|>[get_weather(city='San Francisco', metric='celsius')]<|python_end|>" | |
+PYTHON_TAGS_FUNCTION_CALL = FunctionCall( | |
+ name="get_weather", | |
+ arguments='{"city": "San Francisco", "metric": "celsius"}', | |
+) | |
+PYTHON_TAGS_MULTI_FUNCTION_OUTPUT='<|python_start|>get_boiling_point(liquid_name="polyjuice", celcius=true); get_boiling_point(liquid_name="polyjuice", celcius=false)<|python_end|>' | |
+PYTHON_TAGS_MULTI_FUNCTION_CALL = [ | |
+ FunctionCall( | |
+ name="get_boiling_point", | |
+ arguments='{"liquid_name": "polyjuice", "celcius": true}', | |
+ ), | |
+ FunctionCall( | |
+ name="get_boiling_point", | |
+ arguments='{"liquid_name": "polyjuice", "celcius": false}', | |
+ ), | |
+] | |
+LOWERCASE_BOOL_FUNCTION_OUTPUT = "has_things(thing1=true, thing2=false)" | |
+LOWERCASE_BOOL_FUNCTION_CALL = FunctionCall( | |
+ name="has_things", | |
+ arguments='{"thing1": true, "thing2": false}', | |
+) | |
@pytest.mark.parametrize("streaming", [True, False]) | |
@@ -118,6 +139,28 @@ TEST_CASES = [ | |
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", | |
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], | |
id="parallel_calls_nonstreaming"), | |
+ pytest.param(True, | |
+ PYTHON_TAGS_FUNCTION_OUTPUT, [PYTHON_TAGS_FUNCTION_CALL], | |
+ id="python_tags_streaming"), | |
+ pytest.param(False, | |
+ PYTHON_TAGS_FUNCTION_OUTPUT, [PYTHON_TAGS_FUNCTION_CALL], | |
+ id="python_tags_nonstreaming"), | |
+ pytest.param(True, | |
+ PYTHON_TAGS_MULTI_FUNCTION_OUTPUT, | |
+ PYTHON_TAGS_MULTI_FUNCTION_CALL, | |
+ id="python_tags_multi_streaming"), | |
+ pytest.param(False, | |
+ PYTHON_TAGS_MULTI_FUNCTION_OUTPUT, | |
+ PYTHON_TAGS_MULTI_FUNCTION_CALL, | |
+ id="python_tags_multi_nonstreaming"), | |
+ pytest.param(True, | |
+ f"[{LOWERCASE_BOOL_FUNCTION_OUTPUT}]", | |
+ [LOWERCASE_BOOL_FUNCTION_CALL], | |
+ id="lowercase_bool_streaming"), | |
+ pytest.param(False, | |
+ f"[{LOWERCASE_BOOL_FUNCTION_OUTPUT}]", | |
+ [LOWERCASE_BOOL_FUNCTION_CALL], | |
+ id="lowercase_bool_nonstreaming"), | |
] | |
diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
index bb91a35af..717a2ea21 100644 | |
--- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
+++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
@@ -62,6 +62,17 @@ class PythonicToolParser(ToolParser): | |
Extract the tool calls from a complete model response. | |
""" | |
+ # Remove <|python_start|> and <|python_end|>, | |
+ # as Llama 4 models sometimes output those tokens. | |
+ if model_output.startswith("<|python_start|>"): | |
+ model_output = model_output[len("<|python_start|>"):] | |
+ if not model_output.startswith("["): | |
+ model_output = "[" + model_output | |
+ model_output = model_output.replace(");", "),") | |
+ model_output = model_output.replace("<|python_end|>", "") | |
+ if not model_output.endswith("]"): | |
+ model_output += "]" | |
+ | |
if not (self.TOOL_CALL_REGEX.match(model_output)): | |
return ExtractedToolCallInformation(tools_called=False, | |
tool_calls=[], | |
@@ -100,6 +111,36 @@ class PythonicToolParser(ToolParser): | |
request: ChatCompletionRequest, | |
) -> Union[DeltaMessage, None]: | |
+ start_python_tag = "<|python_start|>" | |
+ end_python_tag = "<|python_end|>" | |
+ | |
+ # If we start with the entire start python tag, remove it | |
+ if current_text.startswith(start_python_tag): | |
+ current_text = current_text[len(start_python_tag):] | |
+ if not current_text.startswith("["): | |
+ current_text = "[" + current_text | |
+ current_text = current_text.replace(");", "),") | |
+ | |
+ # If we end with the entire end python tag, remove it | |
+ if current_text.endswith(end_python_tag): | |
+ current_text = current_text[:-len(end_python_tag)] | |
+ | |
+ # If we start with part of the start python tag, remove it | |
+ for i in range(1, len(start_python_tag)): | |
+ start_tag_substr = start_python_tag[:-i] | |
+ if current_text.startswith(start_tag_substr): | |
+ current_text = current_text[len(start_tag_substr):] | |
+ | |
+ # If we end with part of the end python tag, remove it | |
+ for i in range(1, len(end_python_tag)): | |
+ end_tag_substr = end_python_tag[:i] | |
+ if current_text.endswith(end_tag_substr): | |
+ current_text = current_text[:-len(end_tag_substr)] | |
+ | |
+ # If there's nothing left after removing python tags, stop parsing | |
+ if not current_text: | |
+ return None | |
+ | |
if not current_text.startswith("["): | |
return DeltaMessage(content=delta_text) | |
@@ -189,6 +230,8 @@ def _get_parameter_value(val: ast.expr) -> Any: | |
} | |
elif isinstance(val, ast.List): | |
return [_get_parameter_value(v) for v in val.elts] | |
+ elif isinstance(val, ast.Name) and val.id in ["true", "false"]: | |
+ return True if val.id == "true" else False | |
else: | |
raise _UnexpectedAstError("Tool call arguments must be literals") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment