Skip to content

Instantly share code, notes, and snippets.

@grahama1970
Last active January 16, 2025 14:03
Show Gist options
  • Save grahama1970/44f6fe6b1ecfee145b8275b0227227cf to your computer and use it in GitHub Desktop.
Save grahama1970/44f6fe6b1ecfee145b8275b0227227cf to your computer and use it in GitHub Desktop.
ArangoDB hybrid search implementation combining BM25 text search, embedding similarity (using sentence-transformers), and keyword matching. Includes Python utilities and an AQL query for intelligent document retrieval with configurable thresholds and scoring. Perhaps use RapidFuzz for post-processing later.
// Hybrid glossary search: three retrieval passes over glossary_view
// (embedding cosine similarity, BM25 full-text relevance, exact keyword
// matching), followed by a merge/dedupe of the first two passes.
// Bind vars: @embedding_search (query vector), @embedding_similarity_threshold,
// @search_text, @bm25_similarity_threshold, @k / @b (BM25 tuning), @top_n.
LET results = (
// Pass 1: embedding similarity against the precomputed query vector
LET embedding_results = (
FOR doc IN glossary_view
LET similarity = COSINE_SIMILARITY(doc.embedding, @embedding_search)
FILTER similarity >= @embedding_similarity_threshold
SORT similarity DESC
LIMIT @top_n
RETURN {
doc: doc,
_key: doc._key,
similarity_score: similarity,
bm25_score: 0 // placeholder so both passes share one result shape
}
)
// Pass 2: BM25 full-text relevance over term + primary_definition
LET bm25_results = (
FOR doc IN glossary_view
SEARCH ANALYZER(
doc.term IN TOKENS(@search_text, "text_en") OR
doc.primary_definition IN TOKENS(@search_text, "text_en"),
"text_en"
)
OPTIONS { collections: ["glossary"] } // limit the view to the glossary collection
LET bm25_score = BM25(doc, @k, @b) // @k / @b: term-saturation and length-normalization knobs
FILTER bm25_score > @bm25_similarity_threshold
SORT bm25_score DESC
LIMIT @top_n
RETURN {
doc: doc,
_key: doc._key,
similarity_score: 0, // placeholder so both passes share one result shape
bm25_score: bm25_score
}
)
// Pass 3: glossary terms that literally occur inside @search_text
LET keyword_processing = (
// Collect candidate terms found verbatim (case-insensitively) in the query
LET keyword_results = (
FOR entry IN glossary_view
// Strip HTML tags, then non-alphanumerics (mirrors the Python normalize_text helper)
LET normalizedEntryTerm = REGEX_REPLACE(
REGEX_REPLACE(entry.term, "<[^>]+>", ""),
"[^a-zA-Z0-9 ]", ""
)
FILTER LENGTH(entry.term) > 0
// Whole-word match of the term inside the query; 3rd arg = case-insensitive
FILTER REGEX_TEST(@search_text, CONCAT('\\b', LOWER(normalizedEntryTerm), '\\b'), true)
FILTER LENGTH(normalizedEntryTerm) >= 2 // drop single-character noise
// Heuristic: keep only terms whose first letter is capitalized
FILTER SUBSTRING(entry.term, 0, 1) == UPPER(SUBSTRING(entry.term, 0, 1))
LET termLength = LENGTH(normalizedEntryTerm)
RETURN {
term: entry.term,
entry_length: termLength,
definition: entry.primary_definition
}
)
// Group matches by their first token so variants (e.g. "Auditor"/"Auditors")
// collapse to one entry. NOTE(review): "text_analyzer" is a custom analyzer —
// assumed to stem/normalize; confirm its definition in the database.
LET grouped_keyword_results = (
FOR t IN keyword_results
LET normalizedTerm = t.term
LET rootWord = FIRST(TOKENS(normalizedTerm, "text_analyzer"))
COLLECT groupKey = rootWord INTO groupedTerms
// Within each group, keep the longest surface form as the representative
LET longestTermInfo = FIRST(
FOR g IN groupedTerms
SORT LENGTH(g.t.term) DESC
RETURN {
term: g.t.term,
definition: g.t.definition
}
)
RETURN {
rootWord: groupKey,
longestTerm: longestTermInfo.term,
definition: longestTermInfo.definition
}
)
// Sort alphabetically and cap at @top_n
LET final_keyword_results = (
FOR result IN grouped_keyword_results
SORT result.longestTerm ASC
LIMIT @top_n
RETURN {
term: result.longestTerm,
definition: result.definition
}
)
// Note: a subquery RETURN makes keyword_processing a one-element array
// wrapping final_keyword_results (consumers must unwrap one level)
RETURN final_keyword_results
)
// Merge passes 1 and 2 by _key. UNION_DISTINCT rarely removes anything here
// (the score placeholders differ between passes), so the COLLECT by _key is
// what actually dedupes, keeping the best score of each kind per document.
LET merged_results = (
FOR result IN UNION_DISTINCT(embedding_results, bm25_results)
COLLECT key = result._key INTO group
LET doc = FIRST(group[*].result.doc)
LET similarity_score = MAX(group[*].result.similarity_score)
LET bm25_score = MAX(group[*].result.bm25_score)
RETURN {
"doc": doc,
"_key": key,
"similarity_score": similarity_score,
"bm25_score": bm25_score
}
)
// Rank merged docs: embedding similarity first, BM25 as tie-breaker
LET final_merged_results = (
FOR result IN merged_results
SORT result.similarity_score DESC, result.bm25_score DESC
LIMIT @top_n
RETURN result
)
// Expose each pass individually plus the merged ranking
RETURN {
bm25_results: bm25_results,
embedding_results: embedding_results,
keyword_results: keyword_processing,
merged_results: final_merged_results
}
)
RETURN results
from pathlib import Path
from arango import ArangoClient
import arango
import regex as re
import json
from loguru import logger
from deepmerge import always_merger
from sentence_transformers import SentenceTransformer
from utils.get_project_root import get_project_root
# Resolve the project root once at import time; reused below to build the
# config-file and AQL-query paths.
project_dir = get_project_root()
# Helper functions
def load_config(file_path: str | Path) -> dict:
"""Load configuration from a JSON file."""
try:
with open(file_path, "r") as file:
return json.load(file)
except Exception as e:
logger.error(f"Failed to load configuration from {file_path}: {e}")
raise
def load_aql_query(file_path: str | Path) -> str:
"""Load AQL query from file."""
try:
path = Path(file_path) if isinstance(file_path, str) else file_path
with open(path, "r") as file:
return file.read().strip()
except Exception as e:
logger.error(f"Failed to load AQL query from {file_path}: {e}")
raise
def normalize_text(text: str) -> str:
    """Normalize text: drop HTML tags, then special characters, then collapse
    runs of whitespace to single spaces and trim the ends."""
    # Substitution order matters: tags first, then character filtering,
    # then whitespace collapsing.
    cleaned = text
    for pattern, replacement in (
        ('<[^>]+>', ''),
        (r'[^a-zA-Z0-9 ]', ''),
        (r'\s+', ' '),
    ):
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
def initialize_arangodb_client(config: dict) -> tuple[ArangoClient, arango.database.StandardDatabase]:
    """Create an ArangoDB client and open the database named in config["arango"].

    Returns the (client, database) pair so callers can keep the client alive.
    """
    arango_cfg = config["arango"]
    client = ArangoClient(hosts=arango_cfg["hosts"])
    database = client.db(
        arango_cfg["db_name"],
        username=arango_cfg["username"],
        password=arango_cfg["password"],
    )
    return client, database
