Last active
July 8, 2022 17:54
-
-
Save generall/4d427a286a255d782d4c9ba3b6496032 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# File: neural_searcher.py
from qdrant_client import QdrantClient | |
from sentence_transformers import SentenceTransformer | |
class NeuralSearcher:
    """Text-to-vector search over a single Qdrant collection.

    A query string is encoded into a vector with a SentenceTransformer model,
    the vector is sent to Qdrant to find the closest stored vectors, and the
    payloads attached to those hits are returned.
    """

    def __init__(
        self,
        collection_name,
        *,
        model_name='distilbert-base-nli-stsb-mean-tokens',
        device='cpu',
        host='localhost',
        port=6333,
    ):
        """Create the encoder model and the Qdrant client.

        Args:
            collection_name: Name of the Qdrant collection to search in.
            model_name: Identifier passed to ``SentenceTransformer``.
            device: Device string passed to ``SentenceTransformer``.
            host: Host of the Qdrant server.
            port: Port of the Qdrant server.

        The keyword defaults reproduce the previously hard-coded values, so
        ``NeuralSearcher(collection_name)`` behaves as before.
        """
        self.collection_name = collection_name
        # Initialize the encoder model
        self.model = SentenceTransformer(model_name, device=device)
        # Initialize the Qdrant client
        self.qdrant_client = QdrantClient(host=host, port=port)

    def search(self, text: str, top: int = 5) -> list:
        """Return the payloads of the stored vectors closest to ``text``.

        Args:
            text: Free-text query to encode and search for.
            top: Maximum number of hits requested from Qdrant (default 5).

        Returns:
            A list with the ``payload`` of every hit, in the order the client
            returned the hits. Ids and similarity scores are dropped.
        """
        # Convert the text query into a plain list of numbers
        vector = self.model.encode(text).tolist()

        # Use `vector` to search for the closest vectors in the collection.
        # NOTE(review): newer qdrant-client releases appear to name the `top`
        # keyword `limit` - confirm against the installed client version.
        search_result = self.qdrant_client.search(
            collection_name=self.collection_name,
            query_vector=vector,
            query_filter=None,  # We don't want any filters for now
            top=top,
        )

        # `search_result` holds the found hits together with their stored
        # payload; in this function we are interested in the payload only.
        return [hit.payload for hit in search_result]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment