Created May 14, 2025 10:52
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8bc733fd..eaea63f8 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -161,45 +161,52 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
 async def _process_vllm_chat_completion_stream_response(
     stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
 ) -> AsyncGenerator:
-    event_type = ChatCompletionResponseEventType.start
-    tool_call_buf = UnparseableToolCall()
+    yield ChatCompletionResponseStreamChunk(
+        event=ChatCompletionResponseEvent(
+            event_type=ChatCompletionResponseEventType.start,
+            delta=TextDelta(text=""),
+        )
+    )
+    event_type = ChatCompletionResponseEventType.progress
+    tool_call_bufs: dict[str, UnparseableToolCall] = {}
     async for chunk in stream:
         if not chunk.choices:
             log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
             continue
         choice = chunk.choices[0]
         if choice.finish_reason:
-            args_str = tool_call_buf.arguments
-            args = None
-            try:
-                args = {} if not args_str else json.loads(args_str)
-            except Exception as e:
-                log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
-            if args:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=event_type,
-                        delta=ToolCallDelta(
-                            tool_call=ToolCall(
-                                call_id=tool_call_buf.call_id,
-                                tool_name=tool_call_buf.tool_name,
-                                arguments=args,
-                                arguments_json=args_str,
+            for _index, tool_call_buf in sorted(tool_call_bufs.items()):
+                args_str = tool_call_buf.arguments
+                args = None
+                try:
+                    args = {} if not args_str else json.loads(args_str)
+                except Exception as e:
+                    log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+                if args:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=event_type,
+                            delta=ToolCallDelta(
+                                tool_call=ToolCall(
+                                    call_id=tool_call_buf.call_id,
+                                    tool_name=tool_call_buf.tool_name,
+                                    arguments=args,
+                                    arguments_json=args_str,
+                                ),
+                                parse_status=ToolCallParseStatus.succeeded,
                             ),
-                            parse_status=ToolCallParseStatus.succeeded,
-                        ),
+                        )
                     )
-                )
-            elif args_str:
-                yield ChatCompletionResponseStreamChunk(
-                    event=ChatCompletionResponseEvent(
-                        event_type=ChatCompletionResponseEventType.progress,
-                        delta=ToolCallDelta(
-                            tool_call=str(tool_call_buf),
-                            parse_status=ToolCallParseStatus.failed,
-                        ),
+                elif args_str:
+                    yield ChatCompletionResponseStreamChunk(
+                        event=ChatCompletionResponseEvent(
+                            event_type=ChatCompletionResponseEventType.progress,
+                            delta=ToolCallDelta(
+                                tool_call=str(tool_call_buf),
+                                parse_status=ToolCallParseStatus.failed,
+                            ),
+                        )
                     )
-                )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
                     event_type=ChatCompletionResponseEventType.complete,
@@ -209,11 +216,16 @@ async def _process_vllm_chat_completion_stream_response(
                 )
             )
         elif choice.delta.tool_calls:
-            tool_call = convert_tool_call(choice.delta.tool_calls[0])
-            tool_call_buf.tool_name += str(tool_call.tool_name)
-            tool_call_buf.call_id += tool_call.call_id
-            # TODO: remove str() when dict type for 'arguments' is no longer allowed
-            tool_call_buf.arguments += str(tool_call.arguments)
+            for delta_tool_call in choice.delta.tool_calls:
+                tool_call = convert_tool_call(delta_tool_call)
+                if delta_tool_call.index not in tool_call_bufs:
+                    tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+                tool_call_buf = tool_call_bufs[delta_tool_call.index]
+                tool_call_buf.tool_name += str(tool_call.tool_name)
+                tool_call_buf.call_id += tool_call.call_id
+                tool_call_buf.arguments += (
+                    tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+                )
         else:
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
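
The core of the vllm.py change is that streamed tool-call deltas are no longer accumulated into a single shared buffer; each delta is routed to a buffer keyed by its OpenAI `index` field, so parallel tool calls stay separate and are emitted in index order once the final chunk arrives. A minimal, standalone sketch of that buffering pattern (simplified dict-based chunks and a hypothetical `ToolCallBuffer` helper, not the llama_stack types):

```python
# Sketch only: accumulate OpenAI-style streaming tool-call deltas into
# per-index buffers so concurrent tool calls are not merged together.
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuffer:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def accumulate_tool_calls(stream) -> dict[int, ToolCallBuffer]:
    """Merge partial tool-call deltas, keyed by the OpenAI `index` field."""
    bufs: dict[int, ToolCallBuffer] = {}
    for deltas in stream:
        for delta in deltas:  # each delta: {"index": int, "id": str, "function": {...}}
            buf = bufs.setdefault(delta["index"], ToolCallBuffer())
            buf.call_id += delta.get("id") or ""
            fn = delta.get("function") or {}
            buf.tool_name += fn.get("name") or ""
            buf.arguments += fn.get("arguments") or ""
    return bufs


# Two tool calls whose argument strings arrive as interleaved fragments:
stream = [
    [{"index": 1, "id": "tc_1", "function": {"name": "power", "arguments": '{"number": '}}],
    [{"index": 2, "id": "tc_2", "function": {"name": "multiple", "arguments": '{"first_number": 4}'}}],
    [{"index": 1, "id": "", "function": {"name": "", "arguments": "28}"}}],
]
for index, buf in sorted(accumulate_tool_calls(stream).items()):
    print(index, buf.tool_name, json.loads(buf.arguments))
```
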
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e2314d44..cc000052 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
             tool_name = tc.tool_name
             if isinstance(tool_name, BuiltinTool):
                 tool_name = tool_name.value
+
+            # arguments_json can be None, so attempt it first and fall back to arguments
+            if hasattr(tc, "arguments_json") and tc.arguments_json:
+                arguments = tc.arguments_json
+            else:
+                arguments = json.dumps(tc.arguments)
             result["tool_calls"].append(
                 {
                     "id": tc.call_id,
                     "type": "function",
                     "function": {
                         "name": tool_name,
-                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                        "arguments": arguments,
                     },
                 }
             )
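
The openai_compat.py change addresses a subtle truthiness bug: `hasattr(tc, "arguments_json")` can be True while the attribute's value is None, in which case the old one-liner emitted `"arguments": None` instead of serializing `tc.arguments`. A small illustrative snippet (using `SimpleNamespace` as a stand-in for the real tool-call object):

```python
# Illustration only: why checking hasattr() alone is not enough when
# arguments_json exists as an attribute but is None.
import json
from types import SimpleNamespace

tc = SimpleNamespace(arguments={"number": 28}, arguments_json=None)

old_style = tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments)
new_style = tc.arguments_json if getattr(tc, "arguments_json", None) else json.dumps(tc.arguments)

print(old_style)  # None -- not a usable OpenAI "arguments" payload
print(new_style)  # '{"number": 28}'
```
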
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index a2e3b64c..e533b20d 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
 from openai.types.chat.chat_completion_chunk import (
     ChoiceDelta as OpenAIChoiceDelta,
 )
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+    ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
 from openai.types.model import Model as OpenAIModel
 
 from llama_stack.apis.inference import (
@@ -205,8 +211,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 1
-    assert chunks[0].event.stop_reason == StopReason.end_of_turn
+    assert len(chunks) == 2
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "complete"
+    assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
+@pytest.mark.asyncio
+async def test_tool_call_delta_streaming_arguments_dict():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments="",
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="tc_1",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            )
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 3
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "complete"
+
+
+@pytest.mark.asyncio
+async def test_multiple_tool_calls():
+    async def mock_stream():
+        mock_chunk_1 = OpenAIChatCompletionChunk(
+            id="chunk-1",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=1,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="power",
+                                    arguments='{"number": 28, "power": 3}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_2 = OpenAIChatCompletionChunk(
+            id="chunk-2",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(
+                    delta=OpenAIChoiceDelta(
+                        content="",
+                        tool_calls=[
+                            OpenAIChoiceDeltaToolCall(
+                                id="",
+                                index=2,
+                                function=OpenAIChoiceDeltaToolCallFunction(
+                                    name="multiple",
+                                    arguments='{"first_number": 4, "second_number": 7}',
+                                ),
+                            ),
+                        ],
+                    ),
+                    finish_reason=None,
+                    index=0,
+                )
+            ],
+        )
+        mock_chunk_3 = OpenAIChatCompletionChunk(
+            id="chunk-3",
+            created=1,
+            model="foo",
+            object="chat.completion.chunk",
+            choices=[
+                OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+            ],
+        )
+        for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+            yield chunk
+
+    chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+    assert len(chunks) == 4
+    assert chunks[0].event.event_type.value == "start"
+    assert chunks[1].event.event_type.value == "progress"
+    assert chunks[1].event.delta.type == "tool_call"
+    assert chunks[1].event.delta.parse_status.value == "succeeded"
+    assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+    assert chunks[2].event.event_type.value == "progress"
+    assert chunks[2].event.delta.type == "tool_call"
+    assert chunks[2].event.delta.parse_status.value == "succeeded"
+    assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+    assert chunks[3].event.event_type.value == "complete"
 
 
 @pytest.mark.asyncio
@@ -230,7 +392,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
             yield chunk
 
     chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
-    assert len(chunks) == 0
+    assert len(chunks) == 1
+    assert chunks[0].event.event_type.value == "start"
 
 
 def test_chat_completion_doesnt_block_event_loop(caplog):