Last active
May 28, 2018 17:01
-
-
Save hamelsmu/9d3e23b23bb425bb35b28bfc02111e22 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class search_engine:
    """Bundle the pieces needed to run a semantic code search.

    Holds a pre-built nearest-neighbour index, the dataframe describing each
    indexed item, and the function that embeds a text query into the vector
    space used by the index.
    """

    # Columns that `search` reads from `ref_df` for every hit.
    _REQUIRED_COLUMNS = ('url', 'code')

    def __init__(self,
                 nmslib_index,
                 ref_df,
                 query2emb_func):
        """
        Parameters
        ==========
        nmslib_index : nmslib object
            This is a pre-computed search index. Only its ``knnQuery``
            method is used.
        ref_df : pandas.DataFrame
            This dataframe contains meta-data for search results.
            Must contain the columns 'code' and 'url'. Rows are looked up
            by position, so row order must match the ids stored in the index.
        query2emb_func : callable
            A function that takes as input a string and returns a vector
            that is in the same vector space as what is loaded into the
            search index.

        Raises
        ======
        AssertionError
            If `ref_df` lacks the 'url' or 'code' column.
        """
        # Raised explicitly instead of via `assert` so the check still runs
        # under `python -O`; the exception type is kept as AssertionError so
        # existing callers that catch it keep working.
        missing = [col for col in self._REQUIRED_COLUMNS
                   if col not in ref_df.columns]
        if missing:
            raise AssertionError(
                f'ref_df is missing required column(s): {missing}')

        self.search_index = nmslib_index
        self.ref_df = ref_df
        self.query2emb_func = query2emb_func

    def search(self, str_search, k=2):
        """
        Prints the url and code of the nearest neighbors to the search query.

        The printed distance is labelled "cosine dist"; the actual metric is
        whatever the index was built with (presumably cosine similarity --
        confirm against the index construction code).

        Parameters
        ==========
        str_search : str
            a search query. Ex: "read data into pandas dataframe"
        k : int
            the number of nearest neighbors to print. Defaults to 2.

        Returns
        =======
        None
            Results are written to standard output only.
        """
        query = self.query2emb_func(str_search)
        idxs, dists = self.search_index.knnQuery(query, k=k)

        for idx, dist in zip(idxs, dists):
            # Fetch the row once instead of once per column.
            row = self.ref_df.iloc[idx]
            url = row['url']
            code = row['code']
            print(f'cosine dist:{dist:.4f} url: {url}\n---------------\n')
            print(code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment