Skip to content

Instantly share code, notes, and snippets.

@grahama1970
Last active January 11, 2025 14:59
Show Gist options
  • Save grahama1970/3e5127bbd4a4e69f3ae8bcd96fb056a1 to your computer and use it in GitHub Desktop.
Save grahama1970/3e5127bbd4a4e69f3ae8bcd96fb056a1 to your computer and use it in GitHub Desktop.
This script implements a hybrid search system using ArangoDB that combines: 1. Vector similarity search using COSINE_SIMILARITY 2. BM25 text search with custom text analyzer 3. Fuzzy string matching using Levenshtein distance
import os
from loguru import logger
from typing import List, Dict
def load_aql_query(filename: str) -> str:
    """
    Load an AQL query from a file.

    Args:
        filename: Name of the query file. It is joined onto the
            ``app/backend/vllm/beta/utils/aql`` directory, which is resolved
            relative to the current working directory. An absolute path is
            used as-is (``os.path.join`` discards earlier components when a
            later one is absolute).

    Returns:
        The complete text of the query file.

    Raises:
        FileNotFoundError: If the file does not exist (logged, re-raised).
        Exception: Any other error raised while opening or reading the file
            (logged, then re-raised unchanged).
    """
    try:
        file_path = os.path.join("app/backend/vllm/beta/utils/aql", filename)
        # Decode explicitly as UTF-8: without it, open() falls back to the
        # platform locale encoding, so a query containing non-ASCII text
        # could load differently (or fail) depending on the host.
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read()
    except FileNotFoundError:
        logger.error(f"File not found: {filename}")
        raise
    except Exception as e:
        logger.error(f"Error loading AQL query {filename}: {e}")
        raise
def format_aql_query(aql_query, bind_vars):
    """
    Return ``aql_query`` with every bind placeholder replaced by its value.

    The substitution is done in a single pass over whole placeholder tokens,
    so a short name such as ``@b`` can no longer clobber the beginning of a
    longer one such as ``@bm25_threshold``.

    Args:
        aql_query: AQL text containing ``@name`` value placeholders and/or
            ``@@name`` collection placeholders.
        bind_vars: Mapping of bind-variable names to values. As with the
            ArangoDB drivers, a collection bind is keyed with a leading
            ``@`` (e.g. ``{"@view_name": "my_view"}``).

    Returns:
        The query text with known placeholders inlined. Placeholders that
        have no entry in ``bind_vars`` are left untouched.

    Note:
        The result is produced by text substitution. Executing a query with
        real bind variables through the driver remains the safe way to run
        it with untrusted values.
    """
    import json
    import re

    def render(name, value):
        # Collection binds stand for an identifier, not a string literal,
        # so they are inserted bare.
        if name.startswith("@"):
            return str(value)
        # JSON literals are used for values: strings are quoted and escaped,
        # booleans/None become true/false/null, and nested list elements
        # are rendered recursively. default=str keeps unknown types printable.
        return json.dumps(value, ensure_ascii=False, default=str)

    def substitute(match):
        name = match.group(1)
        if name not in bind_vars:
            return match.group(0)
        return render(name, bind_vars[name])

    # "@" followed by an optional second "@" (collection bind) and the name.
    return re.sub(r"@(@?\w+)", substitute, aql_query)
def merge_duplicate_results(results: List[Dict]) -> List[Dict]:
    """
    Merge duplicate results by combining scores for documents with the same key.

    Results sharing a ``_key`` are collapsed into one entry. The first
    occurrence supplies the non-score fields; for each of
    ``similarity_score``, ``bm25_score`` and ``text_similarity`` the maximum
    across the duplicates is kept. A score that is absent counts as 0.

    Args:
        results: Result dicts, each containing at least ``_key``.

    Returns:
        A new list of merged dicts sorted in descending order by
        ``exact_match``, ``text_similarity``, ``similarity_score`` and
        ``bm25_score``. The input dicts are not modified.

    Raises:
        KeyError: If a result has no ``_key``.
    """
    score_fields = ("similarity_score", "bm25_score", "text_similarity")
    merged: Dict = {}
    for result in results:
        key = result["_key"]
        existing = merged.get(key)
        if existing is None:
            # Shallow-copy so merging scores never writes into the caller's dicts.
            merged[key] = dict(result)
            continue
        for field in score_fields:
            # Either side may lack a given score (e.g. no text_similarity),
            # so both lookups default to 0 instead of raising KeyError.
            existing[field] = max(existing.get(field, 0), result.get(field, 0))
    return sorted(
        merged.values(),
        key=lambda item: (
            item.get("exact_match", False),
            item.get("text_similarity", 0),
            item.get("similarity_score", 0),
            item.get("bm25_score", 0),
        ),
        reverse=True,
    )
// Hybrid retrieval query: scores documents of one view by vector cosine
// similarity and by BM25 text relevance, merges both result sets per
// document key, and returns the top matches.
//
// Bind parameters:
//   @@view_name           - view to iterate / search
//   @vector               - query embedding compared against doc.vector
//   @search_text          - raw query text, tokenized for the BM25 search
//   @k1, @b               - arguments passed to BM25()
//   @bm25_threshold       - minimum BM25 score to keep a text match
//   @similarity_threshold - minimum cosine similarity to keep a vector match
//   @top_n                - maximum number of rows returned
//
// NOTE(review): only queryVector is read through its LET alias below;
// search_text, k1, b and bm25_threshold are aliased here but the subqueries
// reference the bind parameters directly.
LET queryVector = @vector
LET search_text = @search_text
LET k1 = @k1
LET b = @b
LET bm25_threshold = @bm25_threshold
// Vector similarity results
// Every document yielded by the view is scored (this loop has no SEARCH
// clause); rows below @similarity_threshold are dropped. bm25_score is set
// to 0 so both branches return rows of the same shape.
// Presumably doc.vector holds the stored embedding - confirm against the
// ingestion code.
LET vector_results = (
FOR doc IN @@view_name
LET similarity = COSINE_SIMILARITY(doc.vector, queryVector)
FILTER similarity >= @similarity_threshold
RETURN {
doc: doc,
_key: doc._key,
similarity_score: similarity,
bm25_score: 0
}
)
// BM25 results
// Matches documents whose description contains any token of @search_text,
// tokenizing with "custom_text_analyzer", then keeps rows whose BM25 score
// reaches @bm25_threshold. similarity_score is set to 0 for shape parity.
// Assumes an analyzer named "custom_text_analyzer" exists and is applied to
// description in the view - confirm against the view definition.
LET bm25_results = (
FOR doc IN @@view_name
SEARCH ANALYZER(doc.description IN TOKENS(@search_text, "custom_text_analyzer"), "custom_text_analyzer")
LET bm25 = BM25(doc, @k1, @b)
FILTER bm25 >= @bm25_threshold
RETURN {
doc: doc,
_key: doc._key,
similarity_score: 0,
bm25_score: bm25
}
)
// Merged results
// A document found by both branches appears as two rows (one with
// bm25_score 0, one with similarity_score 0). Grouping by _key and taking
// MAX of each score collapses them into a single row carrying both scores;
// the document body is taken from the first row of the group.
LET merged_results = (
FOR result IN UNION_DISTINCT(vector_results, bm25_results)
COLLECT key = result._key INTO group
LET doc = FIRST(group[*].result.doc)
LET similarity_score = MAX(group[*].result.similarity_score)
LET bm25_score = MAX(group[*].result.bm25_score)
RETURN {
doc: doc,
_key: key,
similarity_score: similarity_score,
bm25_score: bm25_score
}
)
// Final sorted results
// Ordered by vector similarity first, BM25 as tie-breaker, capped at @top_n.
FOR result IN merged_results
SORT result.similarity_score DESC, result.bm25_score DESC
LIMIT @top_n
RETURN result
from nltk.corpus import stopwords
import nltk
import os

