ChromaDB story ingester
import os
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
import hashlib
import logging
from rich.logging import RichHandler
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.console import Console
from rich.panel import Panel
import traceback
import dotenv

dotenv.load_dotenv(".env")

# Set up rich logging
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)
log = logging.getLogger("rich")

# Initialize the Console
console = Console()
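
# The OpenAI embedding function used below expects OPENAI_API_KEY in the
# environment; load_dotenv(".env") above picks it up from a .env file next to
# this script. Illustrative .env contents (the key value is a placeholder):
#
#   OPENAI_API_KEY=sk-...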


# Wraps a persistent ChromaDB collection: creating or fetching it, generating
# OpenAI embeddings, and adding, updating, querying, and clearing documents.
class ChromaDBManager:
    def __init__(self, path="chroma", collection_name="story_collection"):
        self.client = chromadb.PersistentClient(path=path)
        self.collection_name = collection_name
        self.collection = self.get_or_create_collection()
        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    def get_or_create_collection(self):
        try:
            collection = self.client.get_collection(name=self.collection_name)
            log.info(f"Using existing collection '{self.collection_name}'.")
        except Exception:
            log.info(
                f"Collection '{self.collection_name}' not found. Creating new collection."
            )
            collection = self.client.create_collection(
                name=self.collection_name,
                metadata={"description": "story dataset chunks"},
            )
            log.info(f"Created new collection '{self.collection_name}'.")
        return collection

    def generate_embeddings(self, texts):
        try:
            log.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = self.openai_ef(texts)
            log.info("Embeddings generated successfully")
            return embeddings
        except Exception as e:
            log.error(f"Error generating embeddings: {str(e)}")
            raise

    def add_documents(self, chunks, batch_size=50):
        try:
            # Chunk IDs are SHA-256 hashes of the chunk text, so re-running the
            # ingester only embeds chunks that are new or have changed.
            ingested_hashes = set(self.collection.get()["ids"])
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=True,
            ) as progress:
                task = progress.add_task(
                    "[green]Processing chunks...", total=len(chunks)
                )
                to_ingest = []
                to_ingest_hashes = []
                chunk_positions = []
                to_update_metadata = []
                seen_hashes = set()  # guards against identical chunks within this run
                for chunk_position, chunk in enumerate(chunks):
                    progress.update(task, advance=1)
                    # Ignore very short fragments.
                    if len(chunk) < 35:
                        continue
                    chunk_hash = hashlib.sha256(chunk.encode("utf-8")).hexdigest()
                    # Skip exact duplicates already seen in this run, so a batch
                    # never contains the same ID twice.
                    if chunk_hash in seen_hashes:
                        continue
                    seen_hashes.add(chunk_hash)
                    if chunk_hash not in ingested_hashes:
                        to_ingest.append(chunk)
                        to_ingest_hashes.append(chunk_hash)
                        chunk_positions.append(chunk_position)
                    else:
                        # Update metadata for existing chunks
                        to_update_metadata.append(
                            (chunk_hash, {"chunk_position": chunk_position})
                        )
                    if len(to_ingest) >= batch_size:
                        self._ingest_batch(to_ingest, to_ingest_hashes, chunk_positions)
                        to_ingest = []
                        to_ingest_hashes = []
                        chunk_positions = []
                    if len(to_update_metadata) >= batch_size:
                        self.collection.update(
                            ids=[item[0] for item in to_update_metadata],
                            metadatas=[item[1] for item in to_update_metadata],
                        )
                        to_update_metadata = []
                if to_ingest:
                    self._ingest_batch(to_ingest, to_ingest_hashes, chunk_positions)
                if to_update_metadata:
                    self.collection.update(
                        ids=[item[0] for item in to_update_metadata],
                        metadatas=[item[1] for item in to_update_metadata],
                    )
            # Remove hashes that are no longer in the dataset
            current_hashes = set(
                hashlib.sha256(chunk.encode("utf-8")).hexdigest()
                for chunk in chunks
                if len(chunk) >= 35
            )
            hashes_to_remove = ingested_hashes - current_hashes
            if hashes_to_remove:
                self.collection.delete(ids=list(hashes_to_remove))
                log.info(
                    f"Removed {len(hashes_to_remove)} outdated hashes from the collection"
                )
        except Exception as e:
            log.error(f"Error in add_documents: {str(e)}")
            raise

    def _ingest_batch(self, chunks, hashes, chunk_positions):
        try:
            embeddings = self.generate_embeddings(chunks)
            metadatas = [
                {"chunk_position": chunk_position} for chunk_position in chunk_positions
            ]
            self.collection.add(
                documents=chunks, metadatas=metadatas, ids=hashes, embeddings=embeddings
            )
            log.info(f"Ingested batch of {len(chunks)} chunks")
        except Exception as e:
            log.error(f"Error ingesting batch: {str(e)}")
            raise

    def query(self, query_text, n_results=5):
        try:
            query_embedding = self.generate_embeddings([query_text])[0]
            results = self.collection.query(
                query_embeddings=[query_embedding], n_results=n_results
            )
            return results
        except Exception as e:
            log.error(f"Error querying collection: {str(e)}")
            raise

    def clear_collection(self):
        try:
            self.client.delete_collection(name=self.collection_name)
            log.info(f"Collection '{self.collection_name}' deleted successfully")
            self.collection = self.get_or_create_collection()
        except Exception as e:
            log.error(f"Error deleting collection: {str(e)}")
            raise


def read_and_split_file(file_path, chunk_size=1000, chunk_overlap=200):
    try:
        with console.status("[bold green]Reading and splitting file..."):
            with open(file_path, "r", encoding="utf-8") as f:
                data = f.read()
            log.info(f"File read successfully: {file_path}")
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
            )
            chunks = text_splitter.split_text(data)
            log.info(f"File split into {len(chunks)} chunks")
        return chunks
    except Exception as e:
        log.error(f"Error reading or splitting file: {str(e)}")
        raise


def ingest_story_file(file_path, clear_existing=False):
    try:
        console.print(Panel.fit("[bold blue]Starting Story File Ingestion[/bold blue]"))
        chroma_manager = ChromaDBManager()
        if clear_existing:
            chroma_manager.clear_collection()
        chunks = read_and_split_file(file_path)
        # Get the existing hashes before adding new documents
        existing_hashes = set(chroma_manager.collection.get()["ids"])
        # Add documents
        chroma_manager.add_documents(chunks)
        # After all processing is complete, get the new hashes
        new_hashes = set(chroma_manager.collection.get()["ids"])
        # Calculate statistics
        added_chunks = new_hashes - existing_hashes
        removed_chunks = existing_hashes - new_hashes
        unchanged_chunks = existing_hashes.intersection(new_hashes)
        collection_info = chroma_manager.collection.count()
        console.print(
            Panel.fit(
                f"[bold green]Ingestion Complete![/bold green]\n"
                f"Total documents in the collection: {collection_info}\n"
                f"Chunks added: {len(added_chunks)}\n"
                f"Chunks removed: {len(removed_chunks)}\n"
                f"Chunks unchanged: {len(unchanged_chunks)}"
            )
        )
        return chroma_manager
    except Exception as e:
        log.error(f"Error in ingest_story_file: {str(e)}")
        log.error(traceback.format_exc())


# Example usage
if __name__ == "__main__":
    try:
        file_path = "Z:/Datasets/story/story.txt"
        chroma_manager = ingest_story_file(file_path, clear_existing=False)
    except Exception as e:
        log.error(f"Unhandled exception in main: {str(e)}")
        log.error(traceback.format_exc())
        raise
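
    # Illustrative follow-up (query text and result count are arbitrary examples):
    # pull the chunks most similar to a question from the freshly ingested collection.
    if chroma_manager is not None:
        results = chroma_manager.query("Who is the main character?", n_results=3)
        for doc, meta in zip(results["documents"][0], results["metadatas"][0]):
            console.print(f"chunk {meta['chunk_position']}: {doc[:80]}")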