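"""
Minimal semantic search demo: download a handful of Wikipedia articles,
embed them with BERT, index the text plus embeddings into Elasticsearch,
and serve a small CLI that retrieves candidates with a keyword match and
re-ranks them by cosine similarity between query and article embeddings.

Assumes a config.py next to this script defining USER and ELASTIC_PASSWORD.
"""
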
import wikipedia
import torch
from transformers import BertTokenizer, BertModel
from elasticsearch import Elasticsearch
from config import USER, ELASTIC_PASSWORD
import warnings
import numpy as np

warnings.simplefilter(action='ignore', category=Warning)

def download_wikipedia_articles(titles):
    """Fetch the full text of each Wikipedia page, skipping pages that fail."""
    articles = []
    for title in titles:
        try:
            content = wikipedia.page(title).content
            articles.append({"title": title, "content": content})
        except wikipedia.exceptions.DisambiguationError as e:
            print(f"Disambiguation error for {title}: {e.options}")
        except wikipedia.exceptions.PageError:
            print(f"Page not found for {title}")
    return articles

def initialize_elastic_search():
    """Connect to a local Elasticsearch instance with basic auth."""
    return Elasticsearch(
        hosts="https://localhost:9200",
        http_auth=(USER, ELASTIC_PASSWORD),
        verify_certs=False
    )

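# Note: in the elasticsearch-py 8.x client, http_auth above is deprecated in
# favor of basic_auth=(USER, ELASTIC_PASSWORD), and verify_certs=False turns
# off TLS verification, which is only reasonable for a local dev cluster.
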
def text_to_embedding(text, tokenizer, model):
    """Embed text as BERT's pooled [CLS] output.

    Input is truncated to 128 tokens, so long articles are represented
    by their opening text only.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs['pooler_output'].numpy()[0]

def index_articles_to_elastic_search(articles, es, tokenizer, model):
    """Index each article's text plus its BERT embedding, keyed by title."""
    for article in articles:
        embedding = text_to_embedding(article["content"], tokenizer, model)
        body = {"title": article["title"], "content": article["content"], "embedding": embedding.tolist()}
        es.index(index="wikipedia", id=article["title"], body=body)

def search(es, query):
    """Lexical retrieval: match the query against article content (BM25)."""
    body = {
        "query": {
            "bool": {
                "must": [{"match": {"content": query}}],
            }
        }
    }
    response = es.search(index="wikipedia", body=body)
    return response['hits']['hits']

def cosine_similarity(v1, v2):
    """Compute cosine similarity between two vectors."""
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

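# The ranking below is done client-side. As a sketch of an alternative (not
# part of the original script), Elasticsearch can compute cosine similarity
# server-side with a script_score query, assuming the "embedding" field were
# mapped as a dense_vector, which this script does not set up. The helper
# name search_by_script_score is hypothetical.
def search_by_script_score(es, query_embedding):
    body = {
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # cosineSimilarity returns values in [-1, 1]; adding 1.0
                    # keeps Elasticsearch scores non-negative
                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                    "params": {"query_vector": query_embedding.tolist()},
                },
            }
        }
    }
    return es.search(index="wikipedia", body=body)["hits"]["hits"]
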
def search_and_rank_by_similarity(es, query, tokenizer, model):
    """Retrieve candidates lexically, then re-rank by embedding similarity."""
    # Initial search to get articles that contain the query in their content
    results = search(es, query)
    # If no results are found, return an empty list
    if not results:
        return []
    # Convert query to embedding
    query_embedding = text_to_embedding(query, tokenizer, model)
    # Get embeddings of the articles from the results
    embeddings = [result["_source"]["embedding"] for result in results]
    # Calculate cosine similarities
    similarities = [cosine_similarity(query_embedding, np.array(embedding)) for embedding in embeddings]
    # Sort results by similarity and keep the top 4
    top_results = sorted(zip(results, similarities), key=lambda x: x[1], reverse=True)[:4]
    # Debug output: a snippet of each kept article with its similarity score
    print([(result[0]['_source']['content'][:200], result[1]) for result in top_results])
    return [result[0] for result in top_results]

def cli(es, tokenizer, model):
    """Simple read-eval loop: take a question, print the top-ranked articles."""
    while True:
        query = input("Enter your question (or type 'exit' to quit): ")
        if query.lower() == 'exit':
            break
        results = search_and_rank_by_similarity(es, query, tokenizer, model)
        if results:
            print("Top 4 Wikipedia results based on similarity:")
            for hit in results:
                print(f"Title: {hit['_source']['title']}\nSnippet: {hit['_source']['content'][:200]}...\n")
        else:
            print("No relevant articles found.")

if __name__ == "__main__":
    # Seed titles; the wikipedia package resolves these via search, so the
    # exact pages returned may vary.
    titles = ["president obama", "biden president", "president Trump", "Clinton", "president Bush", "macron president"]
    articles = download_wikipedia_articles(titles)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    es = initialize_elastic_search()
    index_articles_to_elastic_search(articles, es, tokenizer, model)
    cli(es, tokenizer, model)
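
# Running this script requires a reachable Elasticsearch instance at
# https://localhost:9200 and the packages imported above, e.g.:
#   pip install wikipedia torch transformers elasticsearch numpy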