Created
March 11, 2021 19:21
-
-
Save xhluca/157920dd54a3b959f3a5ad6097803f48 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
from typing import List | |
try:
    from sklearn.decomposition import TruncatedSVD
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity
except ImportError as err:
    # Only a failed import should produce the install hint. A bare ``except``
    # would also intercept KeyboardInterrupt/SystemExit and relabel unrelated
    # failures as "scikit-learn is missing".
    error_msg = (
        "Couldn't import scikit-learn. To use the toy models, you will need to "
        "install it with `pip install scikit-learn`."
    )
    # ImportError is a subclass of Exception, so callers that catch the
    # previously raised generic Exception keep working; ``from err`` keeps the
    # original import failure attached as the cause.
    raise ImportError(error_msg) from err
import numpy as np | |
class SearchEngine:
    """Toy passage retriever: TF-IDF features reduced with truncated SVD.

    Passages are vectorized with ``TfidfVectorizer``, projected with
    ``TruncatedSVD`` and ranked against a query by cosine similarity.

    Args:
        n_components: Requested number of SVD components. When the fitted
            TF-IDF vocabulary has fewer features than this, the number of
            components is reduced to the vocabulary size at build time.
        max_df: Forwarded to ``TfidfVectorizer(max_df=...)``.
        min_df: Forwarded to ``TfidfVectorizer(min_df=...)``.
    """

    def __init__(self, n_components: int = 300, max_df=0.9, min_df=1):
        self.n_components = n_components
        self.svd = TruncatedSVD(n_components)
        self.vectorizer = TfidfVectorizer(max_df=max_df, min_df=min_df)

    def build_knowledge_base(self, passages: List[dict]):
        """Fit the vectorizer and SVD on ``passages`` and cache their encodings.

        Args:
            passages: Dicts that each provide the passage text under the
                ``"content"`` key.

        Sets ``self.passages``, ``self.contents`` and ``self.content_encs``.
        """
        self.passages = np.array(passages)
        self.contents = np.array([p["content"] for p in self.passages])
        content_tfidf = self.vectorizer.fit_transform(self.contents)

        # TruncatedSVD cannot extract more components than there are TF-IDF
        # features, so a small knowledge base would otherwise fail to fit
        # with the default of 300 components.
        n_features = content_tfidf.shape[1]
        self.svd.set_params(n_components=min(self.n_components, n_features))

        self.content_encs = self.svd.fit_transform(content_tfidf)

    def retrieve_idx(self, query: str, k: int = 10) -> List[int]:
        """Return the indices of the ``k`` passages most similar to ``query``.

        Indices refer to positions in the list given to
        ``build_knowledge_base`` and are ordered from most to least similar.
        """
        query_tfidf = self.vectorizer.transform([query])
        query_enc = self.svd.transform(query_tfidf)
        # A single query yields a (1, n_passages) matrix; select that row
        # explicitly so the scores stay 1-D even with a single passage.
        sim_scores = cosine_similarity(query_enc, self.content_encs)[0]
        return self._top_k(sim_scores, k)

    def retrieve(self, query: str, k: int = 10) -> List[dict]:
        """Return the ``k`` passages most similar to ``query``, best first."""
        best_idx = self.retrieve_idx(query, k)
        return self.passages[best_idx].tolist()

    @staticmethod
    def _top_k(scores, k: int) -> List[int]:
        """Return the positions of the ``k`` largest ``scores``, best first."""
        ranked = np.asarray(scores).argsort()[::-1]
        # A negative k would make the slice drop results from the end instead
        # of limiting the count; treat it as "no results".
        return ranked[: max(k, 0)].tolist()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment