test for 2 files
Created October 5, 2024 22:03
import asyncio
import os
import sys
import json
from typing import Optional, Dict, List, Any
from datetime import datetime, timezone

from arango import ArangoClient
from arango.exceptions import ArangoError, CollectionCreateError
from loguru import logger
from dotenv import load_dotenv

from verifaix.utils.infer_types import infer_basic_type

load_dotenv('../../.env')

from verifaix.arangodb_helper.utils.bulk_insert import insert_bulk_json
from verifaix.utils.loguru_setup import setup_logger

setup_logger()
# Config (from config.py)
def merge_arango_config(arango_config: Optional[Dict[str, str]]) -> Dict[str, str]:
    """Merge a user-provided ArangoDB configuration with defaults from environment variables."""
    default_arango_config = {
        "host": os.getenv("ARANGO_DB_HOST", "http://localhost:8529"),
        "username": os.getenv("ARANGO_DB_USERNAME", "root"),
        "password": os.getenv("ARANGO_DB_PASSWORD", "openSesame"),
        "db_name": os.getenv("ARANGO_DB_NAME", "default_db"),
        "collection_name": os.getenv("ARANGO_DB_COLLECTION_NAME", "default_collection"),
        "schema_collection": os.getenv("ARANGO_DB_SCHEMA_COLLECTION_NAME", "schema_cache")
    }
    if arango_config:
        default_arango_config.update(arango_config)
    return default_arango_config
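
# Illustrative sketch (not part of the original gist): a user-supplied config
# overrides only the keys it provides; every other key falls back to the
# environment defaults above. The values here are hypothetical.
def example_merge_arango_config():
    merged = merge_arango_config({"db_name": "nuclear", "collection_name": "Glossary"})
    assert merged["db_name"] == "nuclear"  # user override wins
    assert "host" in merged                # untouched keys keep their env/default values
    return merged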
# Connection (from connection.py)
def setup_arango_client(config: Dict[str, Any]):
    """Set up the ArangoDB client."""
    logger.remove()
    logger.add(sys.stdout, colorize=True, format="<green>{time}</green> <level>{message}</level>")
    arango_config = merge_arango_config(config.get('arango_config'))
    client = ArangoClient(hosts=arango_config['host'])
    return client, arango_config

def connect_to_arango(client, config: Dict[str, str]):
    """Establish a synchronous connection to ArangoDB."""
    try:
        sys_db = client.db('_system', username=config['username'], password=config['password'])
        if not sys_db.has_database(config['db_name']):
            sys_db.create_database(config['db_name'])
            logger.debug(f"Created database: {config['db_name']}")
        db = client.db(config['db_name'], username=config['username'], password=config['password'])
        logger.debug(f"Connected to database: {config['db_name']}")
        return db
    except ArangoError as e:
        logger.error(f"Failed to connect to the database: {str(e)}")
        raise

def ensure_database_exists(client, config: Dict[str, str]) -> bool:
    """Ensure that the database exists and report whether it is empty."""
    try:
        sys_db = client.db('_system', username=config['username'], password=config['password'])
        if not sys_db.has_database(config['db_name']):
            sys_db.create_database(config['db_name'])
            logger.info(f"Created database: {config['db_name']}")
            return True
        db = client.db(config['db_name'], username=config['username'], password=config['password'])
        collections = db.collections()
        non_system_collections = [col for col in collections if not col['name'].startswith('_')]
        if not non_system_collections:
            logger.info(f"Database '{config['db_name']}' exists but is empty.")
            return True
        else:
            logger.info(f"Database '{config['db_name']}' exists and contains collections.")
            return False
    except ArangoError as e:
        logger.error(f"Error ensuring database exists: {str(e)}")
        raise

def ensure_collections_exist(db, collections: Dict[str, List[Dict[str, Any]]], force: bool = False):
    """Ensure that collections exist in the database, create them if necessary, and truncate them if forced."""
    for collection_name, collection_data in collections.items():
        if not db.has_collection(collection_name):
            ensure_collection_exists(db, collection_name)
        if force:
            db.collection(collection_name).truncate()

def ensure_collection_exists(db, collection_name: str):
    """Ensure that the specified collection exists in the database."""
    try:
        if not db.has_collection(collection_name):
            db.create_collection(collection_name)
            logger.info(f"Created collection: {collection_name}")
        else:
            logger.info(f"Collection {collection_name} already exists.")
    except CollectionCreateError as e:
        logger.error(f"Error creating collection '{collection_name}': {e}")
        raise
    except ArangoError as e:
        logger.error(f"ArangoDB error: {e}")
        raise

def initialize_database(client, arango_config, json_file_path: str, force: bool = True, batch_size: int = 500):
    """Initialize the database and insert collections and metadata from the provided JSON file."""
    try:
        if not json_file_path:
            raise ValueError("JSON file path must be provided.")
        with open(json_file_path, 'r') as file:
            data = json.load(file)
        db = client.db(arango_config['db_name'], username=arango_config['username'], password=arango_config['password'])

        # Step 1: Ensure collections exist (create if necessary)
        ensure_collections_exist(db, data, force)

        # Step 2: Insert bulk data into collections
        bulk_insert_into_collections(db, data, batch_size)

        logger.info(f"Database initialized with data from {json_file_path}")
    except FileNotFoundError:
        logger.error(f"JSON file not found: {json_file_path}")
        raise
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding JSON data: {e}")
        raise
    except Exception as e:
        logger.error(f"Unexpected error during database initialization: {str(e)}")
        raise
# Data operations (from data_operations.py)
def execute_aql_query(db, query: str, bind_vars: Optional[Dict] = None):
    """Execute an AQL query with optional bind variables."""
    try:
        cursor = db.aql.execute(query, bind_vars=bind_vars)
        return list(cursor)
    except ArangoError as e:
        logger.error(f"Error executing AQL query: {str(e)}")
        raise
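
# Illustrative sketch (not part of the original gist): parameterized AQL via
# bind variables. "@@collection" binds a collection name while "@status" binds
# a plain value; the collection and field names here are hypothetical.
def example_execute_aql_query(db):
    query = """
    FOR doc IN @@collection
        FILTER doc.status == @status
        LIMIT 10
        RETURN doc
    """
    bind_vars = {"@collection": "reactor_data", "status": "active"}
    return execute_aql_query(db, query, bind_vars=bind_vars)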
def get_collection_metadata(db, collection_name: str) -> Optional[Dict[str, Any]]:
    """Retrieve metadata for a specific collection."""
    try:
        if not db.has_collection("collection_metadata"):
            logger.warning("Metadata collection 'collection_metadata' does not exist.")
            return None
        query = "FOR doc IN collection_metadata FILTER doc._key == @collection_name RETURN doc"
        bind_vars = {"collection_name": collection_name}
        cursor = db.aql.execute(query, bind_vars=bind_vars)
        metadata = next(cursor, None)
        if metadata:
            return metadata
        else:
            logger.warning(f"No metadata found for collection: {collection_name}")
            return None
    except ArangoError as e:
        logger.error(f"Error retrieving metadata for collection '{collection_name}': {str(e)}")
        raise
def insert_document(db, collection_name: str, document: dict):
    """Insert a document into the specified collection, adding a _last_updated timestamp."""
    try:
        collection = db.collection(collection_name)
        # Add the _last_updated field with the current timestamp
        document['_last_updated'] = datetime.now(timezone.utc).isoformat()
        collection.insert(document)
        logger.info(f"Inserted document into collection '{collection_name}'.")
    except ArangoError as e:
        logger.error(f"Error inserting document into collection '{collection_name}': {e}")
        raise

def update_document(db, collection_name: str, document_key: str, update_data: dict):
    """Update a document in the specified collection, adding or refreshing the _last_updated timestamp."""
    try:
        collection = db.collection(collection_name)
        # Add the _last_updated field with the current timestamp
        update_data['_last_updated'] = datetime.now(timezone.utc).isoformat()
        # python-arango's update() expects a document dict containing _key
        collection.update({**update_data, '_key': document_key})
        logger.info(f"Updated document in collection '{collection_name}' with key '{document_key}'.")
    except ArangoError as e:
        logger.error(f"Error updating document in collection '{collection_name}' with key '{document_key}': {e}")
        raise

def upsert_document(db, collection_name: str, document: dict):
    """
    Upsert a document into the specified collection.
    If the document exists, it is updated; otherwise, it is inserted.
    The _last_updated field is added or refreshed with the current timestamp.

    Args:
        db: The ArangoDB connection object.
        collection_name (str): The name of the collection to store the document.
        document (dict): The document to be upserted. It must contain the _key field.
    """
    try:
        # Ensure the collection exists
        if not db.has_collection(collection_name):
            db.create_collection(collection_name)
            logger.info(f"Created collection: {collection_name}")
        collection = db.collection(collection_name)

        # Add or update the _last_updated field with the current timestamp
        document['_last_updated'] = datetime.now(timezone.utc).isoformat()

        # Check if the document exists based on _key
        document_key = document.get('_key')
        if not document_key:
            raise ValueError("Document must contain a '_key' field for upsert operation.")
        existing_doc = collection.get(document_key)
        if existing_doc:
            # If the document exists, update it
            collection.update(document)
            logger.info(f"Updated document with _key '{document_key}' in collection '{collection_name}'.")
        else:
            # If the document doesn't exist, insert it
            collection.insert(document)
            logger.info(f"Inserted document with _key '{document_key}' into collection '{collection_name}'.")
    except ArangoError as e:
        logger.error(f"Error upserting document in collection '{collection_name}': {str(e)}")
        raise
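
# Illustrative sketch (not part of the original gist): upserting by _key.
# The first call inserts, the second updates the same document, and
# _last_updated is refreshed on both. Collection and values are hypothetical.
def example_upsert_document(db):
    doc = {"_key": "reactor_7", "status": "active", "output_mw": 950}
    upsert_document(db, "reactor_data", doc)   # inserted
    doc["status"] = "maintenance"
    upsert_document(db, "reactor_data", doc)   # updated in place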
async def fetch_documents(db, collection_name: str, limit: int = 5) -> List[Dict[str, Any]]:
    """Fetch sample documents from the specified collection."""
    try:
        query = "FOR doc IN @@collection LIMIT @limit RETURN doc"
        bind_vars = {"@collection": collection_name, "limit": limit}
        cursor = await asyncio.to_thread(execute_aql_query, db, query, bind_vars)
        return list(cursor) if cursor else []
    except ArangoError as e:
        logger.error(f"Error fetching documents from collection '{collection_name}': {str(e)}")
        return []
def bulk_insert_into_collections(db, collections: Dict[str, List[Dict[str, Any]]], batch_size: int = 500):
    """Insert bulk data into collections."""
    for collection_name, collection_data in collections.items():
        insert_bulk_json(db, collection_name, collection_data, batch_size)
def collection_last_updated(db, collection_name: str) -> Optional[str]:
    """
    Get the most recent _last_updated value, or stamp all documents with the
    current time if the field does not exist yet.

    Args:
        db: ArangoDB connection.
        collection_name (str): The name of the collection to query.

    Returns:
        str: The most recent _last_updated value in ISO format, or the current
        time if the documents were just stamped.
    """
    current_time = datetime.now(timezone.utc).isoformat()
    try:
        # Query for the most recent _last_updated value
        query = """
        FOR doc IN @@collection
            SORT doc._last_updated DESC
            LIMIT 1
            RETURN doc._last_updated
        """
        cursor = db.aql.execute(query, bind_vars={"@collection": collection_name})
        last_updated = next(cursor, None)

        # If a _last_updated value is found, return it
        if last_updated:
            return last_updated

        # Otherwise, stamp every document with the current time
        update_query = """
        FOR doc IN @@collection
            UPDATE doc WITH { _last_updated: @current_time } IN @@collection
        """
        db.aql.execute(update_query, bind_vars={"@collection": collection_name, "current_time": current_time})
        return current_time
    except ArangoError as e:
        logger.error(f"Error while querying/updating collection '{collection_name}': {str(e)}")
        raise
def get_collection_indexes(db, collection_name: str) -> List[Dict]:
    """Fetch the index information for a given collection in ArangoDB."""
    try:
        collection = db.collection(collection_name)
        indexes = collection.indexes()
        return indexes
    except Exception as e:
        logger.error(f"Error retrieving indexes for collection '{collection_name}': {str(e)}")
        return []
def get_field_types(db, collection_name: str) -> Dict[str, str]:
    """
    Infers field types for a given collection in ArangoDB by analyzing a sample document.
    Handles basic types such as string, integer, float, boolean, and datetime.

    Args:
        db: ArangoDB database connection.
        collection_name (str): The name of the collection to analyze.

    Returns:
        Dict[str, str]: A dictionary mapping field names to inferred types.
    """
    try:
        collection = db.collection(collection_name)

        # Check whether the collection has a schema defined (if ArangoDB schema validation is used)
        collection_info = collection.properties()
        schema_info = collection_info.get('schema') or {}
        field_types = {}

        # If a schema is defined, use it
        if schema_info and 'rule' in schema_info:
            rule = schema_info['rule']
            if 'properties' in rule:
                field_types = {
                    field: field_info.get('type', 'unknown')
                    for field, field_info in rule['properties'].items()
                }

        # If no schema is available, infer field types from a sample document
        if not field_types:
            sample_document = next(collection.all(), None)  # Retrieve a sample document
            if sample_document:
                field_types = {k: infer_basic_type(v) for k, v in sample_document.items() if not k.startswith('_')}
        return field_types
    except Exception as e:
        logger.error(f"Error inferring field types for collection '{collection_name}': {str(e)}")
        return {}
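
# Illustrative sketch (not part of the original gist): for a sample document like
#   {"_key": "r1", "name": "Unit 1", "output_mw": 950.5, "online": true}
# and no server-side schema, get_field_types would return roughly
#   {"name": "str", "output_mw": "float", "online": "bool"}
# (exact labels depend on infer_basic_type); system fields like _key are skipped.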
def sample_documents(db, collection_name: str, sample_size: int = 5) -> list:
    """
    Fetches a random sample of documents from the specified collection.

    Args:
        db: The database connection.
        collection_name (str): The name of the collection to sample from.
        sample_size (int): The number of documents to sample. Default is 5.

    Returns:
        list: A list of sampled documents.
    """
    try:
        # Randomly sample a small number of documents for context.
        # SORT RAND() is fine for the small samples used here.
        query = """
        FOR doc IN @@collection
            SORT RAND()
            LIMIT @sample_size
            RETURN doc
        """
        cursor = db.aql.execute(query, bind_vars={"@collection": collection_name, "sample_size": sample_size})
        return list(cursor)
    except Exception as e:
        logger.error(f"Error sampling documents from collection '{collection_name}': {str(e)}")
        return []
###
# Usage examples
###

def usage_example_query_random_glossary_terms():
    """Query the 'Glossary' collection for 5 randomly sampled documents."""
    config = {
        "arango_config": {
            "host": "http://localhost:8529",
            "db_name": "nuclear",
            "username": "root",
            "password": "openSesame",
            "collection_name": "Glossary"
        },
        "log_level": "DEBUG"
    }
    client, arango_config = setup_arango_client(config)
    db = connect_to_arango(client, arango_config)

    query = """
    FOR doc IN Glossary
        SORT RAND()
        LIMIT 5
        RETURN doc
    """
    try:
        results = execute_aql_query(db, query)
        print("Random sample of 5 Glossary terms:")
        for idx, doc in enumerate(results, 1):
            print(f"{idx}. Term: {doc.get('term', 'N/A')}")
            print(f"   Definition: {doc.get('definition', 'N/A')}")
            print(f"   Source: {doc.get('source', 'N/A')}")
            print("---")
    except Exception as e:
        logger.error(f"Error querying Glossary collection: {str(e)}")

if __name__ == "__main__":
    usage_example_query_random_glossary_terms()
import asyncio
import json
from typing import Optional, List, Dict, Any

from loguru import logger
from arango.exceptions import ArangoError, CollectionCreateError

from verifaix.arangodb_helper.utils.human_readable_last_updated import human_readable_last_updated
from verifaix.arangodb_helper.utils.upsert_decision_schema import upsert_decision_schema
from verifaix.utils.calculate_token_count import calculate_token_count
from verifaix.utils.save_cache_to_json import save_cache_to_json
from verifaix.arangodb_helper.arango_client import (
    fetch_documents, get_collection_indexes, get_field_types, insert_document, sample_documents,
    setup_arango_client, connect_to_arango, ensure_collection_exists,
    execute_aql_query, collection_last_updated
)
from verifaix.arangodb_helper.utils.bulk_insert import upsert_json_list
from verifaix.arangodb_helper.utils.generate_llm_metadata import generate_llm_metadata
from verifaix.arangodb_helper.utils.get_collection_metadata import get_collection_metadata, get_collection_properties
from verifaix.arangodb_helper.utils.infer_cardinality import infer_cardinality
from verifaix.arangodb_helper.utils.filter_document_fields import filter_document_fields
from verifaix.utils.loguru_setup import setup_logger

setup_logger()
###
# Main functions: get_schemas (which returns both), get_database_schema, get_decision_schema
###

async def get_schemas(db, config: dict) -> tuple:
    """
    Main function to retrieve and return both the database schema and the decision schema.
    """
    try:
        # Retrieve the database schema. Awaited because it makes LLM calls.
        database_schema = await get_database_schema(db, config)

        # Create the decision schema. Synchronous because it does nothing intensive.
        decision_schema = get_decision_schema(database_schema, db, config)

        return decision_schema, database_schema
    except Exception as e:
        logger.error(f"Error during schema retrieval: {str(e)}")
        raise
async def get_database_schema(db, config: dict) -> dict:
    try:
        # Retrieve the list of collections in the database using a non-blocking call
        collections = await asyncio.to_thread(db.collections)

        # Filter the collections based on the included collections specified in the config
        included_collections = config.get('included_collections', None)
        included_collections = set(col.lower().replace(" ", "_") for col in (included_collections or []))
        if included_collections:
            collections = [col for col in collections if col['name'].lower() in included_collections]

        schema = {}
        metadata_tasks = []
        properties_tasks = []
        collection_names_needing_metadata = []

        # Iterate over each collection to initiate metadata or property retrieval tasks
        for collection in collections:
            collection_name = collection['name']

            # Skip system collections
            if collection_name.startswith('_'):
                continue

            # Fetch the document count for each collection
            document_count = await asyncio.to_thread(db[collection_name].count)

            # Last updated collection timestamp
            last_updated = await asyncio.to_thread(collection_last_updated, db, collection_name)

            # Fetch the indexes for the collection
            indexes = await asyncio.to_thread(get_collection_indexes, db, collection_name)

            # Fetch field types (inferred from schema or sample documents)
            field_types = await asyncio.to_thread(get_field_types, db, collection_name)

            # Sample a few documents from the collection for LLM context.
            # Note: _key/_id/_rev are stripped later in get_decision_schema via filter_document_fields.
            sample_document_result = await asyncio.to_thread(
                sample_documents, db, collection_name, sample_size=2
            )

            # Attempt to fetch existing metadata for the collection
            metadata = await get_collection_metadata(db, schema, collection_name, config)
            if metadata is None:
                # If metadata is not available, create a task to generate it using an LLM
                documents = await fetch_documents(db, collection_name)
                if documents:
                    metadata_task = asyncio.create_task(generate_llm_metadata(collection_name, documents, config))
                    metadata_task.set_name(collection_name)
                    metadata_tasks.append(metadata_task)
                    collection_names_needing_metadata.append(collection_name)
                schema[collection_name] = {
                    'name': collection_name,
                    'document_count': document_count,
                    'last_updated': last_updated,
                    'indexes': indexes,
                    'field_types': field_types,  # Inferred field types
                    'sample_documents': sample_document_result,  # Random sample documents for LLM context
                    'metadata': None  # Filled in after the LLM tasks complete
                }
            else:
                # If metadata exists, add a task to retrieve collection properties
                properties_tasks.append((collection_name, asyncio.to_thread(get_collection_properties, db, collection_name)))
                schema[collection_name] = {
                    'name': collection_name,
                    'document_count': document_count,
                    'metadata': metadata,
                    'last_updated': last_updated,
                    'indexes': indexes,
                    'field_types': field_types,  # Inferred or existing field types
                    'sample_documents': sample_document_result  # Random sample documents for LLM context
                }

        # Run all metadata generation tasks concurrently
        if metadata_tasks:
            generated_metadata_list = await asyncio.gather(*metadata_tasks)
            for collection_name, metadata in zip(collection_names_needing_metadata, generated_metadata_list):
                if metadata:
                    # Update only the metadata; the rest of the entry was populated above
                    schema[collection_name]['metadata'] = metadata
                    # Add a task to retrieve collection properties now that metadata exists
                    properties_tasks.append((collection_name, asyncio.to_thread(get_collection_properties, db, collection_name)))

        # Run all property retrieval tasks concurrently
        properties_results = await asyncio.gather(*(task for _, task in properties_tasks))
        for (collection_name, _), properties in zip(properties_tasks, properties_results):
            if collection_name in schema:
                schema[collection_name]['properties'] = properties

        # Cache the updated schema to avoid redundant LLM calls in the future
        await asyncio.to_thread(upsert_json_list, db, "schema_cache", schema)

        logger.debug(f"Database schema retrieved: {schema}")
        return schema
    except Exception as e:
        logger.error(f"Failed to get database schema: {str(e)}")
        raise
def get_decision_schema(arangodb_schema: Dict, db, config: Dict) -> Dict:
    """
    Creates an optimized decision schema tailored for efficient processing by an LLM.
    Detects foreign key relationships between collections dynamically using common naming patterns.
    """
    try:
        included_collections = config.get("included_collections", [])
        ignore_empty_collections = config.get("ignore_empty_collections", False)

        decision_schema = {
            "conditions": [],
            "relations": {},
        }

        # Suffixes used to identify potential foreign keys
        foreign_key_suffixes = ["_id", "_fk", "id_fk", "id", "fk"]

        # Iterate through each collection in the schema
        for collection_name, details in arangodb_schema.items():
            if collection_name not in included_collections:
                continue

            properties = details.get("properties", {})
            document_count = details.get("document_count", 0)
            last_updated = details.get("last_updated", "N/A")
            metadata = details.get("metadata", {})
            indexes = details.get("indexes", [])

            # Fetch and filter sample documents (drop _key, _id, _rev, etc.)
            sample_documents = details.get("sample_documents", [])
            cleaned_sample_documents = [
                filter_document_fields(doc) for doc in sample_documents
            ]

            # Skip the collection if it is empty and ignore_empty_collections is set
            if ignore_empty_collections and not document_count:
                continue

            # Extract description and field descriptions from metadata
            description = metadata.get("description", "")
            fields_description = metadata.get("fields", {})

            # Filter out properties that start with '_' (like _key, _id, _rev)
            filtered_properties = {
                field: properties[field] for field in properties if not field.startswith('_')
            }

            # Prepare a simplified fields dictionary using descriptions from metadata
            simplified_fields = {
                field: {
                    "description": fields_description.get(field, "").split(".")[0],
                    "type": filtered_properties.get(field, "unknown")
                } for field in filtered_properties.keys()
            }

            # Detect foreign key fields based on naming patterns
            foreign_key_fields = [
                field for field in filtered_properties.keys()
                if any(field.endswith(suffix) for suffix in foreign_key_suffixes)
            ]

            # Infer relationships for foreign key fields
            for fk_field in foreign_key_fields:
                # Assume the foreign key references a collection named after the field minus its suffix
                suffix = next(s for s in foreign_key_suffixes if fk_field.endswith(s))
                target_collection = fk_field[: -len(suffix)].rstrip('_')
                if target_collection in arangodb_schema:
                    if collection_name not in decision_schema["relations"]:
                        decision_schema["relations"][collection_name] = []
                    # Infer cardinality between source and target collection
                    cardinality = infer_cardinality(db, collection_name, fk_field, target_collection)
                    decision_schema["relations"][collection_name].append({
                        "foreign_key": fk_field,
                        "references": target_collection,
                        "cardinality": cardinality
                    })

            # Add each collection's unique conditions to the decision schema
            condition_data = {
                "collection_name": collection_name,
                "document_count": document_count,
                "last_updated": human_readable_last_updated(last_updated),  # Human-readable timestamp
                "if": list(filtered_properties.keys()),  # Use filtered properties
                "then": collection_name,
                "description": description,
                "fields": simplified_fields,
                "indexes": [  # Simplified indexes, primary type only
                    {"fields": idx["fields"], "type": idx["type"]}
                    for idx in indexes if idx["type"] == "primary"
                ],
                "sample_documents": cleaned_sample_documents[:1]  # Limit to 1 filtered sample document
            }
            decision_schema["conditions"].append(condition_data)

        # Calculate token count for the entire decision schema at the end
        schema_json = json.dumps(decision_schema)
        decision_schema["token_count"] = calculate_token_count(schema_json)

        # Return the populated decision schema if it has conditions
        if decision_schema["conditions"]:
            logger.info("Saving decision schema to cache")
            save_cache_to_json(decision_schema, 'verifaix/arangodb_helper/data/decision_schema.json')
            # Cache the updated decision schema to avoid redundant LLM calls in the future
            upsert_decision_schema(db, "decision_schema", decision_schema)
            return decision_schema
        else:
            logger.warning("No decision schema conditions were generated")
            return None
    except Exception as e:
        logger.exception(f"Error creating decision schema: {str(e)}")
        raise
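
# Illustrative sketch (not part of the original gist) of the foreign-key
# inference above: given a "radiation_levels" collection with a "reactor_id"
# field and a sibling "reactor" collection in the schema, the relation entry
# would look roughly like (names and cardinality hypothetical):
#   decision_schema["relations"]["radiation_levels"] == [
#       {"foreign_key": "reactor_id", "references": "reactor",
#        "cardinality": "many-to-one"}
#   ]
# The cardinality value is whatever infer_cardinality reports for the data.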
###
# Helper functions kept for reference: get_collection_metadata, generate_metadata_for_collection_from_llm.
# The live versions are imported from verifaix.arangodb_helper.utils above.
###

# async def get_collection_metadata(db, schema: dict, collection_name: str, config: dict) -> Optional[Dict[str, Any]]:
#     try:
#         # Ensure the 'collection_metadata' collection exists in the database
#         await asyncio.to_thread(ensure_collection_exists, db, "collection_metadata")
#         # Query the 'collection_metadata' collection to retrieve existing metadata
#         query = "FOR doc IN collection_metadata FILTER doc._key == @collection_name RETURN doc"
#         bind_vars = {"collection_name": collection_name}
#         cursor = await asyncio.to_thread(execute_aql_query, db, query, bind_vars)
#         metadata = cursor[0] if cursor else None
#         # Return the metadata if it exists
#         if metadata:
#             return metadata
#         return None
#     except ArangoError as e:
#         logger.error(f"Error retrieving metadata for collection '{collection_name}': {str(e)}")
#         raise

# async def generate_metadata_for_collection_from_llm(collection_name: str, documents: List[Dict[str, Any]], config: dict) -> Optional[Dict[str, Any]]:
#     try:
#         logger.debug(f"Generating metadata for collection '{collection_name}' using LLM.")
#         # Generate metadata for the collection using an LLM call
#         metadata = {
#             "_key": collection_name,
#             "name": collection_name,
#             "description": f"Metadata for the {collection_name} collection.",
#             "fields": {k: f"Description for field {k}" for k in documents[0].keys()} if documents else {}
#         }
#         return metadata
#     except Exception as e:
#         logger.error(f"Failed to generate metadata for collection '{collection_name}': {str(e)}")
#         return None
###
# Usage example: get_database_schema, get_decision_schema
###

if __name__ == "__main__":
    config = {
        "arango_config": {
            "host": "http://localhost:8529",
            "db_name": "nuclear",
            "username": "root",
            "password": "openSesame",
        },
        "llm_config": {
            "model": "openai/gpt-4o-mini",
            "json_mode": True
        },
        "included_collections": [
            'emergency_protocols', 'employee_records', 'radiation_levels',
            'reactor_data', 'waste_management'
        ],
    }

    # Set up the ArangoDB client and connect to the database
    client, arango_config = setup_arango_client(config)
    db = connect_to_arango(client, arango_config)

    decision_schema, database_schema = asyncio.run(get_schemas(db, config))
    print(decision_schema)
    # print(database_schema)