ANN Benchmark (https://github.com/erikbern/ann-benchmarks) modules for HNSW implementations of pgvector and pg_embedding.
Benchmark configuration for pg_embedding_hnsw:
float:
  any:
  - base_args: ['@metric']
    constructor: PGEmbedding
    disabled: false
    docker_tag: ann-benchmarks-pg_embedding_hnsw
    module: ann_benchmarks.algorithms.pg_embedding_hnsw
    name: pg_embedding_hnsw
    run_groups:
      M-12:
        arg_groups: [{M: 12, efConstruction: 60}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
      M-16:
        arg_groups: [{M: 16, efConstruction: 40}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
      M-24:
        arg_groups: [{M: 24, efConstruction: 40}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
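In this config, base_args: ['@metric'] passes the dataset's metric to the constructor, each run group's arg_groups gives the build-time parameters (M, efConstruction), and every value in query_args is applied as efSearch against the already-built index. A rough sketch of that flow using the PGEmbedding class defined below, assuming a local PostgreSQL set up as in fit(); the data and driver loop are placeholders, not the actual ann-benchmarks runner:

# Illustration only -- the real driver lives in ann-benchmarks; the data below is a stand-in.
import numpy as np

train = np.random.rand(10000, 100).astype(np.float32)
queries = train[:100]

algo = PGEmbedding("angular", {"M": 16, "efConstruction": 40})   # one arg_group = one index build
algo.fit(train)                                                  # COPY the vectors and build the HNSW index once
for ef_search in [10, 20, 40, 80, 120, 200, 400, 600, 800]:      # the query_args sweep
    algo.set_query_arguments(ef_search)
    results = [algo.query(q, 10) for q in queries]               # recall/QPS would be measured here
algo.done()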
Benchmark module for pg_embedding (ann_benchmarks.algorithms.pg_embedding_hnsw):
# License: https://github.com/erikbern/ann-benchmarks/blob/main/LICENSE
import subprocess
import sys

import psycopg

from ..base.module import BaseANN


class PGEmbedding(BaseANN):
    def __init__(self, metric, method_param):
        self._metric = metric
        self._m = method_param["M"]
        self._ef_construction = method_param["efConstruction"]
        self._cur = None

        # pg_embedding stores vectors as real[]; angular uses the cosine operator (<=>),
        # euclidean the L2 operator (<->).
        if metric == "angular":
            self._query = "SELECT id FROM items ORDER BY embedding <=> %s::real[] LIMIT %s"
        elif metric == "euclidean":
            self._query = "SELECT id FROM items ORDER BY embedding <-> %s::real[] LIMIT %s"
        else:
            raise RuntimeError(f"unknown metric {metric}")

    def fit(self, X):
        # subprocess.run("service postgresql start", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
        conn = psycopg.connect(user="ubuntu", password="ann", dbname="ann", host="/tmp", autocommit=True)
        cur = conn.cursor()
        self._cur = cur
        cur.execute("CREATE TABLE items (id int, embedding real[])")
        # STORAGE PLAIN keeps the embeddings inline instead of TOASTing them
        cur.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")
        print("copying data...")
        with cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
            for i, embedding in enumerate(X):
                copy.write_row((i, embedding.tolist()))
        print("creating index...")
        if self._metric == "angular":
            cur.execute(
                "CREATE INDEX ON items USING hnsw (embedding ann_cos_ops) WITH (dims=%d, m = %d, efConstruction = %d)"
                % (X.shape[1], self._m, self._ef_construction)
            )
        elif self._metric == "euclidean":
            cur.execute(
                "CREATE INDEX ON items USING hnsw (embedding ann_l2_ops) WITH (dims=%d, m = %d, efConstruction = %d)"
                % (X.shape[1], self._m, self._ef_construction)
            )
        else:
            raise RuntimeError(f"unknown metric {self._metric}")
        cur.execute("RESET min_parallel_table_scan_size")
        print("vacuum and checkpoint")
        cur.execute("VACUUM ANALYZE items;")
        cur.execute("CHECKPOINT;")
        # warm the table and the index into shared buffers (requires the pg_prewarm extension)
        print("warm cache")
        cur.execute("SELECT pg_prewarm('items')")
        cur.execute("SELECT pg_prewarm('items_embedding_idx')")
        print("done!")
        self._cur = cur

    def set_query_arguments(self, ef_search):
        # efSearch is an index storage parameter in pg_embedding, so it is changed via ALTER INDEX
        self._ef_search = ef_search
        self._cur.execute("ALTER INDEX items_embedding_idx SET ( efSearch = %d )" % self._ef_search)
        self._cur.execute("SET work_mem = '4GB'")

    def query(self, v, n):
        self._cur.execute(self._query, (v.tolist(), n), binary=True, prepare=True)
        return [id for id, in self._cur.fetchall()]

    def get_memory_usage(self):
        # index size on disk, in KB
        if self._cur is None:
            return 0
        self._cur.execute("SELECT pg_relation_size('items_embedding_idx')")
        return self._cur.fetchone()[0] / 1024

    def done(self):
        self._cur.execute("DROP TABLE items")
        self._cur.close()

    def __str__(self):
        return f"PGEmbeddingHNSW(m={self._m}, ef_construction={self._ef_construction}, ef_search={self._ef_search})"
Benchmark configuration for pgvector_hnsw:
float:
  any:
  - base_args: ['@metric']
    constructor: PGVector
    disabled: false
    docker_tag: ann-benchmarks-pgvector-hnsw
    module: ann_benchmarks.algorithms.pgvector_hnsw
    name: pgvector_hnsw
    run_groups:
      M-12:
        arg_groups: [{M: 12, efConstruction: 60}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
      M-16:
        arg_groups: [{M: 16, efConstruction: 40}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
      M-24:
        arg_groups: [{M: 24, efConstruction: 40}]
        args: {}
        query_args: [[10, 20, 40, 80, 120, 200, 400, 600, 800]]
Benchmark module for pgvector (ann_benchmarks.algorithms.pgvector_hnsw):
# License: https://github.com/erikbern/ann-benchmarks/blob/main/LICENSE
import subprocess
import sys

import pgvector.psycopg
import psycopg

from ..base.module import BaseANN


class PGVector(BaseANN):
    def __init__(self, metric, method_param):
        self._metric = metric
        self._m = method_param["M"]
        self._ef_construction = method_param["efConstruction"]
        self._cur = None

        # pgvector: <=> is cosine distance, <-> is L2 distance
        if metric == "angular":
            self._query = "SELECT id FROM items ORDER BY embedding <=> %s LIMIT %s"
        elif metric == "euclidean":
            self._query = "SELECT id FROM items ORDER BY embedding <-> %s LIMIT %s"
        else:
            raise RuntimeError(f"unknown metric {metric}")

    def fit(self, X):
        conn = psycopg.connect(user="ubuntu", password="ann", dbname="ann", host="/tmp", autocommit=True)
        pgvector.psycopg.register_vector(conn)   # adapt numpy arrays to the vector type
        cur = conn.cursor()
        self._cur = cur
        cur.execute("CREATE TABLE IF NOT EXISTS items (id int, embedding vector(%d))" % X.shape[1])
        # STORAGE PLAIN keeps the embeddings inline instead of TOASTing them
        cur.execute("ALTER TABLE items ALTER COLUMN embedding SET STORAGE PLAIN")
        print("copying data...")
        with cur.copy("COPY items (id, embedding) FROM STDIN") as copy:
            for i, embedding in enumerate(X):
                copy.write_row((i, embedding))
        print("creating index...")
        cur.execute("SET min_parallel_table_scan_size TO 1")
        if self._metric == "angular":
            cur.execute(
                "CREATE INDEX ON items USING hnsw (embedding vector_cosine_ops) WITH (m = %d, ef_construction = %d)"
                % (self._m, self._ef_construction)
            )
        elif self._metric == "euclidean":
            cur.execute(
                "CREATE INDEX ON items USING hnsw (embedding vector_l2_ops) WITH (m = %d, ef_construction = %d)"
                % (self._m, self._ef_construction)
            )
        else:
            raise RuntimeError(f"unknown metric {self._metric}")
        cur.execute("RESET min_parallel_table_scan_size")
        print("vacuum and checkpoint")
        cur.execute("VACUUM ANALYZE items;")
        cur.execute("CHECKPOINT;")
        # warm the table and the index into shared buffers (requires the pg_prewarm extension)
        print("warm cache")
        cur.execute("SELECT pg_prewarm('items')")
        cur.execute("SELECT pg_prewarm('items_embedding_idx')")
        print("done!")

    def set_query_arguments(self, ef_search):
        # pgvector exposes the search beam width as the hnsw.ef_search GUC
        self._ef_search = ef_search
        self._cur.execute("SET hnsw.ef_search = %d" % ef_search)
        self._cur.execute("SET work_mem = '4GB'")

    def query(self, v, n):
        self._cur.execute(self._query, (v, n), binary=True, prepare=True)
        return [id for id, in self._cur.fetchall()]

    def get_memory_usage(self):
        # index size on disk, in KB
        if self._cur is None:
            return 0
        self._cur.execute("SELECT pg_relation_size('items_embedding_idx')")
        return self._cur.fetchone()[0] / 1024

    def done(self):
        self._cur.execute("DROP TABLE items")
        self._cur.close()

    def __str__(self):
        return f"PGVectorHNSW(m={self._m}, ef_construction={self._ef_construction}, ef_search={self._ef_search})"