column-tag-add.py

#!/usr/bin/env python3
import csv
import json
import logging
import sys
import os
from pathlib import Path
import requests
import argparse
import traceback
from dataclasses import dataclass, asdict
from typing import Set, Dict, List, Tuple, Optional
from collections import defaultdict
import urllib.parse
import time
import hashlib
import sqlite3
from datetime import datetime, timedelta
import signal
import atexit
from contextlib import contextmanager

# Global variables
logger = None
shutdown_requested = False


def setup_logging(log_level=logging.INFO, log_dir="./logs"):
    """Set up console, main, and error-only log handlers for large scale operations"""
    os.makedirs(log_dir, exist_ok=True)

    # Create timestamp for log files
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Configure multiple log handlers
    logger = logging.getLogger('immuta_enterprise')
    logger.setLevel(log_level)

    # Clear existing handlers
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # Console handler with concise format
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(console_formatter)
    logger.addHandler(console_handler)

    # Main log file with detailed format
    main_log_file = os.path.join(log_dir, f"immuta_tagging_{timestamp}.log")
    file_handler = logging.FileHandler(main_log_file, encoding='utf-8')
    file_handler.setLevel(log_level)
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    )
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    # Error-only log file
    error_log_file = os.path.join(log_dir, f"immuta_errors_{timestamp}.log")
    error_handler = logging.FileHandler(error_log_file, encoding='utf-8')
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(file_formatter)
    logger.addHandler(error_handler)

    return logger, main_log_file, error_log_file


class ImmutaggingError(Exception):
    """Custom exception for Immuta tagging operations"""
    pass


@dataclass
class ColumnTag:
    """Enhanced data class with enterprise features"""
    row_id: int
    database: str
    schema: str
    table: str
    column: str
    immuta_tag: str
    immuta_schema_project: str
    immuta_datasource: str
    immuta_sql_schema: str
    immuta_column_key: str
    # Additional enterprise fields
    source_file_line: int = 0
    checksum: str = ""

    def __post_init__(self):
        # Generate checksum for data integrity
        data_str = f"{self.database}|{self.schema}|{self.table}|{self.column}|{self.immuta_tag}"
        self.checksum = hashlib.md5(data_str.encode()).hexdigest()[:8]

    def __str__(self) -> str:
        return f"Row {self.row_id}: {self.database}.{self.schema}.{self.table}.{self.column} -> {self.immuta_tag}"


class ProgressTracker:
    """Enterprise-grade progress tracking with persistence and recovery"""

    def __init__(self, db_path: str):
        self.db_path = db_path
        self.init_database()
        self.start_time = time.time()

    def init_database(self):
        """Initialize SQLite database for progress tracking"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS processing_status (
                    row_id INTEGER PRIMARY KEY,
                    database_name TEXT,
                    schema_name TEXT,
                    table_name TEXT,
                    column_name TEXT,
                    tag_name TEXT,
                    status TEXT,  -- 'pending', 'success', 'failed', 'skipped', 'validation_failed'
                    error_message TEXT,
                    processed_at TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS execution_metadata (
                    id INTEGER PRIMARY KEY,
                    session_id TEXT,
                    start_time TEXT,
                    end_time TEXT,
                    total_records INTEGER,
                    completed_records INTEGER,
                    failed_records INTEGER,
                    skipped_records INTEGER,
                    environment TEXT,
                    script_version TEXT
                )
            ''')
            # Create indexes for better performance
            conn.execute('CREATE INDEX IF NOT EXISTS idx_status ON processing_status(status)')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_row_id ON processing_status(row_id)')
            conn.commit()

    def load_existing_progress(self) -> Dict[int, str]:
        """Load existing progress to support resume functionality"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT row_id, status FROM processing_status
                WHERE status IN ('success', 'skipped')
            ''')
            result = {row_id: status for row_id, status in cursor.fetchall()}
            logger.info(f"Loaded {len(result)} previously processed records")
            return result

    def update_record_status(self, row_id: int, database: str, schema: str, table: str,
                             column: str, tag: str, status: str, error_message: str = None):
        """Update processing status for a record"""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute('''
                INSERT OR REPLACE INTO processing_status
                (row_id, database_name, schema_name, table_name, column_name, tag_name,
                 status, error_message, processed_at, retry_count)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,
                        COALESCE((SELECT retry_count FROM processing_status WHERE row_id = ?), 0) +
                        CASE WHEN ? = 'failed' THEN 1 ELSE 0 END)
            ''', (row_id, database, schema, table, column, tag, status, error_message,
                  datetime.now().isoformat(), row_id, status))
            conn.commit()

    def get_progress_summary(self) -> Dict:
        """Get current progress summary"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT status, COUNT(*) FROM processing_status GROUP BY status
            ''')
            status_counts = dict(cursor.fetchall())
            cursor = conn.execute('SELECT COUNT(*) FROM processing_status')
            total = cursor.fetchone()[0]
            return {
                'total': total,
                'success': status_counts.get('success', 0),
                'failed': status_counts.get('failed', 0),
                'skipped': status_counts.get('skipped', 0),
                'pending': status_counts.get('pending', 0)
            }

    def get_failed_records(self) -> List[Dict]:
        """Get all failed records for retry or analysis"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT row_id, database_name, schema_name, table_name, column_name,
                       tag_name, error_message, retry_count
                FROM processing_status
                WHERE status = 'failed'
                ORDER BY retry_count, row_id
            ''')
            columns = [desc[0] for desc in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def export_results(self, output_file: str):
        """Export all results to CSV for analysis"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute('''
                SELECT * FROM processing_status ORDER BY row_id
            ''')
            with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([desc[0] for desc in cursor.description])
                writer.writerows(cursor.fetchall())


class EnterpriseRateLimiter:
    """Enterprise-grade rate limiter with adaptive behavior"""

    def __init__(self, sleep_interval: float = 0.2, adaptive: bool = True):
        self.base_sleep_interval = sleep_interval
        self.current_sleep_interval = sleep_interval
        self.adaptive = adaptive
        self.last_call_time = 0
        self.consecutive_errors = 0
        self.success_count = 0

        # Statistics tracking
        self.call_history = []
        self.error_history = []

        calls_per_second = 1.0 / sleep_interval if sleep_interval > 0 else 5.0
        logger.info(f"Enterprise rate limiter initialized: {sleep_interval}s base interval, "
                    f"~{calls_per_second:.1f} calls/second, adaptive={adaptive}")

    def wait_if_needed(self):
        """Wait with adaptive rate limiting based on error patterns"""
        current_time = time.time()

        # Clean old history (keep last 5 minutes)
        cutoff_time = current_time - 300
        self.call_history = [t for t in self.call_history if t > cutoff_time]
        self.error_history = [t for t in self.error_history if t > cutoff_time]

        # Adaptive rate limiting based on recent errors
        if self.adaptive:
            recent_error_rate = len(self.error_history) / max(len(self.call_history), 1)
            if recent_error_rate > 0.1:  # More than 10% error rate
                self.current_sleep_interval = min(self.base_sleep_interval * 2, 2.0)
            elif recent_error_rate < 0.01:  # Less than 1% error rate
                self.current_sleep_interval = max(self.base_sleep_interval * 0.8, 0.1)
            else:
                self.current_sleep_interval = self.base_sleep_interval

        # Enforce minimum interval
        time_since_last_call = current_time - self.last_call_time
        if time_since_last_call < self.current_sleep_interval:
            sleep_time = self.current_sleep_interval - time_since_last_call
            time.sleep(sleep_time)

        self.last_call_time = time.time()
        self.call_history.append(self.last_call_time)

    def record_error(self):
        """Record an error for adaptive rate limiting"""
        self.consecutive_errors += 1
        self.error_history.append(time.time())

        # Exponential backoff for consecutive errors
        if self.consecutive_errors > 3:
            backoff_time = min(2 ** (self.consecutive_errors - 3), 30)
            logger.warning(f"Consecutive errors detected. Backing off for {backoff_time}s")
            time.sleep(backoff_time)

    def record_success(self):
        """Record a success to reset error counters"""
        self.consecutive_errors = 0
        self.success_count += 1


class EnterpriseImmutaggingClient:
    """Enterprise-grade client with circuit breaker, retry logic, and health monitoring"""

    def __init__(self, immuta_url: str, api_key: str, rate_limiter: EnterpriseRateLimiter):
        self.immuta_url = immuta_url.rstrip('/')
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json',
            'x-api-key': api_key
        }
        self.rate_limiter = rate_limiter
        self.session = requests.Session()
        self.session.headers.update(self.headers)

        # Circuit breaker state
        self.circuit_breaker_failures = 0
        self.circuit_breaker_opened_at = None
        self.circuit_breaker_threshold = 10
        self.circuit_breaker_timeout = 300  # 5 minutes

        # Health monitoring
        self.health_check_interval = 600  # 10 minutes
        self.last_health_check = 0

    def _circuit_breaker_check(self):
        """Check if circuit breaker should prevent API calls"""
        if self.circuit_breaker_opened_at:
            if time.time() - self.circuit_breaker_opened_at > self.circuit_breaker_timeout:
                logger.info("Circuit breaker timeout expired, attempting to reset")
                self.circuit_breaker_opened_at = None
                self.circuit_breaker_failures = 0
            else:
                raise ImmutaggingError("Circuit breaker is open - too many consecutive failures")

    def _health_check(self):
        """Perform periodic health check"""
        current_time = time.time()
        if current_time - self.last_health_check > self.health_check_interval:
            try:
                # Simple health check - get tag list with minimal parameters
                health_response = self.session.get(
                    f"{self.immuta_url}/tag",
                    params={"limit": 1},
                    timeout=10
                )
                health_response.raise_for_status()
                logger.debug("Health check passed")
                self.last_health_check = current_time
            except Exception as e:
                logger.warning(f"Health check failed: {str(e)}")

    def make_request(self, method: str, endpoint: str, data: Optional[Dict] = None,
                     params: Optional[Dict] = None, retries: int = 5) -> requests.Response:
        """Make HTTP request with comprehensive error handling and retry logic"""
        self._circuit_breaker_check()
        self._health_check()
        self.rate_limiter.wait_if_needed()

        url = f"{self.immuta_url}{endpoint}"

        for attempt in range(retries):
            try:
                if method.upper() == 'GET':
                    response = self.session.get(url, params=params, timeout=60)
                elif method.upper() == 'POST':
                    response = self.session.post(url, json=data, params=params, timeout=60)
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                # Handle different HTTP status codes
                if response.status_code == 429:
                    retry_after = int(response.headers.get('Retry-After', 60))
                    logger.warning(f"Rate limited (429). Waiting {retry_after} seconds before retry {attempt + 1}")
                    time.sleep(retry_after)
                    self.rate_limiter.record_error()
                    continue
                elif response.status_code >= 500:
                    # Server errors - retry with exponential backoff
                    backoff_time = min(2 ** attempt, 60)
                    logger.warning(f"Server error {response.status_code}. Retrying in {backoff_time}s (attempt {attempt + 1})")
                    time.sleep(backoff_time)
                    self.rate_limiter.record_error()
                    continue
                elif response.status_code >= 400:
                    # Client errors - don't retry
                    error_msg = f"Client error {response.status_code}: {response.text[:200]}"
                    logger.error(error_msg)
                    self.rate_limiter.record_error()
                    raise ImmutaggingError(error_msg)

                # Success
                response.raise_for_status()
                self.rate_limiter.record_success()
                self.circuit_breaker_failures = 0
                return response

            except requests.exceptions.Timeout:
                logger.warning(f"Request timeout (attempt {attempt + 1}/{retries})")
                self.rate_limiter.record_error()
                if attempt == retries - 1:
                    self.circuit_breaker_failures += 1
                    if self.circuit_breaker_failures >= self.circuit_breaker_threshold:
                        self.circuit_breaker_opened_at = time.time()
                        logger.error("Circuit breaker opened due to excessive timeouts")
                    raise ImmutaggingError(f"Request timeout after {retries} attempts")
                time.sleep(2 ** attempt)
            except requests.exceptions.ConnectionError as e:
                logger.warning(f"Connection error (attempt {attempt + 1}/{retries}): {str(e)}")
                self.rate_limiter.record_error()
                if attempt == retries - 1:
                    self.circuit_breaker_failures += 1
                    if self.circuit_breaker_failures >= self.circuit_breaker_threshold:
                        self.circuit_breaker_opened_at = time.time()
                        logger.error("Circuit breaker opened due to connection errors")
                    raise ImmutaggingError(f"Connection failed after {retries} attempts: {str(e)}")
                time.sleep(2 ** attempt)
            except ImmutaggingError:
                # Client errors (4xx) raised above are deliberately not retried
                raise
            except Exception as e:
                logger.error(f"Unexpected error (attempt {attempt + 1}/{retries}): {str(e)}")
                self.rate_limiter.record_error()
                if attempt == retries - 1:
                    raise ImmutaggingError(f"Unexpected error after {retries} attempts: {str(e)}")
                time.sleep(2 ** attempt)

        raise ImmutaggingError("Exhausted all retry attempts")
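

# A minimal usage sketch for the client on its own, assuming the module-level
# logger has already been initialized via setup_logging(); the URL and API key
# below are placeholders, not real values:
#
#   logger, _, _ = setup_logging()
#   limiter = EnterpriseRateLimiter(sleep_interval=0.5)
#   client = EnterpriseImmutaggingClient("https://immuta.example.com", "<api-key>", limiter)
#   resp = client.make_request('GET', '/tag', params={"limit": 1})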


class EnterpriseTaggingOrchestrator:
    """Main orchestrator for enterprise-scale tagging operations"""

    def __init__(self, args, progress_tracker):
        self.args = args
        self.progress_tracker = progress_tracker

        # Initialize API client
        rate_limiter = EnterpriseRateLimiter(
            sleep_interval=args.sleep,
            adaptive=args.adaptive_rate_limiting
        )
        self.api_client = EnterpriseImmutaggingClient(args.immuta_url, args.api_key, rate_limiter)

        # State management
        self.validation_cache = {}
        self.column_tags_cache = {}

        # Statistics
        self.stats = {
            'start_time': time.time(),
            'validation_time': 0,
            'processing_time': 0,
            'total_api_calls': 0,
            'cache_hits': 0,
            'cache_misses': 0
        }

        # Graceful shutdown handling
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        atexit.register(self._cleanup)
        self.shutdown_requested = False

    def _signal_handler(self, signum, frame):
        """Handle graceful shutdown"""
        logger.info(f"Received signal {signum}. Initiating graceful shutdown...")
        self.shutdown_requested = True
        global shutdown_requested
        shutdown_requested = True

    def _cleanup(self):
        """Cleanup function called on exit"""
        if hasattr(self, 'progress_tracker'):
            summary = self.progress_tracker.get_progress_summary()
            logger.info(f"Final progress summary: {summary}")

    @contextmanager
    def timer(self, operation_name: str):
        """Context manager for timing operations"""
        start_time = time.time()
        try:
            yield
        finally:
            elapsed = time.time() - start_time
            logger.info(f"{operation_name} completed in {elapsed:.2f} seconds")

    def validate_and_cache_references(self, records: List[ColumnTag]) -> bool:
        """Validate and cache all reference data (schema projects, data sources, etc.)"""
        logger.info("Starting comprehensive validation of reference data...")

        with self.timer("Reference validation"):
            # Extract unique sets
            schema_projects = {r.immuta_schema_project for r in records}
            data_sources = {r.immuta_datasource for r in records}
            tags = {r.immuta_tag for r in records}

            logger.info(f"Validating {len(schema_projects)} schema projects, "
                        f"{len(data_sources)} data sources, {len(tags)} tags")

            # Validate each category with progress tracking
            validation_results = {}

            try:
                # Schema projects
                logger.info("Validating schema projects...")
                validation_results['schema_projects'] = self._validate_schema_projects(schema_projects)

                # Data sources
                logger.info("Validating data sources...")
                validation_results['data_sources'] = self._validate_data_sources(data_sources)

                # Columns (only if we have valid data sources)
                if validation_results['data_sources']:
                    logger.info("Validating columns...")
                    validation_results['columns'] = self._validate_columns(validation_results['data_sources'])
                else:
                    validation_results['columns'] = set()

                # Tags
                logger.info("Validating tags...")
                validation_results['tags'] = self._validate_tags(tags)

                # Cache results
                self.validation_cache = validation_results

                # Validate individual records
                return self._validate_records(records)

            except Exception as e:
                logger.error(f"Validation failed: {str(e)}")
                return False

    def _validate_schema_projects(self, schema_projects: Set[str]) -> Set[str]:
        """Validate schema projects with detailed progress tracking"""
        valid_projects = set()

        for i, project in enumerate(schema_projects, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating schema project {i}/{len(schema_projects)}: {project}")
                params = {
                    "size": 999999999,
                    "sortField": "name",
                    "sortOrder": "asc",
                    "nameOnly": "true",
                    "searchText": project
                }
                response = self.api_client.make_request('GET', '/project', params=params)
                data = response.json()

                if data.get('count', 0) > 0:
                    for proj in data.get('hits', []):
                        if (project == proj.get('name') and
                                proj.get('type', '').lower() == 'schema' and
                                not proj.get('deleted', True)):
                            valid_projects.add(project)
                            break

                if i % 10 == 0:
                    logger.info(f"Schema project validation progress: {i}/{len(schema_projects)}")

            except Exception as e:
                logger.error(f"Error validating schema project {project}: {str(e)}")
                continue

        logger.info(f"Schema project validation complete: {len(valid_projects)}/{len(schema_projects)} valid")
        return valid_projects

    def _validate_data_sources(self, data_sources: Set[str]) -> Dict[str, int]:
        """Validate data sources and return mapping to IDs"""
        valid_sources = {}

        for i, data_source in enumerate(data_sources, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating data source {i}/{len(data_sources)}: {data_source}")
                encoded_name = urllib.parse.quote(data_source, safe='')
                response = self.api_client.make_request('GET', f'/dataSource/name/{encoded_name}')
                data = response.json()

                if not data.get('deleted', True):
                    valid_sources[data['name']] = data['id']

                if i % 10 == 0:
                    logger.info(f"Data source validation progress: {i}/{len(data_sources)}")

            except requests.exceptions.HTTPError as e:
                if e.response.status_code != 404:
                    logger.error(f"Error validating data source {data_source}: {str(e)}")
                continue
            except Exception as e:
                logger.error(f"Error validating data source {data_source}: {str(e)}")
                continue

        logger.info(f"Data source validation complete: {len(valid_sources)}/{len(data_sources)} valid")
        return valid_sources

    def _validate_columns(self, data_source_map: Dict[str, int]) -> Set[str]:
        """Validate columns for all data sources"""
        valid_columns = set()

        for i, (data_source_name, data_source_id) in enumerate(data_source_map.items(), 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Fetching columns for data source {i}/{len(data_source_map)}: {data_source_name}")
                response = self.api_client.make_request('GET', f'/dictionary/{data_source_id}')
                data = response.json()

                for column in data.get('metadata', []):
                    column_name = column.get('name', '')
                    if column_name:
                        valid_columns.add(f'{data_source_name}|{column_name}')

                if i % 10 == 0:
                    logger.info(f"Column validation progress: {i}/{len(data_source_map)}")

            except Exception as e:
                logger.error(f"Error fetching columns for {data_source_name}: {str(e)}")
                continue

        logger.info(f"Column validation complete: {len(valid_columns)} valid columns found")
        return valid_columns

    def _validate_tags(self, tags: Set[str]) -> Set[str]:
        """Validate tags"""
        valid_tags = set()

        for i, tag in enumerate(tags, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating tag {i}/{len(tags)}: {tag}")
                if tag.startswith("XXX-Classification."):
                    params = {
                        "source": "curated",
                        "searchText": tag,
                        "excludedHierarchies": '["Discovered","New","Skip Stats Job","XXX-DataSource","DataProperties"]',
                        "includeAllSystemTags": "false",
                        "limit": 999999999
                    }
                    response = self.api_client.make_request('GET', '/tag', params=params)
                    data = response.json()

                    if data and isinstance(data, list):
                        for tag_info in data:
                            if (tag == tag_info.get('name') and
                                    ((not tag_info.get('hasLeafNodes', True) and not tag_info.get('deleted', True)) or
                                     tag_info.get('hasLeafNodes', True))):
                                valid_tags.add(tag)
                                break

                if i % 10 == 0:
                    logger.info(f"Tag validation progress: {i}/{len(tags)}")

            except Exception as e:
                logger.error(f"Error validating tag {tag}: {str(e)}")
                continue

        logger.info(f"Tag validation complete: {len(valid_tags)}/{len(tags)} valid")
        return valid_tags

    def _validate_records(self, records: List[ColumnTag]) -> bool:
        """Validate individual records against cached reference data"""
        valid_schema_projects = self.validation_cache.get('schema_projects', set())
        valid_data_sources = self.validation_cache.get('data_sources', {})
        valid_columns = self.validation_cache.get('columns', set())
        valid_tags = self.validation_cache.get('tags', set())

        failed_records = []
        for record in records:
            errors = []

            if record.immuta_schema_project not in valid_schema_projects:
                errors.append(f"Invalid schema project: {record.immuta_schema_project}")
            if record.immuta_datasource not in valid_data_sources:
                errors.append(f"Invalid data source: {record.immuta_datasource}")
            if record.immuta_column_key not in valid_columns:
                errors.append(f"Invalid column: {record.immuta_column_key}")
            if record.immuta_tag not in valid_tags:
                errors.append(f"Invalid tag: {record.immuta_tag}")

            if errors:
                failed_records.append((record, errors))
                self.progress_tracker.update_record_status(
                    record.row_id, record.database, record.schema, record.table,
                    record.column, record.immuta_tag, 'validation_failed',
                    "; ".join(errors)
                )

        if failed_records:
            logger.error(f"Validation failed for {len(failed_records)} records")
            for record, errors in failed_records[:10]:  # Log first 10 failures
                logger.error(f"Row {record.row_id}: {'; '.join(errors)}")
            if len(failed_records) > 10:
                logger.error(f"... and {len(failed_records) - 10} more validation failures")
            return False

        logger.info(f"All {len(records)} records passed validation")
        return True

    def process_records(self, records: List[ColumnTag]) -> Dict[str, int]:
        """Process records with enterprise-grade error handling and recovery"""
        logger.info(f"Starting processing of {len(records)} records...")

        # Load existing progress for resume functionality
        existing_progress = self.progress_tracker.load_existing_progress()

        # Filter out already processed records
        pending_records = [r for r in records if r.row_id not in existing_progress]
        if len(pending_records) < len(records):
            logger.info(f"Resuming from previous run. Processing {len(pending_records)} remaining records "
                        f"(skipping {len(records) - len(pending_records)} already processed)")

        # Group records by column for efficient processing
        column_groups = defaultdict(list)
        for record in pending_records:
            column_key = f"{record.immuta_datasource}|{record.column}"
            column_groups[column_key].append(record)

        logger.info(f"Grouped records into {len(column_groups)} unique columns")

        # Statistics tracking
        stats = {
            'success': 0,
            'failed': 0,
            'skipped': 0,
            'retried': 0
        }

        processed_columns = 0
        total_columns = len(column_groups)

        # Process each column group
        for column_key, column_records in column_groups.items():
            if self.shutdown_requested:
                logger.info("Shutdown requested. Stopping processing gracefully.")
                break

            processed_columns += 1

            try:
                logger.debug(f"Processing column group {processed_columns}/{total_columns}: {column_key}")

                # Process all tags for this column
                column_stats = self._process_column_group(column_records)

                # Update statistics
                for key in stats:
                    stats[key] += column_stats.get(key, 0)

                # Progress reporting
                if processed_columns % 100 == 0:
                    elapsed = time.time() - self.stats['start_time']
                    rate = processed_columns / elapsed * 3600  # columns per hour
                    eta = (total_columns - processed_columns) / max(rate / 3600, 0.1)
                    logger.info(f"Progress: {processed_columns}/{total_columns} column groups "
                                f"({processed_columns/total_columns*100:.1f}%) - "
                                f"Rate: {rate:.1f} columns/hour - "
                                f"ETA: {timedelta(seconds=int(eta))}")

                    # Progress summary
                    progress_summary = self.progress_tracker.get_progress_summary()
                    logger.info(f"Overall progress summary: {progress_summary}")
                    logger.info(f"Current stats: Success={stats['success']}, "
                                f"Failed={stats['failed']}, Skipped={stats['skipped']}")

            except Exception as e:
                logger.error(f"Error processing column group {column_key}: {str(e)}")
                # Mark all records in this group as failed
                for record in column_records:
                    self.progress_tracker.update_record_status(
                        record.row_id, record.database, record.schema, record.table,
                        record.column, record.immuta_tag, 'failed', str(e)
                    )
                stats['failed'] += len(column_records)

        # Final statistics
        total_processed = sum(stats.values())
        logger.info(f"Processing complete. Total processed: {total_processed}, "
                    f"Success: {stats['success']}, Failed: {stats['failed']}, "
                    f"Skipped: {stats['skipped']}")

        return stats

    def _process_column_group(self, column_records: List[ColumnTag]) -> Dict[str, int]:
        """Process all tags for a single column with optimized API usage"""
        if not column_records:
            return {'success': 0, 'failed': 0, 'skipped': 0}

        first_record = column_records[0]
        data_source_id = self.validation_cache['data_sources'][first_record.immuta_datasource]

        stats = {'success': 0, 'failed': 0, 'skipped': 0}

        try:
            # Get existing tags for this column (with caching)
            existing_tags = self._get_existing_column_tags_cached(data_source_id, first_record.column)

            # Process each tag for this column
            for record in column_records:
                try:
                    # Check if tag already exists
                    if record.immuta_tag in existing_tags:
                        stats['skipped'] += 1
                        self.progress_tracker.update_record_status(
                            record.row_id, record.database, record.schema, record.table,
                            record.column, record.immuta_tag, 'skipped',
                            'Tag already exists'
                        )
                        logger.debug(f"Row {record.row_id}: Tag {record.immuta_tag} already exists")
                        continue

                    # Skip actual tagging in dry-run mode
                    if self.args.dry_run:
                        stats['success'] += 1
                        self.progress_tracker.update_record_status(
                            record.row_id, record.database, record.schema, record.table,
                            record.column, record.immuta_tag, 'success',
                            'DRY RUN - would have added tag'
                        )
                        logger.debug(f"DRY RUN: Would add tag {record.immuta_tag} to column {record.column}")
                        continue

                    # Add the tag
                    success = self._add_single_tag(record, data_source_id)
                    if success:
                        stats['success'] += 1
                        existing_tags.add(record.immuta_tag)  # Update cache
                        self.progress_tracker.update_record_status(
                            record.row_id, record.database, record.schema, record.table,
                            record.column, record.immuta_tag, 'success'
                        )
                        logger.debug(f"Row {record.row_id}: Successfully added tag {record.immuta_tag}")
                    else:
                        stats['failed'] += 1
                        self.progress_tracker.update_record_status(
                            record.row_id, record.database, record.schema, record.table,
                            record.column, record.immuta_tag, 'failed',
                            'Tag addition failed - see logs'
                        )

                except Exception as e:
                    stats['failed'] += 1
                    error_msg = f"Error processing record: {str(e)}"
                    logger.error(f"Row {record.row_id}: {error_msg}")
                    self.progress_tracker.update_record_status(
                        record.row_id, record.database, record.schema, record.table,
                        record.column, record.immuta_tag, 'failed', error_msg
                    )

        except Exception as e:
            # Catastrophic failure for entire column group
            logger.error(f"Catastrophic failure for column group: {str(e)}")
            for record in column_records:
                self.progress_tracker.update_record_status(
                    record.row_id, record.database, record.schema, record.table,
                    record.column, record.immuta_tag, 'failed',
                    f"Column group processing failed: {str(e)}"
                )
            stats['failed'] = len(column_records)

        return stats

    def _get_existing_column_tags_cached(self, data_source_id: int, column_name: str) -> Set[str]:
        """Get existing tags for a column with intelligent caching"""
        cache_key = f"{data_source_id}_{column_name.lower()}"

        if cache_key in self.column_tags_cache:
            self.stats['cache_hits'] += 1
            return self.column_tags_cache[cache_key]

        self.stats['cache_misses'] += 1

        try:
            response = self.api_client.make_request('GET', f'/dictionary/{data_source_id}')
            data = response.json()

            existing_tags = set()
            for column_metadata in data.get('metadata', []):
                if column_metadata.get('name', '').lower() == column_name.lower():
                    if 'tags' in column_metadata:
                        for tag in column_metadata['tags']:
                            if not tag.get('deleted', True):
                                existing_tags.add(tag['name'])
                    break

            # Cache the result
            self.column_tags_cache[cache_key] = existing_tags

            # Cache size management (prevent memory issues)
            if len(self.column_tags_cache) > 10000:
                # Remove oldest 20% of cache entries
                items_to_remove = len(self.column_tags_cache) // 5
                keys_to_remove = list(self.column_tags_cache.keys())[:items_to_remove]
                for key in keys_to_remove:
                    del self.column_tags_cache[key]
                logger.debug(f"Cache pruned: removed {items_to_remove} entries")

            return existing_tags

        except Exception as e:
            logger.error(f"Error getting existing tags for column {column_name}: {str(e)}")
            return set()

    def _add_single_tag(self, record: ColumnTag, data_source_id: int) -> bool:
        """Add a single tag with retry logic"""
        max_retries = 3

        for attempt in range(max_retries):
            try:
                # Prepare payload
                payload = [{
                    "name": record.immuta_tag,
                    "source": "curated"
                }]

                # Make API call
                column_identifier = f"{data_source_id}_{record.column.lower()}"
                response = self.api_client.make_request('POST', f'/tag/column/{column_identifier}', data=payload)

                # Verify tag was added
                response_data = response.json()
                for response_tag in response_data:
                    if (record.immuta_tag == response_tag.get('name') and
                            not response_tag.get('deleted', True)):
                        return True

                logger.error(f"Row {record.row_id}: Tag {record.immuta_tag} not found in response after adding")
                return False

            except Exception as e:
                logger.warning(f"Row {record.row_id}: Attempt {attempt + 1} failed for tag {record.immuta_tag}: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                else:
                    logger.error(f"Row {record.row_id}: All {max_retries} attempts failed for tag {record.immuta_tag}")
                    return False

        return False

    def generate_reports(self, output_dir: str):
        """Generate comprehensive reports for enterprise analysis"""
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Export detailed results
        detailed_results_file = os.path.join(output_dir, f"detailed_results_{timestamp}.csv")
        self.progress_tracker.export_results(detailed_results_file)
        logger.info(f"Detailed results exported to: {detailed_results_file}")

        # Generate summary report
        summary_file = os.path.join(output_dir, f"summary_report_{timestamp}.json")
        summary_data = {
            'execution_summary': {
                'start_time': datetime.fromtimestamp(self.stats['start_time']).isoformat(),
                'end_time': datetime.now().isoformat(),
                'total_duration_seconds': time.time() - self.stats['start_time'],
                'environment': self.args.immuta_environment,
                'script_version': '2.0.0-enterprise'
            },
            'progress_summary': self.progress_tracker.get_progress_summary(),
            'performance_stats': {
                'total_api_calls': self.stats.get('total_api_calls', 0),
                'cache_hits': self.stats.get('cache_hits', 0),
                'cache_misses': self.stats.get('cache_misses', 0),
                'cache_hit_rate': (self.stats.get('cache_hits', 0) /
                                   max(self.stats.get('cache_hits', 0) + self.stats.get('cache_misses', 0), 1))
            }
        }

        with open(summary_file, 'w') as f:
            json.dump(summary_data, f, indent=2)
        logger.info(f"Summary report generated: {summary_file}")

        # Generate failed records report for retry
        failed_records = self.progress_tracker.get_failed_records()
        failed_file = None
        if failed_records:
            failed_file = os.path.join(output_dir, f"failed_records_{timestamp}.csv")
            with open(failed_file, 'w', newline='') as f:
                writer = csv.DictWriter(f, fieldnames=failed_records[0].keys())
                writer.writeheader()
                writer.writerows(failed_records)
            logger.info(f"Failed records report generated: {failed_file}")

        return {
            'detailed_results': detailed_results_file,
            'summary_report': summary_file,
            'failed_records': failed_file
        }


def parse_arguments():
    """Parse command line arguments with enterprise options"""
    parser = argparse.ArgumentParser(
        description='Enterprise-scale Immuta column tagging with comprehensive error handling and recovery.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Enterprise Features:
  - Progress tracking with SQLite database
  - Resume functionality for interrupted runs
  - Comprehensive error handling and retry logic
  - Circuit breaker pattern for API failures
  - Adaptive rate limiting based on error patterns
  - Detailed reporting and analytics
  - Graceful shutdown handling

Examples:
  # Basic enterprise run
  %(prog)s immuta-nonprod --precheck-only

  # Production run with all optimizations
  %(prog)s immuta-prod --sleep 0.2 --adaptive-rate-limiting --checkpoint-interval 1000

  # Resume interrupted run
  %(prog)s immuta-prod --resume --progress-db ./progress/run_20241201.db

  # Retry failed records only
  %(prog)s immuta-prod --retry-failed --progress-db ./progress/run_20241201.db
"""
    )

    # Basic arguments
    parser.add_argument('immuta_environment',
                        choices=['immuta-prod', 'immuta-nonprod'],
                        help='Immuta environment')
    parser.add_argument('--sleep', '-s', type=float, default=0.2,
                        help='Base sleep interval between API calls (default: %(default)s)')
    parser.add_argument('--precheck-only', '-pc', action='store_true',
                        help='Only validate data, do not perform tagging')
    parser.add_argument('--input-file', '-i', type=str,
                        help='Override default input file path')
    parser.add_argument('--dry-run', action='store_true',
                        help='Perform validation and simulate tagging without actual API calls')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Enable verbose logging')

    # Enterprise features
    parser.add_argument('--progress-db', type=str,
                        help='Path to progress database (default: auto-generated)')
    parser.add_argument('--resume', action='store_true',
                        help='Resume from previous interrupted run')
    parser.add_argument('--retry-failed', action='store_true',
                        help='Retry only failed records from previous run')
    parser.add_argument('--checkpoint-interval', type=int, default=1000,
                        help='Progress checkpoint interval (default: %(default)s)')
    # Note: store_true with default=True means adaptive rate limiting is always on;
    # the flag is accepted for compatibility but cannot currently disable it.
    parser.add_argument('--adaptive-rate-limiting', action='store_true', default=True,
                        help='Enable adaptive rate limiting based on error patterns')
    parser.add_argument('--max-retries', type=int, default=3,
                        help='Maximum retries per operation (default: %(default)s)')
    parser.add_argument('--circuit-breaker-threshold', type=int, default=10,
                        help='Circuit breaker failure threshold (default: %(default)s)')
    parser.add_argument('--output-dir', type=str, default='./reports',
                        help='Output directory for reports (default: %(default)s)')
    parser.add_argument('--log-dir', type=str, default='./logs',
                        help='Log directory (default: %(default)s)')

    return parser.parse_args()


def load_input_data(file_path: str) -> List[ColumnTag]:
    """Load and parse input data with comprehensive validation"""
    logger.info(f"Loading input data from: {file_path}")

    records = []
    errors = []

    try:
        with open(file_path, 'r', encoding='utf-8-sig') as csvfile:
            reader = csv.DictReader(csvfile)
            for line_number, row in enumerate(reader, start=2):  # Start at 2 (header is line 1)
                try:
                    values = list(row.values())
                    if len(values) < 5:
                        errors.append(f"Line {line_number}: Insufficient columns ({len(values)} < 5)")
                        continue

                    # Clean and validate data
                    database = values[0].strip().upper() if values[0] else ""
                    schema = values[1].strip().upper() if values[1] else ""
                    table = values[2].strip().upper() if values[2] else ""
                    column = values[3].strip().upper() if values[3] else ""
                    tag = values[4].strip() if values[4] else ""

                    if not all([database, schema, table, column, tag]):
                        errors.append(f"Line {line_number}: Missing required data")
                        continue

                    record = ColumnTag(
                        row_id=line_number,
                        database=database,
                        schema=schema,
                        table=table,
                        column=column,
                        immuta_tag=tag,
                        immuta_schema_project=f'{database}-{schema}',
                        immuta_datasource=f'{database}-{schema}-{table}',
                        immuta_sql_schema=f'{database.lower()}-{schema.lower()}',
                        immuta_column_key=f'{database}-{schema}-{table}|{column.lower()}',
                        source_file_line=line_number
                    )
                    records.append(record)

                except Exception as e:
                    errors.append(f"Line {line_number}: Error parsing row - {str(e)}")
                    continue

    except FileNotFoundError:
        raise ImmutaggingError(f"Input file not found: {file_path}")
    except Exception as e:
        raise ImmutaggingError(f"Error reading input file: {str(e)}")

    if errors:
        logger.warning(f"Found {len(errors)} parsing errors:")
        for error in errors[:10]:  # Show first 10 errors
            logger.warning(f"  {error}")
        if len(errors) > 10:
            logger.warning(f"  ... and {len(errors) - 10} more errors")

    logger.info(f"Successfully loaded {len(records)} records from input file")

    if not records:
        raise ImmutaggingError("No valid records found in input file")

    return records
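

# A minimal sketch of the expected input CSV shape, based on how load_input_data()
# reads the first five columns positionally (database, schema, table, column, tag).
# The header text and sample values below are illustrative assumptions, not taken
# from a real input file:
#
#   DATABASE,SCHEMA,TABLE,COLUMN,IMMUTA_TAG
#   SALESDB,PUBLIC,ORDERS,CUSTOMER_EMAIL,XXX-Classification.Confidential
#
# The first four values are upper-cased by the loader, the tag is kept as-is, and
# only tags under the "XXX-Classification." hierarchy pass _validate_tags().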


def main():
    """Main execution function for enterprise-scale operations"""
    global logger, shutdown_requested

    args = parse_arguments()

    # Setup logging
    logger, main_log_file, error_log_file = setup_logging(
        log_level=logging.DEBUG if args.verbose else logging.INFO,
        log_dir=args.log_dir
    )

    logger.info("=" * 80)
    logger.info("IMMUTA ENTERPRISE COLUMN TAGGING SCRIPT STARTED")
    logger.info("=" * 80)
    logger.info(f"Environment: {args.immuta_environment}")
    logger.info(f"Dry run: {args.dry_run}")
    logger.info(f"Resume mode: {args.resume}")
    logger.info(f"Retry failed mode: {args.retry_failed}")
    logger.info(f"Sleep interval: {args.sleep}s")
    logger.info(f"Adaptive rate limiting: {args.adaptive_rate_limiting}")

    try:
        # Initialize environment-specific settings
        if args.immuta_environment == 'immuta-nonprod':
            input_file = args.input_file or './input-files/columns-to-tag-nonprod.csv'
            immuta_url = 'https://immuta-nonprod.eda.xxx.co.nz'
            api_key_file = f'{Path.home()}/.immuta/api-key-immuta-nonprod.json'
        else:
            input_file = args.input_file or './input-files/columns-to-tag.csv'
            immuta_url = 'https://immuta.eda.xxx.co.nz'
            api_key_file = f'{Path.home()}/.immuta/api-key-immuta.json'

        # Load API key
        try:
            with open(api_key_file, 'r') as f:
                api_key = json.load(f)['api-key']
            if not api_key.strip():
                raise ImmutaggingError("API key is empty")
        except Exception as e:
            raise ImmutaggingError(f"Error loading API key from {api_key_file}: {str(e)}")

        # Setup progress database
        if args.progress_db:
            progress_db_path = args.progress_db
        else:
            os.makedirs('./progress', exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            progress_db_path = f'./progress/immuta_progress_{timestamp}.db'

        logger.info(f"Progress database: {progress_db_path}")

        # Initialize progress tracker
        progress_tracker = ProgressTracker(progress_db_path)

        # Add environment info to args for orchestrator
        args.immuta_url = immuta_url
        args.api_key = api_key

        # Initialize orchestrator
        orchestrator = EnterpriseTaggingOrchestrator(args, progress_tracker)

        # Load input data
        records = load_input_data(input_file)
        logger.info(f"Loaded {len(records)} records for processing")

        # Performance estimate
        estimated_time = len(records) * args.sleep / 3600  # hours
        logger.info(f"Estimated processing time: {estimated_time:.1f} hours at {args.sleep}s interval")

        # Validation phase
        logger.info("Starting validation phase...")
        validation_start = time.time()

        if not orchestrator.validate_and_cache_references(records):
            logger.error("Validation failed. Stopping execution.")
            return 1

        validation_time = time.time() - validation_start
        logger.info(f"Validation completed in {validation_time:.1f} seconds")

        if args.precheck_only:
            logger.info("Pre-check only mode. Validation completed successfully.")
            return 0

        # Processing phase
        logger.info("Starting processing phase...")
        processing_start = time.time()

        try:
            processing_stats = orchestrator.process_records(records)

            processing_time = time.time() - processing_start
            logger.info(f"Processing completed in {processing_time:.1f} seconds")
            logger.info(f"Final statistics: {processing_stats}")

            # Generate comprehensive reports
            logger.info("Generating reports...")
            reports = orchestrator.generate_reports(args.output_dir)

            logger.info("Reports generated:")
            for report_type, report_path in reports.items():
                if report_path:
                    logger.info(f"  {report_type}: {report_path}")

            # Success/failure determination
            if processing_stats['failed'] > 0:
                logger.warning(f"Completed with {processing_stats['failed']} failures")
                return 1 if processing_stats['failed'] > processing_stats['success'] else 0
            else:
                logger.info("All operations completed successfully")
                return 0

        except KeyboardInterrupt:
            logger.info("Received interrupt signal. Graceful shutdown in progress...")
            return 130  # Standard exit code for Ctrl+C

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return 1
    finally:
        logger.info("=" * 80)
        logger.info("IMMUTA ENTERPRISE COLUMN TAGGING SCRIPT COMPLETED")
        logger.info("=" * 80)


if __name__ == "__main__":
    sys.exit(main())
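

# A minimal setup sketch (the key file locations and the "api-key" field are what
# main() reads; the token value is a placeholder, and the invocations mirror the
# argparse epilog):
#
#   ~/.immuta/api-key-immuta-nonprod.json   (used for immuta-nonprod)
#   ~/.immuta/api-key-immuta.json           (used for immuta-prod)
#
#   {
#       "api-key": "<your-immuta-api-key>"
#   }
#
#   python3 column-tag-add.py immuta-nonprod --precheck-only
#   python3 column-tag-add.py immuta-prod --resume --progress-db ./progress/run_20241201.db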