Created
September 6, 2023 14:26
-
-
Save VoVAllen/a83d2ee4b56a2a152019d768926f1a40 to your computer and use it in GitHub Desktop.
transaction test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import h5py | |
import psycopg2 | |
import psycopg2.extras | |
from math import sqrt | |
import ipdb | |
import numpy as np | |
from tqdm import tqdm | |
# Dimensionality of the embedding column: substituted into the
# `vector({dims})` column type in main() and used as the loop bound in length().
DIMS = 100
# Path of the HDF5 dataset file that main() opens read-only.
DATASET = "./glove-100-angular.hdf5"
# Connection string handed to psycopg2.connect() for both connections in main().
DATABASE = "dbname=postgres user=postgres host=127.0.0.1 port=5432"
def length(x):
    """Return the Euclidean length of the first ``DIMS`` components of ``x``.

    ``x`` must support integer indexing for every position in
    ``range(DIMS)``; any components beyond ``DIMS`` are ignored.
    """
    squared_total = 0.0
    for position in range(DIMS):
        component = x[position]
        squared_total += component * component
    return sqrt(squared_total)
def norm(a):
    """Scale every row of ``a`` so that its entries sum to 1.

    Each row is divided by the sum of its own entries (this is a row-sum
    normalisation, not a Euclidean one). ``a`` is left untouched; a new
    array is returned.
    """
    # Re-insert the reduced axis so the per-row totals broadcast across columns.
    per_row_total = np.expand_dims(a.sum(axis=1), axis=1)
    return a / per_row_total
# SQL template for the HNSW index on train.embedding. main() fills the
# {distance_op}, {m} and {ef_construct} placeholders with str.format().
CREATE_INDEX = """
CREATE INDEX ON train USING hnsw (embedding {distance_op}) WITH (m = {m}, ef_construction = {ef_construct});
"""
from itertools import islice | |
def chunk(it, size):
    """Return an iterator over consecutive tuples of up to ``size`` items.

    Items are pulled lazily from ``it``. The final tuple may be shorter than
    ``size``; iteration ends as soon as an empty tuple would be produced.
    """
    source = iter(it)

    def _batches():
        while True:
            batch = tuple(islice(source, size))
            if batch == ():
                return
            yield batch

    return _batches()
