VectorChord sampled K-Means

Sample vectors from a pgvector column with TABLESAMPLE, cluster them with faiss K-Means, and export the centroids to a table that VectorChord's external index build can consume.
conda create -n data python=3.11
conda activate data
conda install conda-forge::pgvector-python numpy pytorch::faiss-cpu conda-forge::psycopg tqdm

export OMP_WAIT_POLICY=PASSIVE  # keep idle OpenMP threads from busy-waiting during faiss training

python sample_kmeans_export.py \
  --url "postgresql://postgres:123@localhost:5432/postgres" \
  --table embeddings2 \
  --vector-col dense \
  --k 100000 \
  --dim 384 \
  --metric cos \
  --output centroids
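The script writes one row per centroid (columns id, parent, vector) into public.centroids. Before building the index, a quick sanity check might look like this; vector_dims is pgvector's dimension accessor, and the expected values assume the flags used above:

-- Optional sanity check on the exported centroids
SELECT count(*) FROM public.centroids;                     -- expect 100000 (--k)
SELECT vector_dims(vector) FROM public.centroids LIMIT 1;  -- expect 384 (--dim)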
CREATE INDEX ON embeddings2 USING vchordrq (dense vector_cosine_ops) WITH (options = $$
residual_quantization = true
build.pin = true
[build.external]
table = 'public.centroids'
$$);
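Once the index is built, a smoke-test query might look like the sketch below. The table and column names come from the index statement above; the vchordrq.probes value is an assumed starting point (higher probes trade speed for recall), and EXPLAIN confirms the index is actually used:

-- Smoke test (sketch): rank rows by cosine distance through the new index
SET vchordrq.probes = 100;  -- assumed value; tune for your recall target
EXPLAIN ANALYZE
SELECT ctid FROM embeddings2
ORDER BY dense <=> (SELECT dense FROM embeddings2 LIMIT 1)
LIMIT 10;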
#!/usr/bin/env python3
import asyncio
import argparse
from typing import List

import numpy as np
from tqdm import tqdm
import psycopg
from pgvector.psycopg import register_vector_async
from faiss import Kmeans

KEEPALIVE_KWARGS = {
    "keepalives": 1,
    "keepalives_idle": 30,
    "keepalives_interval": 5,
    "keepalives_count": 5,
}
SAMPLES_PER_CLUSTER = 10
NITER = 10
SEED = 42


def parse_args():
    p = argparse.ArgumentParser(
        description="Sample vectors via TABLESAMPLE SYSTEM, run KMeans, export centroids"
    )
    p.add_argument(
        "--url", required=True, help="postgresql://user:pass@host:port/dbname"
    )
    p.add_argument(
        "--schema",
        default="public",
        help="Schema of the source table (default: public)",
    )
    p.add_argument("--table", required=True, help="Source table name")
    p.add_argument("--vector-col", required=True, help="Vector column name (pgvector)")
    p.add_argument(
        "--k", type=int, required=True, help="Number of clusters (centroids)"
    )
    p.add_argument("--dim", type=int, required=True, help="Vector dimension")
    p.add_argument(
        "--metric",
        choices=["l2", "cos"],
        default="cos",
        help="Distance for clustering normalization",
    )
    p.add_argument(
        "--output",
        default=None,
        help="Output table name for centroids (default: <table>_centroids)",
    )
    return p.parse_args()


async def create_connection(url: str):
    conn = await psycopg.AsyncConnection.connect(
        conninfo=url, autocommit=True, **KEEPALIVE_KWARGS
    )
    await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
    await register_vector_async(conn)
    return conn


async def fetch_sampled_vectors(
    conn: psycopg.AsyncConnection,
    schema: str,
    table: str,
    vcol: str,
    sample_rows: int,
) -> List[np.ndarray]:
    # Refresh planner statistics so reltuples gives a usable row estimate.
    await conn.execute(f"ANALYZE {schema}.{table}")
    async with conn.cursor() as cur:
        await cur.execute(
            "SELECT reltuples::BIGINT AS estimated_count FROM pg_class WHERE relname = %s",
            (table,),
        )
        row = await cur.fetchone()
        if not row:
            raise ValueError(f"Table {schema}.{table} not found")
        total_rows = int(row[0])
    print(f"Estimated total rows in {schema}.{table}: {total_rows}")
    # TABLESAMPLE SYSTEM takes a percentage; clamp to 100 and guard against
    # a zero or negative reltuples estimate.
    prob = min(100.0, 100 * sample_rows / max(total_rows, 1))
    sql = f"SELECT {vcol} FROM {schema}.{table} TABLESAMPLE SYSTEM(%s)"
    vectors: List[np.ndarray] = []
    fetch_size = 10_000
    async with conn.cursor() as cur:
        await cur.execute(sql, (prob,))
        batch = await cur.fetchmany(fetch_size)
        with tqdm(
            desc="Fetching sampled vectors", unit="vec", total=sample_rows
        ) as pbar:
            while batch:
                for (v,) in batch:
                    if v is not None:
                        vectors.append(np.asarray(v, dtype=np.float32))
                pbar.update(len(batch))
                batch = await cur.fetchmany(fetch_size)
    if not vectors:
        raise ValueError("No vectors returned from sampling query.")
    return vectors


def preprocess(X: np.ndarray, metric: str) -> np.ndarray:
    if metric == "cos":
        # L2-normalize rows so cosine distance reduces to inner product.
        norms = np.linalg.norm(X, axis=1, keepdims=True)
        norms[norms == 0.0] = 1.0
        return X / norms
    return X


def run_kmeans(X: np.ndarray, dim: int, metric: str, k: int) -> np.ndarray:
    if k > len(X):
        raise ValueError(
            f"k={k} is greater than number of sampled vectors ({len(X)}). "
            "Reduce k or increase sample size."
        )
    kmeans = Kmeans(
        dim,
        k,
        gpu=False,
        verbose=True,
        niter=NITER,
        seed=SEED,
        spherical=metric != "l2",
    )
    kmeans.train(X)
    return kmeans.centroids


async def write_centroids(
    conn: psycopg.AsyncConnection,
    output_table: str,
    dim: int,
    centroids: np.ndarray,
):
    await conn.execute(f"DROP TABLE IF EXISTS {output_table}")
    await conn.execute(
        f"CREATE TABLE {output_table} (id integer, parent integer, vector vector({dim}))"
    )
    async with conn.cursor().copy(
        f"COPY {output_table} (id, parent, vector) FROM STDIN WITH (FORMAT BINARY)"
    ) as copy:
        copy.set_types(["integer", "integer", "vector"])
        for i, c in enumerate(centroids):
            await copy.write_row((i, None, c))
            # Yield to the event loop until libpq's send buffer drains.
            while conn.pgconn.flush() == 1:
                await asyncio.sleep(0)
    print(f"Wrote {len(centroids)} centroids to {output_table}")


async def main():
    args = parse_args()
    out_table = args.output or f"{args.table}_centroids"
    conn = await create_connection(args.url)
    dim = args.dim
    sample_rows = SAMPLES_PER_CLUSTER * args.k
    vecs_list = await fetch_sampled_vectors(
        conn,
        args.schema,
        args.table,
        args.vector_col,
        sample_rows,
    )
    X = np.vstack(vecs_list)
    print(f"Collected {len(X)} sampled vectors of dim={X.shape[1]}")
    if X.shape[1] != dim:
        raise ValueError(
            f"Vector dim mismatch: provided dim={dim}, sampled dim={X.shape[1]}"
        )
    X_proc = preprocess(X, args.metric)
    centers = run_kmeans(X_proc, args.dim, args.metric, args.k)
    # Re-normalize: spherical KMeans centroids are not guaranteed unit-length.
    centers = preprocess(centers, args.metric)
    await write_centroids(conn, f"{args.schema}.{out_table}", dim, centers)
    await conn.close()


if __name__ == "__main__":
    asyncio.run(main())