Runs a worst-case-scenario benchmark on a Chroma HNSW index to demonstrate the effects of fragmentation under frequent add/delete cycles
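How it works: each cycle adds a batch of random embeddings, runs one query, then deletes the whole batch. Chroma's HNSW index handles deletes as soft deletes, so deleted elements remain in the graph; the live item count therefore stays flat while the graph's element count and on-disk size keep growing. get_usage() reports that gap as a fragmentation ratio, and each cycle's timings, memory, and I/O counters are written to a parquet file for analysis.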
# Third-party dependencies: chromadb, numpy, overrides, pandas (plus pyarrow or
# fastparquet for to_parquet), psutil, pydantic, rich.
import argparse
import gc
import os
import time
import uuid
from abc import ABC
from typing import Any, List, Optional, TypedDict

import numpy as np
import pandas as pd
import psutil
from overrides import EnforceOverrides, override
from pydantic import BaseModel, Field
from rich.console import Console
from rich.progress import track
from rich.prompt import Confirm
from rich.table import Table

import chromadb
from chromadb import QueryResult
from chromadb.api.configuration import CollectionConfiguration, HNSWConfiguration
from chromadb.segment import VectorReader
from chromadb.types import SegmentScope

np.random.seed(42)
class PerfCounters(TypedDict):
    cpu_percent: float
    memory: float
    written_total: int
    read_total: int
process = psutil.Process(os.getpid())


def get_perf_counters() -> PerfCounters:
    global process
    # cpu_percent() reports CPU utilization since the previous call (the very first call returns 0.0).
    cpu_percent = process.cpu_percent()
    memory = process.memory_info().rss / 1024 / 1024  # resident set size in MiB
    # io_counters() is not available on macOS; fall back to zeros there.
    if hasattr(process, "io_counters"):
        written_total = process.io_counters().write_bytes
        read_total = process.io_counters().read_bytes
    else:
        written_total = 0
        read_total = 0
    return PerfCounters(cpu_percent=cpu_percent, memory=memory, written_total=written_total, read_total=read_total)
def get_directory_size(directory: str) -> int:
    total_size = 0
    for dirpath, _, filenames in os.walk(directory):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            # skip symbolic links
            if not os.path.islink(fp):
                total_size += os.path.getsize(fp)
    return total_size
class IndexUsage(TypedDict):
    total_items: int  # elements in the HNSW graph, including soft-deleted ones
    items_in_use: int  # elements still visible through the collection
    size_on_disk: int  # bytes
    fragmentation_ratio: float  # percentage of graph elements that are soft-deleted
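# Worked example: after many add/delete cycles the graph might hold 10_000
# elements with only 500 still live, giving
# fragmentation_ratio = (10_000 - 500) / 10_000 * 100 = 95.0 (percent).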
class IUT(ABC, EnforceOverrides):
    """Index under test: the minimal interface the benchmark drives."""

    def load_index(self, path: str, hnsw_configuration: Optional[HNSWConfiguration] = None) -> None:
        raise NotImplementedError

    def add_embeddings(self, embeddings: np.ndarray) -> List[Any]:
        raise NotImplementedError

    def delete_embeddings(self, ids: List[Any]) -> None:
        raise NotImplementedError

    def query_embeddings(self, query: np.ndarray, n_results: Optional[int] = 10) -> QueryResult:
        raise NotImplementedError

    def get_usage(self) -> IndexUsage:
        """Returns a snapshot of index size, live item count, and fragmentation."""
        raise NotImplementedError
class ChromaIndex(IUT):
    client: chromadb.ClientAPI
    collection: chromadb.Collection
    max_id: int

    @override
    def load_index(self, path: str, hnsw_configuration: Optional[HNSWConfiguration] = None) -> None:
        self.client = chromadb.PersistentClient(path)
        config = None
        if hnsw_configuration:
            config = CollectionConfiguration(hnsw_configuration=hnsw_configuration)
        self.collection = self.client.get_or_create_collection("hnsw", configuration=config)
        self.max_id = self.collection.count() + 1

    @override
    def add_embeddings(self, embeddings: np.ndarray) -> List[Any]:
        ids = [f"{self.max_id + i}" for i in range(embeddings.shape[0])]
        self.max_id += embeddings.shape[0]  # advance so successive batches get fresh IDs
        self.collection.add(ids=ids, embeddings=embeddings)
        return ids

    @override
    def delete_embeddings(self, ids: List[Any]) -> None:
        self.collection.delete(ids=ids)

    @override
    def query_embeddings(self, query: np.ndarray, n_results: Optional[int] = 10) -> QueryResult:
        return self.collection.query(query_embeddings=query, n_results=n_results)

    @override
    def get_usage(self) -> IndexUsage:
        vector_segments = [
            s
            for s in self.client._server._sysdb.get_segments()
            if s["scope"] == SegmentScope.VECTOR and s["collection"] == self.collection.id
        ]
        if len(vector_segments) == 0:
            return IndexUsage(total_items=0, items_in_use=0, size_on_disk=0, fragmentation_ratio=0.0)
        index_path = os.path.join(self.client.get_settings().persist_directory, str(vector_segments[0]["id"]))
        segment = self.client._server._manager.get_segment(vector_segments[0]["collection"], VectorReader)
        hnsw_item_count = segment._index.element_count  # includes soft-deleted elements
        index_active_items = self.collection.count()
        fragmentation = (
            ((hnsw_item_count - index_active_items) / hnsw_item_count) * 100 if hnsw_item_count else 0.0
        )
        return IndexUsage(
            total_items=hnsw_item_count,
            items_in_use=index_active_items,
            size_on_disk=get_directory_size(index_path),
            fragmentation_ratio=fragmentation,
        )
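# Note: get_usage() above relies on Chroma's private internals
# (client._server._sysdb, client._server._manager, segment._index). These are
# not public API and may change between Chroma versions.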
class StatRow(BaseModel):
    # Fields are annotated as concrete types but default to None; pydantic does not
    # validate defaults, so unset fields simply end up as nulls in the output frame.
    run_id: str = Field(default=None)
    cycle_id: int = Field(default=None)
    cycle_size: int = Field(default=None)
    vector_size: int = Field(default=None)
    cycle_memory_start: float = Field(default=None)
    cycle_memory_end: float = Field(default=None)
    cycle_total_time: float = Field(default=None)
    cycle_start_time: float = Field(default=None)
    add_memory_start: float = Field(default=None)  # memory before add
    add_memory_end: float = Field(default=None)  # memory after add
    add_time: float = Field(default=None)
    add_cpu_usage: float = Field(default=None)
    add_bytes_written: int = Field(default=None)
    add_bytes_read: int = Field(default=None)
    query_memory_end: float = Field(default=None)  # memory after query
    query_time: float = Field(default=None)
    query_cpu_usage: float = Field(default=None)
    query_bytes_written: int = Field(default=None)
    query_bytes_read: int = Field(default=None)
    delete_memory_end: float = Field(default=None)  # memory after delete
    delete_time: float = Field(default=None)
    delete_cpu_usage: float = Field(default=None)
    delete_bytes_written: int = Field(default=None)
    delete_bytes_read: int = Field(default=None)
    index_total_items: int = Field(default=None)
    index_items_in_use: int = Field(default=None)
    index_size_on_disk: int = Field(default=None)
    index_fragmentation_ratio: float = Field(default=None)
def main(args: argparse.Namespace):
    # Pre-generate all embeddings up front: cycles x cycle_size x vector_size.
    data = np.random.uniform(-1, 1, (args.cycles, args.cycle_size, args.vector_size))
    index_under_test = ChromaIndex()
    index_under_test.load_index(args.run_id)
    stats = []
    console = Console()
    # Show the configuration and ask for confirmation before starting.
    table = Table(title="Confirm configuration")
    table.add_column("Prop", justify="center", style="bold magenta")
    table.add_column("Value", justify="center", style="bold cyan")
    table.add_row("Run ID", str(args.run_id))
    table.add_row("Cycles", str(args.cycles))
    table.add_row("Number of embeddings per cycle", str(args.cycle_size))
    table.add_row("Embedding Size", str(args.vector_size))
    console.print(table)
    if Confirm.ask("Looks good?", default=False):
        print("Starting, hold tight")
    else:
        print("Aborting")
        return
    for i in track(range(int(args.cycles))):
        gc.collect()
        sr = StatRow()
        sr.run_id = args.run_id
        sr.cycle_id = i
        sr.cycle_start_time = time.perf_counter()
        pc = get_perf_counters()
        sr.cycle_memory_start = pc["memory"]
        sr.vector_size = args.vector_size
        sr.cycle_size = args.cycle_size
        # Add a full batch of embeddings.
        sr.add_memory_start = pc["memory"]
        add_start_time = time.perf_counter()
        ids = index_under_test.add_embeddings(data[i])
        pc = get_perf_counters()
        sr.add_memory_end = pc["memory"]
        sr.add_time = time.perf_counter() - add_start_time
        sr.add_cpu_usage = pc["cpu_percent"]
        sr.add_bytes_written = pc["written_total"]
        sr.add_bytes_read = pc["read_total"]
        usage = index_under_test.get_usage()
        # Query with one randomly chosen vector from this cycle's batch.
        query_start_time = time.perf_counter()
        try:
            index_under_test.query_embeddings([data[i][np.random.choice(data[i].shape[0])].tolist()], n_results=10)
        except Exception as e:
            print(e)
            print(usage)
        pc = get_perf_counters()
        sr.query_memory_end = pc["memory"]
        sr.query_time = time.perf_counter() - query_start_time
        sr.query_cpu_usage = pc["cpu_percent"]
        sr.query_bytes_written = pc["written_total"]
        sr.query_bytes_read = pc["read_total"]
        # Delete the whole batch, leaving soft-deleted elements behind in the graph.
        delete_start_time = time.perf_counter()
        index_under_test.delete_embeddings(ids)
        pc = get_perf_counters()
        sr.delete_memory_end = pc["memory"]
        sr.delete_time = time.perf_counter() - delete_start_time
        sr.delete_cpu_usage = pc["cpu_percent"]
        sr.delete_bytes_written = pc["written_total"]
        sr.delete_bytes_read = pc["read_total"]
        usage = index_under_test.get_usage()
        sr.index_total_items = usage["total_items"]
        sr.index_items_in_use = usage["items_in_use"]
        sr.index_size_on_disk = usage["size_on_disk"]
        sr.index_fragmentation_ratio = usage["fragmentation_ratio"]
        sr.cycle_total_time = time.perf_counter() - sr.cycle_start_time
        sr.cycle_memory_end = pc["memory"]
        stats.append(sr.model_dump())
        gc.collect()
    df = pd.DataFrame(stats)
    df.to_parquet(os.path.join(args.output_path, f"{args.run_id}.parquet"), index=False)
if __name__ == "__main__":
    ar = argparse.ArgumentParser()
    ar.add_argument("--run-id", type=str, default=str(uuid.uuid4()), help="Unique identifier for this run; also used as the persist directory")
    ar.add_argument("--cycles", "-c", type=int, default=1000, help="Number of cycles to run")
    ar.add_argument("--cycle-size", "-s", type=int, default=500, help="Number of items to add per cycle")
    ar.add_argument("--vector-size", "-e", type=int, default=1536, help="Size of the vectors to add")
    ar.add_argument("--output-path", "-o", type=str, default=".", help="Path to save the output parquet file")
    a = ar.parse_args()
    main(a)
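A typical invocation (the script name here is assumed; use whatever name you saved the gist under):

python hnsw_fragmentation_benchmark.py -c 1000 -s 500 -e 1536 -o .

With the defaults this runs 1000 cycles of 500 vectors at 1536 dimensions and writes <run-id>.parquet to the output path.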