def generate_search_embedding(search_text: str, model: SentenceTransformer) -> list:
    """Normalize the query text, embed it with the given model, and return the
    embedding as a plain Python list (JSON-serializable for AQL bind vars)."""
    cleaned = normalize_text(search_text)
    vector = model.encode(cleaned)
    return vector.tolist()
def execute_bm25_embedding_keyword_query(config: dict, model: SentenceTransformer) -> list:
    """
    Run the combined BM25 + embedding + keyword AQL query.

    Connects to ArangoDB, loads the query file named in config, embeds the
    search text, binds all query parameters, and returns the result rows.

    Raises on any connection, query, or embedding failure after logging it.
    """
    try:
        logger.info("Connecting to ArangoDB...")
        client, db = initialize_arangodb_client(config)

        logger.info("Loading AQL query...")
        aql_query = load_aql_query(config["aql"]["query_path"])

        logger.info("Normalizing search text and generating embeddings...")
        search_embedding = generate_search_embedding(config["search"]["text"], model)

        # Bind-variable names must match the @placeholders in the AQL file.
        aql_cfg = config["aql"]
        bind_vars = {
            "search_text": config["search"]["text"],
            "embedding_search": search_embedding,
            "embedding_similarity_threshold": aql_cfg["embedding_similarity_threshold"],
            "bm25_similarity_threshold": aql_cfg["bm25_similarity_threshold"],
            "k": aql_cfg["k"],
            "b": aql_cfg["b"],
            "top_n": aql_cfg["top_n"],
        }

        logger.info("Executing AQL query...")
        cursor = db.aql.execute(aql_query, bind_vars=bind_vars)
        # Materialize the cursor so callers get a plain list.
        rows = list(cursor)
        logger.debug(f"Query results: {json.dumps(rows, indent=4)}")
        return rows
    except arango.AQLQueryExecuteError as e:
        logger.error(f"AQL query execution failed: {e.error_message}")
        logger.debug(f"AQL query: {e.query}")
        logger.debug(f"Query parameters: {e.parameters}")
        raise
    except arango.ArangoServerError as e:
        logger.error(f"ArangoDB server error: {e.error_message}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error executing AQL query: {e}")
        raise
def validate_config(config: dict) -> bool:
    """Verify that every required configuration section and key is present.

    Raises ValueError naming the first missing section or key; returns True
    when the structure is complete.
    """
    required_keys = {
        "arango": ["hosts", "db_name", "username", "password"],
        "aql": ["query_path", "embedding_similarity_threshold", "bm25_similarity_threshold", "k", "b", "top_n"],
        "model": ["name"],
        "search": ["text"]
    }
    for section, keys in required_keys.items():
        if section not in config:
            raise ValueError(f"Missing section in config: {section}")
        # Report the first absent key, in declaration order.
        missing = [key for key in keys if key not in config[section]]
        if missing:
            section_keys = missing[0]
            raise ValueError(f"Missing key in config[{section}]: {section_keys}")
    return True
def main():
    """Load and validate the base config, overlay runtime overrides, then run
    the hybrid BM25/embedding/keyword search query."""
    try:
        # Base configuration lives alongside the project utilities.
        config_path = project_dir / "utils/config.json"
        config = load_config(config_path)
        validate_config(config)

        # Runtime overrides deep-merged on top of the file-based config.
        updates = {
            "search": {
                "text": "What is an Audit administrator where do I find an Auditor?!"
            },
            "aql": {
                "query_path": str(project_dir / "utils/aql/bm25_embedding_keyword_combined.aql"),
                "embedding_similarity_threshold": 0.6,
                "top_n": 5,
            },
            "model": {
                "name": "sentence-transformers/all-mpnet-base-v2"
            },
        }
        config = always_merger.merge(config, updates)

        logger.info("Loading the SentenceTransformer model...")
        model = SentenceTransformer(config["model"]["name"])

        results = execute_bm25_embedding_keyword_query(config, model)
        logger.success("Query executed successfully!")
    except Exception as e:
        logger.error(f"Script execution failed: {e}")
        raise


if __name__ == "__main__":
    main()
{
"arango": {
"hosts": "http://localhost:8529",
"db_name": "verifaix",
"username": "root",
"password": "openSesame"
},
"aql": {
"query_path": "utils/aql/bm25_embedding_keyword_combined.aql",
"embedding_similarity_threshold": 0.6,
"bm25_similarity_threshold": 5,
"k": 2.4,
"b": 1,
"top_n": 5
},
"model": {
"name": "sentence-transformers/all-mpnet-base-v2"
},
"search": {
"text": "What is an Audit administrator where do I find an Auditor?!"
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment