@generall
Created May 20, 2022 21:05
Search 10k by 10k vectors fast
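Benchmark script for Qdrant: it uploads 10k random 2200-dimensional vectors into a collection, lowers the indexing threshold so the index gets built, waits for the collection status to turn green, and then searches all 10k vectors (top 3 each) from 10 worker processes over gRPC. It assumes a Qdrant instance on localhost with REST on port 6333 and gRPC on port 6334 (e.g. started with docker run -p 6333:6333 -p 6334:6334 qdrant/qdrant).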
import asyncio
import time
from multiprocessing import Pool

import httpx
import numpy as np
from grpclib.client import Channel
from qdrant_client import QdrantClient
from qdrant_client.grpc import PointsStub, WithPayloadSelector
from qdrant_client.http.models import Distance, OptimizersConfigDiff, \
    UpdateCollection, CollectionStatus


class Searcher:
    # Per-process client state, populated by the Pool initializer.
    client: QdrantClient = None
    collection_name = None
    top = None
    channel = None
    points_client: PointsStub = None

    @classmethod
    def init_client(
            cls,
            collection_name="sample",
            host="localhost",
            top=3,
    ):
        # Runs once in each worker process: every process gets its own
        # HTTP client and gRPC channel.
        cls.top = top
        cls.collection_name = collection_name
        cls.client = QdrantClient(
            host=host,
            limits=httpx.Limits(max_connections=None, max_keepalive_connections=0)
        )
        cls.channel = Channel(host, port=6334)
        cls.points_client = PointsStub(cls.channel)

    @classmethod
    async def _search(cls, vector):
        # Raw gRPC search; skip payload and vector in the response to
        # keep transfers minimal.
        return await cls.points_client.search(
            collection_name=cls.collection_name,
            vector=list(vector),
            filter=None,
            top=cls.top,
            with_vector=False,
            with_payload=WithPayloadSelector(enable=False)
        )

    @classmethod
    def search_one(cls, vector):
        # Bridge the async gRPC call into the synchronous Pool worker.
        loop = asyncio.get_event_loop()
        res = loop.run_until_complete(cls._search(vector))
        return res


def upload_data(qdrant_client, data, collection_name):
    qdrant_client.recreate_collection(
        collection_name=collection_name,
        vector_size=data.shape[1],
        distance=Distance.COSINE,
        optimizers_config=OptimizersConfigDiff(
            default_segment_number=3
        )
    )
    qdrant_client.upload_collection(
        collection_name=collection_name,
        vectors=data,
        payload=None,
        ids=None,  # ids are assigned automatically
        batch_size=256,
        parallel=4
    )
    # Lower the indexing threshold so the index is built for a
    # collection of this size.
    qdrant_client.http.collections_api.update_collection(
        collection_name=collection_name,
        update_collection=UpdateCollection(
            optimizers_config=OptimizersConfigDiff(
                indexing_threshold=1000
            )
        )
    )


def wait_collection_green(client, collection_name):
    # Poll the collection status until indexing has finished.
    wait_time = 1.0
    total = 0
    status = None
    while status != CollectionStatus.GREEN:
        time.sleep(wait_time)
        total += wait_time
        collection_info = client.openapi_client.collections_api.get_collection(collection_name)
        status = collection_info.result.status
        print(status)
    return total


if __name__ == '__main__':
    # 10k vectors of dimension 2200, rounded to coarse integer values.
    data = np.random.random((10000, 2200)) * 2
    data = np.around(data)

    qdrant_client = QdrantClient(
        host='localhost',
        prefer_grpc=True,
        port=6333,
    )
    collection_name = "sample"

    upload_data(
        qdrant_client=qdrant_client,
        data=data,
        collection_name=collection_name,
    )
    wait_collection_green(client=qdrant_client, collection_name=collection_name)
    print("Indexes are ready, searching")

    num_proc = 10
    start_time = time.time()
    with Pool(
            processes=num_proc,
            initializer=Searcher.init_client,
            initargs=(collection_name, "localhost", 3)
    ) as pool:
        # Fan the 10k query vectors out across the worker processes.
        result = list(pool.imap(
            Searcher.search_one,
            iterable=data
        ))
    end_time = time.time()
    print("Elapsed search time:", end_time - start_time)