Last active
December 31, 2023 17:16
-
-
Save kzinmr/c4b45728e1306c8148fad715ea86b77c to your computer and use it in GitHub Desktop.
Simplify the [supabase/vecs](https://github.com/supabase/vecs) library
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import uuid | |
from collections.abc import Iterable | |
from dataclasses import dataclass | |
from enum import Enum | |
from typing import Any, Self | |
from typing import cast as typing_cast | |
from pgvector.sqlalchemy import Vector | |
from sqlalchemy import ( | |
BinaryExpression, | |
Boolean, | |
Column, | |
ColumnElement, | |
MetaData, | |
String, | |
Table, | |
and_, | |
cast, | |
create_engine, | |
delete, | |
func, | |
or_, | |
select, | |
text, | |
TextClause, | |
) | |
from sqlalchemy.dialects import postgresql | |
from sqlalchemy.orm import Session, sessionmaker | |
from sqlalchemy.schema import DropTable | |
from src.lib.util import chunk_list | |
# Postgres schema under which all collection tables (and their indexes) live.
SCHEMANAME = "vecs"
# Values storable in a record's JSONB metadata column.
MetadataValues = str | int | float | bool | list[str | int] | dict[str, Any]
# A record's metadata: string keys mapped to JSON-compatible values.
Metadata = dict[str, MetadataValues]
# Numeric types accepted as vector components.
Numeric = int | float | complex
# A full vector record: (id, vector components, metadata).
Record = tuple[str, Iterable[Numeric], Metadata]
class IndexMeasure(str, Enum):
    """
    An enum representing the types of distance measures available for indexing.

    The values are stable strings so they can be serialized/compared directly.

    Attributes:
        cosine_distance (str): The cosine distance measure for indexing.
        l2_distance (str): The Euclidean (L2) distance measure for indexing.
        max_inner_product (str): The maximum inner product measure for indexing.
    """

    cosine_distance = "cosine_distance"
    l2_distance = "l2_distance"
    max_inner_product = "max_inner_product"
# Maps the IndexMeasure enum options to the SQL ops string required by
# the pgvector `create index` statement.
INDEX_MEASURE_TO_OPS = {
    IndexMeasure.cosine_distance: "vector_cosine_ops",
    IndexMeasure.l2_distance: "vector_l2_ops",
    IndexMeasure.max_inner_product: "vector_ip_ops",
}
# Maps each IndexMeasure to an accessor returning the corresponding
# SQLAlchemy/pgvector distance method on a vector column (e.g.
# `col.cosine_distance`), used by Collection.query to build the ORDER BY.
INDEX_MEASURE_TO_SQLA_ACC = {
    IndexMeasure.cosine_distance: lambda x: x.cosine_distance,
    IndexMeasure.l2_distance: lambda x: x.l2_distance,
    IndexMeasure.max_inner_product: lambda x: x.max_inner_product,
}
class CollectionNotFound(Exception):
    """Raised when a requested vector collection does not exist."""

    pass
class ArgError(Exception):
    """Raised when a method is called with invalid arguments."""

    pass
class MismatchedDimension(ArgError):
    """Raised when sources of truth for a collection's embedding dimension disagree."""

    pass
class FilterError(Exception):
    """Raised when a metadata filter expression is malformed or misused."""

    pass
class Unreachable(Exception):
    """
    Raised from code paths that should be logically impossible to reach;
    serves as a loud assertion in exhaustive branch handling.
    """

    pass
def build_filters(json_col: Column, filters: dict) -> ColumnElement[bool] | BinaryExpression[bool]:
    """
    Build a SQLAlchemy filter clause over a JSONB column from a filter dict.

    Args:
        json_col (Column): The JSONB metadata column of the collection table.
        filters (dict): A single-entry dict, either a logical combinator
            ``{"$and": [...]}`` / ``{"$or": [...]}`` or a field condition
            such as ``{"year": {"$gte": 2020}}``.

    Raises:
        FilterError: If filter conditions are not correctly formatted.

    Returns:
        The filter clause for the SQL query.
    """
    if not isinstance(filters, dict):
        raise FilterError("filters must be a dict")
    # BUGFIX: an empty dict previously fell past the loop without returning a
    # clause; require exactly one entry instead of only rejecting len > 1.
    if len(filters) != 1:
        raise FilterError("filters must have exactly one entry")

    key, value = next(iter(filters.items()))
    if not isinstance(key, str):
        raise FilterError("*filters* keys must be strings")

    if key in ("$and", "$or"):
        if not isinstance(value, list):
            raise FilterError("$and/$or filters must have associated list of conditions")
        combinator = and_ if key == "$and" else or_
        return combinator(*[build_filters(json_col, subcond) for subcond in value])

    # A field condition: the value must be a single {operator: operand} dict.
    if not isinstance(value, dict):
        # BUGFIX: previously raised Unreachable here; a malformed clause is a
        # user input error, so report it as a FilterError.
        raise FilterError("filter clause must be a dict of {operator: value}")
    if len(value) != 1:
        raise FilterError("only one operator permitted")

    operator, clause = next(iter(value.items()))
    if operator not in ("$eq", "$ne", "$lt", "$lte", "$gt", "$gte", "$in"):
        raise FilterError("unknown operator")

    # Equality of singular values can take advantage of the metadata index
    # using the containment operator (@>). Containment can not be used to test
    # equality of lists or dicts so we restrict to single values with a
    # __len__ check.
    if operator == "$eq" and not hasattr(clause, "__len__"):
        _contains_value = cast({key: clause}, postgresql.JSONB)
        return cast(json_col.op("@>")(_contains_value), Boolean)

    if operator == "$in":
        if not isinstance(clause, list):
            raise FilterError("argument to $in filter must be a list")
        for elem in clause:
            if not isinstance(elem, (int, str, float)):
                # BUGFIX: message typo ("list or scalars" -> "list of scalars")
                raise FilterError("argument to $in filter must be a list of scalars")
        # Cast the array of scalars to postgres jsonb values so we can
        # directly compare json types in the query.
        _contains_value_lst = [cast(elem, postgresql.JSONB) for elem in clause]
        return json_col.op("->")(key).in_(_contains_value_lst)

    matches_value = cast(clause, postgresql.JSONB)
    # Comparison operators; also handles non-singular $eq/$ne values.
    if operator == "$eq":
        return json_col.op("->")(key) == matches_value
    if operator == "$ne":
        return json_col.op("->")(key) != matches_value
    if operator == "$lt":
        return json_col.op("->")(key) < matches_value
    if operator == "$lte":
        return json_col.op("->")(key) <= matches_value
    if operator == "$gt":
        return json_col.op("->")(key) > matches_value
    if operator == "$gte":
        return json_col.op("->")(key) >= matches_value
    # All operators were enumerated above; reaching here is a logic error.
    raise Unreachable()
class Collection:
    """
    A named vector collection backed by a single Postgres table with columns
    ``id`` (text primary key), ``vec`` (pgvector, fixed dimension) and
    ``metadata`` (JSONB, defaults to an empty object).
    """

    def __init__(self, name: str, dimension: int, metadata: MetaData):
        """
        Initializes a new instance of the `Collection` class.

        Args:
            name (str): The name of the collection (also the table name).
            dimension (int): The dimension of the vectors in the collection.
            metadata (MetaData): The SQLAlchemy schema metadata the table is
                registered against (not the per-record JSONB metadata).
        """
        self.name = name
        self.dimension = dimension
        self.metadata = metadata
        self.table = Table(
            name,
            metadata,
            Column("id", String, primary_key=True),
            Column("vec", Vector(dimension), nullable=False),
            Column(
                "metadata",
                postgresql.JSONB,
                # Default to an empty JSON object rather than NULL.
                server_default=text("'{}'::jsonb"),
                nullable=False,
            ),
            # Allow re-instantiating a Collection for a table already
            # registered in *metadata* without raising.
            extend_existing=True,
        )
        # Cached index names, lazily resolved by get_vector_index /
        # get_metadata_index; None until first successful lookup.
        self._vector_index: str | None = None
        self._metadata_index: str | None = None

    def __repr__(self) -> str:
        """
        Returns a string representation of the `Collection` instance.

        Returns:
            str: A string representation of the `Collection` instance.
        """
        return f'vecs.Collection(name="{self.name}", dimension={self.dimension})'

    def size(self, session: Session) -> int:
        """
        Returns the number of vectors in the collection.

        Args:
            session: A database session for executing SQL commands.

        Returns:
            int: The number of vectors in the collection (0 if the count
            query returns no scalar).
        """
        with session.begin():
            stmt = select(func.count()).select_from(self.table)
            return session.execute(stmt).scalar() or 0

    def upsert(self, session: Session, records: Iterable[Record], chunk_size: int = 500) -> None:
        """
        Inserts or updates vector records in the collection.

        Uses Postgres `INSERT ... ON CONFLICT (id) DO UPDATE` so existing ids
        have their vector and metadata overwritten.

        Args:
            session: The database session for executing SQL commands.
            records (Iterable[Record]): An iterable of records to upsert. Each record is a tuple:
                - The first element is a unique string identifier.
                - The second element is an iterable of numeric values representing the vector.
                - The third element is metadata associated with the vector.
            chunk_size (int): The number of records to upsert in each batch (default: 500).
        """
        with session.begin():
            for chunk in chunk_list(records, chunk_size):
                stmt = postgresql.insert(self.table).values(chunk)
                stmt = stmt.on_conflict_do_update(
                    index_elements=[self.table.c.id],
                    set_=dict(vec=stmt.excluded.vec, metadata=stmt.excluded.metadata),
                )
                session.execute(stmt)
        return None

    def fetch(self, session: Session, ids: Iterable[str], chunk_size: int = 12) -> list[Record]:
        """
        Fetches vectors from the collection by their identifiers.

        Args:
            session: The database session for executing SQL commands.
            ids (Iterable[str]): An iterable of vector identifiers. A bare
                string is rejected (it would iterate per character).
            chunk_size (int): The number of records to fetch in each batch (default: 12).

        Returns:
            list[Record]: A list of the fetched vectors. Ids not present in
            the table are silently omitted.

        Raises:
            ArgError: If *ids* is a single string rather than an iterable of strings.
        """
        if isinstance(ids, str):
            raise ArgError("ids must be a list of strings")
        records: list[Record] = []
        with session.begin():
            for id_chunk in chunk_list(ids, chunk_size):
                stmt = select(self.table).where(self.table.c.id.in_(id_chunk))
                chunk_records = [typing_cast(Record, row) for row in session.execute(stmt)]
                records.extend(chunk_records)
        return records

    def delete(
        self, session: Session, ids: Iterable[str] | None = None, filters: Metadata | None = None, chunk_size: int = 12
    ) -> list[str]:
        """
        Deletes vectors from the collection by matching filters or ids.

        Args:
            session: The database session for executing SQL commands.
            ids (Iterable[str], optional): An iterable of vector identifiers.
            filters (Metadata | None, optional): Filters to apply to the deletion. Defaults to None.
            chunk_size (int): The number of records to delete in each batch (default: 12).

        Returns:
            list[str]: A list of the identifiers of the deleted vectors.

        Raises:
            ArgError: If neither ids nor filters are provided, or both are provided.
        """
        if ids is None and filters is None:
            raise ArgError("Either ids or filters must be provided.")
        if ids is not None and filters is not None:
            raise ArgError("Either ids or filters must be provided, not both.")
        # Reject a bare string, which would otherwise iterate per character.
        if ids and isinstance(ids, str):
            raise ArgError("ids must be a list of strings")
        del_ids: list[str] = []
        with session.begin():
            if ids:
                for id_chunk in chunk_list(ids, chunk_size):
                    stmt = delete(self.table).where(self.table.c.id.in_(id_chunk)).returning(self.table.c.id)
                    del_ids.extend(session.execute(stmt).scalars().fetchall())
            if filters:
                meta_filter = build_filters(self.table.c.metadata, filters)
                stmt = delete(self.table).where(meta_filter).returning(self.table.c.id)
                del_ids.extend(session.execute(stmt).scalars().fetchall())
        return del_ids

    def create_metadata_index(
        self,
        session: Session,
        replace: bool = True,
    ) -> Self:
        """
        Create a GIN (jsonb_path_ops) index over the metadata column.

        Args:
            session: The database session for executing SQL commands.
            replace (bool, optional): Whether to replace the existing index. Defaults to True.

        Returns:
            Collection: This collection (for chaining).

        Raises:
            ArgError: If an index already exists and *replace* is False.
        """
        # Random suffix keeps the index name unique across re-creations.
        unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
        index_statement = text(
            f"""
            CREATE INDEX ix_meta_{unique_string}
              ON {SCHEMANAME}."{self.table.name}"
              USING gin (metadata jsonb_path_ops);
            """
        )
        with session.begin():
            # drop existing index if it exists
            index = self.get_metadata_index(session)
            if index is not None:
                if not replace:
                    raise ArgError("An index already exists and replace is set to False.")
                session.execute(text(f"DROP INDEX IF EXISTS {index};"))
                self._metadata_index = None
            session.execute(index_statement)
        return self

    def get_metadata_index(self, session: Session) -> str | None:
        """
        Retrieves the SQL name of the collection's metadata index, if it exists.

        The result is cached on the instance after the first successful lookup.

        Args:
            session: The database session for executing SQL commands.

        Returns:
            str | None: The name of the index, or None if no index exists.
        """
        # NOTE(review): this matches ANY index named ix_meta% in the schema,
        # not only this table's index — with multiple collections it may
        # return another collection's index. TODO confirm/scoping via pg_index.
        if self._metadata_index is None:
            query = text(
                f"""
                SELECT
                    relname as table_name
                FROM
                    pg_class pc
                WHERE
                    pc.relnamespace = '{SCHEMANAME}'::regnamespace
                    AND relname ILIKE 'ix_meta%'
                    AND pc.relkind = 'i'
                """
            )
            self._metadata_index = session.execute(query).scalar_one_or_none()
        return self._metadata_index

    def create_vector_index(
        self,
        session: Session,
        measure: IndexMeasure = IndexMeasure.cosine_distance,
        m: int = 16,
        ef_construction: int = 64,
        replace: bool = True,
    ) -> None:
        """
        Creates an HNSW vector similarity search index for the collection.

        Args:
            session: A database session or connection (client.SessionLocal())
            measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
            m (int): Maximum number of connections per node per layer (default: 16)
            ef_construction (int): Size of the dynamic candidate list for constructing the graph (default: 64)
            replace (bool, optional): Whether to replace the existing index. Defaults to True.

        Raises:
            ArgError: If an invalid index measure is used, or if *replace* is False and an index already exists.
        """
        # Random suffix keeps the index name unique across re-creations.
        unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
        if (ops := INDEX_MEASURE_TO_OPS.get(measure, None)) is None:
            raise ArgError("Unknown index measure")
        # NOTE(review): unlike create_metadata_index, the table name is not
        # double-quoted here; names needing quoting would fail — confirm.
        index_statement = text(
            f"""
            CREATE INDEX ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
              ON {SCHEMANAME}.{self.table.name}
              USING hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
            """
        )
        with session.begin():
            # drop existing index if it exists
            index = self.get_vector_index(session)
            if index is not None:
                if not replace:
                    raise ArgError("An index already exists and replace is set to False.")
                session.execute(text(f"DROP INDEX IF EXISTS {index};"))
                self._vector_index = None
            session.execute(index_statement)
        return None

    def get_vector_index(self, session: Session) -> str | None:
        """
        Retrieves the SQL name of the collection's vector index, if it exists.

        The result is cached on the instance after the first successful lookup.

        Args:
            session: The database session for executing SQL commands.

        Returns:
            str | None: The name of the index, or None if no index exists.
        """
        # NOTE(review): like get_metadata_index, this matches any ix_vector%
        # index in the schema, not only this table's — TODO confirm scoping.
        if self._vector_index is None:
            query = text(
                f"""
                SELECT
                    relname as table_name
                FROM
                    pg_class pc
                WHERE
                    pc.relnamespace = '{SCHEMANAME}'::regnamespace
                    AND relname ILIKE 'ix_vector%'
                    AND pc.relkind = 'i'
                """
            )
            self._vector_index = session.execute(query).scalar_one_or_none()
        return self._vector_index

    def query(
        self,
        session: Session,
        data: Iterable[Numeric],
        limit: int = 10,
        filters: dict | None = None,
        measure: IndexMeasure = IndexMeasure.cosine_distance,
        include_value: bool = False,
        include_metadata: bool = False,
        ef_search: int = 40,
    ) -> list[Record] | list[str]:
        """
        Executes a similarity search in the collection.

        Args:
            session: The database session for executing SQL commands.
            data (Iterable[Numeric]): The vector to use as the query.
            limit (int, optional): The maximum number of results to return. Defaults to 10.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
            measure (IndexMeasure, optional): The distance measure to use for the search. Defaults to 'cosine_distance'.
            include_value (bool, optional): Whether to include the distance value in the results. Defaults to False.
            include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False.
            ef_search (int, optional): Size of the dynamic candidate list for HNSW index search.

        Returns:
            list[Record] | list[str]: Row tuples when value/metadata columns
            are requested, otherwise a plain list of id strings.

        Raises:
            ArgError: If *limit* exceeds 1000 or *measure* is invalid.
        """
        if limit > 1000:
            raise ArgError("limit must be <= 1000")
        # Ensure valid measure
        if measure not in INDEX_MEASURE_TO_SQLA_ACC:
            raise ArgError("Invalid index measure")
        distance_lambda = INDEX_MEASURE_TO_SQLA_ACC[measure]
        distance_clause = distance_lambda(self.table.c.vec)(data)
        cols = [self.table.c.id]
        if include_value:
            cols.append(distance_clause)
        if include_metadata:
            cols.append(self.table.c.metadata)
        stmt = select(*cols)
        if filters is not None:
            stmt = stmt.filter(build_filters(self.table.c.metadata, filters))  # type: ignore
        stmt = stmt.order_by(distance_clause)
        stmt = stmt.limit(limit)
        with session.begin():
            if self.get_vector_index(session) is not None:
                # SET LOCAL scopes the ef_search setting to this transaction.
                session.execute(text("SET LOCAL hnsw.ef_search = :ef_search").bindparams(ef_search=ef_search))
            if len(cols) == 1:
                # Only ids were selected; return them as plain strings.
                return [str(x) for x in session.scalars(stmt)]
            return [typing_cast(Record, x) for x in session.execute(stmt)]
class CollectionController:
    """
    Schema-level operations on vector collections: lookup, creation, listing
    and deletion of the Postgres tables that back them.
    """

    def __init__(self, client: "Client"):
        """
        Initialize the controller.

        Args:
            client: The owning ``Client``; supplies the shared schema
                ``MetaData`` used when reconstructing collections.
        """
        # BUGFIX: this initializer was missing. Client instantiates
        # CollectionController(self), which raised TypeError, and the
        # self.client attribute read by list_collections/delete_collection
        # was never set.
        self.client = client

    def get_with_dimension_query(self, name: str) -> TextClause:
        """
        Build a query returning (table_name, embedding_dim) for collection
        *name*, reading the `vec` column's atttypmod from the pg catalogs.
        Tables whose names start with '_' are treated as internal and skipped.
        """
        return text(
            f"""
            select
                relname as table_name,
                atttypmod as embedding_dim
            from
                pg_class pc
                join pg_attribute pa
                    on pc.oid = pa.attrelid
            where
                pc.relnamespace = '{SCHEMANAME}'::regnamespace
                and pc.relkind = 'r'
                and pa.attname = 'vec'
                and not pc.relname ^@ '_'
                and pc.relname = :name
            """
        ).bindparams(name=name)

    def get_or_create_collection(
        self,
        session: Session,
        name: str,
        dimension: int,
        meta: MetaData,
    ) -> Collection:
        """
        Get a vector collection by name, or create its backing table if no
        collection with *name* exists.

        Args:
            session (Session): The database session for executing SQL commands.
            name (str): The name of the collection.
            dimension (int): The dimensionality of the vectors in the collection.
            meta (MetaData): The SQLAlchemy metadata associated with the DB schema.

        Returns:
            Collection: The existing or newly created collection.

        Raises:
            MismatchedDimension: If *dimension* does not match the dimension
                of the existing collection.
        """
        query = self.get_with_dimension_query(name)
        with session.begin():
            query_result = session.execute(query).fetchone()
            if query_result is None:
                # No backing table yet: create it.
                collection = Collection(name, dimension, metadata=meta)
                collection.table.create(session.bind)
            else:
                _, collection_dimension = query_result
                if dimension != collection_dimension:
                    raise MismatchedDimension(
                        "Dimensions reported by dimension argument and existing collection do not match"
                    )
                collection = Collection(name, dimension, metadata=meta)
        return collection

    def list_collections(self, session: Session) -> list[Collection]:
        """
        List all vector collections in the schema.

        Args:
            session (Session): The database session for executing SQL commands.

        Returns:
            list[Collection]: A list of all collections.
        """
        metadata = self.client.meta
        query = text(
            f"""
            SELECT
                relname as table_name,
                atttypmod as embedding_dim
            FROM
                pg_class pc
                JOIN pg_attribute pa
                    ON pc.oid = pa.attrelid
            WHERE
                pc.relnamespace = '{SCHEMANAME}'::regnamespace
                AND pc.relkind = 'r'
                AND pa.attname = 'vec'
                AND not pc.relname ^@ '_'
            """
        )
        # BUGFIX: Result objects are not context managers, so the previous
        # `with session.execute(query) as result:` failed at runtime; run the
        # query inside an ordinary transaction instead.
        with session.begin():
            result = session.execute(query)
            return [Collection(name, dimension, metadata) for name, dimension in result]

    def delete_collection(self, session: Session, name: str) -> None:
        """
        Delete a vector collection, dropping its backing table.

        Args:
            session (Session): The database session for executing SQL commands.
            name (str): The name of the collection.

        Returns:
            None

        Raises:
            CollectionNotFound: If no collection with the requested name exists.
        """
        query = self.get_with_dimension_query(name)
        with session.begin():
            query_result = session.execute(query).fetchone()
            if query_result is None:
                raise CollectionNotFound("No collection found with requested name")
            name, dimension = query_result
            collection = Collection(name, dimension, metadata=self.client.meta)
            session.execute(DropTable(collection.table, if_exists=True))
        return
class Client:
    """
    The `Client` class serves as an interface to a PostgreSQL database with pgvector support.
    It facilitates the creation, retrieval, listing and deletion of vector collections,
    while managing connections to the database.

    A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create
    and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table.

    The `Client` class also supports usage as a context manager via __enter__ and __exit__
    to ensure the connection to the database is properly closed after operations, or it can be used directly.
    """

    def __init__(self, connection_string: str):
        """
        Initialize a Client instance: create the engine, ensure the schema and
        the `vector` extension exist, and verify the installed pgvector
        version supports HNSW indexes.

        Args:
            connection_string (str): A string representing the database connection information.

        Raises:
            ArgError: If the installed pgvector version is older than 0.5.0
                (HNSW indexes require pgvector >= 0.5.0).
        """
        self.engine = create_engine(connection_string, echo=True, pool_pre_ping=True)
        self.meta = MetaData(schema=SCHEMANAME)
        self.SessionLocal = sessionmaker(bind=self.engine)
        # Placeholder; overwritten below from pg_available_extensions.
        self.vector_version: str = "0.5"
        with self.SessionLocal() as sess:
            with sess.begin():
                sess.execute(text(f"create schema if not exists {SCHEMANAME};"))
                sess.execute(text("create extension if not exists vector;"))
                self.vector_version = sess.execute(
                    text("select installed_version from pg_available_extensions where name = 'vector' limit 1;")
                ).scalar_one()
        # BUGFIX: the version check was inverted -- it raised for pgvector
        # >= 0.5 (exactly the versions that DO support HNSW) and accepted
        # 0.0-0.4. Compare the parsed (major, minor) numerically instead of
        # string-prefix matching.
        major, minor = (int(part) for part in self.vector_version.split(".")[:2])
        if (major, minor) < (0, 5):
            raise ArgError("HNSW Unavailable. Upgrade your pgvector installation to > 0.5.0 to enable HNSW support")
        # NOTE: the session above is released, but the engine and its
        # connection pool live on; they are disposed explicitly via
        # disconnect() / __exit__ rather than here.
        self.collection_controller = CollectionController(self)

    def disconnect(self) -> None:
        """
        Disconnect the client from the database, disposing the engine's
        connection pool.

        Returns:
            None
        """
        self.engine.dispose()
        return

    def __enter__(self) -> Self:
        """
        Enable use of the 'with' statement.

        Returns:
            Client: The current instance of the Client.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Disconnect the client on exiting the 'with' statement context.

        Args:
            exc_type: The exception type, if any.
            exc_val: The exception value, if any.
            exc_tb: The traceback, if any.

        Returns:
            None
        """
        self.disconnect()
        return

    def get_or_create_collection(
        self,
        name: str,
        dimension: int,
    ) -> Collection:
        """Get collection *name*, creating it with *dimension* if absent."""
        with self.SessionLocal() as sess:
            # BUGFIX: the controller's signature requires the schema MetaData;
            # it was not being passed, so this call raised TypeError.
            return self.collection_controller.get_or_create_collection(sess, name, dimension, self.meta)

    def list_collections(self) -> list[Collection]:
        """List all vector collections in the schema."""
        with self.SessionLocal() as sess:
            return self.collection_controller.list_collections(sess)

    def delete_collection(self, name: str) -> None:
        """Delete collection *name* and drop its backing table."""
        with self.SessionLocal() as sess:
            return self.collection_controller.delete_collection(sess, name)
def create_client(connection_string: str) -> Client:
    """Construct and return a :class:`Client` for *connection_string* (a Postgres connection string)."""
    client = Client(connection_string)
    return client
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment