sentence-transformers-search.py
import os
import re
import sys
from string import punctuation

import torch
from sentence_transformers import SentenceTransformer, CrossEncoder, util
def index_markdown_files(dir_path):
    # Use the bi-encoder to encode all passages, so that we can use it with semantic search.
    bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
    bi_encoder.max_seq_length = 256  # Truncate long passages to 256 tokens.

    # The bi-encoder retrieves candidate passages; a cross-encoder then re-ranks them to improve quality.
    cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

    # Regexp matching runs of punctuation and whitespace, used to split content into sentence-like chunks.
    re_punctuation = re.compile(r'[\s{}]+'.format(re.escape(punctuation)))

    # Walk through the directory tree.
    passages = []
    for root, dirs, files in os.walk(dir_path):
        for file in files:
            print(".", end="", flush=True)
            # Check if the file is a markdown file.
            if file.endswith(".md"):
                # Read the file's contents.
                with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                    content = f.read()
                # Add paragraphs, skipping headings and empty chunks.
                passages.extend(p for p in content.split('\n\n') if p and not p.startswith("#"))
                # Add sentence-like chunks, skipping empty strings.
                passages.extend(p for p in re_punctuation.split(content) if p)
    print("Passages:", len(passages))

    # Encode all passages into our vector space.
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
    return (passages, bi_encoder, cross_encoder, corpus_embeddings)
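
# Encoding the corpus is the slow step, so re-running the script on unchanged docs is
# wasteful. A minimal caching sketch follows; it is an assumption, not part of the
# original script: the function name and the "embeddings.pt" cache path are
# hypothetical, and it persists the passages and embeddings via torch.save/torch.load.
def index_markdown_files_cached(dir_path, cache_path="embeddings.pt"):
    if os.path.exists(cache_path):
        # Reload the models, then reuse the cached passages and embeddings.
        bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
        bi_encoder.max_seq_length = 256
        cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
        passages, corpus_embeddings = torch.load(cache_path)
    else:
        passages, bi_encoder, cross_encoder, corpus_embeddings = index_markdown_files(dir_path)
        torch.save((passages, corpus_embeddings), cache_path)
    return (passages, bi_encoder, cross_encoder, corpus_embeddings)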
def search_markdown_files(passages, bi_encoder, cross_encoder, corpus_embeddings, query):
    top_k = 32  # Number of passages we want to retrieve with the bi-encoder.

    # Encode the query using the bi-encoder and find potentially relevant passages.
    question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
    #question_embedding = question_embedding.cuda()
    hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
    hits = hits[0]  # Get the hits for the first query.

    # Score all retrieved passages with the cross-encoder.
    cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
    cross_scores = cross_encoder.predict(cross_inp)

    # Sort results by the cross-encoder scores and print the best passage.
    for idx in range(len(cross_scores)):
        hits[idx]['cross-score'] = cross_scores[idx]
    hits = sorted(hits, key=lambda h: h['cross-score'], reverse=True)
    hit = hits[0]
    print(passages[hit['corpus_id']].replace("\n", " "))
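
# For intuition, util.semantic_search above boils down to cosine similarity followed
# by top-k selection. A simplified sketch of that idea (not the library's actual
# implementation; the function name is hypothetical):
def semantic_search_sketch(question_embedding, corpus_embeddings, top_k=32):
    # Cosine similarity between the query and every passage embedding.
    scores = util.cos_sim(question_embedding, corpus_embeddings)[0]
    top = torch.topk(scores, k=min(top_k, scores.shape[0]))
    return [{'corpus_id': int(i), 'score': float(s)} for s, i in zip(top.values, top.indices)]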
# Pass the root folder as an argument, e.g., python sentence-transformers-search.py docs/
if __name__ == "__main__":
    (passages, bi_encoder, cross_encoder, corpus_embeddings) = index_markdown_files(sys.argv[1])
    while True:
        query = input("Q: ")
        print("A: ", end="")
        search_markdown_files(passages, bi_encoder, cross_encoder, corpus_embeddings, query)
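
# Example session (illustrative; the folder name and output depend on your setup):
#   $ python sentence-transformers-search.py docs/
#   ....Passages: <count>
#   Q: how do I enable verbose logging?
#   A: <best-matching passage, re-ranked by the cross-encoder>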