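"""FastAPI app that accepts .txt and .pdf uploads, stores their parsed text,
chunks and embeds the content in a background task, and answers questions by
retrieving the most similar chunks and passing them as context to the chat
model (a small retrieval-augmented generation service)."""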
import io
import os
import shutil

import openai
from fastapi import (FastAPI, UploadFile,
                     HTTPException, Depends, BackgroundTasks)
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session

from background_tasks import TextProcessor, client
from db import get_db, File, FileChunk
from file_parser import FileParser

app = FastAPI()

class QuestionModel(BaseModel):
    question: str


class AskModel(BaseModel):
    document_id: int
    question: str

@app.get("/") | |
async def root(db: Session = Depends(get_db)): | |
# Query the database for all files | |
files_query = select(File) | |
files = db.scalars(files_query).all() | |
# Format and return the list of files including file_id and filename | |
files_list = [{"file_id": file.file_id, "file_name": file.file_name} for file in files] | |
return files_list | |
@app.post("/uploadfile/") | |
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile, db: Session = Depends(get_db)): # noqa | |
# Define allowed file extensions | |
allowed_extensions = ["txt", "pdf"] | |
# Check if the file extension is allowed | |
file_extension = file.filename.split('.')[-1] | |
if file_extension not in allowed_extensions: | |
raise HTTPException(status_code=400, detail="File type not allowed") | |
folder = "sources" | |
try: | |
# Ensure the directory exists | |
os.makedirs(folder, exist_ok=True) | |
# Secure way to save the file | |
file_location = os.path.join(folder, file.filename) | |
file_content = await file.read() # Read file content as bytes | |
with open(file_location, "wb+") as file_object: | |
# Convert bytes content to a file-like object | |
file_like_object = io.BytesIO(file_content) | |
# Use shutil.copyfileobj for secure file writing | |
shutil.copyfileobj(file_like_object, file_object) | |
content_parser = FileParser(file_location) | |
file_text_content = content_parser.parse() | |
# save file details in the database | |
new_file = File(file_name=file.filename, | |
file_content=file_text_content) | |
db.add(new_file) | |
db.commit() | |
db.refresh(new_file) | |
# Add background job for processing file content | |
background_tasks.add_task(TextProcessor(db, new_file.file_id).chunk_and_embed, file_text_content) # noqa | |
return {"info": "File saved", "filename": file.filename} | |
except Exception as e: | |
# Log the exception (add actual logging in production code) | |
print(f"Error saving file: {e}") | |
raise HTTPException(status_code=500, detail="Error saving file") | |
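# --- Retrieval and question answering ---------------------------------------
# The helper below embeds the incoming question and ranks the stored chunks of
# one document by vector distance; /ask/ passes the top matches to the chat
# model as context, and /find-similar-chunks/ returns the raw matches.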
# Function to get the chunks of a file most similar to a question
async def get_similar_chunks(file_id: int, question: str, db: Session):
    try:
        # Create an embedding for the question with the shared OpenAI client
        response = client.embeddings.create(input=question,
                                            model="text-embedding-ada-002")
        question_embedding = response.data[0].embedding

        # Rank the file's chunks by L2 distance (pgvector) and keep the top 10
        similar_chunks_query = (
            select(FileChunk)
            .where(FileChunk.file_id == file_id)
            .order_by(FileChunk.embedding_vector.l2_distance(question_embedding))
            .limit(10)
        )
        similar_chunks = db.scalars(similar_chunks_query).all()
        return similar_chunks
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/ask/") | |
async def ask_question(request: AskModel, db: Session = Depends(get_db)): | |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
openai.api_key = OPENAI_API_KEY | |
if OPENAI_API_KEY is None: | |
raise HTTPException(status_code=500, detail="OpenAI API key is not set") | |
try: | |
similar_chunks = await get_similar_chunks(request.document_id, request.question, db) | |
# Construct context from the similar chunks' texts | |
context_texts = [chunk.chunk_text for chunk in similar_chunks] | |
context = " ".join(context_texts) | |
# Update the system message with the context | |
system_message = f"You are a helpful assistant. Here is the context to use to reply to questions: {context}" | |
# Make the OpenAI API call with the updated context | |
response = openai.chat.completions.create( | |
model="gpt-3.5-turbo", | |
messages=[ | |
{"role": "system", "content": system_message}, | |
{"role": "user", "content": request.question}, | |
] | |
) | |
return {"response": response.choices[0].message.content} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
@app.post("/find-similar-chunks/{file_id}") | |
async def find_similar_chunks_endpoint(file_id: int, question_data: QuestionModel, db: Session = Depends(get_db)): | |
try: | |
similar_chunks = await get_similar_chunks(file_id, question_data.question, db) | |
# Format the response | |
formatted_response = [ | |
{"chunk_id": chunk.chunk_id, "chunk_text": chunk.chunk_text} | |
for chunk in similar_chunks | |
] | |
return formatted_response | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) |
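
# Example usage (a rough sketch; it assumes this file is saved as main.py, the
# database behind db.py is reachable, and OPENAI_API_KEY is set in the
# environment; the document id and file names below are made up):
#
#   uvicorn main:app --reload
#
#   # Upload a document; its file_id is then visible via GET /
#   curl -F "file=@report.pdf" http://localhost:8000/uploadfile/
#
#   # List uploaded files with their ids
#   curl http://localhost:8000/
#
#   # Ask a question about document 1
#   curl -X POST http://localhost:8000/ask/ \
#        -H "Content-Type: application/json" \
#        -d '{"document_id": 1, "question": "What is this document about?"}'
#
#   # Inspect the raw retrieved chunks for document 1
#   curl -X POST http://localhost:8000/find-similar-chunks/1 \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this document about?"}'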