The ArangoDBHelper class provides comprehensive management for an ArangoDB instance, handling initialization, connection, schema retrieval, and collection management. It integrates LLM-based metadata generation, ensuring structured data for collections. The class supports asynchronous database initialization, embedding storage, and AQL query execution.
import importlib | |
import os | |
import json | |
import asyncio | |
import sys | |
from arango import ArangoClient | |
from arango.exceptions import ArangoError, CollectionCreateError | |
import datetime | |
import logging | |
from typing import List, Dict, Optional, Any, Union | |
from functools import lru_cache | |
from dotenv import load_dotenv | |
import nltk | |
from nltk.corpus import wordnet as wn | |
import uuid | |
from embed_rank.config import Config | |
from embed_rank.utils.run_async_or_sync import run_async_or_sync | |
from verifaix.llm_client.get_litellm_response import get_litellm_response | |
from verifaix.utils.chunked import chunked | |
from verifaix.utils.json_cleaner import clean_json_string | |
from verifaix.utils.load_cache_from_json import load_cache_from_json | |
import concurrent.futures | |
load_dotenv('../../.env') | |
config = Config() | |
class ArangoDBHelper: | |
def __init__(self, config: Dict[str, Any]): | |
""" | |
Initialize the ArangoDBHelper class with a unified config dictionary. | |
Args: | |
config (Dict[str, Any]): A configuration dictionary containing all relevant settings. | |
""" | |
self.log_level = config.get("log_level", "INFO").upper() | |
        self.debug_mode = config.get("debug_mode", bool(int(os.getenv('DEBUG_MODE', '0'))))  # True/False; config takes precedence over the env var | |
self._logger = None # Lazy Load the logger | |
# Merge arango_config from the unified config dictionary | |
self._merged_config = self._merge_arango_config(config.get('arango_config')) | |
# Load values from merged config | |
self.host = self._merged_config['host'] | |
self.db_name = self._merged_config['db_name'] | |
self.username = self._merged_config['username'] | |
self.password = self._merged_config['password'] | |
self.collection_name = self._merged_config['collection_name'] | |
self.schema_collection_name = self._merged_config['schema_collection'] | |
self.llm_params = config.get('llm_params', None) | |
self.json_file_path = config.get('json_file_path', None) | |
self.force = config.get('force', False) | |
self.generate_metadata = config.get('generate_metadata', False) | |
self.client = ArangoClient(hosts=self.host) | |
self.db = None | |
### | |
# Lazy Loads | |
### | |
@property | |
def logger(self) -> logging.Logger: | |
"""Lazy-load the logger and ensure it's properly set up with the ColoredFormatter.""" | |
if self._logger is None: | |
try: | |
# Import the ColoredFormatter from your utility module | |
from verifaix.utils.colored_logger_mini import ColoredFormatter | |
# Initialize logger | |
self._logger = logging.getLogger("ArangoDBHelper") | |
# Clear any existing handlers to prevent conflicts | |
self._logger.handlers.clear() | |
# Set up the ColoredFormatter and StreamHandler | |
handler = logging.StreamHandler(sys.stdout) | |
handler.setFormatter(ColoredFormatter('%(levelname)s: %(message)s')) | |
self._logger.addHandler(handler) | |
# Set the log level dynamically based on the log_level parameter from the config | |
self._logger.setLevel(getattr(logging, self.log_level, logging.INFO)) | |
except ImportError as e: | |
raise ImportError(f"Failed to import ColoredFormatter: {e}") from e | |
return self._logger | |
def _lazy_load_nltk(self): | |
"""Lazy load NLTK resources.""" | |
import nltk | |
nltk.download('wordnet') | |
nltk.download('punkt') | |
### | |
# Initialization and Connection Handling | |
### | |
async def connect(self): | |
"""Establish connection to ArangoDB if not already connected.""" | |
if self.db is None: | |
try: | |
sys_db = self.client.db('_system', username=self.username, password=self.password) | |
if not sys_db.has_database(self.db_name): | |
sys_db.create_database(self.db_name) | |
self.db = self.client.db(self.db_name, username=self.username, password=self.password) | |
# If a JSON file path is provided, initialize the database | |
if self.json_file_path: | |
self.logger.info(f"JSON file path provided: {self.json_file_path}. Initializing database.") | |
                    await self.initialize_database(self.json_file_path, force=self.force, generate_metadata=self.generate_metadata) | |
except ArangoError as e: | |
self.logger.error(f"Failed to connect to the database: {str(e)}", exc_info=self.debug_mode) | |
raise | |
except Exception as e: | |
self.logger.error(f"Unexpected error during connection or database initialization: {str(e)}") | |
raise | |
def _merge_arango_config(self, arango_config: Optional[Dict[str, str]]) -> Dict[str, str]: | |
""" | |
Merge user-provided ArangoDB configuration with defaults from environment variables. | |
Ensures that all necessary fields are present. | |
""" | |
default_arango_config = { | |
"host": os.getenv("ARANGO_DB_HOST", "http://localhost:8529"), | |
"username": os.getenv("ARANGO_DB_USERNAME", "root"), | |
"password": os.getenv("ARANGO_DB_PASSWORD", "openSesame"), | |
"db_name": os.getenv("ARANGO_DB_NAME", "default_db"), | |
"collection_name": os.getenv("ARANGO_DB_COLLECTION_NAME", "default_collection"), | |
"schema_collection": os.getenv("ARANGO_DB_SCHEMA_COLLECTION_NAME", "schema_cache") | |
} | |
merged_config = default_arango_config.copy() | |
if arango_config: | |
merged_config.update(arango_config) | |
# Validate that all required keys are present | |
required_keys = ["host", "username", "password", "db_name", "collection_name", "schema_collection"] | |
missing_keys = [key for key in required_keys if key not in merged_config] | |
if missing_keys: | |
logging.warning(f"Missing ArangoDB configuration keys: {missing_keys}") | |
return merged_config | |
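    # Illustrative example (not part of the original gist): passing a partial arango_config | |
    # overrides only those keys, and the rest fall back to the ARANGO_DB_* environment | |
    # variables or their hard-coded defaults, e.g. | |
    #   helper = ArangoDBHelper({"arango_config": {"db_name": "nuclear"}}) | |
    #   helper.db_name -> "nuclear" | |
    #   helper.host    -> os.getenv("ARANGO_DB_HOST", "http://localhost:8529") | |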
def ensure_database_exists(self): | |
"""Ensure that the database exists. If not, create it.""" | |
try: | |
sys_db = self.client.db('_system', username=self.username, password=self.password) | |
# Check if the database exists and create it if not | |
has_db = sys_db.has_database(self.db_name) | |
if not has_db: | |
sys_db.create_database(self.db_name) | |
logging.debug(f"Created database: {self.db_name}") # Changed to debug | |
# Connect to the database | |
self.db = self.client.db(self.db_name, username=self.username, password=self.password) | |
logging.debug(f"Connected to database: {self.db_name}") # Changed to debug | |
self._ensure_schema_collection() | |
except ArangoError as e: | |
logging.error(f"Error ensuring database exists: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def validate_arango_connection(self): | |
"""Validate ArangoDB connection details are properly configured.""" | |
required_vars = ['ARANGO_DB_HOST', 'ARANGO_DB_USERNAME', 'ARANGO_DB_PASSWORD', 'ARANGO_DB_NAME'] | |
for var in required_vars: | |
if not os.getenv(var): | |
raise ValueError(f"Environment variable {var} is missing or not set.") | |
self.logger.debug("ArangoDB environment variables validated successfully.") # Changed to debug | |
### | |
# Database Schema Management | |
### | |
def get_database_schema(self, included_collections: Optional[List[str]] = None) -> dict: | |
"""Retrieve the schema for the connected ArangoDB.""" | |
try: | |
collections = self.db.collections() | |
# Clean up the collection names (lowercase and replace spaces with underscores) | |
included_collections = set(col.lower().replace(" ", "_") for col in (included_collections or [])) | |
included_collections = list(included_collections) | |
# Filter collections if included_collections is specified | |
if included_collections: | |
collections = [ | |
col | |
for col in collections | |
if col['name'] in included_collections | |
] | |
schema = {} | |
for collection in collections: | |
collection_name = collection['name'] | |
# Skip system collections and only include collections specified in included_collections | |
if collection_name.startswith('_'): | |
continue | |
if included_collections and collection_name not in included_collections: | |
continue | |
# Fetch the properties of the collection | |
schema[collection_name] = { | |
'name': collection_name, | |
'properties': self._get_collection_properties(collection_name) | |
} | |
self.logger.debug(f"Database schema retrieved: {schema}") | |
return schema | |
except Exception as e: | |
self.logger.error(f"Failed to get database schema: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def _get_collection_properties(self, collection_name: str) -> dict: | |
"""Get the properties of a collection.""" | |
try: | |
query = f"FOR doc IN {collection_name} LIMIT 1 RETURN KEEP(doc, ATTRIBUTES(doc, true))" | |
cursor = self.db.aql.execute(query) | |
sample_doc = next(cursor, None) | |
return {k: type(v).__name__ for k, v in sample_doc.items()} if sample_doc else {} | |
except Exception as e: | |
self.logger.error(f"Error retrieving collection properties: {str(e)}", exc_info=self.debug_mode) | |
raise | |
@lru_cache(maxsize=1) | |
def _get_cached_schema(self, db_name): | |
try: | |
cache_collection = self.db.collection(self.schema_collection_name) | |
cache_doc = cache_collection.get(db_name) | |
if cache_doc: | |
return json.loads(cache_doc['schema']) | |
except ArangoError: | |
return None | |
def _cache_schema(self, db_name, schema): | |
        try: | |
            # Ensure the schema collection exists before accessing it | |
            if not self.db.has_collection(self.schema_collection_name): | |
                self.db.create_collection(self.schema_collection_name) | |
            cache_collection = self.db.collection(self.schema_collection_name) | |
            # Force rewrite of schema cache if needed | |
            if self.force: | |
                cache_collection.truncate() | |
cache_collection.insert({ | |
'_key': db_name, | |
'schema': json.dumps(schema), | |
'timestamp': datetime.datetime.now().isoformat() | |
}, overwrite=True) | |
except ArangoError as e: | |
self.logger.error(f"Error caching schema: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def _ensure_schema_collection(self): | |
"""Ensure the schema cache collection exists.""" | |
if self.db is None: | |
raise ValueError("Database connection not established") | |
try: | |
has_collection = self.db.has_collection(self.schema_collection_name) | |
if not has_collection: | |
self.db.create_collection(self.schema_collection_name) | |
except CollectionCreateError as e: | |
logging.error(f"Failed to create schema collection '{self.schema_collection_name}': {e.message} - {e.error_code}", exc_info=self.debug_mode) | |
raise | |
def initialize_schemas(self) -> tuple[Dict, Dict]: | |
""" | |
Initialize the database schema and decision schema. | |
Optionally filter collections using included_collections from the configuration. | |
Returns: | |
tuple: A tuple containing the ArangoDB schema and the decision schema. | |
""" | |
try: | |
            # Pass the included_collections argument if it's set in the merged ArangoDB configuration | |
            included_collections = self._merged_config.get('included_collections') | |
# Fetch the ArangoDB schema with optional collection filtering | |
arangodb_schema = self.get_database_schema(included_collections=included_collections) | |
if not arangodb_schema: | |
raise ValueError("Failed to retrieve ArangoDB schema") | |
# Create the decision schema based on the ArangoDB schema | |
decision_schema = self.get_decision_schema(arangodb_schema) | |
if not decision_schema: | |
raise ValueError("Failed to create decision schema") | |
self.logger.info("Database schema and decision schema initialized successfully.") | |
# Return both schemas as a tuple | |
return arangodb_schema, decision_schema | |
except Exception as e: | |
# Log the error and re-raise the exception | |
self.logger.error(f"Error initializing schemas: {str(e)}", exc_info=self.debug_mode) | |
raise | |
### | |
# Collection Metadata Handling | |
### | |
async def load_collection_metadata(self, metadata: Dict[str, Any]): | |
"""Insert collection metadata directly into the ArangoDB collection.""" | |
try: | |
collection_name = "collection_metadata" | |
# Ensure the metadata collection exists | |
if not self.db.has_collection(collection_name): | |
self.db.create_collection(collection_name) | |
# Insert metadata into the collection | |
self.insert_bulk_json(collection_name, [metadata]) | |
self.logger.info(f"Collection metadata inserted into '{collection_name}'") | |
except ArangoError as e: | |
self.logger.error(f"Error inserting metadata: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def ensure_collection_metadata(self, metadata: Dict[str, Any]): | |
"""Ensure the collection metadata is stored in ArangoDB.""" | |
try: | |
collection_name = "collection_metadata" | |
# Ensure the metadata collection exists | |
if not self.db.has_collection(collection_name): | |
self.db.create_collection(collection_name) | |
# Insert metadata into the collection with the collection name as the _key | |
formatted_metadata = [ | |
{"_key": name, **details} for name, details in metadata.items() | |
] | |
self.insert_bulk_json(collection_name, formatted_metadata) | |
self.logger.info(f"Metadata for collections inserted into '{collection_name}'") | |
except ArangoError as e: | |
self.logger.error(f"Error inserting metadata: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def get_collection_metadata(self, collection_name: str) -> Optional[Dict[str, Any]]: | |
"""Retrieve metadata for a specific collection from the collection_metadata collection.""" | |
try: | |
# Check if the collection_metadata collection exists | |
if not self.db.has_collection("collection_metadata"): | |
self.logger.warning(f"Metadata collection 'collection_metadata' does not exist.") | |
return None | |
# Query the collection_metadata for the requested collection_name | |
query = "FOR doc IN collection_metadata FILTER doc._key == @collection_name RETURN doc" | |
bind_vars = {"collection_name": collection_name} | |
cursor = self.db.aql.execute(query, bind_vars=bind_vars) | |
metadata = next(cursor, None) # Retrieve the first document if exists | |
if metadata: | |
return metadata | |
else: | |
self.logger.warning(f"No metadata found for collection: {collection_name}") | |
return None | |
except ArangoError as e: | |
self.logger.error(f"Error retrieving metadata for collection '{collection_name}': {str(e)}", exc_info=self.debug_mode) | |
raise | |
async def generate_collection_metadata(self, data: Dict[str, Any], llm_params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: | |
""" | |
Generate collection metadata using LLM based on the provided data concurrently. | |
Args: | |
data (Dict[str, Any]): The data dictionary containing collections and their documents. | |
llm_params (Optional[Dict[str, Any]]): Parameters for the LLM call. | |
Returns: | |
List[Dict[str, Any]]: A list of metadata documents for each collection. | |
""" | |
default_llm_params = { | |
"model": "openai/gpt-4o-mini", # Replace with your desired model | |
"max_tokens": 500, | |
"temperature": 0.2, | |
"json_mode": True, | |
# "top_k": 1, | |
# "top_p": 1.0 | |
} | |
        # Merge defaults with instance-level llm_params and any params passed to this call | |
        llm_params = default_llm_params | (self.llm_params or {}) | (llm_params or {}) | |
collection_metadata = [] | |
async def generate_metadata_for_collection(collection_name: str, documents: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: | |
# Use only the first few documents to limit token usage | |
sample_docs = documents[:5] | |
# Construct a prompt for the LLM | |
prompt = ( | |
f"Given the following sample data from the '{collection_name}' collection:\n\n" | |
f"{sample_docs}\n\n" | |
"Please provide:\n\n" | |
f"1. A concise description of the '{collection_name}' collection.\n\n" | |
"2. A list of fields in the collection, each with a brief description.\n\n" | |
"Provide the output in JSON format as:\n\n" | |
"{\n" | |
f' "collection_name": "{collection_name}",\n' | |
' "description": "<collection description>",\n' | |
' "fields": {\n' | |
' "field1": "<field1 description>",\n' | |
' "field2": "<field2 description>",\n' | |
" ...\n" | |
" }\n" | |
"}" | |
) | |
system_message = ( | |
"You are a database expert who writes clear and concise descriptions. " | |
"Respond in JSON format only." | |
) | |
messages = [ | |
{"role": "system", "content": system_message}, | |
{"role": "user", "content": prompt} | |
] | |
# Use LLM to get the response | |
try: | |
response = await get_litellm_response(messages, llm_params) | |
# Clean and return the response content | |
if response and 'choices' in response and response['choices']: | |
metadata_collection= clean_json_string(response['choices'][0]['message']['content'], return_dict=True, logger=self.logger) | |
else: | |
self.logger.error("No valid response received from LLM.", exc_info=self.debug_mode) | |
metadata_collection = {} | |
# If no metadata is generated, return None to signify an error or failure | |
if not metadata_collection: | |
return None | |
# Create the metadata document | |
metadata_doc = { | |
"_key": collection_name, | |
"description": metadata_collection.get('description', ''), | |
"fields": metadata_collection.get('fields', {}) | |
} | |
return metadata_doc | |
except Exception as e: | |
self.logger.error(f"Error generating metadata for collection '{collection_name}': {str(e)}", exc_info=self.debug_mode) | |
return None | |
# Create a list of coroutines for all collections | |
tasks = [ | |
generate_metadata_for_collection(collection_name, documents) | |
for collection_name, documents in data.items() | |
if collection_name != 'collection_metadata' | |
] | |
# Run all tasks concurrently | |
results = await asyncio.gather(*tasks, return_exceptions=True) | |
# Collect successful results | |
for result in results: | |
if isinstance(result, dict): | |
collection_metadata.append(result) | |
elif isinstance(result, Exception): | |
self.logger.error(f"Exception occurred: {str(result)}", exc_info=self.debug_mode) | |
return collection_metadata | |
### | |
# Collection Management | |
### | |
def ensure_collection_exists(self, collection_name: str): | |
""" | |
Ensure that the specified collection exists, and create it if it doesn't. | |
""" | |
try: | |
if not self.db.has_collection(collection_name): | |
self.db.create_collection(collection_name) | |
logging.info(f"Created collection: {collection_name}") | |
except CollectionCreateError as e: | |
logging.error(f"Error creating collection '{collection_name}': {e}", exc_info=self.debug_mode) | |
raise | |
except ArangoError as e: | |
logging.error(f"ArangoDB error: {e}", exc_info=self.debug_mode) | |
raise | |
@staticmethod | |
def _validate_collection_name(name): | |
if not name or not isinstance(name, str) or not name.isidentifier() or (name.startswith('_') and name != '_schema_cache'): | |
raise ValueError(f"Invalid collection name: {name}. Collection names must be alphanumeric and not start with an underscore unless it's a system collection.") | |
    def load_json_in_chunks(self, file_path: str, chunk_size: int = 1000): | |
        """Yield documents from a JSON file in chunks of chunk_size (assumes a top-level JSON array).""" | |
        with open(file_path, 'r') as file: | |
            documents = json.load(file) | |
        for i in range(0, len(documents), chunk_size): | |
            yield documents[i:i + chunk_size] | |
def insert_bulk_json(self, collection_name: str, json_data: List[Dict[str, Any]], batch_size: int = 500): | |
"""Inserts bulk JSON data into the specified collection in batches.""" | |
# await self.connect() | |
if not self.db.has_collection(collection_name): | |
self.db.create_collection(collection_name) | |
collection = self.db.collection(collection_name) | |
with concurrent.futures.ThreadPoolExecutor() as executor: | |
futures = [] | |
for batch in chunked(json_data, size=batch_size, logger=self.logger): | |
futures.append(executor.submit(collection.insert_many, batch)) | |
for future in concurrent.futures.as_completed(futures): | |
try: | |
future.result() # Check for exceptions | |
except Exception as e: | |
logging.error(f"Error inserting batch: {str(e)}", exc_info=self.debug_mode) | |
raise | |
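    # Illustrative call (hypothetical document list): 1,200 documents would be split by | |
    # chunked() into 500/500/200-document batches and inserted concurrently on the pool: | |
    #   helper.insert_bulk_json("reactor_data", docs, batch_size=500) | |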
#### | |
# Query and Decision Schema Execution | |
#### | |
    def execute_aql_query(self, query: str, bind_vars: Optional[Dict] = None) -> list: | |
        """Execute an AQL query with optional bind variables.""" | |
        try: | |
            self.ensure_database_exists() | |
            cursor = self.db.aql.execute(query, bind_vars=bind_vars) | |
            return list(cursor) | |
        except ArangoError as e: | |
            self.logger.error(f"Error executing AQL query: {str(e)}", exc_info=self.debug_mode) | |
            raise | |
def execute_aql_query_with_schema(self, query: str, schema: dict): | |
"""Execute AQL query using the schema.""" | |
try: | |
            # Ensure database exists before executing the query | |
            self.ensure_database_exists() | |
            result = self.execute_aql_query(query) | |
if not result: | |
return {"status": "empty_result", "message": "No results found."} | |
return {"status": "success", "data": result} | |
except Exception as e: | |
self.logger.error(f"Error executing AQL query: {e}", exc_info=self.debug_mode) | |
raise | |
def get_decision_schema(self, arangodb_schema: Dict[str, Dict[str, Any]], ignore_empty_collections: bool = False) -> Dict: | |
""" | |
Creates an optimized decision schema tailored for efficient processing by an LLM. | |
Detects foreign key relationships between collections dynamically using common naming patterns. | |
Args: | |
arangodb_schema (Dict[str, Dict[str, Any]]): The ArangoDB schema. | |
ignore_empty_collections (bool): If True, empty collections will be ignored. | |
Returns: | |
Dict: The optimized decision schema, including detected foreign key relationships. | |
""" | |
try: | |
decision_schema = { | |
"conditions": [], | |
"relations": {} | |
} | |
common_fields = ["_key", "_id", "_rev"] | |
foreign_key_suffixes = ["_id", "_fk", "id_fk", "id", "fk"] | |
# Iterate through collections in the schema | |
for collection_name, details in arangodb_schema.items(): | |
fields = details.get("properties", {}) | |
# Skip empty collections if ignore_empty_collections | |
if ignore_empty_collections and details.get("documents", 0) == 0: | |
self.logger.debug(f"Skipping empty collection: {collection_name}") | |
continue | |
# Detect foreign key fields based on naming patterns | |
foreign_key_fields = [ | |
field for field in fields.keys() | |
if any(field.endswith(suffix) for suffix in foreign_key_suffixes) and field not in common_fields | |
] | |
# If foreign key fields are found, infer relationships | |
for fk_field in foreign_key_fields: | |
# Assume the foreign key references another collection named similarly | |
target_collection = fk_field.replace('_id', '') | |
if target_collection in arangodb_schema: | |
if collection_name not in decision_schema["relations"]: | |
decision_schema["relations"][collection_name] = [] | |
decision_schema["relations"][collection_name].append({ | |
"foreign_key": fk_field, | |
"references": target_collection | |
}) | |
# Retrieve collection metadata | |
metadata = self.get_collection_metadata(collection_name) | |
if metadata is None: | |
self.logger.warning(f"No metadata available for collection: {collection_name}") | |
description = "" | |
fields_description = {} | |
else: | |
# Simplify metadata by removing internal identifiers and shortening descriptions | |
description = metadata.get("description", "") | |
fields_description = metadata.get("fields", {}) | |
# Simplify field descriptions | |
simplified_fields = { | |
field: fields_description.get(field, "").split(".")[0] for field in fields.keys() if field not in common_fields | |
} | |
# Add the collection's unique conditions to the decision schema | |
decision = { | |
"if": list(fields.keys()), | |
"then": collection_name, | |
"description": description.split(".")[0] + ".", | |
"fields": simplified_fields | |
} | |
decision_schema["conditions"].append(decision) | |
return decision_schema | |
except Exception as e: | |
self.logger.error(f"Error creating decision schema: {str(e)}", exc_info=self.debug_mode) | |
raise | |
### | |
# Document Handling | |
### | |
def store_embedding(self, collection_name: str, text: str, embedding: List[float], metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: | |
"""Store embedding into the database.""" | |
try: | |
document = { | |
"text": text, | |
"embedding": embedding, | |
"metadata": metadata or {}, | |
"embedding_id": str(uuid.uuid4()), | |
"created_at": datetime.datetime.now().isoformat() | |
} | |
return self.db.collection(collection_name).insert(document) | |
except ArangoError as e: | |
logging.error(f"Error storing embedding in collection {collection_name}: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def retrieve_embedding(self, collection_name: str, text: Optional[str] = None, doc_id: Optional[str] = None) -> Optional[Dict[str, Any]]: | |
"""Retrieve embedding based on text or document ID.""" | |
try: | |
            if text: | |
                cursor = self.db.collection(collection_name).find({'text': text}, limit=1) | |
                return next(cursor, None) | |
            elif doc_id: | |
                # get() returns the document dict directly (or None), not a cursor | |
                return self.db.collection(collection_name).get(doc_id) | |
            else: | |
                raise ValueError("Either 'text' or 'doc_id' must be provided to retrieve an embedding.") | |
except ArangoError as e: | |
logging.error(f"Error retrieving embedding from collection {collection_name}: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def get_all_documents(self, collection_name): | |
try: | |
query = f"FOR doc IN {collection_name} RETURN doc" | |
cursor = self.db.aql.execute(query) | |
return list(cursor) | |
except ArangoError as e: | |
logging.error(f"Error retrieving all documents: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def get_last_processed_document(self, collection_name): | |
try: | |
query = f"FOR doc IN {collection_name} SORT doc.created_at DESC LIMIT 1 RETURN doc" | |
cursor = self.db.aql.execute(query) | |
            return next(cursor, None) | |
except ArangoError as e: | |
logging.error(f"Error retrieving last processed document: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def get_new_or_updated_documents(self, collection_name, last_processed_timestamp): | |
try: | |
query = f"FOR doc IN {collection_name} FILTER doc.created_at > @last_processed_timestamp RETURN doc" | |
bind_vars = {"last_processed_timestamp": last_processed_timestamp} | |
cursor = self.db.aql.execute(query, bind_vars=bind_vars) | |
return list(cursor) | |
except ArangoError as e: | |
logging.error(f"Error retrieving new or updated documents: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def insert_edge_document(self, collection_name, document): | |
try: | |
self.db.collection(collection_name).insert(document) | |
except ArangoError as e: | |
logging.error(f"Error inserting edge document: {str(e)}", exc_info=self.debug_mode) | |
raise | |
def create_edge_collection(self, collection_name): | |
try: | |
if not self.db.has_collection(collection_name): | |
self.db.create_collection(collection_name, edge=True) | |
except CollectionCreateError as e: | |
logging.error(f"Error creating edge collection {collection_name}: {e.message} - {e.error_code}", exc_info=self.debug_mode) | |
raise | |
except ArangoError as e: | |
logging.error(f"ArangoDB error while creating edge collection: {e.message}", exc_info=self.debug_mode) | |
raise | |
### | |
    # Database Initialization | |
### | |
async def initialize_database(self, json_file_path: Optional[str] = None, force: bool = True, generate_metadata: bool = True): | |
""" | |
Initialize the database and insert collections and metadata from the provided JSON file. | |
If force is True, existing collections will be truncated before inserting new data. | |
Args: | |
json_file_path (Optional[str]): Path to the JSON file containing the data. | |
force (bool): If True, truncate existing collections before inserting new data. | |
""" | |
try: | |
# If the json_file_path is provided during initialization, automatically use it | |
if json_file_path is None: | |
if self.json_file_path: | |
json_file_path = self.json_file_path | |
else: | |
raise ValueError("JSON file path must be provided or set during initialization.") | |
# Load JSON data | |
data = load_cache_from_json(json_file_path, self.logger) | |
# Ensure the database exists | |
self.ensure_database_exists() | |
# Check if 'collection_metadata' exists | |
if 'collection_metadata' not in data and generate_metadata: | |
self.logger.debug("No 'collection_metadata' found in JSON. Generating using LLM.") # Changed to debug | |
# Generate collection metadata using the LLM | |
collection_metadata = await self.generate_collection_metadata(data) | |
# Add 'collection_metadata' to the data dictionary | |
data['collection_metadata'] = collection_metadata | |
# Iterate through collections and ensure they exist | |
for collection_name, collection_data in data.items(): | |
# Ensure the collection exists; create if it doesn't | |
if not self.db.has_collection(collection_name): | |
self.logger.debug(f"Collection '{collection_name}' does not exist. Creating collection.") # Changed to debug | |
self.ensure_collection_exists(collection_name) | |
else: | |
self.logger.debug(f"Collection '{collection_name}' already exists.") # Changed to debug | |
# If force is True, truncate the collection | |
if force: | |
self.logger.debug(f"Truncating collection '{collection_name}' before inserting new data.") # Changed to debug | |
self.db.collection(collection_name).truncate() | |
# Insert data into the collection | |
self.insert_bulk_json(collection_name, collection_data) | |
self.logger.debug(f"Data inserted into collection '{collection_name}'.") # Changed to debug | |
except FileNotFoundError as e: | |
self.logger.error(f"JSON file not found: {json_file_path}") | |
raise | |
except json.JSONDecodeError as e: | |
self.logger.error(f"Error decoding JSON data: {e}") | |
raise | |
except Exception as e: | |
self.logger.error(f"Unexpected error during database initialization: {str(e)}") | |
raise | |
### | |
# Usage Examples | |
### | |
def initialize_nuclear_database(db_name, json_file_path: str, batch_size: int = 500) -> None: | |
""" | |
Initialize the 'nuclear' database and load data from the provided JSON file. | |
Args: | |
json_file_path (str): Path to the JSON file containing the reactor data. | |
batch_size (int): Number of documents to insert in one batch (default is 500). | |
Raises: | |
FileNotFoundError: If the JSON file is not found. | |
json.JSONDecodeError: If there's an error decoding the JSON data. | |
Exception: For any other unexpected errors. | |
""" | |
    try: | |
        logging.info(f"Initializing the nuclear database '{db_name}' from file: {json_file_path}") | |
        # Initialize ArangoDBHelper with the target database name | |
        arango_helper = ArangoDBHelper({"arango_config": {"db_name": db_name}}) | |
        # Ensure the database exists | |
        arango_helper.ensure_database_exists() | |
# Load JSON data in a memory-efficient manner | |
with open(json_file_path, 'r') as file: | |
try: | |
data = json.load(file) | |
except json.JSONDecodeError as e: | |
logging.error(f"Error decoding JSON file {json_file_path}: {str(e)}") | |
raise | |
# Insert data for each collection | |
for collection_name, documents in data.items(): | |
if isinstance(documents, list): | |
# Process in batches | |
logging.info(f"Processing collection '{collection_name}' with {len(documents)} documents.") | |
for i in range(0, len(documents), batch_size): | |
batch = documents[i:i + batch_size] | |
                    arango_helper.insert_bulk_json(collection_name, batch)  # raises on failure | |
                    logging.info(f"Successfully inserted batch {i // batch_size + 1} into {collection_name}") | |
else: | |
logging.warning(f"Collection '{collection_name}' has unexpected data format and was skipped.") | |
logging.info("Nuclear database initialization completed successfully.") | |
except FileNotFoundError as e: | |
logging.error(f"JSON file not found: {json_file_path}") | |
raise | |
except json.JSONDecodeError as e: | |
logging.error(f"Error decoding JSON data: {e}") | |
raise | |
except Exception as e: | |
logging.error(f"Unexpected error during database initialization: {str(e)}") | |
raise | |
def usage_example_initialize_db(): | |
db_name = "nuclear" or os.environ.get("ARANGO_DB_NAME") | |
project_path = os.environ.get("PROJECT_PATH") | |
json_file_path = os.path.join(project_path, "verifaix/data/db/reactor_data.json") | |
initialize_nuclear_database(db_name, json_file_path) | |
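# A minimal sketch (not part of the original gist) showing execute_aql_query with bind | |
# variables against the sample reactor_data collection; it assumes the database has | |
# already been populated, e.g. by usage_example_initialize_db above. | |
async def usage_example_run_aql_query(): | |
    arango_helper = ArangoDBHelper({"arango_config": {"db_name": "nuclear"}}) | |
    await arango_helper.connect() | |
    query = "FOR r IN reactor_data FILTER r.status == @status RETURN r.name" | |
    active_reactors = arango_helper.execute_aql_query(query, bind_vars={"status": "Active"}) | |
    print(active_reactors)  # e.g. ["Reactor 1"] for the sample data below | |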
async def usage_example_get_schema(): | |
project_path = os.environ.get("PROJECT_PATH") | |
arango_config = { | |
"db_name": "nuclear", | |
"included_collections": | |
["reactor_data","radiation_levels", "employee_records", "waste_management"] | |
} | |
    arango_helper = ArangoDBHelper({"arango_config": arango_config}) | |
await arango_helper.connect() | |
included_collections = [ | |
"reactor_data","radiation_levels", "employee_records", "waste_management" | |
] | |
schema = arango_helper.get_database_schema(included_collections) | |
print(schema) | |
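# A minimal sketch (not part of the original gist) of store_embedding/retrieve_embedding; | |
# the collection name and the toy 4-dimensional vector are illustrative placeholders, and | |
# in practice the embedding would come from an embedding model. | |
async def usage_example_embeddings(): | |
    arango_helper = ArangoDBHelper({"arango_config": {"db_name": "nuclear"}}) | |
    await arango_helper.connect() | |
    arango_helper.ensure_collection_exists("document_embeddings") | |
    arango_helper.store_embedding( | |
        collection_name="document_embeddings", | |
        text="Reactor 1 quarterly inspection summary", | |
        embedding=[0.12, -0.34, 0.56, 0.78], | |
        metadata={"source": "example"} | |
    ) | |
    doc = arango_helper.retrieve_embedding("document_embeddings", text="Reactor 1 quarterly inspection summary") | |
    print(doc["embedding_id"] if doc else "not found") | |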
async def usage_example_get_decision_schema(): | |
project_path = os.environ.get("PROJECT_PATH") | |
config = { | |
"arango_config": { | |
"db_name": "nuclear", | |
"included_collections": [ | |
"reactor_data", "radiation_levels", "employee_records", "waste_management" | |
], | |
"host": "http://localhost:8529", | |
"username": "root", | |
"password": "openSesame" | |
}, | |
"llm_config": { | |
"model": "openai/gpt-4o-mini", | |
"json_mode": True | |
}, | |
"json_file_path": os.path.join(project_path, "verifaix/data/db/reactor_data_no_metadata.json"), | |
"force": True, | |
"generate_metadata": True, | |
"debug_mode": True, | |
"log_level": "INFO", | |
} | |
arango_helper = ArangoDBHelper(config) | |
await arango_helper.connect() | |
included_collections = config['arango_config']['included_collections'] | |
schema = arango_helper.get_database_schema(included_collections) | |
decision_schema = arango_helper.get_decision_schema(schema) | |
print(decision_schema) | |
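# A minimal sketch (not part of the original gist) showing how the collection metadata | |
# generated during initialize_database can be read back with get_collection_metadata; it | |
# assumes usage_example_get_decision_schema (or an equivalent initialization) has run. | |
async def usage_example_read_collection_metadata(): | |
    arango_helper = ArangoDBHelper({"arango_config": {"db_name": "nuclear"}}) | |
    await arango_helper.connect() | |
    metadata = arango_helper.get_collection_metadata("reactor_data") | |
    if metadata: | |
        print(metadata.get("description", "")) | |
        print(list(metadata.get("fields", {}).keys())) | |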
if __name__ == "__main__": | |
asyncio.run(usage_example_get_decision_schema()) | |
Sample input data loaded by the usage examples above:
{ | |
"reactor_data": [ | |
{ | |
"_key": "reactor_1", | |
"name": "Reactor 1", | |
"type": "Pressurized Water Reactor", | |
"status": "Active", | |
"thermal_capacity_mw": 3000, | |
"location": "Unit 1" | |
}, | |
{ | |
"_key": "reactor_2", | |
"name": "Reactor 2", | |
"type": "Boiling Water Reactor", | |
"status": "Maintenance", | |
"thermal_capacity_mw": 2500, | |
"location": "Unit 2" | |
} | |
], | |
"radiation_levels": [ | |
{ | |
"_key": "rad_sensor_1", | |
"sensor_location": "North Boundary", | |
"reactor_data_id": "reactor_1", | |
"radiation_level_sieverts": 0.002, | |
"timestamp": "2023-09-26T08:00:00Z" | |
}, | |
{ | |
"_key": "rad_sensor_2", | |
"sensor_location": "Reactor Control Room", | |
"reactor_data_id": "reactor_2", | |
"radiation_level_sieverts": 0.0005, | |
"timestamp": "2023-09-26T08:05:00Z" | |
} | |
], | |
"employee_records": [ | |
{ | |
"_key": "emp_001", | |
"name": "John Doe", | |
"role": "Reactor Operator", | |
"reactor_data_id": "reactor_1", | |
"radiation_exposure_sieverts": 0.0008, | |
"last_radiation_check": "2023-09-25" | |
}, | |
{ | |
"_key": "emp_002", | |
"name": "Jane Smith", | |
"role": "Safety Engineer", | |
"reactor_data_id": "reactor_2", | |
"radiation_exposure_sieverts": 0.0002, | |
"last_radiation_check": "2023-09-25" | |
} | |
], | |
"waste_management": [ | |
{ | |
"_key": "waste_001", | |
"waste_type": "Spent Fuel", | |
"reactor_data_id": "reactor_1", | |
"storage_location": "Dry Cask Storage", | |
"amount_kg": 500, | |
"last_inspection": "2023-09-20" | |
}, | |
{ | |
"_key": "waste_002", | |
"waste_type": "Low-Level Waste", | |
"reactor_data_id": "reactor_2", | |
"storage_location": "Waste Treatment Facility", | |
"amount_kg": 150, | |
"last_inspection": "2023-09-22" | |
} | |
] | |
} |
Example decision schema produced by get_decision_schema() for the sample data:
{ | |
"conditions": [ | |
{ | |
"if": [ | |
"location", | |
"name", | |
"status", | |
"thermal_capacity_mw", | |
"type" | |
], | |
"then": "reactor_data", | |
"description": "The 'reactor_data' collection contains information about various nuclear reactors, including their operational status, type, thermal capacity, and location.", | |
"fields": { | |
"location": "The specific location or unit designation of the reactor", | |
"name": "The name of the reactor", | |
"status": "The current operational status of the reactor (e", | |
"thermal_capacity_mw": "The thermal capacity of the reactor measured in megawatts (MW)", | |
"type": "The type of reactor, indicating its design and operational characteristics" | |
} | |
}, | |
{ | |
"if": [ | |
"amount_kg", | |
"last_inspection", | |
"reactor_data_id", | |
"storage_location", | |
"waste_type" | |
], | |
"then": "waste_management", | |
"description": "The 'waste_management' collection stores information about various types of waste generated from nuclear reactors, including their storage locations, amounts, and inspection dates.", | |
"fields": { | |
"amount_kg": "The quantity of waste measured in kilograms", | |
"last_inspection": "The date when the last inspection of the waste was conducted", | |
"reactor_data_id": "The identifier for the reactor that generated the waste", | |
"storage_location": "The location where the waste is stored", | |
"waste_type": "The category of waste, such as Spent Fuel or Low-Level Waste" | |
} | |
}, | |
{ | |
"if": [ | |
"last_radiation_check", | |
"name", | |
"radiation_exposure_sieverts", | |
"reactor_data_id", | |
"role" | |
], | |
"then": "employee_records", | |
"description": "The 'employee_records' collection stores information about employees working in a nuclear facility, including their roles, radiation exposure levels, and last radiation check dates.", | |
"fields": { | |
"last_radiation_check": "The date when the last radiation exposure check was conducted for the employee", | |
"name": "The full name of the employee", | |
"radiation_exposure_sieverts": "The amount of radiation exposure the employee has received, measured in sieverts", | |
"reactor_data_id": "A reference ID linking the employee to a specific reactor they are associated with", | |
"role": "The job title or position of the employee within the organization" | |
} | |
}, | |
{ | |
"if": [ | |
"radiation_level_sieverts", | |
"reactor_data_id", | |
"sensor_location", | |
"timestamp" | |
], | |
"then": "radiation_levels", | |
"description": "The 'radiation_levels' collection stores data from radiation sensors, including their locations, associated reactor data, measured radiation levels in sieverts, and timestamps of the readings.", | |
"fields": { | |
"radiation_level_sieverts": "The measured radiation level in sieverts, indicating the amount of radiation exposure", | |
"reactor_data_id": "An identifier linking the radiation data to a specific reactor", | |
"sensor_location": "The physical location of the radiation sensor", | |
"timestamp": "The date and time when the radiation level was recorded, in ISO 8601 format" | |
} | |
} | |
], | |
"relations": { | |
"waste_management": [ | |
{ | |
"foreign_key": "reactor_data_id", | |
"references": "reactor_data" | |
} | |
], | |
"employee_records": [ | |
{ | |
"foreign_key": "reactor_data_id", | |
"references": "reactor_data" | |
} | |
], | |
"radiation_levels": [ | |
{ | |
"foreign_key": "reactor_data_id", | |
"references": "reactor_data" | |
} | |
] | |
} | |
} |