Created May 14, 2025 00:42
-
-
Save bbrowning/b5007709015cb2aabd85e0bd08e6d60f to your computer and use it in GitHub Desktop.
Diff of changes to the Llama 4 pythonic tool parser
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
index fbbbc1fb2..5d232f44a 100644 | |
--- a/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
+++ b/tests/entrypoints/openai/tool_parsers/test_pythonic_tool_parser.py | |
@@ -52,6 +52,16 @@ ESCAPED_STRING_FUNCTION_CALL = FunctionCall( | |
name="get_weather", | |
arguments='{"city": "Martha\'s Vineyard", "metric": "\\"cool units\\""}', | |
) | |
+PYTHON_TAGS_FUNCTION_OUTPUT="<|python_start|>[get_weather(city='San Francisco', metric='celsius')]<|python_end|>" | |
+PYTHON_TAGS_FUNCTION_CALL = FunctionCall( | |
+ name="get_weather", | |
+ arguments='{"city": "San Francisco", "metric": "celsius"}', | |
+) | |
+LOWERCASE_BOOL_FUNCTION_OUTPUT = "has_things(thing1=true, thing2=false)" | |
+LOWERCASE_BOOL_FUNCTION_CALL = FunctionCall( | |
+ name="has_things", | |
+ arguments='{"thing1": true, "thing2": false}', | |
+) | |
@pytest.mark.parametrize("streaming", [True, False]) | |
@@ -118,6 +128,20 @@ TEST_CASES = [ | |
f"[{SIMPLE_FUNCTION_OUTPUT}, {MORE_TYPES_FUNCTION_OUTPUT}]", | |
[SIMPLE_FUNCTION_CALL, MORE_TYPES_FUNCTION_CALL], | |
id="parallel_calls_nonstreaming"), | |
+ pytest.param(True, | |
+ PYTHON_TAGS_FUNCTION_OUTPUT, [PYTHON_TAGS_FUNCTION_CALL], | |
+ id="python_tags_streaming"), | |
+ pytest.param(False, | |
+ PYTHON_TAGS_FUNCTION_OUTPUT, [PYTHON_TAGS_FUNCTION_CALL], | |
+ id="python_tags_nonstreaming"), | |
+ pytest.param(True, | |
+ f"[{LOWERCASE_BOOL_FUNCTION_OUTPUT}]", | |
+ [LOWERCASE_BOOL_FUNCTION_CALL], | |
+ id="lowercase_bool_streaming"), | |
+ pytest.param(False, | |
+ f"[{LOWERCASE_BOOL_FUNCTION_OUTPUT}]", | |
+ [LOWERCASE_BOOL_FUNCTION_CALL], | |
+ id="lowercase_bool_nonstreaming"), | |
] | |
diff --git a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
index bb91a35af..e421ade5e 100644 | |
--- a/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
+++ b/vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py | |
@@ -62,6 +62,12 @@ class PythonicToolParser(ToolParser): | |
Extract the tool calls from a complete model response. | |
""" | |
+ # remove <|python_start|> and <|python_end|> | |
+ # as Llama 4 model sometime will output those tokens | |
+ if model_output.startswith("<|python_start|>"): | |
+ model_output = model_output[len("<|python_start|>"):] | |
+ model_output = model_output.replace("<|python_end|>", "") | |
+ | |
if not (self.TOOL_CALL_REGEX.match(model_output)): | |
return ExtractedToolCallInformation(tools_called=False, | |
tool_calls=[], | |
@@ -100,6 +106,33 @@ class PythonicToolParser(ToolParser): | |
request: ChatCompletionRequest, | |
) -> Union[DeltaMessage, None]: | |
+ start_python_tag = "<|python_start|>" | |
+ end_python_tag = "<|python_end|>" | |
+ | |
+ # If we start with the entire start python tag, remove it | |
+ if current_text.startswith(start_python_tag): | |
+ current_text = current_text[len(start_python_tag):] | |
+ | |
+ # If we end with the entire end python tag, remove it | |
+ if current_text.endswith(end_python_tag): | |
+ current_text = current_text[:-len(end_python_tag)] | |
+ | |
+ # If we start with part of the start python tag, remove it | |
+ for i in range(1, len(start_python_tag)): | |
+ start_tag_substr = start_python_tag[:-i] | |
+ if current_text.startswith(start_tag_substr): | |
+ current_text = current_text[len(start_tag_substr):] | |
+ | |
+ # If we end with part of the end python tag, remove it | |
+ for i in range(1, len(end_python_tag)): | |
+ end_tag_substr = end_python_tag[:i] | |
+ if current_text.endswith(end_tag_substr): | |
+ current_text = current_text[:-len(end_tag_substr)] | |
+ | |
+ # If there's nothing left after removing python tags, stop parsing | |
+ if not current_text: | |
+ return None | |
+ | |
if not current_text.startswith("["): | |
return DeltaMessage(content=delta_text) | |
@@ -189,6 +222,8 @@ def _get_parameter_value(val: ast.expr) -> Any: | |
} | |
elif isinstance(val, ast.List): | |
return [_get_parameter_value(v) for v in val.elts] | |
+ elif isinstance(val, ast.Name) and val.id in ["true", "false"]: | |
+ return True if val.id == "true" else False | |
else: | |
raise _UnexpectedAstError("Tool call arguments must be literals") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment