RAG Server
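A small Retrieval-Augmented Generation (RAG) server built on FastAPI. It generates text embeddings through a local Ollama instance, stores them in SQLite, ingests .txt and .csv files from an ingress/ folder, and answers prompts by retrieving the most similar stored chunks via cosine similarity and passing them to a chat model as context.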
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import requests
import time
import sqlite3
import threading
from contextlib import closing
from pathlib import Path
from queue import Queue, Empty

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

app = FastAPI()

embeddings_model = "mxbai-embed-large"
# Other chat models tried during development:
#   "llama3", "dolphin-mixtral", "mixtral:latest", "dolphin-mixtral:latest"
chat_model = "dolphin-mistral:latest"
ollama_host = "http://localhost:11434"
embeddings_model_db = "default"
# Database connection utility
def get_db_connection():
    global embeddings_model_db
    conn = sqlite3.connect(f"./embeddings/{embeddings_model_db}.db")
    conn.row_factory = sqlite3.Row
    return conn
# API models
class EmbeddingRequest(BaseModel):
    source: str
    content: str

class ChatRequest(BaseModel):
    messages: list

class RagRequest(BaseModel):
    prompt: str
    related_count: int
    max_tokens: int

class ChangeEmbeddingDBFilename(BaseModel):
    name: str
# Database setup
def setup_database():
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS embeddings (
                    id INTEGER PRIMARY KEY,
                    source TEXT,
                    content TEXT,
                    embedding TEXT,
                    last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP)
            """)
            conn.commit()

setup_database()
@app.post("/api/change_embedding_db/") | |
async def change_embedding_db(data: ChangeEmbeddingDBFilename): | |
global embeddings_model_db | |
embeddings_model_db = data.name | |
setup_database() | |
return {"status": "success"} | |
# Endpoint to get all embeddings
@app.get("/api/embeddings/")
async def get_embeddings():
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("SELECT * FROM embeddings")
            rows = cursor.fetchall()
    return [dict(row) for row in rows]
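# Example (dumps every stored row as JSON; same port assumption as above):
#   curl http://localhost:11435/api/embeddings/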
def make_embeddings_safe_for_db(embedding):
    # Store the embedding list as a brace-delimited string, e.g. "{0.1, 0.2, ...}"
    return str(embedding).replace('[', '{').replace(']', '}')
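# Round-trip sketch: a list like [0.1, 0.2] is stored as the string "{0.1, 0.2}",
# and the retrieval endpoints below parse it back by stripping the braces and
# splitting on commas before rebuilding a numpy array.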
def insert_embedding(content, source):
    print(f"Inserting embedding for {len(content)} bytes from {source}")
    response = requests.post(ollama_host + "/api/embeddings", json={"model": embeddings_model, "prompt": content})
    if response.status_code == 200:
        embedding = response.json()['embedding']
        with get_db_connection() as conn:
            with closing(conn.cursor()) as cursor:
                embedding = make_embeddings_safe_for_db(embedding)
                # TODO: check whether this (source, content) pair already exists before inserting
                cursor.execute("INSERT INTO embeddings (source, content, embedding) VALUES (?, ?, ?)", (source, content, embedding))
                conn.commit()
        return {"status": "success", "content": content, "embedding": embedding}
    else:
        raise HTTPException(status_code=response.status_code, detail="Error processing embeddings")
# Endpoint to insert text and its embedding
@app.post("/api/insert_text_embeddings/")
async def insert_text_embeddings(data: EmbeddingRequest):
    return insert_embedding(data.content, data.source)
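# Example request (hypothetical source/content values):
#   curl -X POST http://localhost:11435/api/insert_text_embeddings/ \
#        -H "Content-Type: application/json" \
#        -d '{"source": "notes.txt", "content": "Some text to embed"}'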
# Ingest every file in the "ingress" folder, using the filename as the source.
@app.post("/api/ingress_file_embeddings/")
async def ingress_file_embeddings():
    # Get all files in the ingress folder
    ingress_folder = Path("ingress")
    for file in ingress_folder.iterdir():
        if file.is_file():
            if file.suffix == ".txt":
                with open(file, "r") as f:
                    content = f.read()
                # Chunk the content into 255-character pieces
                for i in range(0, len(content), 255):
                    insert_embedding(content[i:i + 255], file.name + " - chunk " + str(i))
elif file.suffix == ".csv": | |
data = file.read_text() | |
lines = data.split("\n") | |
avg_list = [] | |
for index in range(2, len(lines) - 1,5): | |
start = time.time() | |
content = "" | |
if index-2 >0: | |
content += lines[index-2] + "\n" | |
if index-1 >0: | |
content += lines[index-1] + "\n" | |
content += lines[index] + "\n" | |
if index+1 < len(lines): | |
content += lines[index+1] + "\n" | |
if index+2 < len(lines): | |
content += lines[index+2] + "\n" | |
insert_embedding(content, file.name+" - line "+str(index)) | |
end = time.time() | |
avg_list.append(end-start) | |
avg = sum(avg_list)/len(avg_list) | |
print(f"Average time for processing 5 lines: {avg} seconds, time remaining for {len(lines)-index} lines: {avg*(len(lines)-index)} seconds") | |
avg_list = avg_list[-10:] | |
return {"status": "success"} | |
def ingress_thread(queue):
    # Worker: pull (source, lines) chunks off the queue and embed them,
    # giving up after five consecutive failures or idle seconds.
    failout = 5
    while failout > 0:
        try:
            file, lines = queue.get(timeout=1)
        except Empty:
            failout -= 1
            continue
        content = "\n".join(lines)
        try:
            insert_embedding(content, file)
        except Exception as e:
            print(e)
            failout -= 1
@app.post("/api/fast_csv_ingress/") | |
async def fast_csv_ingress(): | |
queue = Queue() | |
for file in Path('ingress').glob("*.csv"): | |
lines = file.read_text().split("\n") | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
threadding.Thread(target=ingress_thread, args=(queue,)).start() | |
for index in range(2, len(lines) - 1, 5): | |
queue.put((f"{file.name} - lines {index-2} - {index+2}", lines[index-2:index+3])) | |
while queue.qsize() > 0: | |
time.sleep(1) | |
return {"status": "success"} | |
def generate_embedding(prompt):
    response = requests.post(ollama_host + "/api/embeddings", json={"model": embeddings_model, "prompt": prompt})
    if response.status_code == 200:
        return response.json()['embedding']
    else:
        raise Exception("Error generating embeddings")
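# The Ollama embeddings endpoint responds with JSON of the form
# {"embedding": [0.1, 0.2, ...]}, which is why only the 'embedding' key is read above.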
# Retrieval-Augmented Generation using embeddings (retrieval only, no chat call)
@app.post("/api/rag_test")
async def perform_ragtest(data: RagRequest):
    with get_db_connection() as conn:
        with closing(conn.cursor()) as cursor:
            cursor.execute("SELECT content, embedding FROM embeddings")
            embeddings = cursor.fetchall()
    query_emb = generate_embedding(data.prompt)
    # Parse the brace-delimited strings back into numpy arrays
    db_embs = [np.array(row['embedding'][1:-1].split(','), dtype=float) for row in embeddings]
    cos_sims = cosine_similarity([query_emb], db_embs)[0]
    indices = np.argsort(cos_sims)[::-1][:data.related_count]
    related_prompts = " ".join(embeddings[i]['content'] for i in indices)
    system_prompt = "You are helpful, here is some info related to the user's question:\n" + related_prompts
    return {"system_prompt": system_prompt, "related_prompts": related_prompts}
@app.post("/api/reset_embeddings_db") | |
async def reset_embeddings_db(): | |
with get_db_connection() as conn: | |
with closing(conn.cursor()) as cursor: | |
cursor.execute("DELETE FROM embeddings") | |
conn.commit() | |
return {"status": "success"} | |
@app.post("/api/rag") | |
async def perform_rag(data: RagRequest): | |
# Create a connection to the database | |
with get_db_connection() as conn: | |
with closing(conn.cursor()) as cursor: | |
# Retrieve all embeddings from the database | |
cursor.execute("SELECT source, content, embedding FROM embeddings") | |
embeddings = cursor.fetchall() | |
# Generate the embedding for the prompt | |
query_emb = array([generate_embedding(data.prompt)]) | |
# Convert stored embeddings from strings back to numpy arrays | |
db_embs = array([fromstring(emb['embedding'][1:-1], sep=',') for emb in embeddings]) | |
# Compute cosine similarities | |
cos_sims = cosine_similarity(query_emb, db_embs)[0] | |
indices = argsort(cos_sims)[::-1][:data.related_count] # Top 3 related prompts | |
# Construct related prompts text | |
related_prompts = "" | |
#"\n".join(embeddings[i]['content'] for i in indices) | |
for i in indices: | |
related_prompts += f""" | |
#{embeddings[i]['source']} | |
``` | |
{embeddings[i]['content']} | |
```""" | |
system_prompt = f"You are helpful, here is some info related to the user's question:\n{related_prompts}\nThe next message is the users question" | |
print() | |
print(system_prompt) | |
print(data.prompt) | |
# Query an external chat model | |
response = requests.post(ollama_host+"/api/chat", json={ | |
"stream": False, | |
"model": chat_model, | |
"messages": [ | |
{"role": "system", "content": system_prompt}, | |
{"role": "user", "content": data.prompt} | |
], | |
"max_tokens": data.max_tokens | |
}) | |
print(response.status_code) | |
print(response.text) | |
if response.status_code == 200: | |
print(response.json()) | |
print() | |
return response.json() | |
else: | |
raise HTTPException(status_code=response.status_code, detail="Error processing chat with model") | |
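# Example request (hypothetical values; returns the raw Ollama chat response):
#   curl -X POST http://localhost:11435/api/rag \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Summarize my notes", "related_count": 5, "max_tokens": 512}'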
if __name__ == "__main__": | |
import uvicorn | |
uvicorn.run(app, host="0.0.0.0", port=11435) |