def main():
    """Run the transaction-visibility scenario against the `vector` extension.

    Steps:
      1. Read the first ``N`` training vectors and the first test vector from
         the HDF5 file at ``DATASET``.
      2. Recreate the ``vector`` extension and the ``train`` table, build the
         HNSW index, insert the first half of the rows and commit.
      3. Insert the second half on the same connection without committing.
      4. From a second connection, run a nearest-neighbour query ordered by
         ``<=>`` and print the returned ids and their count.

    Both connections are always closed on exit; closing the writer connection
    discards the uncommitted second half.

    Raises:
        KeyError: if the dataset lacks the ``train`` or ``test`` entry.
    """
    TEST = "test"
    TRAIN = "train"
    N = 10000

    with h5py.File(DATASET, "r") as dataset:
        missing = [key for key in (TEST, TRAIN) if key not in dataset]
        if missing:
            raise KeyError(
                "dataset {} is missing entries: {}".format(DATASET, missing)
            )
        # .tolist() yields plain Python floats. str(list(ndarray)) would embed
        # "np.float32(...)" reprs under NumPy >= 2, which is not a valid
        # vector literal.
        target = str(dataset[TEST][0].tolist())
        # Only the first N rows are used, so do not load the whole train set.
        train = dataset[TRAIN][:N]

    rows = [
        (index, str(vector.tolist()))
        for index, vector in enumerate(tqdm(train))
    ]
    # Split exactly in the middle (the previous `i <= N // 2` test put one
    # extra row into the first half).
    half = len(rows) // 2
    first_half = rows[:half]
    second_half = rows[half:]

    DDL = """
DROP TABLE IF EXISTS test;
DROP TABLE IF EXISTS train;
DROP EXTENSION IF EXISTS vector;
CREATE EXTENSION vector;
CREATE TABLE train (id integer PRIMARY KEY, embedding vector({dims}) NOT NULL);
"""
    INSERT_TRAIN = "INSERT INTO train VALUES %s"

    conn = psycopg2.connect(DATABASE)
    try:
        print("Start insert")
        # Insert first half and commit it.
        with conn.cursor() as cursor:
            cursor.execute(DDL.format(dims=DIMS))
            cursor.execute(
                CREATE_INDEX.format(
                    distance_op="vector_cosine_ops", m=16, ef_construct=40
                )
            )
            cursor.execute("set hnsw.ef_search=16")
            psycopg2.extras.execute_values(
                cursor, INSERT_TRAIN, first_half, page_size=512
            )
        conn.commit()
        print("Insert first half finished")

        # New transaction: insert the second half and leave it uncommitted.
        # Closing the cursor does not end the transaction on `conn`.
        with conn.cursor() as cursor:
            psycopg2.extras.execute_values(
                cursor, INSERT_TRAIN, second_half, page_size=512
            )
        print("Insert second half finished")

        # Another connection/transaction: it should not be able to see the
        # second half while the writer transaction is still open.
        conn2 = psycopg2.connect(DATABASE)
        try:
            with conn2.cursor() as cursor2:
                cursor2.execute("set hnsw.ef_search=32")
                print("Start query")
                # Bind the query vector as a parameter instead of formatting
                # it into the SQL text.
                cursor2.execute(
                    """
SELECT train.id
FROM train
ORDER BY %s <=> train.embedding
LIMIT 32""",
                    (target,),
                )
                result = cursor2.fetchall()
        finally:
            conn2.close()

        print(result)
        print("Total num results: {}".format(len(result)))
    finally:
        # Closing discards the still-uncommitted second half.
        conn.close()


main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python | |
import h5py | |
import psycopg2 | |
import psycopg2.extras | |
from math import sqrt | |
import ipdb | |
import numpy as np | |
from tqdm import tqdm | |
# Number of components per embedding: used for the `vector({dims})` column
# type created in main() and as the iteration bound in length().
DIMS = 100
# Location of the HDF5 dataset read by main().
DATASET = "./glove-100-angular.hdf5"
# Connection string used for both psycopg2 connections opened in main().
DATABASE = "dbname=postgres user=postgres host=127.0.0.1 port=5432"
def length(x):
    """Compute the Euclidean length over the leading ``DIMS`` entries of ``x``.

    Entries past index ``DIMS - 1`` do not contribute; ``x`` must be
    indexable at every position below ``DIMS``.
    """
    acc = 0.0
    idx = 0
    while idx < DIMS:
        acc += x[idx] * x[idx]
        idx += 1
    return sqrt(acc)
def norm(a):
    """Return a copy of ``a`` whose rows are each divided by their own sum.

    Note that this normalises by the plain row sum (so each row of the
    result sums to 1), not by the Euclidean length of the row.
    """
    totals = a.sum(axis=1)
    # Index with None (== np.newaxis) to turn the totals into a column that
    # broadcasts against every row of `a`.
    return a / totals[:, None]
# SQL template for the index on train.embedding built with the `vectors`
# access method. The index options are passed as a $$-quoted block; main()
# fills {distance_op}, {m} and {ef_construct} with str.format().
CREATE_INDEX = """
CREATE INDEX ON train USING vectors (embedding {distance_op})
WITH (options = $$
capacity = 2097152
[vectors]
memmap = "ram"
[algorithm.hnsw]
memmap = "ram"
m = {m}
ef = {ef_construct}
$$);
"""
from itertools import islice | |
def chunk(it, size):
    """Split ``it`` lazily into tuples holding at most ``size`` items each.

    The returned iterator stops at the first empty tuple, i.e. once the
    underlying iterable is exhausted (or immediately when ``size`` is 0).
    """
    iterator = iter(it)

    def take_next():
        # Pull the next batch; an exhausted iterator yields the empty tuple.
        return tuple(islice(iterator, size))

    # Two-argument iter(): call take_next() until it returns the sentinel ().
    return iter(take_next, ())
