Created
May 20, 2022 21:05
-
-
Save generall/276ad1a161c34df799d10101374b8c3d to your computer and use it in GitHub Desktop.
Search 10k by 10k vectors fast
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import time | |
from multiprocessing import Pool | |
import httpx | |
import numpy as np | |
from grpclib.client import Channel | |
from qdrant_client import QdrantClient | |
from qdrant_client.grpc import PointsStub, WithPayloadSelector | |
from qdrant_client.http.models import Distance, OptimizersConfigDiff, \ | |
UpdateCollection, CollectionStatus | |
class Searcher:
    """Per-process gRPC search worker.

    All state is class-level and populated by ``init_client``, which is meant
    to run exactly once per worker process (as a ``multiprocessing.Pool``
    initializer), so each process gets its own channel and event loop.
    """

    client: QdrantClient = None      # REST client (kept for parity; search goes over gRPC)
    collection_name = None           # target collection, set by init_client
    top = None                       # number of nearest neighbours to return
    channel = None                   # grpclib channel to the qdrant gRPC port
    points_client: PointsStub = None # gRPC stub used for the actual search calls
    loop = None                      # dedicated asyncio loop owned by this process

    @classmethod
    def init_client(
            cls,
            collection_name="sample",
            host="localhost",
            top=3,
    ):
        """Initialize per-process clients and the event loop.

        Args:
            collection_name: collection to search in.
            host: qdrant host (REST on default port, gRPC on 6334).
            top: number of results per query.
        """
        cls.top = top
        cls.collection_name = collection_name
        cls.client = QdrantClient(
            host=host,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=0)
        )
        # Create and remember an explicit event loop instead of relying on
        # asyncio.get_event_loop(): that call is deprecated outside a running
        # loop since Python 3.10 and fails on 3.12+ in a fresh worker process.
        cls.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(cls.loop)
        # grpclib's Channel binds to the current event loop, so it must be
        # constructed after the loop above is installed.
        cls.channel = Channel(host, port=6334)
        cls.points_client = PointsStub(cls.channel)

    @classmethod
    async def _search(cls, vector):
        """Run one gRPC nearest-neighbour search for ``vector``."""
        return await cls.points_client.search(
            collection_name=cls.collection_name,
            vector=list(vector),
            filter=None,
            top=cls.top,
            with_vector=False,
            # Skip payload transfer entirely — we only benchmark search speed.
            with_payload=WithPayloadSelector(enable=False)
        )

    @classmethod
    def search_one(cls, vector):
        """Synchronous wrapper around ``_search`` for use with Pool.imap."""
        if cls.loop is None:
            # Defensive fallback if init_client was not used as initializer.
            cls.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls.loop)
        return cls.loop.run_until_complete(cls._search(vector))
def upload_data(qdrant_client, data, collection_name):
    """Recreate the collection, bulk-upload ``data``, re-enable indexing,
    and block until the collection finishes indexing (status GREEN).

    Args:
        qdrant_client: connected QdrantClient instance.
        data: 2-D array of vectors, shape (n_points, dim).
        collection_name: collection to (re)create and fill.
    """
    qdrant_client.recreate_collection(
        collection_name=collection_name,
        vector_size=data.shape[1],
        distance=Distance.COSINE,
        optimizers_config=OptimizersConfigDiff(
            # Multiple segments allow qdrant to search/index in parallel.
            default_segment_number=3
        )
    )
    qdrant_client.upload_collection(
        collection_name=collection_name,
        vectors=data,
        payload=None,
        ids=None,  # let qdrant assign sequential ids
        batch_size=256,
        parallel=4
    )
    # Lower the indexing threshold so the optimizer starts building the
    # vector index for the freshly uploaded points.
    qdrant_client.http.collections_api.update_collection(
        collection_name=collection_name,
        update_collection=UpdateCollection(
            optimizers_config=OptimizersConfigDiff(
                indexing_threshold=1000
            )
        )
    )

    def wait_collection_green(client, collection_name, poll_interval=1.0, timeout=None):
        """Poll collection status until it is GREEN; return seconds waited.

        Args:
            client: QdrantClient to poll with.
            collection_name: collection to watch.
            poll_interval: seconds between status checks.
            timeout: optional cap in seconds; ``None`` waits forever
                (the original behavior).

        Raises:
            TimeoutError: if ``timeout`` elapses before the status is GREEN.
        """
        waited = 0.0
        while True:
            # Check before sleeping so an already-green collection returns
            # immediately instead of always paying one poll interval.
            collection_info = client.openapi_client.collections_api.get_collection(collection_name)
            status = collection_info.result.status
            print(status)
            if status == CollectionStatus.GREEN:
                return waited
            if timeout is not None and waited >= timeout:
                raise TimeoutError(
                    f"Collection '{collection_name}' not GREEN after {waited:.1f}s"
                )
            time.sleep(poll_interval)
            waited += poll_interval

    wait_collection_green(client=qdrant_client, collection_name=collection_name)
if __name__ == '__main__':
    # 10k random binary-valued vectors of dimension 2200
    # (round(uniform[0, 2)) yields 0s and 1s).
    data = np.around(np.random.random((10000, 2200)) * 2)

    qdrant_client = QdrantClient(
        host='localhost',
        prefer_grpc=True,
        port=6333,
    )
    collection_name = "sample"

    # Create the collection, upload everything, and wait for indexing.
    upload_data(
        qdrant_client=qdrant_client,
        data=data,
        collection_name=collection_name,
    )
    print("Indexes are ready, searching")

    # Fan the 10k queries out over worker processes; each worker sets up its
    # own gRPC client via Searcher.init_client (the Pool initializer).
    num_proc = 10
    start_time = time.time()
    with Pool(
            processes=num_proc,
            initializer=Searcher.init_client,
            initargs=(collection_name, "localhost", 3),
    ) as pool:
        result = list(pool.imap(Searcher.search_one, iterable=data))
    print("Elapsed search time:", time.time() - start_time)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment