ArangoDB hybrid search implementation combining BM25 text search, embedding similarity (using sentence-transformers), and keyword matching. Includes Python utilities and an AQL query for intelligent document retrieval with configurable thresholds and scoring. RapidFuzz could be added later as a post-processing step.
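// Hybrid search over glossary_view: embedding similarity, BM25, and keyword matching,
// merged and deduplicated into a single ranked result set.
// Bind variables: @search_text, @embedding_search, @embedding_similarity_threshold,
// @bm25_similarity_threshold, @k, @b, @top_n (supplied by the Python driver below).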
LET results = (
    // Get embedding results
    LET embedding_results = (
        FOR doc IN glossary_view
            LET similarity = COSINE_SIMILARITY(doc.embedding, @embedding_search)
            FILTER similarity >= @embedding_similarity_threshold
            SORT similarity DESC
            LIMIT @top_n
            RETURN {
                doc: doc,
                _key: doc._key,
                similarity_score: similarity,
                bm25_score: 0
            }
    )

    // Get BM25 results
    LET bm25_results = (
        FOR doc IN glossary_view
            SEARCH ANALYZER(
                doc.term IN TOKENS(@search_text, "text_en") OR
                doc.primary_definition IN TOKENS(@search_text, "text_en"),
                "text_en"
            )
            OPTIONS { collections: ["glossary"] }
            LET bm25_score = BM25(doc, @k, @b)
            FILTER bm25_score > @bm25_similarity_threshold
            SORT bm25_score DESC
            LIMIT @top_n
            RETURN {
                doc: doc,
                _key: doc._key,
                similarity_score: 0,
                bm25_score: bm25_score
            }
    )

    // Keyword processing
    LET keyword_processing = (
        // Get keyword results
        LET keyword_results = (
            FOR entry IN glossary_view
                LET normalizedEntryTerm = REGEX_REPLACE(
                    REGEX_REPLACE(entry.term, "<[^>]+>", ""),
                    "[^a-zA-Z0-9 ]", ""
                )
                FILTER LENGTH(entry.term) > 0
                FILTER REGEX_TEST(@search_text, CONCAT('\\b', LOWER(normalizedEntryTerm), '\\b'), true)
                FILTER LENGTH(normalizedEntryTerm) >= 2
                FILTER SUBSTRING(entry.term, 0, 1) == UPPER(SUBSTRING(entry.term, 0, 1))
                LET termLength = LENGTH(normalizedEntryTerm)
                RETURN {
                    term: entry.term,
                    entry_length: termLength,
                    definition: entry.primary_definition
                }
        )

        // Group keyword results by root word
        LET grouped_keyword_results = (
            FOR t IN keyword_results
                LET normalizedTerm = t.term
                LET rootWord = FIRST(TOKENS(normalizedTerm, "text_analyzer"))
                COLLECT groupKey = rootWord INTO groupedTerms
                LET longestTermInfo = FIRST(
                    FOR g IN groupedTerms
                        SORT LENGTH(g.t.term) DESC
                        RETURN {
                            term: g.t.term,
                            definition: g.t.definition
                        }
                )
                RETURN {
                    rootWord: groupKey,
                    longestTerm: longestTermInfo.term,
                    definition: longestTermInfo.definition
                }
        )

        // Sort and limit keyword results
        LET final_keyword_results = (
            FOR result IN grouped_keyword_results
                SORT result.longestTerm ASC
                LIMIT @top_n
                RETURN {
                    term: result.longestTerm,
                    definition: result.definition
                }
        )

        // Return the final keyword results
        RETURN final_keyword_results
    )

    // Merge and deduplicate embedding and BM25 results
    LET merged_results = (
        FOR result IN UNION_DISTINCT(embedding_results, bm25_results)
            COLLECT key = result._key INTO group
            LET doc = FIRST(group[*].result.doc)
            LET similarity_score = MAX(group[*].result.similarity_score)
            LET bm25_score = MAX(group[*].result.bm25_score)
            RETURN {
                "doc": doc,
                "_key": key,
                "similarity_score": similarity_score,
                "bm25_score": bm25_score
            }
    )

    // Sort and limit merged results
    LET final_merged_results = (
        FOR result IN merged_results
            SORT result.similarity_score DESC, result.bm25_score DESC
            LIMIT @top_n
            RETURN result
    )

    // Return all results
    RETURN {
        bm25_results: bm25_results,
        embedding_results: embedding_results,
        keyword_results: keyword_processing,
        merged_results: final_merged_results
    }
)

RETURN results
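The query assumes an ArangoSearch view named glossary_view over a glossary collection, plus a custom text_analyzer used when grouping keyword hits; their definitions are not part of this gist. A rough, hypothetical sketch of how they might be created with python-arango follows (the analyzer and link properties are illustrative assumptions, not the original setup):

from arango import ArangoClient

# Illustrative setup only: the real view/analyzer definitions are not included in this gist.
client = ArangoClient(hosts="http://localhost:8529")
db = client.db("verifaix", username="root", password="openSesame")

# Custom analyzer referenced by TOKENS(..., "text_analyzer") in the keyword-grouping step.
if not any(a["name"].endswith("text_analyzer") for a in db.analyzers()):
    db.create_analyzer(
        name="text_analyzer",
        analyzer_type="text",
        properties={"locale": "en", "case": "lower", "stemming": True, "accent": False},
        features=["frequency", "norm", "position"],
    )

# ArangoSearch view over the glossary collection, indexing the fields searched by BM25.
if not any(v["name"] == "glossary_view" for v in db.views()):
    db.create_arangosearch_view(
        name="glossary_view",
        properties={
            "links": {
                "glossary": {
                    "fields": {
                        "term": {"analyzers": ["text_en"]},
                        "primary_definition": {"analyzers": ["text_en"]},
                    },
                    "includeAllFields": False,
                }
            }
        },
    )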
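"""
Driver for the hybrid search AQL query above: loads and validates the config,
generates a sentence-transformers embedding for the search text, and runs the
BM25 + embedding + keyword query against ArangoDB.
"""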
from pathlib import Path

from arango import ArangoClient
from arango.database import StandardDatabase
from arango.exceptions import AQLQueryExecuteError, ArangoServerError
import regex as re
import json
from loguru import logger
from deepmerge import always_merger
from sentence_transformers import SentenceTransformer

from utils.get_project_root import get_project_root

# Calculate project directory once
project_dir = get_project_root()
# Helper functions
def load_config(file_path: str | Path) -> dict:
    """Load configuration from a JSON file."""
    try:
        with open(file_path, "r") as file:
            return json.load(file)
    except Exception as e:
        logger.error(f"Failed to load configuration from {file_path}: {e}")
        raise

def load_aql_query(file_path: str | Path) -> str:
    """Load AQL query from file."""
    try:
        path = Path(file_path) if isinstance(file_path, str) else file_path
        with open(path, "r") as file:
            return file.read().strip()
    except Exception as e:
        logger.error(f"Failed to load AQL query from {file_path}: {e}")
        raise

def normalize_text(text: str) -> str:
    """Normalize the input text by removing HTML, special characters, and extra spaces."""
    cleaned_text = re.sub(r"<[^>]+>", "", text)
    cleaned_text = re.sub(r"[^a-zA-Z0-9 ]", "", cleaned_text)
    cleaned_text = re.sub(r"\s+", " ", cleaned_text)
    return cleaned_text.strip()

def initialize_arangodb_client(config: dict) -> tuple[ArangoClient, StandardDatabase]:
    """Initialize and return the ArangoDB client and database."""
    client = ArangoClient(hosts=config["arango"]["hosts"])
    db = client.db(
        config["arango"]["db_name"],
        username=config["arango"]["username"],
        password=config["arango"]["password"]
    )
    return client, db

def generate_search_embedding(search_text: str, model: SentenceTransformer) -> list:
    """Normalize the search text and generate its embedding."""
    search_text = normalize_text(search_text)
    return model.encode(search_text).tolist()
def execute_bm25_embedding_keyword_query(config: dict, model: SentenceTransformer) -> list:
    """Execute the hybrid AQL query using the provided configuration and model."""
    try:
        # Initialize ArangoDB client
        logger.info("Connecting to ArangoDB...")
        client, db = initialize_arangodb_client(config)

        # Load AQL query file
        logger.info("Loading AQL query...")
        aql_query = load_aql_query(config["aql"]["query_path"])

        # Generate search embedding
        logger.info("Normalizing search text and generating embeddings...")
        search_embedding = generate_search_embedding(config["search"]["text"], model)

        # Bind variables for AQL query
        bind_vars = {
            "search_text": config["search"]["text"],
            "embedding_search": search_embedding,
            "embedding_similarity_threshold": config["aql"]["embedding_similarity_threshold"],
            "bm25_similarity_threshold": config["aql"]["bm25_similarity_threshold"],
            "k": config["aql"]["k"],
            "b": config["aql"]["b"],
            "top_n": config["aql"]["top_n"]
        }

        # Execute the AQL query
        logger.info("Executing AQL query...")
        cursor = db.aql.execute(aql_query, bind_vars=bind_vars)
        results = list(cursor)
        logger.debug(f"Query results: {json.dumps(results, indent=4)}")
        return results
    except AQLQueryExecuteError as e:
        logger.error(f"AQL query execution failed: {e.error_message}")
        logger.debug(f"AQL query: {aql_query}")
        safe_bind_vars = {k: v for k, v in bind_vars.items() if k != "embedding_search"}
        logger.debug(f"Query parameters (embedding omitted): {safe_bind_vars}")
        raise
    except ArangoServerError as e:
        logger.error(f"ArangoDB server error: {e.error_message}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error executing AQL query: {e}")
        raise
def validate_config(config: dict) -> bool:
    """Validate the configuration dictionary."""
    required_keys = {
        "arango": ["hosts", "db_name", "username", "password"],
        "aql": ["query_path", "embedding_similarity_threshold", "bm25_similarity_threshold", "k", "b", "top_n"],
        "model": ["name"],
        "search": ["text"]
    }
    for section, keys in required_keys.items():
        if section not in config:
            raise ValueError(f"Missing section in config: {section}")
        for key in keys:
            if key not in config[section]:
                raise ValueError(f"Missing key in config[{section}]: {key}")
    return True
def main():
    """Main function to execute the script."""
    try:
        # Load base configuration from file
        config_path = project_dir / "utils/config.json"
        config = load_config(config_path)

        # Validate the configuration
        validate_config(config)

        # Define the updates
        updates = {
            "search": {
                "text": "What is an Audit administrator where do I find an Auditor?!"
            },
            "aql": {
                "query_path": str(project_dir / "utils/aql/bm25_embedding_keyword_combined.aql"),
                "embedding_similarity_threshold": 0.6,
                "top_n": 5
            },
            "model": {
                "name": "sentence-transformers/all-mpnet-base-v2"
            }
        }

        # Perform a deep merge
        config = always_merger.merge(config, updates)

        # Load the SentenceTransformer model
        logger.info("Loading the SentenceTransformer model...")
        model = SentenceTransformer(config["model"]["name"])

        # Execute the query using the configuration
        results = execute_bm25_embedding_keyword_query(config, model)
        logger.success("Query executed successfully!")
    except Exception as e:
        logger.error(f"Script execution failed: {e}")
        raise

if __name__ == "__main__":
    main()
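For reference, the shape of the returned data follows the final RETURN in the AQL query above. Because everything is wrapped in a LET subquery before the top-level RETURN, the cursor should yield a single one-element list. The helper below is a hypothetical sketch (not part of the gist) showing how a caller might unpack that structure:

def summarize_hybrid_results(results: list) -> None:
    """Hypothetical helper: unpack and log the nested structure returned by the AQL query."""
    payload = results[0][0]  # single RETURN, wrapped once by the outer LET subquery

    bm25_hits = payload["bm25_results"]           # [{doc, _key, similarity_score, bm25_score}, ...]
    embedding_hits = payload["embedding_results"]
    keyword_hits = payload["keyword_results"]     # nested one level deeper: [[{term, definition}, ...]]
    merged_hits = payload["merged_results"]       # deduplicated, sorted by similarity then BM25, top_n

    logger.info(f"{len(bm25_hits)} BM25 hits, {len(embedding_hits)} embedding hits, "
                f"{len(keyword_hits[0]) if keyword_hits else 0} keyword hits")
    for hit in merged_hits:
        logger.info(f"{hit['_key']}: similarity={hit['similarity_score']}, bm25={hit['bm25_score']}")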
{
    "arango": {
        "hosts": "http://localhost:8529",
        "db_name": "verifaix",
        "username": "root",
        "password": "openSesame"
    },
    "aql": {
        "query_path": "utils/aql/bm25_embedding_keyword_combined.aql",
        "embedding_similarity_threshold": 0.6,
        "bm25_similarity_threshold": 5,
        "k": 2.4,
        "b": 1,
        "top_n": 5
    },
    "model": {
        "name": "sentence-transformers/all-mpnet-base-v2"
    },
    "search": {
        "text": "What is an Audit administrator where do I find an Auditor?!"
    }
}
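The description mentions RapidFuzz as a possible post-processing step; it is not implemented here. A minimal, hypothetical sketch (assuming rapidfuzz is installed and operating on the merged_results list returned by the query) could re-rank hits by fuzzy similarity between the search text and each matched term:

from rapidfuzz import fuzz

def rerank_with_rapidfuzz(search_text: str, merged_results: list, min_score: float = 60.0) -> list:
    """Hypothetical post-processing: re-rank merged hits by fuzzy term similarity (0-100)."""
    scored = []
    for hit in merged_results:
        term = hit["doc"].get("term", "")
        fuzz_score = fuzz.token_sort_ratio(search_text, term)
        if fuzz_score >= min_score:
            scored.append({**hit, "fuzz_score": fuzz_score})
    # Highest fuzzy score first; ties keep the original similarity/BM25 order (sorted is stable).
    return sorted(scored, key=lambda h: h["fuzz_score"], reverse=True)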