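"""OpenAI-compatible chat completion proxy backed by a local Ollama model.

Exposes /v1/chat/completions and /v1/models, translating OpenAI-style requests
into LangChain/ChatOllama calls, with optional agent-based function handling.
"""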
import os
import logging
import json
import time
from typing import Dict, Any

import httpx
from dotenv import load_dotenv
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain_ollama import ChatOllama
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage

from agents.base import DefaultAgent
from agents.driver import DriverAgent
from agents.navigator import NavigatorAgent

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

load_dotenv()

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],      # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],      # Allows all methods
    allow_headers=["*"],      # Allows all headers
)

# Root URL of the local Ollama server (no /v1 suffix): both ChatOllama and the
# /api/tags call below expect the native API base.
OLLAMA_BASE = os.getenv('OLLAMA_API_BASE', 'http://localhost:11434')
DEFAULT_MODEL = os.getenv('OLLAMA_MODEL', 'deepseek-r1:32b')


async def stream_response(response: httpx.Response):
    """Relay raw bytes from an upstream httpx response unchanged."""
    async for chunk in response.aiter_bytes():
        yield chunk


# Agents that can handle function-calling requests
AVAILABLE_AGENTS = {
    "driver": DriverAgent(),
    "navigator": NavigatorAgent(),
    "default": DefaultAgent()
}

# OpenAI model names requested by clients are mapped onto the local Ollama model
MODEL_MAPPING = {
    "gpt-4o": DEFAULT_MODEL,
    "gpt-4.1": DEFAULT_MODEL,
}


@app.post("/v1/chat/completions")
async def handle_chat_completion(request: Request) -> Dict[str, Any]:
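    """Accept an OpenAI-style chat completion request and serve it from the
    local Ollama model, returning an OpenAI-shaped (optionally streamed) response."""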
    try:
        request_data = await request.json()
        logger.info("Received chat completion request:")
        logger.info(json.dumps(request_data, indent=2))

        model = request_data.get('model', DEFAULT_MODEL)
        model = MODEL_MAPPING.get(model, model)

        # Fall back to the default agent if an unknown agent name is requested
        agent_name = request_data.get('agent', 'default')
        selected_agent = AVAILABLE_AGENTS.get(agent_name, AVAILABLE_AGENTS['default'])

        messages = request_data.get('messages', [])
        stream = request_data.get('stream', False)

        # Convert messages to langchain format
        langchain_messages = []
        for msg in messages:
            content = msg.get('content', '')
            if msg['role'] == 'system':
                langchain_messages.append(SystemMessage(content=content))
            elif msg['role'] == 'user':
                langchain_messages.append(HumanMessage(content=content))
            elif msg['role'] == 'assistant':
                langchain_messages.append(AIMessage(content=content))

        # Check if request contains function calls
        functions = request_data.get('functions', None)
        function_call = request_data.get('function_call', None)

        # Use agent if there are function calls, otherwise use normal chat
        if functions or function_call:
            response_content = await selected_agent.process_with_functions(
                request_data['messages'][-1]['content'],
                functions,
                function_call
            )

            formatted_response = {
                "id": f"chatcmpl-{os.urandom(12).hex()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model,
                "system_fingerprint": "fp_" + os.urandom(5).hex(),
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response_content,
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    "prompt_tokens": -1,
                    "completion_tokens": -1,
                    "total_tokens": -1
                }
            }
            return formatted_response

        # Streaming vs. non-streaming is decided below by calling astream()
        # or ainvoke(), so no streaming flag is passed here.
        chat_model = ChatOllama(
            model=model,
            base_url=OLLAMA_BASE
        )

        if stream:
            async def generate_stream():
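                # Each yielded item is one Server-Sent Events frame:
                # "data: <chat.completion.chunk JSON>\n\n", followed by a
                # final empty-delta chunk and the "data: [DONE]" sentinel.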
                current_content = ""
                stream_gen = chat_model.astream(langchain_messages)

                async def transform_chunk(chunk):
                    nonlocal current_content
                    if chunk.content:
                        current_content += chunk.content
                        response_json = {
                            "id": f"chatcmpl-{os.urandom(12).hex()}",
                            "object": "chat.completion.chunk",
                            "created": int(time.time()),
                            "model": model,
                            "system_fingerprint": "fp_" + os.urandom(5).hex(),
                            "choices": [{
                                "index": 0,
                                "delta": {
                                    "content": chunk.content
                                },
                                "finish_reason": None
                            }]
                        }
                        return f"data: {json.dumps(response_json)}\n\n"
                    return ""

                async for chunk in stream_gen:
                    yield await transform_chunk(chunk)

                # Send final chunk
                final_json = {
                    "id": f"chatcmpl-{os.urandom(12).hex()}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "system_fingerprint": "fp_" + os.urandom(5).hex(),
                    "choices": [{
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop"
                    }]
                }
                yield f"data: {json.dumps(final_json)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream"
            )
        else:
            response = await chat_model.ainvoke(langchain_messages)

            formatted_response = {
                "id": f"chatcmpl-{os.urandom(12).hex()}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": model,
                "system_fingerprint": "fp_" + os.urandom(5).hex(),
                "choices": [{
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": response.content,
                    },
                    "finish_reason": "stop"
                }],
                "usage": {
                    "prompt_tokens": -1,
                    "completion_tokens": -1,
                    "total_tokens": -1
                }
            }
            logger.info("Formatted response:")
            logger.info(json.dumps(formatted_response, indent=2))
            return formatted_response

    except Exception as e:
        logger.error(f"Error processing request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
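    """List models installed on the local Ollama server in OpenAI /v1/models format."""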
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{OLLAMA_BASE}/api/tags")
            models = response.json()
            return {
                "object": "list",
                "data": [{"id": model["name"],
                          "object": "model",
                          "created": int(time.time()),
                          "owned_by": "organization-owner"}
                         for model in models.get("models", [])]
            }
    except Exception as e:
        logger.error(f"Error listing models: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    host = os.getenv('HOST', '0.0.0.0')
    port = int(os.getenv('PORT', '8000'))

    logger.info(f"Starting API compatibility layer on {host}:{port}")
    logger.info(f"Proxying requests to Ollama at {OLLAMA_BASE}")

    import uvicorn
    uvicorn.run(app, host=host, port=port, log_level="info")
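
# Example client usage (illustrative sketch, not part of this module): any
# OpenAI-compatible client can be pointed at this proxy, e.g. with the
# official `openai` Python package:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#   reply = client.chat.completions.create(
#       model="gpt-4o",  # mapped to the local Ollama model via MODEL_MAPPING
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(reply.choices[0].message.content)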