Last active
November 28, 2023 14:45
-
-
Save simonmesmith/43b4c447584094d2a15793f0c6e60463 to your computer and use it in GitHub Desktop.
Simple RAG for easily embedding documents and querying embeddings
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
SIMPLE RAG!

This file provides a class, Collection, that makes it easy to add
retrieval-augmented generation (RAG) to an application.

There are so many overly complex RAG tools out there, such as LlamaIndex and
LangChain. Even Chroma can be overly complex for some use cases, and I've
run into issues (in Streamlit Share) where Chroma's dependency on SQLite
caused a conflict that I couldn't resolve. Argh!

So I created the class below as a very simple, lightweight, low-dependency
solution for adding RAG to an application. I also coded it to be a fairly
easy drop-in replacement for Chroma.

Its one major limitation is that it doesn't come with a database: everything
is stored in memory. This is fine for small collections, but if you have a
large collection, or don't want to re-embed the same documents every time,
you'll probably want a solution backed by a database, like Chroma.

The code below uses OpenAI embeddings. To use these, you'll need an OpenAI
API key set in the OPENAI_API_KEY environment variable. If you don't want to
use OpenAI embeddings, you can swap in any other embeddings. Just make sure
you use the same embedding function both for adding documents and for
querying them.

INSTALL:
    pip install numpy scikit-learn openai

USAGE:
    from simple_rag import Collection

    collection = Collection()
    collection.add(
        ids=["1", "2"],
        documents=["This is a document.", "This is another document."],
        metadatas=[{"url": "http://test.com"}, {"url": "http://test.com"}],
    )
    results = collection.query(query_texts=["Find a document"])
    print(results)
    # {'documents': [['This is a document.', 'This is another document.']],
    #  'distances': [[0.1539412288910481, 0.17489997771983146]],
    #  'metadatas': [[{'url': 'http://test.com'}, {'url': 'http://test.com'}]]}
"""
import os
from collections.abc import Callable, Sequence

import numpy as np
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity
class Collection: | |
"""A collection of documents with associated metadata and embeddings.""" | |
def __init__(self): | |
self.documents = [] | |
self.ids = [] | |
self.metadatas = [] | |
self.embeddings = [] | |
def add(self, ids: list[str], documents: list[str], metadatas: list[dict]): | |
"""Adds documents to the collection.""" | |
self.ids.extend(ids) | |
self.documents.extend(documents) | |
self.metadatas.extend(metadatas) | |
embeddings = self._embed_documents(documents) | |
if embeddings is not None: | |
self.embeddings.extend(embeddings) | |
def query(self, query_texts: list[str], min_distance: float = 0.3) -> dict: | |
"""Queries the collection for documents similar to the query texts.""" | |
query_embeddings = self._embed_documents(query_texts) | |
results = {"documents": [], "distances": [], "metadatas": []} | |
if query_embeddings is not None: | |
for query_embedding in query_embeddings: | |
distances = ( | |
1 | |
- cosine_similarity([query_embedding], self.embeddings)[0] | |
) | |
relevant_indices = np.where(distances <= min_distance)[0] | |
results["documents"].append( | |
[self.documents[i] for i in relevant_indices] | |
) | |
results["distances"].append( | |
[distances[i] for i in relevant_indices] | |
) | |
results["metadatas"].append( | |
[self.metadatas[i] for i in relevant_indices] | |
) | |
return results | |
def _embed_documents(self, documents: list[str]) -> np.ndarray | None: | |
"""Embeds documents. If you want to use something other than OpenAI | |
embeddings, you can change up this function.""" | |
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) | |
model = "text-embedding-ada-002" | |
max_batch_size = 2048 # Max batch size for OpenAI embeddings | |
embeddings = [] | |
for i in range(0, len(documents), max_batch_size): | |
batch = documents[i : i + max_batch_size] # noqa | |
try: | |
response = client.embeddings.create(model=model, input=batch) | |
batch_embeddings = [data.embedding for data in response.data] | |
embeddings.extend(batch_embeddings) | |
except Exception as e: | |
print(f"Error embedding documents: {e}") | |
embeddings.extend([None] * len(batch)) | |
if embeddings: | |
return np.array(embeddings) | |
else: | |
return None |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment