|
import json
import os

import httpx
import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
|
|
app = FastAPI()

# Point this to your actual vLLM instance. Set the VLLM_BASE_URL environment
# variable to override it without editing the file. A trailing slash is
# stripped so that f"{VLLM_BASE_URL}/{path}" never produces a double slash.
VLLM_BASE_URL = os.environ.get("VLLM_BASE_URL", "http://localhost:8000").rstrip("/")
|
|
|
# Hop-by-hop headers describe a single connection and must not be forwarded
# by a proxy (RFC 9110, section 7.6.1).
_HOP_BY_HOP_HEADERS = frozenset(
    {
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    }
)


def _upstream_headers(request):
    """Return the client's headers that are safe to send to the backend.

    ``host`` and ``content-length`` are regenerated by httpx for the new
    request. ``accept-encoding`` is dropped because the proxy always decodes
    the backend body itself, so httpx should negotiate only encodings it can
    decode.
    """
    drop = _HOP_BY_HOP_HEADERS | {"host", "content-length", "accept-encoding"}
    return {k: v for k, v in request.headers.items() if k.lower() not in drop}


def _downstream_headers(upstream):
    """Return backend response headers that are safe to copy to the client.

    httpx hands back the *decoded* body, so ``content-encoding`` and
    ``content-length`` no longer describe it. The outgoing response class
    sets its own framing instead.
    """
    drop = _HOP_BY_HOP_HEADERS | {"content-encoding", "content-length"}
    return {k: v for k, v in upstream.headers.items() if k.lower() not in drop}


async def _passthrough(method, url, body, headers):
    """Forward a request verbatim and relay the backend reply unchanged.

    The reply body is relayed as raw bytes, so empty or non-JSON bodies work
    as well as JSON ones.
    """
    async with httpx.AsyncClient(timeout=None) as client:
        resp = await client.request(method, url, content=body, headers=headers)
    return Response(
        content=resp.content,
        status_code=resp.status_code,
        headers=_downstream_headers(resp),
    )


async def _relay_stream(url, req_json, headers):
    """Forward a request the client already asked to stream, relaying bytes as-is."""
    client = httpx.AsyncClient(timeout=None)
    try:
        upstream = await client.send(
            client.build_request("POST", url, json=req_json, headers=headers),
            stream=True,
        )
    except BaseException:
        await client.aclose()
        raise

    async def relay():
        # Release the backend connection once the client stops reading.
        try:
            async for piece in upstream.aiter_bytes():
                yield piece
        finally:
            await upstream.aclose()
            await client.aclose()

    return StreamingResponse(
        relay(),
        status_code=upstream.status_code,
        headers=_downstream_headers(upstream),
    )


class _ChoiceAccumulator:
    """Folds the streamed deltas of one choice into a non-streaming choice."""

    _TEXT_FIELDS = ("content", "reasoning_content", "reasoning")

    def __init__(self, index):
        self.index = index
        self.role = "assistant"
        self.text = {field: [] for field in self._TEXT_FIELDS}
        self.tool_calls = {}  # tool-call index -> reassembled tool call
        self.logprobs = None  # becomes a list once any logprobs object arrives
        self.finish_reason = None

    def add(self, choice):
        """Merge one streamed ``choice`` object (and its ``delta``) into the state."""
        if choice.get("finish_reason") is not None:
            self.finish_reason = choice["finish_reason"]

        logprobs = choice.get("logprobs")
        if isinstance(logprobs, dict):
            if self.logprobs is None:
                self.logprobs = []
            self.logprobs.extend(logprobs.get("content") or [])

        delta = choice.get("delta") or {}
        if delta.get("role"):
            self.role = delta["role"]
        for field in self._TEXT_FIELDS:
            piece = delta.get(field)
            if piece:
                self.text[field].append(piece)

        # Tool calls arrive in fragments keyed by "index". The first fragment
        # typically carries the id and name; later ones append argument text.
        for fragment in delta.get("tool_calls") or []:
            call = self.tool_calls.setdefault(
                fragment.get("index", 0),
                {"id": "", "type": "function", "function": {"name": "", "arguments": ""}},
            )
            if fragment.get("id") and not call["id"]:
                call["id"] = fragment["id"]
            if fragment.get("type"):
                call["type"] = fragment["type"]
            function = fragment.get("function") or {}
            if function.get("name"):
                call["function"]["name"] += function["name"]
            if function.get("arguments"):
                call["function"]["arguments"] += function["arguments"]

    def to_choice(self):
        """Return this choice in non-streaming ``chat.completion`` form."""
        message = {"role": self.role, "content": "".join(self.text["content"]) or None}
        for field in ("reasoning_content", "reasoning"):
            joined = "".join(self.text[field])
            if joined:
                message[field] = joined
        message["tool_calls"] = [self.tool_calls[i] for i in sorted(self.tool_calls)]
        return {
            "index": self.index,
            "message": message,
            "logprobs": None if self.logprobs is None else {"content": self.logprobs},
            "finish_reason": "stop" if self.finish_reason is None else self.finish_reason,
        }