def main():
    """Run the transaction-visibility scenario against the `vectors` extension.

    Steps:
      1. Read the first ``N`` training vectors and the first test vector from
         the HDF5 file at ``DATASET``.
      2. Recreate the ``vectors`` extension and the ``train`` table, build the
         index from ``CREATE_INDEX``, insert the first half of the rows and
         commit.
      3. Insert the second half on the same connection without committing.
      4. From a second connection, run a nearest-neighbour query ordered by
         ``<=>`` and print the returned ids and their count.

    Both connections are always closed on exit; closing the writer connection
    discards the uncommitted second half.

    Raises:
        KeyError: if the dataset lacks the ``train`` or ``test`` entry.
    """
    TEST = "test"
    TRAIN = "train"
    N = 10000

    with h5py.File(DATASET, "r") as dataset:
        missing = [key for key in (TEST, TRAIN) if key not in dataset]
        if missing:
            raise KeyError(
                "dataset {} is missing entries: {}".format(DATASET, missing)
            )
        # .tolist() yields plain Python floats. str(list(ndarray)) would embed
        # "np.float32(...)" reprs under NumPy >= 2, which is not a valid
        # vector literal.
        target = str(dataset[TEST][0].tolist())
        # Only the first N rows are used, so do not load the whole train set.
        train = dataset[TRAIN][:N]

    rows = [
        (index, str(vector.tolist()))
        for index, vector in enumerate(tqdm(train))
    ]
    # Split exactly in the middle (the previous `i <= N // 2` test put one
    # extra row into the first half).
    half = len(rows) // 2
    first_half = rows[:half]
    second_half = rows[half:]

    DDL = """
DROP TABLE IF EXISTS test;
DROP TABLE IF EXISTS train;
DROP EXTENSION IF EXISTS vectors;
CREATE EXTENSION vectors;
CREATE TABLE train (id integer PRIMARY KEY, embedding vector({dims}) NOT NULL);
"""
    INSERT_TRAIN = "INSERT INTO train VALUES %s"

    conn = psycopg2.connect(DATABASE)
    try:
        print("Start insert")
        # Insert first half and commit it.
        with conn.cursor() as cursor:
            cursor.execute(DDL.format(dims=DIMS))
            cursor.execute(
                CREATE_INDEX.format(
                    distance_op="cosine_ops", m=16, ef_construct=40
                )
            )
            cursor.execute("set vectors.k=32")
            psycopg2.extras.execute_values(
                cursor, INSERT_TRAIN, first_half, page_size=512
            )
        conn.commit()
        print("Insert first half finished")

        # New transaction: insert the second half and leave it uncommitted.
        # Closing the cursor does not end the transaction on `conn`.
        with conn.cursor() as cursor:
            psycopg2.extras.execute_values(
                cursor, INSERT_TRAIN, second_half, page_size=512
            )
        print("Insert second half finished")

        # Another connection/transaction: it should not be able to see the
        # second half while the writer transaction is still open.
        conn2 = psycopg2.connect(DATABASE)
        try:
            with conn2.cursor() as cursor2:
                cursor2.execute("set vectors.k=32")
                print("Start query")
                # Bind the query vector as a parameter instead of formatting
                # it into the SQL text.
                cursor2.execute(
                    """
SELECT train.id
FROM train
ORDER BY %s <=> train.embedding
LIMIT 32""",
                    (target,),
                )
                result = cursor2.fetchall()
        finally:
            conn2.close()

        print(result)
        print("Total num results: {}".format(len(result)))
    finally:
        # Closing discards the still-uncommitted second half.
        conn.close()


main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment