ChromaDB story ingester: reads a story text file, splits it into overlapping chunks with LangChain's RecursiveCharacterTextSplitter, embeds each chunk with OpenAI's text-embedding-3-small, and keeps a persistent ChromaDB collection in sync (new chunks are added, existing chunks get their positions updated, and stale chunks are deleted).
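Setup note (an assumption, not part of the gist): the script expects an OPENAI_API_KEY entry in a local .env file, plus the packages it imports (chromadb, langchain, rich, python-dotenv) and the openai package that chromadb's OpenAIEmbeddingFunction relies on. A minimal, hypothetical pre-flight check along those lines:

import os
import dotenv

# Hypothetical pre-flight check, mirroring the dotenv loading done in the script below
dotenv.load_dotenv(".env")
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY is missing from .env"

The ingester script itself follows.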
import os
import chromadb
from chromadb.utils import embedding_functions
from langchain.text_splitter import RecursiveCharacterTextSplitter
import hashlib
import logging
from rich.logging import RichHandler
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.console import Console
from rich.panel import Panel
import traceback
import dotenv
dotenv.load_dotenv(".env")
# Set up rich logging
logging.basicConfig(
    level="INFO",
    format="%(message)s",
    datefmt="[%X]",
    handlers=[RichHandler(rich_tracebacks=True)],
)
log = logging.getLogger("rich")
# Initialize the Console
console = Console()

class ChromaDBManager:
    """Manages a persistent ChromaDB collection of story chunks keyed by content hash."""

    def __init__(self, path="chroma", collection_name="story_collection"):
        self.client = chromadb.PersistentClient(path=path)
        self.collection_name = collection_name
        self.collection = self.get_or_create_collection()
        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    def get_or_create_collection(self):
        try:
            collection = self.client.get_collection(name=self.collection_name)
            log.info(f"Using existing collection '{self.collection_name}'.")
        except Exception:
            log.info(
                f"Collection '{self.collection_name}' not found. Creating new collection."
            )
            collection = self.client.create_collection(
                name=self.collection_name,
                metadata={"description": "story dataset chunks"},
            )
            log.info(f"Created new collection '{self.collection_name}'.")
        return collection

    def generate_embeddings(self, texts):
        try:
            log.info(f"Generating embeddings for {len(texts)} texts")
            embeddings = self.openai_ef(texts)
            log.info("Embeddings generated successfully")
            return embeddings
        except Exception as e:
            log.error(f"Error generating embeddings: {str(e)}")
            raise

    def add_documents(self, chunks, batch_size=50):
        try:
            # Document IDs are content hashes, so this set reflects what is already stored
            ingested_hashes = set(self.collection.get()["ids"])
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=True,
            ) as progress:
                task = progress.add_task(
                    "[green]Processing chunks...", total=len(chunks)
                )
                to_ingest = []
                to_ingest_hashes = []
                chunk_positions = []
                to_update_metadata = []
                for chunk_position, chunk in enumerate(chunks):
                    progress.update(task, advance=1)
                    # Skip chunks too short to be worth embedding
                    if len(chunk) < 35:
                        continue
                    chunk_hash = hashlib.sha256(chunk.encode("utf-8")).hexdigest()
                    if chunk_hash not in ingested_hashes:
                        to_ingest.append(chunk)
                        to_ingest_hashes.append(chunk_hash)
                        chunk_positions.append(chunk_position)
                    else:
                        # Update metadata for existing chunks
                        to_update_metadata.append(
                            (chunk_hash, {"chunk_position": chunk_position})
                        )
                    # Flush new chunks and metadata updates in batches
                    if len(to_ingest) >= batch_size:
                        self._ingest_batch(to_ingest, to_ingest_hashes, chunk_positions)
                        to_ingest = []
                        to_ingest_hashes = []
                        chunk_positions = []
                    if len(to_update_metadata) >= batch_size:
                        self.collection.update(
                            ids=[item[0] for item in to_update_metadata],
                            metadatas=[item[1] for item in to_update_metadata],
                        )
                        to_update_metadata = []
                # Flush whatever remains after the loop
                if to_ingest:
                    self._ingest_batch(to_ingest, to_ingest_hashes, chunk_positions)
                if to_update_metadata:
                    self.collection.update(
                        ids=[item[0] for item in to_update_metadata],
                        metadatas=[item[1] for item in to_update_metadata],
                    )
            # Remove hashes that are no longer in the dataset
            current_hashes = set(
                hashlib.sha256(chunk.encode("utf-8")).hexdigest()
                for chunk in chunks
                if len(chunk) >= 35
            )
            hashes_to_remove = ingested_hashes - current_hashes
            if hashes_to_remove:
                self.collection.delete(ids=list(hashes_to_remove))
                log.info(
                    f"Removed {len(hashes_to_remove)} outdated hashes from the collection"
                )
        except Exception as e:
            log.error(f"Error in add_documents: {str(e)}")
            raise

    def _ingest_batch(self, chunks, hashes, chunk_positions):
        try:
            embeddings = self.generate_embeddings(chunks)
            metadatas = [
                {"chunk_position": chunk_position} for chunk_position in chunk_positions
            ]
            self.collection.add(
                documents=chunks, metadatas=metadatas, ids=hashes, embeddings=embeddings
            )
            log.info(f"Ingested batch of {len(chunks)} chunks")
        except Exception as e:
            log.error(f"Error ingesting batch: {str(e)}")
            raise

    def query(self, query_text, n_results=5):
        try:
            query_embedding = self.generate_embeddings([query_text])[0]
            results = self.collection.query(
                query_embeddings=[query_embedding], n_results=n_results
            )
            return results
        except Exception as e:
            log.error(f"Error querying collection: {str(e)}")
            raise

    def clear_collection(self):
        try:
            self.client.delete_collection(name=self.collection_name)
            log.info(f"Collection '{self.collection_name}' deleted successfully")
            self.collection = self.get_or_create_collection()
        except Exception as e:
            log.error(f"Error deleting collection: {str(e)}")
            raise

def read_and_split_file(file_path, chunk_size=1000, chunk_overlap=200):
    try:
        with console.status("[bold green]Reading and splitting file..."):
            with open(file_path, "r", encoding="utf-8") as f:
                data = f.read()
            log.info(f"File read successfully: {file_path}")
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
            )
            chunks = text_splitter.split_text(data)
            log.info(f"File split into {len(chunks)} chunks")
            return chunks
    except Exception as e:
        log.error(f"Error reading or splitting file: {str(e)}")
        raise

def ingest_story_file(file_path, clear_existing=False):
    try:
        console.print(Panel.fit("[bold blue]Starting Story File Ingestion[/bold blue]"))
        chroma_manager = ChromaDBManager()
        if clear_existing:
            chroma_manager.clear_collection()
        chunks = read_and_split_file(file_path)
        # Get the existing hashes before adding new documents
        existing_hashes = set(chroma_manager.collection.get()["ids"])
        # Add documents
        chroma_manager.add_documents(chunks)
        # After all processing is complete, get the new hashes
        new_hashes = set(chroma_manager.collection.get()["ids"])
        # Calculate statistics
        added_chunks = new_hashes - existing_hashes
        removed_chunks = existing_hashes - new_hashes
        unchanged_chunks = existing_hashes.intersection(new_hashes)
        collection_info = chroma_manager.collection.count()
        console.print(
            Panel.fit(
                f"[bold green]Ingestion Complete![/bold green]\n"
                f"Total documents in the collection: {collection_info}\n"
                f"Chunks added: {len(added_chunks)}\n"
                f"Chunks removed: {len(removed_chunks)}\n"
                f"Chunks unchanged: {len(unchanged_chunks)}"
            )
        )
        return chroma_manager
    except Exception as e:
        log.error(f"Error in ingest_story_file: {str(e)}")
        log.error(traceback.format_exc())

# Example usage
if __name__ == "__main__":
    try:
        file_path = "Z:/Datasets/story/story.txt"
        chroma_manager = ingest_story_file(file_path, clear_existing=False)
    except Exception as e:
        log.error(f"Unhandled exception in main: {str(e)}")
        log.error(traceback.format_exc())
        raise
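Using the SHA-256 hash of each chunk as its ChromaDB document ID is what makes re-ingestion idempotent: unchanged chunks keep their IDs and are never re-embedded, while edited chunks show up as one removal plus one addition. As a minimal sketch of how the collection could be queried after ingestion, reusing the ChromaDBManager defined above (the query text and result handling are illustrative only, and assume ingest_story_file has already populated the default "chroma" directory):

from pprint import pprint

manager = ChromaDBManager()  # re-opens the persisted "chroma" directory
results = manager.query("Who is the story's main character?", n_results=3)

# collection.query returns parallel lists of lists; index 0 corresponds to the single query
for document, metadata in zip(results["documents"][0], results["metadatas"][0]):
    print(f"--- chunk at position {metadata['chunk_position']} ---")
    pprint(document)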