@tazarov
Last active August 2, 2024 13:07
Runs a worst-case scenario benchmark on Chroma HNSW index to demonstrate effects of fragmentation on frequent add/delete
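For reference, the benchmark is typically invoked as, e.g., python hnsw_fragmentation_bench.py --cycles 100 --cycle-size 500 --vector-size 1536 (the file name here is a placeholder; the flags and their defaults are defined in the argparse section at the bottom of the script).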
import argparse
import gc
from abc import ABC
from typing import List, Any, TypedDict, Optional
from overrides import EnforceOverrides, override
from pydantic import BaseModel, Field
from rich.console import Console
from rich.progress import track
from rich.prompt import Confirm
from rich.table import Table
from chromadb import QueryResult
from chromadb.api.configuration import HNSWConfiguration, CollectionConfiguration
from chromadb.segment import VectorReader
from chromadb.types import SegmentScope
import uuid
import numpy as np
import pandas as pd
import time
import os
import psutil
import chromadb
np.random.seed(42)

class PerfCounters(TypedDict):
    cpu_percent: float
    memory: float
    written_total: int
    read_total: int

# handle to the current Python process, sampled by get_perf_counters()
process = psutil.Process(os.getpid())

def get_perf_counters() -> PerfCounters:
    """Sample CPU, memory and I/O counters for the current process."""
    global process
    cpu_percent = process.cpu_percent()
    memory = process.memory_info().rss / 1024 / 1024  # resident set size in MiB
    # io_counters() is not available on all platforms (e.g. macOS)
    if hasattr(process, "io_counters"):
        written_total = process.io_counters().write_bytes
        read_total = process.io_counters().read_bytes
    else:
        written_total = 0
        read_total = 0
    return PerfCounters(cpu_percent=cpu_percent, memory=memory, written_total=written_total, read_total=read_total)

def get_directory_size(directory: str) -> int:
    """Return the total size in bytes of all files under the given directory."""
    total_size = 0
    for dirpath, _, filenames in os.walk(directory):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            # skip symbolic links
            if not os.path.islink(fp):
                total_size += os.path.getsize(fp)
    return total_size

class IndexUsage(TypedDict):
    total_items: int
    items_in_use: int
    size_on_disk: int
    fragmentation_ratio: float

class IUT(ABC, EnforceOverrides):
    """Index under test: the minimal interface exercised by the benchmark."""

    def load_index(self, path: str, hnsw_configuration: Optional[HNSWConfiguration] = None) -> None:
        raise NotImplementedError

    def add_embeddings(self, embeddings: np.ndarray) -> List[Any]:
        raise NotImplementedError

    def delete_embeddings(self, ids: List[Any]) -> None:
        raise NotImplementedError

    def query_embeddings(self, query: np.ndarray, n_results: Optional[int] = 10) -> QueryResult:
        raise NotImplementedError

    def get_usage(self) -> IndexUsage:
        """Return usage statistics (item counts, size on disk, fragmentation) for the index."""
        raise NotImplementedError

class ChromaIndex(IUT):
    client: chromadb.ClientAPI
    collection: chromadb.Collection
    max_id: int

    @override
    def load_index(self, path: str, hnsw_configuration: Optional[HNSWConfiguration] = None) -> None:
        self.client = chromadb.PersistentClient(path)
        config = None
        if hnsw_configuration:
            config = CollectionConfiguration(hnsw_configuration=hnsw_configuration)
        self.collection = self.client.get_or_create_collection("hnsw", configuration=config)
        self.max_id = self.collection.count() + 1

    @override
    def add_embeddings(self, embeddings: np.ndarray) -> List[Any]:
        # ids restart from the same max_id on every call; this is safe here because
        # each benchmark cycle deletes the ids it added before the next add
        ids = [f"{self.max_id + i}" for i in range(embeddings.shape[0])]
        self.collection.add(ids=ids, embeddings=embeddings)
        return ids

    @override
    def delete_embeddings(self, ids: List[Any]) -> None:
        self.collection.delete(ids=ids)

    @override
    def query_embeddings(self, query: np.ndarray, n_results: Optional[int] = 10) -> QueryResult:
        return self.collection.query(query_embeddings=query, n_results=n_results)

    @override
    def get_usage(self) -> IndexUsage:
        # locate the vector (HNSW) segment that backs this collection
        vector_segments = [
            s
            for s in self.client._server._sysdb.get_segments()
            if s["scope"] == SegmentScope.VECTOR and s["collection"] == self.collection.id
        ]
        if len(vector_segments) == 0:
            return IndexUsage(total_items=0, items_in_use=0, size_on_disk=0, fragmentation_ratio=0.0)
        index_path = os.path.join(self.client.get_settings().persist_directory, str(vector_segments[0]["id"]))
        hnsw_item_count = 0
        if len(vector_segments) > 0:
            segment = self.client._server._manager.get_segment(
                vector_segments[0]["collection"], VectorReader
            )
            # element_count includes entries still held by the HNSW index even after deletion
            hnsw_item_count = segment._index.element_count
        index_active_items = self.collection.count()
        return IndexUsage(
            total_items=hnsw_item_count,
            items_in_use=index_active_items,
            size_on_disk=get_directory_size(index_path),
            fragmentation_ratio=((hnsw_item_count - index_active_items) / hnsw_item_count) * 100,
        )

class StatRow(BaseModel):
    """One row of benchmark statistics, collected per add/query/delete cycle."""

    run_id: str = Field(default=None)
    cycle_id: int = Field(default=None)
    cycle_size: int = Field(default=None)
    vector_size: int = Field(default=None)
    cycle_memory_start: float = Field(default=None)
    cycle_memory_end: float = Field(default=None)
    cycle_total_time: float = Field(default=None)
    cycle_start_time: float = Field(default=None)
    add_memory_start: float = Field(default=None)  # memory before add
    add_memory_end: float = Field(default=None)  # memory after add
    add_time: float = Field(default=None)
    add_cpu_usage: float = Field(default=None)
    add_bytes_written: int = Field(default=None)
    add_bytes_read: int = Field(default=None)
    query_memory_end: float = Field(default=None)  # memory after query
    query_time: float = Field(default=None)
    query_cpu_usage: float = Field(default=None)
    query_bytes_written: int = Field(default=None)
    query_bytes_read: int = Field(default=None)
    delete_memory_end: float = Field(default=None)  # memory after delete
    delete_time: float = Field(default=None)
    delete_cpu_usage: float = Field(default=None)
    delete_bytes_written: int = Field(default=None)
    delete_bytes_read: int = Field(default=None)
    index_total_items: int = Field(default=None)
    index_items_in_use: int = Field(default=None)
    index_size_on_disk: int = Field(default=None)
    index_fragmentation_ratio: float = Field(default=None)

def main(args: argparse.Namespace):
    # pre-generate all embeddings up front so data generation does not skew the timings
    data = np.random.uniform(-1, 1, (args.cycles, args.cycle_size, args.vector_size))
    index_under_test = ChromaIndex()
    index_under_test.load_index(args.run_id)
    stats = []
    console = Console()

    # Create a table to confirm the configuration
    table = Table(title="Confirm configuration")
    # Add columns
    table.add_column("Prop", justify="center", style="bold magenta")
    table.add_column("Value", justify="center", style="bold cyan")
    # Add rows
    table.add_row("Run ID", str(args.run_id))
    table.add_row("Cycles", str(args.cycles))
    table.add_row("Number of embeddings per cycle", str(args.cycle_size))
    table.add_row("Embedding Size", str(args.vector_size))
    console.print(table)

    confirm = Confirm.ask("Looks good?", default=False)
    if confirm:
        print("Starting, hold tight")
    else:
        print("Not good")
        return

    for i in track(range(int(args.cycles))):
        gc.collect()
        sr = StatRow()
        sr.run_id = args.run_id
        sr.cycle_id = i
        sr.cycle_start_time = time.perf_counter()
        pc = get_perf_counters()
        sr.cycle_memory_start = pc["memory"]
        sr.vector_size = args.vector_size
        sr.cycle_size = args.cycle_size

        # add this cycle's batch of embeddings
        sr.add_memory_start = pc["memory"]
        add_start_time = time.perf_counter()
        ids = index_under_test.add_embeddings(data[i])
        pc = get_perf_counters()
        sr.add_memory_end = pc["memory"]
        sr.add_time = time.perf_counter() - add_start_time
        sr.add_cpu_usage = pc["cpu_percent"]
        sr.add_bytes_written = pc["written_total"]
        sr.add_bytes_read = pc["read_total"]
        usage = index_under_test.get_usage()

        # query with a randomly chosen embedding from the batch just added
        query_start_time = time.perf_counter()
        try:
            index_under_test.query_embeddings([data[i][np.random.choice(data[i].shape[0])].tolist()], n_results=10)
        except Exception as e:
            print(e)
            print(usage)
        pc = get_perf_counters()
        sr.query_memory_end = pc["memory"]
        sr.query_time = time.perf_counter() - query_start_time
        sr.query_cpu_usage = pc["cpu_percent"]
        sr.query_bytes_written = pc["written_total"]
        sr.query_bytes_read = pc["read_total"]

        # delete everything that was added in this cycle
        delete_start_time = time.perf_counter()
        index_under_test.delete_embeddings(ids)
        pc = get_perf_counters()
        sr.delete_memory_end = pc["memory"]
        sr.delete_time = time.perf_counter() - delete_start_time
        sr.delete_cpu_usage = pc["cpu_percent"]
        sr.delete_bytes_written = pc["written_total"]
        sr.delete_bytes_read = pc["read_total"]

        # snapshot index usage after the delete to capture fragmentation
        usage = index_under_test.get_usage()
        sr.index_total_items = usage["total_items"]
        sr.index_items_in_use = usage["items_in_use"]
        sr.index_size_on_disk = usage["size_on_disk"]
        sr.index_fragmentation_ratio = usage["fragmentation_ratio"]
        sr.cycle_total_time = time.perf_counter() - sr.cycle_start_time
        sr.cycle_memory_end = pc["memory"]
        stats.append(sr.model_dump())
        gc.collect()

    df = pd.DataFrame(stats)
    df.to_parquet(os.path.join(args.output_path, f"{args.run_id}.parquet"), index=False)

if __name__ == "__main__":
    ar = argparse.ArgumentParser()
    ar.add_argument("--run-id", type=str, default=str(uuid.uuid4()), help="Unique identifier for this run")
    ar.add_argument("--cycles", "-c", type=int, default=1000, help="Number of cycles to run")
    ar.add_argument("--cycle-size", "-s", type=int, default=500, help="Number of items to add per cycle")
    ar.add_argument("--vector-size", "-e", type=int, default=1536, help="Size of the vectors to add")
    ar.add_argument("--output-path", "-o", type=str, default=".", help="Path to save the output")
    a = ar.parse_args()
    main(a)
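
Each run writes its per-cycle statistics to a single parquet file named after the run id. A minimal sketch for inspecting that output, assuming a run id of my-test-run and the default output path (the file name below is illustrative):

import pandas as pd

# load the stats written by the benchmark
df = pd.read_parquet("my-test-run.parquet")

# how per-operation latency, fragmentation and on-disk size evolve over the cycles
cols = ["cycle_id", "add_time", "query_time", "delete_time",
        "index_fragmentation_ratio", "index_size_on_disk"]
print(df[cols].describe())
print(df[cols].tail(10))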