Last active
June 3, 2023 10:54
-
-
Save YAHYA-H/5b4c78c6bdd6b31dc702454a4c32b188 to your computer and use it in GitHub Desktop.
Search engine with BM25
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
import nltk | |
from sklearn.datasets import fetch_20newsgroups | |
from rank_bm25 import BM25Okapi | |
from nltk.corpus import stopwords | |
from nltk.tokenize import word_tokenize | |
from nltk.stem import WordNetLemmatizer | |
# Fetch the NLTK resources required for tokenization, stopword removal
# and lemmatization (no-ops if already present on disk).
for resource in ('punkt', 'stopwords', 'wordnet'):
    nltk.download(resource)

# Use the full 20 Newsgroups corpus as the searchable document collection.
data = fetch_20newsgroups(subset='all')
documents = data.data

# Shared preprocessing state: English stopword set and a WordNet lemmatizer.
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
def preprocess_text(text):
    """Normalize raw text for BM25 indexing and querying.

    Strips non-alphabetic characters, lowercases, tokenizes, removes
    English stopwords and lemmatizes the remaining tokens.

    Args:
        text: Raw document or query string.

    Returns:
        A single space-joined string of cleaned tokens (may be empty).
    """
    # Keep letters only; digits and punctuation become spaces.
    text = re.sub(r"[^a-zA-Z]", " ", text)
    # Lowercase BEFORE filtering: NLTK's stopword list is all-lowercase,
    # so without this, capitalized stopwords ("The", "And") slip through.
    tokens = word_tokenize(text.lower())
    # Remove stopwords and reduce each surviving token to its lemma.
    tokens = [lemmatizer.lemmatize(token) for token in tokens if token not in stop_words]
    # Join tokens back into a single string
    preprocessed_text = " ".join(tokens)
    return preprocessed_text
# Clean every raw document, then split each cleaned string into the
# token lists that rank_bm25 expects.
preprocessed_documents = [preprocess_text(raw) for raw in documents]
tokenized_documents = [cleaned.split() for cleaned in preprocessed_documents]

# Build the BM25 index over the tokenized corpus.
bm25 = BM25Okapi(tokenized_documents)
def search(query, top_k=5):
    """Rank the corpus against *query* with BM25 and return the best matches.

    Args:
        query: Free-text search string.
        top_k: Number of top-scoring documents to return (default 5).

    Returns:
        A tuple ``(top_results, scores)`` where ``top_results`` is a list of
        ``(preprocessed_document, score)`` pairs sorted by descending score
        and ``scores`` holds the BM25 score of every document in the corpus.
    """
    preprocessed_query = preprocess_text(query)
    # BM25Okapi.get_scores expects a list of query tokens; passing the raw
    # string would iterate over its characters and score garbage.
    scores = bm25.get_scores(preprocessed_query.split())
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    top_results = [(preprocessed_documents[i], scores[i]) for i in top_indices]
    return top_results, scores
# Interactive entry point: read a query, rank the corpus and show the
# top-scoring documents with their BM25 scores.
query = input("Query here: ")
results, scores = search(query)

print(f"Top {len(results)} results for the query '{query}':")
for rank, (doc_text, doc_score) in enumerate(results, start=1):
    print(f"\nResult {rank}:")
    print(f"Score: {doc_score}")
    print(doc_text)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment