@bbrowning
Created May 14, 2025 10:52
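
This patch updates the remote vLLM inference adapter so streamed chat completions can carry multiple tool calls: tool-call deltas are buffered per `index` instead of in a single shared buffer, an explicit `start` event is emitted before any content, dict-typed `arguments` are serialized to JSON while buffering, and `convert_message_to_openai_dict` falls back to serializing `arguments` when `arguments_json` is missing or empty. Unit tests covering the new behavior are included below.
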
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 8bc733fd..eaea63f8 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -161,45 +161,52 @@ def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
async def _process_vllm_chat_completion_stream_response(
stream: AsyncGenerator[OpenAIChatCompletionChunk, None],
) -> AsyncGenerator:
- event_type = ChatCompletionResponseEventType.start
- tool_call_buf = UnparseableToolCall()
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.start,
+ delta=TextDelta(text=""),
+ )
+ )
+ event_type = ChatCompletionResponseEventType.progress
+ tool_call_bufs: dict[str, UnparseableToolCall] = {}
async for chunk in stream:
if not chunk.choices:
log.warning("vLLM failed to generation any completions - check the vLLM server logs for an error.")
continue
choice = chunk.choices[0]
if choice.finish_reason:
- args_str = tool_call_buf.arguments
- args = None
- try:
- args = {} if not args_str else json.loads(args_str)
- except Exception as e:
- log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
- if args:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=event_type,
- delta=ToolCallDelta(
- tool_call=ToolCall(
- call_id=tool_call_buf.call_id,
- tool_name=tool_call_buf.tool_name,
- arguments=args,
- arguments_json=args_str,
+ for _index, tool_call_buf in sorted(tool_call_bufs.items()):
+ args_str = tool_call_buf.arguments
+ args = None
+ try:
+ args = {} if not args_str else json.loads(args_str)
+ except Exception as e:
+ log.warning(f"Failed to parse tool call buffer arguments: {args_str} \nError: {e}")
+ if args:
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=event_type,
+ delta=ToolCallDelta(
+ tool_call=ToolCall(
+ call_id=tool_call_buf.call_id,
+ tool_name=tool_call_buf.tool_name,
+ arguments=args,
+ arguments_json=args_str,
+ ),
+ parse_status=ToolCallParseStatus.succeeded,
),
- parse_status=ToolCallParseStatus.succeeded,
- ),
+ )
)
- )
- elif args_str:
- yield ChatCompletionResponseStreamChunk(
- event=ChatCompletionResponseEvent(
- event_type=ChatCompletionResponseEventType.progress,
- delta=ToolCallDelta(
- tool_call=str(tool_call_buf),
- parse_status=ToolCallParseStatus.failed,
- ),
+ elif args_str:
+ yield ChatCompletionResponseStreamChunk(
+ event=ChatCompletionResponseEvent(
+ event_type=ChatCompletionResponseEventType.progress,
+ delta=ToolCallDelta(
+ tool_call=str(tool_call_buf),
+ parse_status=ToolCallParseStatus.failed,
+ ),
+ )
)
- )
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
@@ -209,11 +216,16 @@ async def _process_vllm_chat_completion_stream_response(
)
)
elif choice.delta.tool_calls:
- tool_call = convert_tool_call(choice.delta.tool_calls[0])
- tool_call_buf.tool_name += str(tool_call.tool_name)
- tool_call_buf.call_id += tool_call.call_id
- # TODO: remove str() when dict type for 'arguments' is no longer allowed
- tool_call_buf.arguments += str(tool_call.arguments)
+ for delta_tool_call in choice.delta.tool_calls:
+ tool_call = convert_tool_call(delta_tool_call)
+ if delta_tool_call.index not in tool_call_bufs:
+ tool_call_bufs[delta_tool_call.index] = UnparseableToolCall()
+ tool_call_buf = tool_call_bufs[delta_tool_call.index]
+ tool_call_buf.tool_name += str(tool_call.tool_name)
+ tool_call_buf.call_id += tool_call.call_id
+ tool_call_buf.arguments += (
+ tool_call.arguments if isinstance(tool_call.arguments, str) else json.dumps(tool_call.arguments)
+ )
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
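
Not part of the patch: a minimal, self-contained sketch of the buffering idea used in the hunk above. Each streamed tool-call delta is accumulated into a buffer keyed by its `index`, so interleaved fragments belonging to different tool calls can be reassembled independently; `ToolCallBuf` and `accumulate` are stand-ins invented here for illustration (the real code uses UnparseableToolCall).

import json
from dataclasses import dataclass


@dataclass
class ToolCallBuf:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""


def accumulate(bufs: dict, index: int, call_id: str, name: str, arguments) -> None:
    # One buffer per tool-call index; fragments for the same index are concatenated.
    buf = bufs.setdefault(index, ToolCallBuf())
    buf.call_id += call_id or ""
    buf.tool_name += name or ""
    # Normalize dict-typed arguments to a JSON string before concatenating fragments.
    buf.arguments += arguments if isinstance(arguments, str) else json.dumps(arguments or {})


bufs = {}
accumulate(bufs, 1, "tc_1", "power", "")
accumulate(bufs, 1, "", "", '{"number": 28, "power": 3}')
accumulate(bufs, 2, "tc_2", "multiple", {"first_number": 4, "second_number": 7})
for index, buf in sorted(bufs.items()):
    print(index, buf.tool_name, json.loads(buf.arguments))
# 1 power {'number': 28, 'power': 3}
# 2 multiple {'first_number': 4, 'second_number': 7}
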
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index e2314d44..cc000052 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -531,13 +531,19 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
tool_name = tc.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
+
+ # arguments_json can be None, so attempt it first and fall back to arguments
+ if hasattr(tc, "arguments_json") and tc.arguments_json:
+ arguments = tc.arguments_json
+ else:
+ arguments = json.dumps(tc.arguments)
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tool_name,
- "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+ "arguments": arguments,
},
}
)
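
Not part of the patch: a quick illustration of the fallback introduced above. When a tool call carries a non-empty, pre-serialized `arguments_json`, that string is used directly; otherwise the `arguments` dict is serialized. `SimpleNamespace` stands in for the real ToolCall type purely for demonstration.

import json
from types import SimpleNamespace


def tool_call_arguments(tc) -> str:
    # Prefer the pre-serialized JSON string when it is present and non-empty.
    if getattr(tc, "arguments_json", None):
        return tc.arguments_json
    return json.dumps(tc.arguments)


print(tool_call_arguments(SimpleNamespace(arguments={"x": 1}, arguments_json=None)))        # {"x": 1}
print(tool_call_arguments(SimpleNamespace(arguments={"x": 1}, arguments_json='{"x": 1}')))  # {"x": 1}
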
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index a2e3b64c..e533b20d 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -24,6 +24,12 @@ from openai.types.chat.chat_completion_chunk import (
from openai.types.chat.chat_completion_chunk import (
ChoiceDelta as OpenAIChoiceDelta,
)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaToolCall as OpenAIChoiceDeltaToolCall,
+)
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDeltaToolCallFunction as OpenAIChoiceDeltaToolCallFunction,
+)
from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
@@ -205,8 +211,164 @@ async def test_tool_call_delta_empty_tool_call_buf():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
- assert len(chunks) == 1
- assert chunks[0].event.stop_reason == StopReason.end_of_turn
+ assert len(chunks) == 2
+ assert chunks[0].event.event_type.value == "start"
+ assert chunks[1].event.event_type.value == "complete"
+ assert chunks[1].event.stop_reason == StopReason.end_of_turn
+
+
[email protected]
+async def test_tool_call_delta_streaming_arguments_dict():
+ async def mock_stream():
+ mock_chunk_1 = OpenAIChatCompletionChunk(
+ id="chunk-1",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(
+ delta=OpenAIChoiceDelta(
+ content="",
+ tool_calls=[
+ OpenAIChoiceDeltaToolCall(
+ id="tc_1",
+ index=1,
+ function=OpenAIChoiceDeltaToolCallFunction(
+ name="power",
+ arguments="",
+ ),
+ )
+ ],
+ ),
+ finish_reason=None,
+ index=0,
+ )
+ ],
+ )
+ mock_chunk_2 = OpenAIChatCompletionChunk(
+ id="chunk-2",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(
+ delta=OpenAIChoiceDelta(
+ content="",
+ tool_calls=[
+ OpenAIChoiceDeltaToolCall(
+ id="tc_1",
+ index=1,
+ function=OpenAIChoiceDeltaToolCallFunction(
+ name="power",
+ arguments='{"number": 28, "power": 3}',
+ ),
+ )
+ ],
+ ),
+ finish_reason=None,
+ index=0,
+ )
+ ],
+ )
+ mock_chunk_3 = OpenAIChatCompletionChunk(
+ id="chunk-3",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+ ],
+ )
+ for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+ yield chunk
+
+ chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+ assert len(chunks) == 3
+ assert chunks[0].event.event_type.value == "start"
+ assert chunks[1].event.event_type.value == "progress"
+ assert chunks[1].event.delta.type == "tool_call"
+ assert chunks[1].event.delta.parse_status.value == "succeeded"
+ assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+ assert chunks[2].event.event_type.value == "complete"
+
+
[email protected]
+async def test_multiple_tool_calls():
+ async def mock_stream():
+ mock_chunk_1 = OpenAIChatCompletionChunk(
+ id="chunk-1",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(
+ delta=OpenAIChoiceDelta(
+ content="",
+ tool_calls=[
+ OpenAIChoiceDeltaToolCall(
+ id="",
+ index=1,
+ function=OpenAIChoiceDeltaToolCallFunction(
+ name="power",
+ arguments='{"number": 28, "power": 3}',
+ ),
+ ),
+ ],
+ ),
+ finish_reason=None,
+ index=0,
+ )
+ ],
+ )
+ mock_chunk_2 = OpenAIChatCompletionChunk(
+ id="chunk-2",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(
+ delta=OpenAIChoiceDelta(
+ content="",
+ tool_calls=[
+ OpenAIChoiceDeltaToolCall(
+ id="",
+ index=2,
+ function=OpenAIChoiceDeltaToolCallFunction(
+ name="multiple",
+ arguments='{"first_number": 4, "second_number": 7}',
+ ),
+ ),
+ ],
+ ),
+ finish_reason=None,
+ index=0,
+ )
+ ],
+ )
+ mock_chunk_3 = OpenAIChatCompletionChunk(
+ id="chunk-3",
+ created=1,
+ model="foo",
+ object="chat.completion.chunk",
+ choices=[
+ OpenAIChoice(delta=OpenAIChoiceDelta(content="", tool_calls=None), finish_reason="tool_calls", index=0)
+ ],
+ )
+ for chunk in [mock_chunk_1, mock_chunk_2, mock_chunk_3]:
+ yield chunk
+
+ chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
+ assert len(chunks) == 4
+ assert chunks[0].event.event_type.value == "start"
+ assert chunks[1].event.event_type.value == "progress"
+ assert chunks[1].event.delta.type == "tool_call"
+ assert chunks[1].event.delta.parse_status.value == "succeeded"
+ assert chunks[1].event.delta.tool_call.arguments_json == '{"number": 28, "power": 3}'
+ assert chunks[2].event.event_type.value == "progress"
+ assert chunks[2].event.delta.type == "tool_call"
+ assert chunks[2].event.delta.parse_status.value == "succeeded"
+ assert chunks[2].event.delta.tool_call.arguments_json == '{"first_number": 4, "second_number": 7}'
+ assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio
@@ -230,7 +392,8 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
yield chunk
chunks = [chunk async for chunk in _process_vllm_chat_completion_stream_response(mock_stream())]
- assert len(chunks) == 0
+ assert len(chunks) == 1
+ assert chunks[0].event.event_type.value == "start"
def test_chat_completion_doesnt_block_event_loop(caplog):
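
Not part of the patch: a hedged consumer-side sketch showing how the emitted chunks might be reduced to the list of parsed tool calls. The field names (`event_type`, `delta.type`, `parse_status`, `tool_call`) are the ones asserted in the tests above; everything else here is assumed for illustration.

async def collect_tool_calls(chunk_stream):
    # Gather every tool call that the stream processor reported as successfully parsed.
    tool_calls = []
    async for chunk in chunk_stream:
        delta = chunk.event.delta
        if chunk.event.event_type.value == "progress" and delta.type == "tool_call":
            if delta.parse_status.value == "succeeded":
                tool_calls.append(delta.tool_call)
    return tool_calls
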