async def _collect_stream(response, model):
    """Consume an SSE chat-completion stream and reassemble one ``chat.completion``.

    Returns a JSONResponse holding either the reassembled completion or, if
    the stream carried an ``error`` event, that error payload.
    """
    response_id = ""
    created = 0
    usage = None
    choices = {}  # choice index -> _ChoiceAccumulator

    async for line in response.aiter_lines():
        # SSE data lines look like "data:<payload>"; the space after ":" is optional.
        if not line.startswith("data:"):
            continue
        data = line[len("data:"):].strip()
        if data == "[DONE]":
            break
        try:
            chunk = json.loads(data)
        except ValueError:
            continue
        if not isinstance(chunk, dict):
            continue

        # NOTE(review): OpenAI-compatible servers may report failures that
        # happen after the stream started as an event with an "error" object.
        # Surface it instead of returning a truncated answer as a success.
        error = chunk.get("error")
        if error:
            status = error.get("code") if isinstance(error, dict) else None
            if not isinstance(status, int) or not 400 <= status <= 599:
                status = 502
            return JSONResponse(status_code=status, content=chunk)

        # Top-level metadata: keep the first id/created, the latest model/usage.
        if not response_id and chunk.get("id"):
            response_id = chunk["id"]
        if not created and chunk.get("created"):
            created = chunk["created"]
        if chunk.get("model"):
            model = chunk["model"]
        if chunk.get("usage"):
            usage = chunk["usage"]

        # Route each delta to its own choice so that n > 1 is not interleaved.
        for choice in chunk.get("choices") or []:
            index = choice.get("index", 0)
            if index not in choices:
                choices[index] = _ChoiceAccumulator(index)
            choices[index].add(choice)

    if not choices:
        # Keep the response well-formed even if no choice was ever streamed.
        choices[0] = _ChoiceAccumulator(0)

    return JSONResponse(
        content={
            "id": response_id,
            "object": "chat.completion",
            "created": created,
            "model": model,
            "choices": [choices[i].to_choice() for i in sorted(choices)],
            "usage": usage or {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
    )


async def _force_stream(url, req_json, headers):
    """Send a non-streaming chat request upstream as a stream and reassemble it."""
    upstream_json = dict(req_json, stream=True)
    # Request final usage stats at the end of the stream (needed to fill
    # "usage"), keeping any other stream_options the client sent.
    options = req_json.get("stream_options")
    upstream_json["stream_options"] = {
        **(options if isinstance(options, dict) else {}),
        "include_usage": True,
    }

    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("POST", url, json=upstream_json, headers=headers) as response:
            if response.status_code != 200:
                # Relay the backend's error body as-is; it need not be JSON.
                await response.aread()
                return Response(
                    content=response.content,
                    status_code=response.status_code,
                    headers=_downstream_headers(response),
                )
            return await _collect_stream(response, req_json.get("model", "unknown"))


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(request: Request, path: str):
    """
    Catch-all proxy in front of the vLLM server at ``VLLM_BASE_URL``.

    A non-streaming ``POST .../v1/chat/completions`` is sent upstream with
    ``stream: true``, and the resulting events are reassembled into a regular
    ``chat.completion`` body. A chat request that already asks for
    ``stream: true`` is relayed as a stream. Everything else (like
    ``/v1/models``) is passed through transparently, query string included.

    Returns 400 if a chat-completion body is not a JSON object, and 502 if
    the backend cannot be reached.
    """
    url = f"{VLLM_BASE_URL}/{path}"
    if request.url.query:
        url = f"{url}?{request.url.query}"

    body = await request.body()
    headers = _upstream_headers(request)

    try:
        # Pass-through for everything that is not a chat-completion POST.
        if not path.endswith("v1/chat/completions") or request.method != "POST":
            return await _passthrough(request.method, url, body, headers)

        try:
            req_json = json.loads(body)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError both subclass it
            return JSONResponse(status_code=400, content={"error": "Invalid JSON"})
        if not isinstance(req_json, dict):
            return JSONResponse(
                status_code=400,
                content={"error": "Request body must be a JSON object"},
            )

        if req_json.get("stream"):
            return await _relay_stream(url, req_json, headers)
        return await _force_stream(url, req_json, headers)
    except httpx.RequestError as exc:
        return JSONResponse(
            status_code=502,
            content={"error": f"Upstream request failed: {type(exc).__name__}: {exc}"},
        )
|
|
|
if __name__ == "__main__":
    # Serve the proxy on 0.0.0.0:8001 by default. Set PROXY_HOST / PROXY_PORT
    # to change the bind address without editing the file.
    uvicorn.run(
        app,
        host=os.environ.get("PROXY_HOST", "0.0.0.0"),
        port=int(os.environ.get("PROXY_PORT", "8001")),
    )