import wikipedia
import torch
from transformers import BertTokenizer, BertModel
from elasticsearch import Elasticsearch
from config import USER, ELASTIC_PASSWORD
import warnings

# Silence all warnings (e.g. the TLS warnings triggered by verify_certs=False below).
warnings.simplefilter(action='ignore', category=Warning)
def download_wikipedia_articles(titles):
    """Fetch the full text of each Wikipedia page, skipping ambiguous or missing titles."""
    articles = []
    for title in titles:
        try:
            content = wikipedia.page(title).content
            articles.append({"title": title, "content": content})
        except wikipedia.exceptions.DisambiguationError as e:
            print(f"Disambiguation error for {title}: {e.options}")
        except wikipedia.exceptions.PageError:
            print(f"Page not found for {title}")
    return articles
def initialize_elastic_search():
    # verify_certs=False disables TLS verification -- fine for a local dev
    # cluster, not for production.
    return Elasticsearch(
        hosts="https://localhost:9200",
        http_auth=(USER, ELASTIC_PASSWORD),
        verify_certs=False
    )
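
# A minimal sketch, not part of the original gist: without an explicit mapping,
# Elasticsearch dynamically maps the "embedding" list as plain floats, which
# cannot be used for vector similarity. Creating the index up front with a
# dense_vector field (768 dims for bert-base-uncased) enables the script_score
# query sketched further below. The function name is hypothetical; call it once
# before indexing if you want vector search.
def create_wikipedia_index(es):
    if not es.indices.exists(index="wikipedia"):
        es.indices.create(index="wikipedia", body={
            "mappings": {
                "properties": {
                    "title": {"type": "text"},
                    "content": {"type": "text"},
                    "embedding": {"type": "dense_vector", "dims": 768}
                }
            }
        })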
def text_to_embedding(text, tokenizer, model):
    # Only the first 128 tokens are embedded; long articles are truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    # pooler_output is the [CLS] representation passed through BERT's pooling layer.
    return outputs.pooler_output.numpy()[0]
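
# A hedged alternative, not part of the original gist: BERT's pooler_output is
# trained for next-sentence prediction and is often a weak sentence embedding;
# mean-pooling the last hidden state over non-padding tokens is a common
# substitute. Sketch only -- the function name is hypothetical.
def text_to_mean_embedding(text, tokenizer, model):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1)           # (1, seq_len, 1)
    summed = (outputs.last_hidden_state * mask).sum(dim=1)  # sum over real tokens
    counts = mask.sum(dim=1)                                # number of real tokens
    return (summed / counts).numpy()[0]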
def index_articles_to_elastic_search(articles, es, tokenizer, model):
    # Store each article together with its embedding (as a plain float list).
    for article in articles:
        embedding = text_to_embedding(article["content"], tokenizer, model)
        doc = {"title": article["title"], "content": article["content"], "embedding": embedding.tolist()}
        es.index(index="wikipedia", id=article["title"], body=doc)
def search(es, query):
    # Plain keyword match on the article text; the stored embeddings are not used here.
    body = {
        "query": {
            "bool": {
                "must": [{"match": {"content": query}}]
            }
        }
    }
    response = es.search(index="wikipedia", body=body)
    return response['hits']['hits']
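
# A minimal sketch, not part of the original gist, of actually using the indexed
# embeddings: a script_score query ranks documents by cosine similarity between
# the query embedding and the stored "embedding" field. Requires the dense_vector
# mapping sketched above; the function name and top_k parameter are hypothetical.
def semantic_search(es, query, tokenizer, model, top_k=3):
    query_vector = text_to_embedding(query, tokenizer, model).tolist()
    body = {
        "size": top_k,
        "query": {
            "script_score": {
                "query": {"match_all": {}},
                "script": {
                    # cosineSimilarity ranges over [-1, 1]; +1.0 keeps scores non-negative.
                    "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                    "params": {"query_vector": query_vector}
                }
            }
        }
    }
    response = es.search(index="wikipedia", body=body)
    return response['hits']['hits']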
def cli(es):
    while True:
        query = input("Enter your question (or type 'exit' to quit): ")
        if query.lower() == 'exit':
            break
        results = search(es, query)
        if results:
            print("Top Wikipedia results:")
            for hit in results:
                print(f"Title: {hit['_source']['title']}\nSnippet: {hit['_source']['content'][:200]}...\n")
        else:
            print("No relevant articles found.")
if __name__ == "__main__":
    titles = ["obama", "biden president", "trump", "clinton", "bush"]
    articles = download_wikipedia_articles(titles)
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    es = initialize_elastic_search()
    index_articles_to_elastic_search(articles, es, tokenizer, model)
    cli(es)