Created September 19, 2022 19:58
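# Benchmark of different ways to query a local Qdrant instance with the
# glove-25-angular dataset from ann-benchmarks.com: sequential requests,
# a single batch request, and the same two modes spread over several processes.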
import multiprocessing
import tempfile
import time
from pathlib import Path
from urllib.request import urlretrieve

import h5py
from qdrant_client import QdrantClient
from qdrant_client.conversions.common_types import VectorParams
from qdrant_client.http.models import Distance, SearchRequest

collection_name = "glove-25-angular"

client = QdrantClient("localhost", 6333, timeout=None)
client.recreate_collection(
    collection_name=collection_name,
    vectors_config=VectorParams(size=25, distance=Distance.COSINE),
)
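

# Searcher keeps the QdrantClient as a class attribute so that init_client can
# be used as a multiprocessing.Pool initializer: each worker process creates a
# single client of its own and reuses it for every request it handles.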
class Searcher:
    client: QdrantClient = None

    @classmethod
    def init_client(cls):
        if cls.client is None:
            cls.client = QdrantClient("localhost", 6333, timeout=None)

    @classmethod
    def search(cls, vector):
        return cls.client.search(
            collection_name=collection_name,
            query_vector=vector,
            limit=10,
        )

    @classmethod
    def search_batch(cls, vectors):
        return cls.client.search_batch(
            collection_name=collection_name,
            requests=[
                SearchRequest(
                    vector=vector,
                    limit=10,
                )
                for vector in vectors
            ]
        )
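

# The main process needs a client as well, for the sequential and
# single-batch runs below.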
Searcher.init_client()

with tempfile.TemporaryDirectory() as tmpdir:
    # Download the glove-25-angular dataset from ann-benchmarks.com
    tmp_path = Path(tmpdir)
    file_path = tmp_path / "glove-25-angular.hdf5"
    urlretrieve("http://ann-benchmarks.com/glove-25-angular.hdf5", file_path)

    data = h5py.File(file_path, "r", driver="stdio")
    train_vectors = [vector.tolist() for vector in data["train"]]
    test_vectors = [vector.tolist() for vector in data["test"]]

    client.upload_collection(
        collection_name=collection_name,
        vectors=train_vectors,
        batch_size=1024,
    )

    # Perform the sequential search, one request per test vector
    start = time.monotonic()
    for vector in test_vectors:
        Searcher.search(vector)
    print("Sequential search:", time.monotonic() - start)

    # Run the requests as a single batch
    start = time.monotonic()
    Searcher.search_batch(test_vectors)
    print("Batch search:", time.monotonic() - start)
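
    # In the two multiprocessing runs below, every pool worker calls
    # Searcher.init_client once (via the `initializer` argument) and then
    # reuses its own QdrantClient for all the requests it handles.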
    # Use many processes and sequential search
    start = time.monotonic()
    with multiprocessing.Pool(
        processes=8, initializer=Searcher.init_client
    ) as pool:
        results = list(
            pool.imap_unordered(Searcher.search, test_vectors)
        )
    print("Multiprocessing search:", time.monotonic() - start)
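
    # Combine both ideas: split the queries into chunks of `size` vectors and
    # let every worker send one batch request per chunk.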
    # Use many processes and batch search
    start = time.monotonic()
    with multiprocessing.Pool(
        processes=8, initializer=Searcher.init_client
    ) as pool:
        size = 10
        test_batches = [
            test_vectors[pos:pos + size] for pos in range(0, len(test_vectors), size)
        ]
        results = list(
            pool.imap_unordered(Searcher.search_batch, test_batches)
        )
    print("Multiprocessing batch search:", time.monotonic() - start)