# main.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import litellm
import asyncio
from typing import List
import os
from dotenv import load_dotenv

load_dotenv()  # Load environment variables from .env file

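# A minimal .env sketch (assumed variable names: LiteLLM conventionally reads
# provider keys from the environment, e.g. OPENAI_API_KEY, ANTHROPIC_API_KEY,
# and DEEPSEEK_API_KEY for the models configured below):
#   OPENAI_API_KEY=sk-...
#   ANTHROPIC_API_KEY=sk-ant-...
#   DEEPSEEK_API_KEY=...
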
app = FastAPI()

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],  # Allow the frontend origin
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Configure LiteLLM
litellm.set_verbose = True

# LLM configurations
llm_configs = {
    "gpt-4": {"model": "gpt-4o"},
    "claude": {"model": "anthropic/claude-3-5-sonnet-20240620"},
    "deepseek": {"model": "deepseek/deepseek-chat"},
}

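# Adding another backend is a one-line change, since LiteLLM routes on the
# "provider/model" string. A hypothetical example (this model name is not
# part of the original configuration):
#   llm_configs["llama"] = {"model": "groq/llama3-70b-8192"}
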
class Question(BaseModel):
    text: str

class LLMResponse(BaseModel):
    model: str
    response: str

class AggregatedResponse(BaseModel):
    individual_responses: List[LLMResponse]
    aggregated_response: str

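# FastAPI serializes AggregatedResponse to JSON shaped like the following
# (illustrative values):
#   {"individual_responses": [{"model": "gpt-4", "response": "..."}, ...],
#    "aggregated_response": "..."}
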
# Ensure logs directory exists
logs_dir = "logs"
os.makedirs(logs_dir, exist_ok=True)

def get_next_log_number():
    """Return the next sequential log number, assuming every .txt file in
    logs/ is named <number>.txt."""
    existing_logs = [f for f in os.listdir(logs_dir) if f.endswith('.txt')]
    if not existing_logs:
        return 1
    return max(int(f.split('.')[0]) for f in existing_logs) + 1

def write_log(question, individual_responses, aggregated_response):
    log_number = get_next_log_number()
    filename = f"{log_number}.txt"
    filepath = os.path.join(logs_dir, filename)
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write("-" * 100 + "\n")
        f.write(f"Question:\n{question}\n\n")
        f.write("-" * 100 + "\n")
        f.write("Individual Responses:\n")
        f.write("-" * 100 + "\n")
        for resp in individual_responses:
            f.write(f"[{resp.model}]\n{resp.response}\n\n")
        f.write("-" * 100 + "\n")
        f.write("Aggregated Response:\n")
        f.write("-" * 100 + "\n")
        f.write(f"{aggregated_response}\n")

async def query_llm(question: str, llm_name: str):
    try:
        response = await litellm.acompletion(
            model=llm_configs[llm_name]["model"],
            messages=[{"role": "user", "content": question}]
        )
        return LLMResponse(model=llm_name, response=response.choices[0].message.content)
    except Exception as e:
        # Return the error in-band so asyncio.gather() still yields one
        # entry per model instead of failing the whole request
        print(f"Error querying {llm_name}: {str(e)}")
        return LLMResponse(model=llm_name, response=f"Error: {str(e)}")

async def aggregate_responses(question: str, responses: List[LLMResponse]):
    aggregator_prompt = """
You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.

User question: {question}

Responses from models:
{responses}

Please provide an aggregated response:
"""
    responses_text = "\n\n".join([f"Model {i+1}:\n{resp.response}" for i, resp in enumerate(responses)])
    full_prompt = aggregator_prompt.format(question=question, responses=responses_text)
    try:
        aggregated = await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet-20240620",
            messages=[{"role": "user", "content": full_prompt}]
        )
        return aggregated.choices[0].message.content
    except Exception as e:
        print(f"Error in aggregation: {str(e)}")
        return "Error occurred during aggregation."

@app.post("/ask", response_model=AggregatedResponse)
async def ask_question(question: Question):
    try:
        # Fan out to all configured models concurrently
        llm_responses = await asyncio.gather(*[query_llm(question.text, llm) for llm in llm_configs.keys()])
        aggregated_response = await aggregate_responses(question.text, llm_responses)
        # Write log
        write_log(question.text, llm_responses, aggregated_response)
        return AggregatedResponse(individual_responses=llm_responses, aggregated_response=aggregated_response)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

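# Example request once the server is running (illustrative question text):
#   curl -X POST http://localhost:8010/ask \
#     -H "Content-Type: application/json" \
#     -d '{"text": "Explain CORS in one paragraph."}'
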
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8010)
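
# Start with `python main.py`, or equivalently via the uvicorn CLI:
#   uvicorn main:app --host 0.0.0.0 --port 8010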