This script implements a hybrid search system using ArangoDB that combines:
1. Vector similarity search using COSINE_SIMILARITY
2. BM25 text search with a custom text analyzer
3. Fuzzy string matching using Levenshtein distance
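The fuzzy-matching step relies on rapidfuzz ratios normalized to the 0-1 range. As a standalone illustration (the strings below are hypothetical, not taken from the dataset), this shows the two ratios the citation filter takes the maximum of:

from rapidfuzz import fuzz

query = "blue screen when updating windows"
description = "blue screen of death when updating windows"

# token_sort_ratio ignores word order; partial_ratio rewards substring overlap.
print(fuzz.token_sort_ratio(query, description) / 100.0)
print(fuzz.partial_ratio(query, description) / 100.0)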
import os
from typing import Dict, List

from loguru import logger


def load_aql_query(filename: str) -> str:
    """
    Load an AQL query from a file.
    """
    try:
        file_path = os.path.join("app/backend/vllm/beta/utils/aql", filename)
        with open(file_path, "r") as file:
            return file.read()
    except FileNotFoundError:
        logger.error(f"File not found: {filename}")
        raise
    except Exception as e:
        logger.error(f"Error loading AQL query {filename}: {e}")
        raise


def format_aql_query(aql_query: str, bind_vars: Dict) -> str:
    """
    Substitute bind variables into an AQL query string (for logging/debugging only).
    """
    formatted_query = aql_query
    # Replace longer names first so that e.g. @b does not clobber @bm25_threshold.
    for key, value in sorted(bind_vars.items(), key=lambda kv: len(kv[0]), reverse=True):
        placeholder = f"@{key}"
        if isinstance(value, str):
            value = f'"{value}"'  # Add quotes around strings
        elif isinstance(value, list):  # Convert lists to array format
            value = "[" + ", ".join(map(str, value)) + "]"
        formatted_query = formatted_query.replace(placeholder, str(value))
    return formatted_query


def merge_duplicate_results(results: List[Dict]) -> List[Dict]:
    """
    Merge duplicate results by combining scores for documents with the same key.
    """
    merged = {}
    for result in results:
        key = result["_key"]
        if key not in merged:
            merged[key] = result
        else:
            for score in ["similarity_score", "bm25_score", "text_similarity"]:
                merged[key][score] = max(merged[key].get(score, 0), result.get(score, 0))
    return sorted(
        merged.values(),
        key=lambda x: (
            x.get("exact_match", False),
            x.get("text_similarity", 0),
            x.get("similarity_score", 0),
            x.get("bm25_score", 0),
        ),
        reverse=True,
    )
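A hypothetical debugging snippet using the helpers above: load the bundled query and inline the scalar bind variables so the result can be pasted into the ArangoDB web UI. The values shown are placeholders, and @vector and @@view_name would still need to be filled in by hand:

from app.backend.vllm.beta.utils.aql.aql_utils import format_aql_query, load_aql_query

aql = load_aql_query("bm25_vector_rag_v2.aql")
print(format_aql_query(aql, {
    "search_text": "blue screen of death",
    "similarity_threshold": 0.5,
    "bm25_threshold": 10,
    "k1": 1.5,
    "b": 0.6,
    "top_n": 10,
}))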
LET queryVector = @vector
LET search_text = @search_text
LET k1 = @k1
LET b = @b
LET bm25_threshold = @bm25_threshold

// Vector similarity results
LET vector_results = (
    FOR doc IN @@view_name
        LET similarity = COSINE_SIMILARITY(doc.vector, queryVector)
        FILTER similarity >= @similarity_threshold
        RETURN {
            doc: doc,
            _key: doc._key,
            similarity_score: similarity,
            bm25_score: 0
        }
)

// BM25 results
LET bm25_results = (
    FOR doc IN @@view_name
        SEARCH ANALYZER(doc.description IN TOKENS(@search_text, "custom_text_analyzer"), "custom_text_analyzer")
        LET bm25 = BM25(doc, @k1, @b)
        FILTER bm25 >= @bm25_threshold
        RETURN {
            doc: doc,
            _key: doc._key,
            similarity_score: 0,
            bm25_score: bm25
        }
)

// Merged results
LET merged_results = (
    FOR result IN UNION_DISTINCT(vector_results, bm25_results)
        COLLECT key = result._key INTO group
        LET doc = FIRST(group[*].result.doc)
        LET similarity_score = MAX(group[*].result.similarity_score)
        LET bm25_score = MAX(group[*].result.bm25_score)
        RETURN {
            doc: doc,
            _key: key,
            similarity_score: similarity_score,
            bm25_score: bm25_score
        }
)

// Final sorted results
FOR result IN merged_results
    SORT result.similarity_score DESC, result.bm25_score DESC
    LIMIT @top_n
    RETURN result
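The query assumes an ArangoSearch view (bound as @@view_name) whose collection links are indexed with a "custom_text_analyzer". That setup lives in create_search_view, which the main script imports from arango_utils but which is not included in this gist. A minimal sketch of what it might do, using python-arango; the analyzer properties and the short stopword list are assumptions:

from arango import ArangoClient

client = ArangoClient(hosts="http://localhost:8529")
db = client.db("verifaix", username="root", password="openSesame")

# Custom text analyzer used by the SEARCH/TOKENS clauses above (assumed properties).
db.create_analyzer(
    name="custom_text_analyzer",
    analyzer_type="text",
    properties={
        "locale": "en",
        "case": "lower",
        "stemming": True,
        "accent": False,
        "stopwords": ["the", "a", "of"],  # the script builds this list from NLTK stopwords
    },
    features=["frequency", "norm", "position"],  # frequency and norm are required for BM25()
)

# ArangoSearch view linking both collections, indexed with the analyzer above.
db.create_arangosearch_view(
    name="combined_search_view",
    properties={
        "links": {
            "microsoft_products": {"includeAllFields": True, "analyzers": ["custom_text_analyzer"]},
            "microsoft_issues": {"includeAllFields": True, "analyzers": ["custom_text_analyzer"]},
        }
    },
)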
from nltk.corpus import stopwords
import nltk

nltk.download('stopwords')

config = {
    "arangodb": {
        "hosts": "http://localhost:8529",
        "db_name": "verifaix",
        "username": "root",
        "password": "openSesame",
    },
    "search_view": {
        "name": "combined_search_view",
        "collections": ["microsoft_products", "microsoft_issues"],
        "stopwords": stopwords.words("english"),
    },
    "query_params": {
        "bm25_threshold": 10,
        "similarity_threshold": 0.5,
        "top_n": 10,
        "citation": False,
        "levenshtein_threshold": 0.95,
    },
}
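The AQL query and the script below expect each document in the linked collections to carry a text "description" field and a pre-computed "vector" field. A hypothetical seed document, assuming python-arango and that the embedding dimension matches whatever model create_embedding uses (384 here is only a placeholder):

from arango import ArangoClient

client = ArangoClient(hosts="http://localhost:8529")
db = client.db("verifaix", username="root", password="openSesame")

if not db.has_collection("microsoft_products"):
    db.create_collection("microsoft_products")

db.collection("microsoft_products").insert({
    "_key": "example_product",
    "description": "Windows 11 update causing blue screen of death",
    # Placeholder embedding; in practice this comes from create_embedding()
    "vector": [0.0] * 384,
})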
""" | |
This script implements a hybrid search system using ArangoDB that combines: | |
1. Vector similarity search using COSINE_SIMILARITY | |
2. BM25 text search with custom text analyzer | |
3. Fuzzy string matching using Levenshtein distance | |
""" | |
from arango import AQLQueryExecuteError, ArangoClient | |
from loguru import logger | |
import pyperclip | |
import asyncio | |
import os | |
import time | |
import warnings | |
from datetime import datetime, timezone | |
from typing import Dict, List, Union | |
import torch | |
import torch.nn.functional as F | |
from transformers import AutoModel, AutoTokenizer | |
from nltk.corpus import stopwords | |
from nltk import download | |
from rapidfuzz import fuzz | |
from app.backend.vllm.beta.utils.aql.aql_utils import load_aql_query | |
# Ensure NLTK stopwords are available (if needed for preprocessing) | |
download("stopwords") | |
from app.backend.vllm.beta.utils.embedding_utils import create_embedding | |
from app.backend.vllm.beta.utils.arango_utils import create_search_view | |
def apply_levenshtein_filter( | |
results: List[Dict], search_text: str, threshold: float | |
) -> List[Dict]: | |
""" | |
Filter results using Levenshtein distance. | |
Args: | |
results: List of search results. | |
search_text: Query text for comparison. | |
threshold: Minimum Levenshtein similarity score. | |
Returns: | |
Filtered and sorted results. | |
""" | |
cutoff = int(threshold * 100) | |
for result in results: | |
description = result["doc"]["description"].lower().strip() | |
search_text_cleaned = search_text.lower().strip() | |
result["levenshtein_score"] = max( | |
fuzz.token_sort_ratio(search_text_cleaned, description, score_cutoff=cutoff) / 100.0, | |
fuzz.partial_ratio(search_text_cleaned, description, score_cutoff=cutoff) / 100.0, | |
) | |
return sorted( | |
[res for res in results if res["levenshtein_score"] >= threshold], | |
key=lambda x: ( | |
x["levenshtein_score"] == 1.0, # Exact matches first | |
x["levenshtein_score"], # Then by Levenshtein score | |
x["similarity_score"], # Then by vector similarity | |
x["bm25_score"], # Finally by BM25 | |
), | |
reverse=True, | |
) | |
async def combined_search( | |
db, | |
view_name, | |
search_text, | |
query_params, | |
): | |
""" | |
Perform a combined search with BM25 and vector similarity, with optional citation matching. | |
Args: | |
db: ArangoDB database instance. | |
view_name: Name of the view to search. | |
search_text: Text to search for. | |
top_n: Number of results to return. | |
similarity_threshold: Minimum threshold for vector similarity. | |
bm25_threshold: Minimum threshold for BM25 scores. | |
citation: If True, applies strict Levenshtein matching for citations (default: False). | |
levenshtein_threshold: Minimum Levenshtein ratio threshold for citations (default: 0.95). | |
""" | |
try: | |
# Create embedding for the search text | |
embedding_result = await create_embedding(search_text) | |
vector = embedding_result["embedding"] | |
# Load AQL query | |
aql_query = load_aql_query("bm25_vector_rag_v2.aql") | |
bind_vars = { | |
"@view_name": view_name, | |
"vector": vector, | |
"search_text": search_text, | |
"top_n": query_params.get('top_n', 10), | |
"similarity_threshold": query_params.get('similarity_threshold', 0.5), | |
"bm25_threshold": query_params.get('bm25_threshold', 10), | |
"k1": query_params.get("k1", 1.5), # Default values for optional params | |
"b": query_params.get("b", 0.6), | |
} | |
# Execute Query in a thread to avoid blocking the async event loop | |
results = await asyncio.to_thread(db.aql.execute, aql_query, bind_vars=bind_vars) | |
results = list(results) | |
# Apply Levenshtein filtering if citation mode is enabled | |
if query_params.get('citation', False): | |
return apply_levenshtein_filter( | |
results, | |
search_text, | |
query_params.get('levenshtein_threshold', 0.95) | |
) | |
# Sort results by similarity score and BM25 score | |
top_n = query_params.get('top_n', 10) | |
return sorted( | |
results, | |
key=lambda x: (x["similarity_score"], x["bm25_score"]), | |
reverse=True, | |
)[:top_n] | |
except AQLQueryExecuteError as e: | |
logger.error(f"AQL execution error: {e.error_message} (Code: {e.error_code})") | |
except Exception as e: | |
logger.error(f"Error in combined_search: {e}") | |
return [] | |
async def main(): | |
from app.backend.vllm.beta.rag_config import config | |
# Connect to ArangoDB | |
client = ArangoClient(hosts=config["arangodb"]["hosts"]) | |
db = await asyncio.to_thread( | |
client.db, | |
config["arangodb"]["db_name"], | |
username=config["arangodb"]["username"], | |
password=config["arangodb"]["password"] | |
) | |
# Create custom search view | |
await create_search_view( | |
db, | |
config["search_view"]["name"], | |
config["search_view"].get("collections", []), | |
config["search_view"].get("stopwords", []), | |
) | |
# Perform Search | |
search_text = "Blue screen of death when updating Windows" | |
results = await combined_search( | |
db=db, | |
view_name=config["search_view"]["name"], | |
search_text=search_text, | |
query_params=config["query_params"] | |
) | |
# Display Results | |
for result in results: | |
print(result) | |
print("--------------------------------") | |
if __name__ == "__main__": | |
asyncio.run(main()) |
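create_embedding (imported from embedding_utils) is also not part of this gist. Given the script's original torch/transformers imports, a plausible minimal sketch using mean pooling over a Hugging Face model; the model name and the L2 normalization are assumptions, and the model must match whatever produced the stored doc.vector values:

import asyncio
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # assumed; not specified in the gist
_tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
_model = AutoModel.from_pretrained(_MODEL_NAME)

async def create_embedding(text: str) -> dict:
    """Return {"embedding": [...]} for the given text (mean-pooled, L2-normalized)."""
    def _encode() -> list:
        inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
        with torch.no_grad():
            outputs = _model(**inputs)
        # Mean pooling over non-padding tokens
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        pooled = summed / mask.sum(dim=1).clamp(min=1e-9)
        return F.normalize(pooled, p=2, dim=1)[0].tolist()

    # Run the blocking model call off the event loop, mirroring the AQL execution pattern
    return {"embedding": await asyncio.to_thread(_encode)}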