|
#!/usr/bin/env python3 |
|
import asyncio |
|
import argparse |
|
from typing import List |
|
import numpy as np |
|
from tqdm import tqdm |
|
import psycopg |
|
from pgvector.psycopg import register_vector_async |
|
from faiss import Kmeans |
|
|
|
# libpq TCP keepalive options, forwarded as keyword arguments when connecting.
KEEPALIVE_KWARGS = dict(
    keepalives=1,
    keepalives_idle=30,
    keepalives_interval=5,
    keepalives_count=5,
)

# Number of vectors to sample for each requested centroid.
SAMPLES_PER_CLUSTER = 10

# Iteration count handed to the k-means trainer.
NITER = 10

# Seed handed to the k-means trainer.
SEED = 42
|
|
|
|
|
def _positive_int(text: str) -> int:
    """Parse *text* as an integer >= 1 (used as an argparse ``type=`` hook)."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv=None):
    """Parse the command-line options of the centroid export script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the original behavior.

    Returns:
        An ``argparse.Namespace`` with ``url``, ``schema``, ``table``,
        ``vector_col``, ``k``, ``dim``, ``metric`` and ``output``.

    Raises:
        SystemExit: Raised by argparse on missing or invalid arguments,
            including non-positive ``--k`` or ``--dim``.
    """
    p = argparse.ArgumentParser(
        description="Sample vectors via TABLESAMPLE SYSTEM, KMeans, export centroids"
    )
    p.add_argument(
        "--url", required=True, help="postgresql://user:pass@host:port/dbname"
    )
    p.add_argument(
        "--schema",
        default="public",
        help="Schema of the source table (default: public)",
    )
    p.add_argument("--table", required=True, help="Source table name")
    p.add_argument("--vector-col", required=True, help="Vector column name (pgvector)")
    p.add_argument(
        "--k",
        type=_positive_int,
        required=True,
        help="Number of clusters (centroids)",
    )
    # --dim is mandatory: nothing in this script infers it, so the help text
    # must not promise inference.
    p.add_argument(
        "--dim",
        type=_positive_int,
        required=True,
        help="Vector dimension (must match the dimension of the sampled vectors)",
    )
    p.add_argument(
        "--metric",
        choices=["l2", "cos"],
        default="cos",
        help="Distance for clustering normalization",
    )
    p.add_argument(
        "--output",
        default=None,
        help="Output table name for centroids (default: <table>_centroids)",
    )
    return p.parse_args(argv)
|
|
|
|
|
async def create_connection(url: str):
    """Open an autocommit async connection with pgvector support registered.

    Connects using ``url`` plus ``KEEPALIVE_KWARGS``, ensures the ``vector``
    extension exists, and registers the pgvector type adapters on the
    connection.

    Args:
        url: PostgreSQL connection string.

    Returns:
        The ready-to-use ``psycopg.AsyncConnection``. The caller owns it and
        is responsible for closing it.

    Raises:
        Whatever psycopg raises if connecting or the setup statements fail;
        in the latter case the connection is closed before re-raising.
    """
    conn = await psycopg.AsyncConnection.connect(
        conninfo=url, autocommit=True, **KEEPALIVE_KWARGS
    )
    try:
        await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        await register_vector_async(conn)
    except BaseException:
        # Setup failed (or the task was cancelled): do not leak the socket.
        await conn.close()
        raise
    return conn
|
|
|
|
|
async def fetch_sampled_vectors(
    conn: psycopg.AsyncConnection,
    schema: str,
    table: str,
    vcol: str,
    sample_rows: int,
) -> List[np.ndarray]:
    """Fetch approximately ``sample_rows`` vectors from ``schema.table``.

    Runs ``ANALYZE`` on the table, reads its estimated row count from
    ``pg_class.reltuples``, converts ``sample_rows`` into a percentage and
    reads ``vcol`` through ``TABLESAMPLE SYSTEM``. NULL values are skipped.

    Args:
        conn: Open async connection.
        schema: Schema of the source table.
        table: Source table name.
        vcol: Name of the vector column to read.
        sample_rows: Desired number of sampled rows; the number actually
            returned depends on the sampling percentage and may differ.

    Returns:
        A list of 1-D ``float32`` arrays, one per non-NULL sampled value.

    Raises:
        ValueError: If the table cannot be found in the catalog, or the
            sampling query yields no vectors.
    """
    # NOTE: schema, table and vcol are interpolated into the SQL text as-is
    # (identifiers cannot be bound as parameters), so they must be trusted.
    qualified = f"{schema}.{table}"
    await conn.execute(f"ANALYZE {qualified}")

    async with conn.cursor() as cur:
        # Resolve the relation via to_regclass so the estimate belongs to the
        # same schema-qualified table that is analyzed and sampled; matching
        # on relname alone could hit a same-named table in another schema.
        await cur.execute(
            "SELECT reltuples::BIGINT AS estimated_count FROM pg_class "
            "WHERE oid = to_regclass(%s::text)",
            (qualified,),
        )
        row = await cur.fetchone()
    if not row:
        raise ValueError(f"Table {schema}.{table} not found")
    total_rows = int(row[0])
    print(f"Estimated total rows in {schema}.{table}: {total_rows}")

    if total_rows > 0:
        prob = min(100.0, 100 * sample_rows / total_rows)
    else:
        # No usable estimate (e.g. empty table): avoid dividing by zero and
        # read everything there is; the emptiness check below still applies.
        prob = 100.0

    sql = f"SELECT {vcol} FROM {schema}.{table} TABLESAMPLE SYSTEM(%s)"
    vectors: List[np.ndarray] = []
    fetch_size = 10_000
    async with conn.cursor() as cur:
        await cur.execute(sql, (prob,))
        with tqdm(
            desc="Fetching sampled vectors", unit="vec", total=sample_rows
        ) as pbar:
            while True:
                batch = await cur.fetchmany(fetch_size)
                if not batch:
                    break
                vectors.extend(
                    np.asarray(v, dtype=np.float32)
                    for (v,) in batch
                    if v is not None
                )
                pbar.update(len(batch))

    if not vectors:
        raise ValueError("No vectors returned from sampling query.")
    return vectors
|
|
|
|
|
def preprocess(X: np.ndarray, metric: str) -> np.ndarray:
    """Scale each row of ``X`` to unit L2 norm when ``metric`` is ``"cos"``.

    For any other metric the input array itself is returned untouched.
    Rows whose norm is exactly zero are divided by 1 and so stay all-zero.
    """
    if metric != "cos":
        return X
    row_lengths = np.linalg.norm(X, axis=1, keepdims=True)
    # Substitute 1 for zero lengths so all-zero rows do not divide by zero.
    divisors = np.where(row_lengths == 0.0, 1.0, row_lengths)
    return X / divisors
|
|
|
|
|
def run_kmeans(X: np.ndarray, dim: int, metric: str, k: int) -> np.ndarray:
    """Train a CPU faiss ``Kmeans`` on ``X`` and return its centroids.

    Args:
        X: Training vectors, one per row.
        dim: Vector dimension passed to the trainer.
        metric: ``"l2"`` for plain k-means; any other value enables the
            trainer's ``spherical`` option.
        k: Number of centroids to compute.

    Raises:
        ValueError: If ``k`` exceeds the number of rows in ``X``.
    """
    if len(X) < k:
        raise ValueError(
            f"k={k} is greater than number of sampled vectors ({len(X)}). Reduce k or increase sample size."
        )

    use_spherical = metric != "l2"
    trainer = Kmeans(
        dim,
        k,
        niter=NITER,
        seed=SEED,
        spherical=use_spherical,
        verbose=True,
        gpu=False,
    )
    trainer.train(X)
    return trainer.centroids
|
|
|
|
|
async def write_centroids(
    conn: psycopg.AsyncConnection,
    output_table: str,
    dim: int,
    centroids: np.ndarray,
):
    """Replace ``output_table`` with one row per centroid.

    The table is dropped if present, recreated as
    ``(id integer, parent integer, vector vector(dim))`` and filled through a
    binary ``COPY``. ``id`` is the centroid's index in ``centroids`` and
    ``parent`` is always NULL.

    Args:
        conn: Open async connection with the pgvector types registered.
        output_table: Target table name, interpolated into SQL as-is, so it
            must come from a trusted source.
        dim: Dimension used for the ``vector`` column type.
        centroids: Iterable of centroid vectors.
    """
    # Run drop/create/copy as one transaction so a failed COPY rolls back and
    # the previously exported table is not lost.
    async with conn.transaction():
        await conn.execute(f"DROP TABLE IF EXISTS {output_table}")
        await conn.execute(
            f"CREATE TABLE {output_table} (id integer, parent integer, vector vector({int(dim)}))"
        )
        # Own the cursor with a context manager so it is closed afterwards.
        async with conn.cursor() as cur:
            async with cur.copy(
                f"COPY {output_table} (id, parent, vector) FROM STDIN WITH (FORMAT BINARY)"
            ) as copy:
                copy.set_types(["integer", "integer", "vector"])
                for i, c in enumerate(centroids):
                    await copy.write_row((i, None, c))
                    # Drain libpq's output buffer, yielding to the event loop
                    # while data is still pending.
                    while conn.pgconn.flush() == 1:
                        await asyncio.sleep(0)
    print(f"Wrote {len(centroids)} centroids to {output_table}")
|
|
|
|
|
async def main():
    """Run the pipeline: sample vectors, cluster them, export the centroids.

    Parses the command line, samples ``SAMPLES_PER_CLUSTER * k`` vectors from
    the source table, checks their dimension against ``--dim``, preprocesses
    them for the chosen metric, runs k-means, and writes the (preprocessed)
    centroids to ``<schema>.<output table>``.

    Raises:
        ValueError: If the sampled vectors' dimension differs from ``--dim``
            (plus any error raised by the individual steps).
    """
    args = parse_args()
    out_table = args.output or f"{args.table}_centroids"
    dim = args.dim
    sample_rows = SAMPLES_PER_CLUSTER * args.k

    conn = await create_connection(args.url)
    try:
        vecs_list = await fetch_sampled_vectors(
            conn,
            args.schema,
            args.table,
            args.vector_col,
            sample_rows,
        )
        X = np.vstack(vecs_list)
        print(f"Collected {len(X)} sampled vectors of dim={X.shape[1]}")
        if X.shape[1] != dim:
            raise ValueError(
                f"Vector dim mismatch: inferred/provided dim={dim}, sampled dim={X.shape[1]}"
            )
        X_proc = preprocess(X, args.metric)
        centers = run_kmeans(X_proc, dim, args.metric, args.k)
        centers = preprocess(centers, args.metric)
        await write_centroids(conn, f"{args.schema}.{out_table}", dim, centers)
    finally:
        # Always release the database connection, including on errors.
        await conn.close()
|
|
|
|
|
# Script entry point: run the async main() to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(main())