Simplify the [supabase/vecs](https://github.com/supabase/vecs) library
import uuid
from collections.abc import Iterable
from enum import Enum
from typing import Any, Self
from typing import cast as typing_cast

from pgvector.sqlalchemy import Vector
from sqlalchemy import (
    BinaryExpression,
    Boolean,
    Column,
    ColumnElement,
    MetaData,
    String,
    Table,
    TextClause,
    and_,
    cast,
    create_engine,
    delete,
    func,
    or_,
    select,
    text,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.schema import DropTable

from src.lib.util import chunk_list
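
# `chunk_list` lives in this project's local util module and is not shown in the
# gist. Judging from its call sites below, the assumed helper splits an iterable
# into lists of at most `size` items; a minimal sketch of that assumption:
#
#     def chunk_list(records, size):
#         chunk = []
#         for record in records:
#             chunk.append(record)
#             if len(chunk) == size:
#                 yield chunk
#                 chunk = []
#         if chunk:
#             yield chunk
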
SCHEMANAME = "vecs"
MetadataValues = str | int | float | bool | list[str | int] | dict[str, Any]
Metadata = dict[str, MetadataValues]
Numeric = int | float | complex
Record = tuple[str, Iterable[Numeric], Metadata]


class IndexMeasure(str, Enum):
    """
    An enum representing the types of distance measures available for indexing.

    Attributes:
        cosine_distance (str): The cosine distance measure for indexing.
        l2_distance (str): The Euclidean (L2) distance measure for indexing.
        max_inner_product (str): The maximum inner product measure for indexing.
    """

    cosine_distance = "cosine_distance"
    l2_distance = "l2_distance"
    max_inner_product = "max_inner_product"


INDEX_MEASURE_TO_OPS = {
    # Maps the IndexMeasure enum options to the SQL ops string required by
    # the pgvector `create index` statement
    IndexMeasure.cosine_distance: "vector_cosine_ops",
    IndexMeasure.l2_distance: "vector_l2_ops",
    IndexMeasure.max_inner_product: "vector_ip_ops",
}

INDEX_MEASURE_TO_SQLA_ACC = {
    IndexMeasure.cosine_distance: lambda x: x.cosine_distance,
    IndexMeasure.l2_distance: lambda x: x.l2_distance,
    IndexMeasure.max_inner_product: lambda x: x.max_inner_product,
}


class CollectionNotFound(Exception):
    """
    Exception raised when attempting to access or manipulate a collection that does not exist.
    """

    ...


class ArgError(Exception):
    """
    Exception raised for invalid arguments when calling a method.
    """

    ...


class MismatchedDimension(ArgError):
    """
    Exception raised when multiple sources of truth for a collection's embedding dimension do not match.
    """

    ...


class FilterError(Exception):
    """
    Exception raised when there's an error related to filter usage in a query.
    """

    ...


class Unreachable(Exception):
    """
    Exception raised when an unreachable part of the code is executed.

    This is typically used for error handling in cases that should be logically impossible.
    """

    ...


def build_filters(json_col: Column, filters: dict) -> ColumnElement[bool] | BinaryExpression[bool]:
    """
    Builds filters for a SQL query based on the provided dictionary.

    Args:
        json_col (Column): The JSONB metadata column of the collection table.
        filters (dict): The dictionary specifying filter conditions.

    Raises:
        FilterError: If filter conditions are not correctly formatted.

    Returns:
        The filter clause for the SQL query.
    """
    if not isinstance(filters, dict):
        raise FilterError("filters must be a dict")
    if len(filters) > 1:
        raise FilterError("max 1 entry per filter")
    for key, value in filters.items():
        if not isinstance(key, str):
            raise FilterError("*filters* keys must be strings")
        if key in ("$and", "$or"):
            if not isinstance(value, list):
                raise FilterError("$and/$or filters must have an associated list of conditions")
            if key == "$and":
                return and_(*[build_filters(json_col, subcond) for subcond in value])
            if key == "$or":
                return or_(*[build_filters(json_col, subcond) for subcond in value])
            raise Unreachable()
        if isinstance(value, dict):
            if len(value) > 1:
                raise FilterError("only one operator permitted")
            for operator, clause in value.items():
                if operator not in ("$eq", "$ne", "$lt", "$lte", "$gt", "$gte", "$in"):
                    raise FilterError("unknown operator")
                # Equality of singular values can take advantage of the metadata index
                # using the containment operator. Containment can not be used to test
                # equality of lists or dicts, so we restrict it to single values with
                # a __len__ check.
                if operator == "$eq" and not hasattr(clause, "__len__"):
                    _contains_value = cast({key: clause}, postgresql.JSONB)
                    return cast(json_col.op("@>")(_contains_value), Boolean)
                if operator == "$in":
                    if not isinstance(clause, list):
                        raise FilterError("argument to $in filter must be a list")
                    for elem in clause:
                        if not isinstance(elem, (int, str, float)):
                            raise FilterError("argument to $in filter must be a list of scalars")
                    # Cast the array of scalars to a postgres array of jsonb so we can
                    # directly compare json types in the query.
                    _contains_value_lst = [cast(elem, postgresql.JSONB) for elem in clause]
                    return json_col.op("->")(key).in_(_contains_value_lst)
                matches_value = cast(clause, postgresql.JSONB)
                # Handles non-singular values.
                if operator == "$eq":
                    return json_col.op("->")(key) == matches_value
                elif operator == "$ne":
                    return json_col.op("->")(key) != matches_value
                elif operator == "$lt":
                    return json_col.op("->")(key) < matches_value
                elif operator == "$lte":
                    return json_col.op("->")(key) <= matches_value
                elif operator == "$gt":
                    return json_col.op("->")(key) > matches_value
                elif operator == "$gte":
                    return json_col.op("->")(key) >= matches_value
    raise Unreachable()
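

# Example filter documents accepted by build_filters, shown as a sketch (the
# "year" and "genre" keys are hypothetical metadata fields, not part of the
# library):
#
#     {"year": {"$eq": 2021}}                                         # single comparison
#     {"genre": {"$in": ["fiction", "sci-fi"]}}                       # membership over scalars
#     {"$and": [{"year": {"$gte": 2000}}, {"year": {"$lt": 2010}}]}   # nested conditions
#
# Each filter dict carries exactly one top-level key: either a metadata field
# mapped to a single {operator: value} pair, or "$and"/"$or" mapped to a list
# of sub-filters.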


class Collection:
    def __init__(self, name: str, dimension: int, metadata: MetaData):
        """
        Initializes a new instance of the `Collection` class.

        Args:
            name (str): The name of the collection.
            dimension (int): The dimension of the vectors in the collection.
            metadata (MetaData): The SQLAlchemy schema metadata the collection's table is registered with.
        """
        self.name = name
        self.dimension = dimension
        self.metadata = metadata
        self.table = Table(
            name,
            metadata,
            Column("id", String, primary_key=True),
            Column("vec", Vector(dimension), nullable=False),
            Column(
                "metadata",
                postgresql.JSONB,
                server_default=text("'{}'::jsonb"),
                nullable=False,
            ),
            extend_existing=True,
        )
        self._vector_index: str | None = None
        self._metadata_index: str | None = None

    def __repr__(self) -> str:
        """
        Returns a string representation of the `Collection` instance.

        Returns:
            str: A string representation of the `Collection` instance.
        """
        return f'vecs.Collection(name="{self.name}", dimension={self.dimension})'

    def size(self, session: Session) -> int:
        """
        Returns the number of vectors in the collection.

        Args:
            session: A database session for executing SQL commands.

        Returns:
            int: The number of vectors in the collection.
        """
        with session.begin():
            stmt = select(func.count()).select_from(self.table)
            return session.execute(stmt).scalar() or 0

    def upsert(self, session: Session, records: Iterable[Record], chunk_size: int = 500) -> None:
        """
        Inserts or updates vector records in the collection.

        Args:
            session: The database session for executing SQL commands.
            records (Iterable[Record]): An iterable of records to upsert. Each record is a tuple:
                - The first element is a unique string identifier.
                - The second element is an iterable of numeric values representing the vector.
                - The third element is metadata associated with the vector.
            chunk_size (int): The number of records to upsert in each batch (default: 500).
        """
        with session.begin():
            for chunk in chunk_list(records, chunk_size):
                stmt = postgresql.insert(self.table).values(chunk)
                stmt = stmt.on_conflict_do_update(
                    index_elements=[self.table.c.id],
                    set_=dict(vec=stmt.excluded.vec, metadata=stmt.excluded.metadata),
                )
                session.execute(stmt)
        return None
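
    # A record-shape sketch (the values are illustrative only): each Record is
    # (id, vector, metadata), e.g.
    #     ("doc1", [0.1, 0.2, 0.3], {"year": 2021})
    # When an id already exists, the ON CONFLICT ... DO UPDATE clause above
    # overwrites that row's vec and metadata columns in place.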

    def fetch(self, session: Session, ids: Iterable[str], chunk_size: int = 12) -> list[Record]:
        """
        Fetches vectors from the collection by their identifiers.

        Args:
            session: The database session for executing SQL commands.
            ids (Iterable[str]): An iterable of vector identifiers.
            chunk_size (int): The number of records to fetch in each batch (default: 12).

        Returns:
            list[Record]: A list of the fetched vectors.
        """
        if isinstance(ids, str):
            raise ArgError("ids must be a list of strings")
        records: list[Record] = []
        with session.begin():
            for id_chunk in chunk_list(ids, chunk_size):
                stmt = select(self.table).where(self.table.c.id.in_(id_chunk))
                chunk_records = [typing_cast(Record, row) for row in session.execute(stmt)]
                records.extend(chunk_records)
        return records

    def delete(
        self, session: Session, ids: Iterable[str] | None = None, filters: Metadata | None = None, chunk_size: int = 12
    ) -> list[str]:
        """
        Deletes vectors from the collection by matching filters or ids.

        Args:
            session: The database session for executing SQL commands.
            ids (Iterable[str], optional): An iterable of vector identifiers.
            filters (Metadata | None, optional): Filters to apply to the deletion. Defaults to None.
            chunk_size (int): The number of records to delete in each batch (default: 12).

        Returns:
            list[str]: A list of the identifiers of the deleted vectors.

        Raises:
            ArgError: If neither ids nor filters are provided, or both are provided.
        """
        if ids is None and filters is None:
            raise ArgError("Either ids or filters must be provided.")
        if ids is not None and filters is not None:
            raise ArgError("Either ids or filters must be provided, not both.")
        if ids and isinstance(ids, str):
            raise ArgError("ids must be a list of strings")
        del_ids: list[str] = []
        with session.begin():
            if ids:
                for id_chunk in chunk_list(ids, chunk_size):
                    stmt = delete(self.table).where(self.table.c.id.in_(id_chunk)).returning(self.table.c.id)
                    del_ids.extend(session.execute(stmt).scalars().fetchall())
            if filters:
                meta_filter = build_filters(self.table.c.metadata, filters)
                stmt = delete(self.table).where(meta_filter).returning(self.table.c.id)
                del_ids.extend(session.execute(stmt).scalars().fetchall())
        return del_ids

    def create_metadata_index(
        self,
        session: Session,
        replace: bool = True,
    ) -> Self:
        """
        Creates a metadata index for the collection.

        Args:
            session: The database session for executing SQL commands.
            replace (bool, optional): Whether to replace the existing index. Defaults to True.

        Returns:
            Collection: The collection itself, to allow chaining.
        """
        unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
        index_statement = text(
            f"""
            CREATE INDEX ix_meta_{unique_string}
            ON {SCHEMANAME}."{self.table.name}"
            USING gin (metadata jsonb_path_ops);
            """
        )
        with session.begin():
            # Drop the existing index if it exists.
            index = self.get_metadata_index(session)
            if index is not None:
                if not replace:
                    raise ArgError("An index already exists and replace is set to False.")
                session.execute(text(f"DROP INDEX IF EXISTS {index};"))
                self._metadata_index = None
            session.execute(index_statement)
        return self
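
    # Note: the GIN index above uses jsonb_path_ops, which accelerates the `@>`
    # containment operator that build_filters emits for the scalar $eq fast
    # path; the other operators go through `->` extraction and do not use it.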

    def get_metadata_index(self, session: Session) -> str | None:
        """
        Retrieves the SQL name of the collection's metadata index, if it exists.

        Args:
            session: The database session for executing SQL commands.

        Returns:
            str | None: The name of the index, or None if no index exists.
        """
        if self._metadata_index is None:
            query = text(
                f"""
                SELECT
                    relname as index_name
                FROM
                    pg_class pc
                WHERE
                    pc.relnamespace = '{SCHEMANAME}'::regnamespace
                    AND relname ILIKE 'ix_meta%'
                    AND pc.relkind = 'i'
                """
            )
            self._metadata_index = session.execute(query).scalar_one_or_none()
        return self._metadata_index

    def create_vector_index(
        self,
        session: Session,
        measure: IndexMeasure = IndexMeasure.cosine_distance,
        m: int = 16,
        ef_construction: int = 64,
        replace: bool = True,
    ) -> None:
        """
        Creates a vector similarity search index for the collection.

        Args:
            session: The database session for executing SQL commands.
            measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'.
            m (int): Maximum number of connections per node per layer (default: 16).
            ef_construction (int): Size of the dynamic candidate list for constructing the graph (default: 64).
            replace (bool, optional): Whether to replace the existing index. Defaults to True.

        Raises:
            ArgError: If an invalid index measure is used, or if *replace* is False and an index already exists.
        """
        unique_string = str(uuid.uuid4()).replace("-", "_")[0:7]
        if (ops := INDEX_MEASURE_TO_OPS.get(measure, None)) is None:
            raise ArgError("Unknown index measure")
        index_statement = text(
            f"""
            CREATE INDEX ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string}
            ON {SCHEMANAME}."{self.table.name}"
            USING hnsw (vec {ops}) WITH (m={m}, ef_construction={ef_construction});
            """
        )
        with session.begin():
            # Drop the existing index if it exists.
            index = self.get_vector_index(session)
            if index is not None:
                if not replace:
                    raise ArgError("An index already exists and replace is set to False.")
                session.execute(text(f"DROP INDEX IF EXISTS {index};"))
                self._vector_index = None
            session.execute(index_statement)
        return None

    def get_vector_index(self, session: Session) -> str | None:
        """
        Retrieves the SQL name of the collection's vector index, if it exists.

        Args:
            session: The database session for executing SQL commands.

        Returns:
            str | None: The name of the index, or None if no index exists.
        """
        if self._vector_index is None:
            query = text(
                f"""
                SELECT
                    relname as index_name
                FROM
                    pg_class pc
                WHERE
                    pc.relnamespace = '{SCHEMANAME}'::regnamespace
                    AND relname ILIKE 'ix_vector%'
                    AND pc.relkind = 'i'
                """
            )
            self._vector_index = session.execute(query).scalar_one_or_none()
        return self._vector_index

    def query(
        self,
        session: Session,
        data: Iterable[Numeric],
        limit: int = 10,
        filters: dict | None = None,
        measure: IndexMeasure = IndexMeasure.cosine_distance,
        include_value: bool = False,
        include_metadata: bool = False,
        ef_search: int = 40,
    ) -> list[Record] | list[str]:
        """
        Executes a similarity search in the collection.

        Args:
            session: The database session for executing SQL commands.
            data (Iterable[Numeric]): The vector to use as the query.
            limit (int, optional): The maximum number of results to return. Defaults to 10.
            filters (dict, optional): Filters to apply to the search. Defaults to None.
            measure (IndexMeasure, optional): The distance measure to use for the search. Defaults to 'cosine_distance'.
            include_value (bool, optional): Whether to include the distance value in the results. Defaults to False.
            include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False.
            ef_search (int, optional): Size of the dynamic candidate list for HNSW index search. Defaults to 40.

        Returns:
            list[Record] | list[str]: The result of the similarity search.
        """
        if limit > 1000:
            raise ArgError("limit must be <= 1000")
        # Ensure a valid measure.
        if measure not in INDEX_MEASURE_TO_SQLA_ACC:
            raise ArgError("Invalid index measure")
        distance_lambda = INDEX_MEASURE_TO_SQLA_ACC[measure]
        distance_clause = distance_lambda(self.table.c.vec)(data)
        cols = [self.table.c.id]
        if include_value:
            cols.append(distance_clause)
        if include_metadata:
            cols.append(self.table.c.metadata)
        stmt = select(*cols)
        if filters is not None:
            stmt = stmt.filter(build_filters(self.table.c.metadata, filters))  # type: ignore
        stmt = stmt.order_by(distance_clause)
        stmt = stmt.limit(limit)
        with session.begin():
            if self.get_vector_index(session) is not None:
                session.execute(text("SET LOCAL hnsw.ef_search = :ef_search").bindparams(ef_search=ef_search))
            if len(cols) == 1:
                return [str(x) for x in session.scalars(stmt)]
            return [typing_cast(Record, x) for x in session.execute(stmt)]
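

# A hedged usage sketch for Collection.query (the `client` and `docs` names
# are illustrative, not part of the library):
#
#     with client.SessionLocal() as sess:
#         ids_only = docs.query(sess, data=[0.1, 0.2, 0.3], limit=5)
#         scored = docs.query(sess, data=[0.1, 0.2, 0.3], include_value=True)
#
# With the default flags the result is a list of id strings; enabling
# include_value and/or include_metadata widens each row to a tuple of
# (id, distance?, metadata?) in that order.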


class CollectionController:
    def __init__(self, client: "Client"):
        self.client = client

    def get_with_dimension_query(self, name: str) -> TextClause:
        return text(
            f"""
            select
                relname as table_name,
                atttypmod as embedding_dim
            from
                pg_class pc
                join pg_attribute pa
                    on pc.oid = pa.attrelid
            where
                pc.relnamespace = '{SCHEMANAME}'::regnamespace
                and pc.relkind = 'r'
                and pa.attname = 'vec'
                and not pc.relname ^@ '_'
                and pc.relname = :name
            """
        ).bindparams(name=name)

    def get_or_create_collection(
        self,
        session: Session,
        name: str,
        dimension: int,
        meta: MetaData,
    ) -> Collection:
        """
        Get a vector collection by name, or create it if no collection with
        *name* exists.

        Args:
            session (Session): The database session for executing SQL commands.
            name (str): The name of the collection.
            dimension (int): The dimensionality of the vectors in the collection.
            meta (MetaData): The SQLAlchemy schema metadata to register the collection's table with.

        Returns:
            Collection: The existing or newly created collection.

        Raises:
            MismatchedDimension: If *dimension* does not match the dimension of an existing collection.
        """
        query = self.get_with_dimension_query(name)
        with session.begin():
            query_result = session.execute(query).fetchone()
            if query_result is None:
                collection = Collection(name, dimension, metadata=meta)
                collection.table.create(session.bind)
            else:
                _, collection_dimension = query_result
                if dimension != collection_dimension:
                    raise MismatchedDimension(
                        "Dimensions reported by dimension argument and existing collection do not match"
                    )
                collection = Collection(name, dimension, metadata=meta)
        return collection

    def list_collections(self, session: Session) -> list[Collection]:
        """
        List all vector collections.

        Args:
            session (Session): The database session for executing SQL commands.

        Returns:
            list[Collection]: A list of all collections.
        """
        metadata = self.client.meta
        query = text(
            f"""
            SELECT
                relname as table_name,
                atttypmod as embedding_dim
            FROM
                pg_class pc
                JOIN pg_attribute pa
                    ON pc.oid = pa.attrelid
            WHERE
                pc.relnamespace = '{SCHEMANAME}'::regnamespace
                AND pc.relkind = 'r'
                AND pa.attname = 'vec'
                AND NOT pc.relname ^@ '_'
            """
        )
        result = session.execute(query)
        return [Collection(name, dimension, metadata) for name, dimension in result]

    def delete_collection(self, session: Session, name: str) -> None:
        """
        Delete a vector collection.

        Args:
            session (Session): The database session for executing SQL commands.
            name (str): The name of the collection.

        Returns:
            None

        Raises:
            CollectionNotFound: If no collection with the requested name exists.
        """
        query = self.get_with_dimension_query(name)
        with session.begin():
            query_result = session.execute(query).fetchone()
            if query_result is None:
                raise CollectionNotFound("No collection found with requested name")
            name, dimension = query_result
            collection = Collection(name, dimension, metadata=self.client.meta)
            session.execute(DropTable(collection.table, if_exists=True))
        return


class Client:
    """
    The `Client` class serves as an interface to a PostgreSQL database with pgvector support.
    It facilitates the creation, retrieval, listing, and deletion of vector collections,
    while managing connections to the database.

    A `Client` instance represents a connection to a PostgreSQL database. This connection can be used
    to create and manipulate vector collections, where each collection is a group of vector records
    in a PostgreSQL table.

    The `Client` class also supports usage as a context manager via its __enter__ and __exit__ methods,
    which ensure the connection to the database is properly closed after operations; it can equally be
    used directly.
    """

    def __init__(self, connection_string: str):
        """
        Initialize a Client instance.

        Args:
            connection_string (str): A string representing the database connection information.
        """
        self.engine = create_engine(connection_string, echo=True, pool_pre_ping=True)
        self.meta = MetaData(schema=SCHEMANAME)
        self.SessionLocal = sessionmaker(bind=self.engine)
        # pgvector version assertion: HNSW indexes require pgvector >= 0.5.0.
        self.vector_version: str = "0.5"
        with self.SessionLocal() as sess:
            with sess.begin():
                sess.execute(text(f"create schema if not exists {SCHEMANAME};"))
                sess.execute(text("create extension if not exists vector;"))
                self.vector_version = sess.execute(
                    text("select installed_version from pg_available_extensions where name = 'vector' limit 1;")
                ).scalar_one()
                if self.vector_version.startswith(("0.0", "0.1", "0.2", "0.3", "0.4")):
                    raise ArgError(
                        "HNSW Unavailable. Upgrade your pgvector installation to >= 0.5.0 to enable HNSW support"
                    )
        # NOTE: The session is released here, but the engine persists and its connection pool
        # remains usable. The engine only needs to be disposed of at application shutdown,
        # so we do not explicitly dispose of it here.
        self.collection_controller = CollectionController(self)

    def disconnect(self) -> None:
        """
        Disconnect the client from the database.

        Returns:
            None
        """
        self.engine.dispose()
        return

    def __enter__(self) -> Self:
        """
        Enable use of the 'with' statement.

        Returns:
            Client: The current instance of the Client.
        """
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """
        Disconnect the client on exiting the 'with' statement context.

        Args:
            exc_type: The exception type, if any.
            exc_val: The exception value, if any.
            exc_tb: The traceback, if any.

        Returns:
            None
        """
        self.disconnect()
        return

    def get_or_create_collection(
        self,
        name: str,
        dimension: int,
    ) -> Collection:
        with self.SessionLocal() as sess:
            return self.collection_controller.get_or_create_collection(sess, name, dimension, self.meta)

    def list_collections(self) -> list[Collection]:
        with self.SessionLocal() as sess:
            return self.collection_controller.list_collections(sess)

    def delete_collection(self, name: str) -> None:
        with self.SessionLocal() as sess:
            return self.collection_controller.delete_collection(sess, name)


def create_client(connection_string: str) -> Client:
    """Creates a client from a Postgres connection string."""
    return Client(connection_string)
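

# A minimal end-to-end sketch, assuming a local Postgres with the pgvector
# extension available (the DSN, collection name, and vectors below are
# illustrative only):
#
#     if __name__ == "__main__":
#         with create_client("postgresql://postgres:postgres@localhost:5432/postgres") as client:
#             docs = client.get_or_create_collection("documents", dimension=3)
#             with client.SessionLocal() as sess:
#                 docs.upsert(sess, [("doc1", [0.1, 0.2, 0.3], {"year": 2021})])
#                 docs.create_vector_index(sess)
#                 ids = docs.query(sess, data=[0.1, 0.2, 0.3], limit=5, filters={"year": {"$eq": 2021}})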