# Fetch the stopword corpus only when it is not already installed, so that
# importing this module does not trigger a network request every time.
try:
    nltk.data.find("corpora/stopwords")
except LookupError:
    nltk.download("stopwords")

# Central configuration for the hybrid search.
# Connection settings can be overridden through environment variables so
# that real credentials do not have to live in source control; when a
# variable is unset the previous hard-coded value is used as the default.
config = {
    "arangodb": {
        "hosts": os.environ.get("ARANGO_HOSTS", "http://localhost:8529"),
        "db_name": os.environ.get("ARANGO_DB_NAME", "verifaix"),
        "username": os.environ.get("ARANGO_USERNAME", "root"),
        # NOTE(review): development default only - set ARANGO_PASSWORD in
        # any shared or deployed environment.
        "password": os.environ.get("ARANGO_PASSWORD", "openSesame"),
    },
    "search_view": {
        "name": "combined_search_view",
        "collections": ["microsoft_products", "microsoft_issues"],
        "stopwords": stopwords.words("english"),
    },
    "query_params": {
        "bm25_threshold": 10,
        "similarity_threshold": 0.5,
        "top_n": 10,
        "citation": False,
        "levenshtein_threshold": 0.95,
    },
}
"""
This script implements a hybrid search system using ArangoDB that combines:
1. Vector similarity search using COSINE_SIMILARITY
2. BM25 text search with custom text analyzer
3. Fuzzy string matching using Levenshtein distance
"""
from arango import AQLQueryExecuteError, ArangoClient
from loguru import logger
import pyperclip
import asyncio
import os
import time
import warnings
from datetime import datetime, timezone
from typing import Dict, List, Union
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from nltk.corpus import stopwords
from nltk import download
from rapidfuzz import fuzz
from app.backend.vllm.beta.utils.aql.aql_utils import load_aql_query
# Ensure NLTK stopwords are available (if needed for preprocessing)
download("stopwords")
from app.backend.vllm.beta.utils.embedding_utils import create_embedding
from app.backend.vllm.beta.utils.arango_utils import create_search_view
def apply_levenshtein_filter(
    results: List[Dict], search_text: str, threshold: float
) -> List[Dict]:
    """
    Score results against the query with fuzzy string matching and filter them.

    Each result's ``doc["description"]`` is compared with ``search_text``
    (both lower-cased and stripped) using ``fuzz.token_sort_ratio`` and
    ``fuzz.partial_ratio``; the higher of the two, scaled to 0..1, is stored
    on the result as ``levenshtein_score``. Results whose document has no
    string ``description`` receive a score of 0.0 instead of raising.

    Args:
        results: List of search results. Each is expected to carry ``doc``,
            ``similarity_score`` and ``bm25_score``. The dicts are updated
            in place with ``levenshtein_score``.
        search_text: Query text for comparison.
        threshold: Minimum score (0..1) a result needs to be kept.

    Returns:
        The results scoring at least ``threshold``, sorted descending by:
        perfect score first, then ``levenshtein_score``,
        ``similarity_score`` and ``bm25_score``.
    """
    # int() truncates, so the cutoff handed to rapidfuzz never exceeds the
    # requested threshold; the exact comparison happens in the filter below.
    cutoff = int(threshold * 100)
    # The normalised query does not change between results: compute it once.
    search_text_cleaned = search_text.lower().strip()

    for result in results:
        doc = result.get("doc")
        description = doc.get("description") if isinstance(doc, dict) else None
        if not isinstance(description, str):
            # Nothing to compare against; treat it as a non-match rather
            # than failing on a missing key or None.
            result["levenshtein_score"] = 0.0
            continue
        description = description.lower().strip()
        result["levenshtein_score"] = max(
            fuzz.token_sort_ratio(search_text_cleaned, description, score_cutoff=cutoff) / 100.0,
            fuzz.partial_ratio(search_text_cleaned, description, score_cutoff=cutoff) / 100.0,
        )

    return sorted(
        [res for res in results if res["levenshtein_score"] >= threshold],
        key=lambda x: (
            x["levenshtein_score"] == 1.0,  # Exact matches first
            x["levenshtein_score"],  # Then by fuzzy-match score
            x["similarity_score"],  # Then by vector similarity
            x["bm25_score"],  # Finally by BM25
        ),
        reverse=True,
    )
async def combined_search(
    db,
    view_name,
    search_text,
    query_params,
):
    """
    Run the hybrid BM25 + vector-similarity query and post-process its rows.

    Args:
        db: ArangoDB database instance; its ``aql.execute`` is run in a
            worker thread.
        view_name: Name of the view to search (bound as ``@@view_name``).
        search_text: Text to search for; it is embedded and also passed to
            the query as-is.
        query_params: Mapping of optional settings, read with these defaults:
            ``top_n`` (10), ``similarity_threshold`` (0.5),
            ``bm25_threshold`` (10), ``k1`` (1.5), ``b`` (0.6),
            ``citation`` (False), ``levenshtein_threshold`` (0.95).

    Returns:
        When ``citation`` is truthy, the rows as filtered and ordered by
        ``apply_levenshtein_filter``. Otherwise the rows sorted descending
        by ``similarity_score`` then ``bm25_score``, cut to ``top_n``.
        An empty list if any error occurred; errors are logged, not raised.
    """
    try:
        # Embed the query text for the vector branch of the search.
        embedding = (await create_embedding(search_text))["embedding"]

        query_text = load_aql_query("bm25_vector_rag_v2.aql")
        limit = query_params.get("top_n", 10)
        bind_vars = {
            "@view_name": view_name,
            "vector": embedding,
            "search_text": search_text,
            "top_n": limit,
            "similarity_threshold": query_params.get("similarity_threshold", 0.5),
            "bm25_threshold": query_params.get("bm25_threshold", 10),
            # Optional BM25 tuning parameters with their fallbacks.
            "k1": query_params.get("k1", 1.5),
            "b": query_params.get("b", 0.6),
        }

        # The driver call blocks, so hand it to a thread and keep the
        # event loop free.
        cursor = await asyncio.to_thread(
            db.aql.execute, query_text, bind_vars=bind_vars
        )
        rows = list(cursor)

        if not query_params.get("citation", False):
            # Plain mode: order by vector similarity, BM25 as tie-breaker.
            rows.sort(
                key=lambda row: (row["similarity_score"], row["bm25_score"]),
                reverse=True,
            )
            return rows[:limit]

        # Citation mode: keep only rows passing the fuzzy-match filter.
        return apply_levenshtein_filter(
            rows,
            search_text,
            query_params.get("levenshtein_threshold", 0.95),
        )
    except AQLQueryExecuteError as e:
        logger.error(f"AQL execution error: {e.error_message} (Code: {e.error_code})")
    except Exception as e:
        logger.error(f"Error in combined_search: {e}")
    return []
async def main():
    """
    Demo entry point: connect to ArangoDB, set up the search view, run one
    hard-coded query through ``combined_search`` and print every result.
    """
    from app.backend.vllm.beta.rag_config import config

    # Open the database connection in a worker thread.
    arango_cfg = config["arangodb"]
    client = ArangoClient(hosts=arango_cfg["hosts"])
    db = await asyncio.to_thread(
        client.db,
        arango_cfg["db_name"],
        username=arango_cfg["username"],
        password=arango_cfg["password"],
    )

    # Set up the custom search view.
    view_cfg = config["search_view"]
    await create_search_view(
        db,
        view_cfg["name"],
        view_cfg.get("collections", []),
        view_cfg.get("stopwords", []),
    )

    # Run the search with the configured query parameters.
    search_text = "Blue screen of death when updating Windows"
    results = await combined_search(
        db=db,
        view_name=view_cfg["name"],
        search_text=search_text,
        query_params=config["query_params"],
    )

    # Print each hit followed by a separator line.
    for hit in results:
        print(hit)
        print("--------------------------------")


if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment