Skip to content

Instantly share code, notes, and snippets.

@davidlu1001
Last active May 26, 2025 11:33
Show Gist options
  • Save davidlu1001/7e106323fa0c9460eda81f6399c98ea4 to your computer and use it in GitHub Desktop.
Save davidlu1001/7e106323fa0c9460eda81f6399c98ea4 to your computer and use it in GitHub Desktop.
column-tag-add.py
#!/usr/bin/env python3
import csv
import json
import logging
import sys
import os
from pathlib import Path
import requests
import argparse
from time import sleep
from sys import exit
import traceback
from dataclasses import dataclass, asdict
from typing import Set, Dict, List, Tuple, Optional
from collections import defaultdict
import urllib.parse
import time
import hashlib
import sqlite3
from datetime import datetime, timedelta
import signal
import atexit
from contextlib import contextmanager
# Module-level state shared by the classes below.
# The logger is None until the entry point assigns one (presumably the logger
# returned by setup_logging - confirm in main()).
logger: Optional[logging.Logger] = None
# Flipped to True by EnterpriseTaggingOrchestrator._signal_handler on SIGINT/SIGTERM.
shutdown_requested: bool = False
def setup_logging(log_level=logging.INFO, log_dir="./logs"):
    """Configure the 'immuta_enterprise' logger with console and file handlers.

    Three handlers are attached:
      * stdout, always at INFO, with a concise format;
      * ``immuta_tagging_<timestamp>.log`` at ``log_level``, with a detailed format;
      * ``immuta_errors_<timestamp>.log`` at ERROR, with the detailed format.

    Calling this again reconfigures the same named logger. Handlers from the
    previous call are detached *and closed*, so their log files are released.

    Args:
        log_level: Level applied to the logger and to the main log file.
        log_dir: Directory for the log files; it is created if missing.

    Returns:
        A tuple ``(logger, main_log_file_path, error_log_file_path)``.
    """
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    logger = logging.getLogger('immuta_enterprise')
    logger.setLevel(log_level)

    # Detach AND close handlers left over from an earlier call. Removing a
    # FileHandler without closing it leaks its open file descriptor.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    # Console: concise format, fixed at INFO regardless of log_level.
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(console_handler)

    # Main log file: every record at log_level and above, with call-site details.
    file_formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s'
    )
    main_log_file = os.path.join(log_dir, f"immuta_tagging_{timestamp}.log")
    file_handler = logging.FileHandler(main_log_file, encoding='utf-8')
    file_handler.setLevel(log_level)
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    # Error-only log file, for quick triage of failures.
    error_log_file = os.path.join(log_dir, f"immuta_errors_{timestamp}.log")
    error_handler = logging.FileHandler(error_log_file, encoding='utf-8')
    error_handler.setLevel(logging.ERROR)
    error_handler.setFormatter(file_formatter)
    logger.addHandler(error_handler)

    return logger, main_log_file, error_log_file
class ImmutaggingError(Exception):
    """Raised when an Immuta API call or a tagging step cannot be completed."""
@dataclass
class ColumnTag:
    """One requested column -> Immuta tag assignment.

    ``checksum`` is derived rather than supplied. ``__post_init__`` always
    overwrites it with the first 8 hex digits of an MD5 over
    database, schema, table, column and tag, joined by ``|``.
    """

    row_id: int
    database: str
    schema: str
    table: str
    column: str
    immuta_tag: str
    immuta_schema_project: str
    immuta_datasource: str
    immuta_sql_schema: str
    immuta_column_key: str
    # Optional bookkeeping fields
    source_file_line: int = 0
    checksum: str = ""

    def __post_init__(self):
        # MD5 serves only as a short fingerprint here, not as a security control.
        key_parts = (self.database, self.schema, self.table, self.column, self.immuta_tag)
        fingerprint_source = "|".join(format(part) for part in key_parts)
        self.checksum = hashlib.md5(fingerprint_source.encode()).hexdigest()[:8]

    def __str__(self) -> str:
        qualified_column = ".".join(
            format(part) for part in (self.database, self.schema, self.table, self.column)
        )
        return "Row {}: {} -> {}".format(self.row_id, qualified_column, self.immuta_tag)
class ProgressTracker:
    """Persist per-row processing status in a SQLite file for resume and reporting.

    Every public method opens a short-lived connection to ``db_path``. The
    connection commits on success, rolls back on error, and is always closed
    before the method returns.
    """

    def __init__(self, db_path: str):
        """Create the tracker and make sure the schema exists in ``db_path``."""
        self.db_path = db_path
        self.init_database()
        self.start_time = time.time()

    @contextmanager
    def _connect(self):
        """Yield a connection that is committed or rolled back, then closed.

        A bare ``sqlite3.Connection`` used as a context manager only manages
        the transaction and never closes the connection. Because this class is
        called once per record, that leaked one handle per call.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def init_database(self):
        """Create the status and metadata tables, plus indexes, if they are missing."""
        with self._connect() as conn:
            # The orchestrator also writes status 'validation_failed'.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS processing_status (
                    row_id INTEGER PRIMARY KEY,
                    database_name TEXT,
                    schema_name TEXT,
                    table_name TEXT,
                    column_name TEXT,
                    tag_name TEXT,
                    status TEXT, -- 'pending', 'success', 'failed', 'skipped'
                    error_message TEXT,
                    processed_at TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')
            # Not written by this class.
            conn.execute('''
                CREATE TABLE IF NOT EXISTS execution_metadata (
                    id INTEGER PRIMARY KEY,
                    session_id TEXT,
                    start_time TEXT,
                    end_time TEXT,
                    total_records INTEGER,
                    completed_records INTEGER,
                    failed_records INTEGER,
                    skipped_records INTEGER,
                    environment TEXT,
                    script_version TEXT
                )
            ''')
            conn.execute('CREATE INDEX IF NOT EXISTS idx_status ON processing_status(status)')
            # row_id is an INTEGER PRIMARY KEY (a rowid alias), so this index is
            # redundant. It is kept so the schema stays identical for existing DBs.
            conn.execute('CREATE INDEX IF NOT EXISTS idx_row_id ON processing_status(row_id)')

    def load_existing_progress(self) -> Dict[int, str]:
        """Return ``{row_id: status}`` for rows already 'success' or 'skipped'.

        These rows are considered done and are excluded when resuming.
        """
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT row_id, status FROM processing_status
                WHERE status IN ('success', 'skipped')
            ''')
            result = {row_id: status for row_id, status in cursor.fetchall()}
        logger.info(f"Loaded {len(result)} previously processed records")
        return result

    def update_record_status(self, row_id: int, database: str, schema: str, table: str,
                             column: str, tag: str, status: str,
                             error_message: Optional[str] = None):
        """Upsert the status row for ``row_id``.

        ``retry_count`` carries over from any previous row with the same id.
        It is incremented by one whenever the new status is 'failed'.
        """
        with self._connect() as conn:
            # The sub-select reads the old retry_count before INSERT OR REPLACE
            # deletes the old row.
            conn.execute('''
                INSERT OR REPLACE INTO processing_status
                (row_id, database_name, schema_name, table_name, column_name, tag_name,
                 status, error_message, processed_at, retry_count)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?,
                    COALESCE((SELECT retry_count FROM processing_status WHERE row_id = ?), 0) +
                    CASE WHEN ? = 'failed' THEN 1 ELSE 0 END)
            ''', (row_id, database, schema, table, column, tag, status, error_message,
                  datetime.now().isoformat(), row_id, status))

    def get_progress_summary(self) -> Dict:
        """Return total row count and per-status counts (0 for absent statuses)."""
        with self._connect() as conn:
            status_counts = dict(conn.execute(
                'SELECT status, COUNT(*) FROM processing_status GROUP BY status'
            ).fetchall())
            total = conn.execute('SELECT COUNT(*) FROM processing_status').fetchone()[0]
        return {
            'total': total,
            'success': status_counts.get('success', 0),
            'failed': status_counts.get('failed', 0),
            'skipped': status_counts.get('skipped', 0),
            'pending': status_counts.get('pending', 0),
            # Rows rejected by pre-flight validation. Without this key they are
            # counted in 'total' but appear in no breakdown bucket.
            'validation_failed': status_counts.get('validation_failed', 0),
        }

    def get_failed_records(self) -> List[Dict]:
        """Return 'failed' rows as dicts, ordered by retry_count then row_id."""
        with self._connect() as conn:
            cursor = conn.execute('''
                SELECT row_id, database_name, schema_name, table_name, column_name,
                       tag_name, error_message, retry_count
                FROM processing_status
                WHERE status = 'failed'
                ORDER BY retry_count, row_id
            ''')
            columns = [desc[0] for desc in cursor.description]
            return [dict(zip(columns, row)) for row in cursor.fetchall()]

    def export_results(self, output_file: str):
        """Write every processing_status row, with a header, to ``output_file`` as CSV."""
        with self._connect() as conn:
            cursor = conn.execute('SELECT * FROM processing_status ORDER BY row_id')
            with open(output_file, 'w', newline='', encoding='utf-8') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([desc[0] for desc in cursor.description])
                # Stream rows from the cursor instead of materialising them all.
                writer.writerows(cursor)
class EnterpriseRateLimiter:
    """Space out API calls with a minimum interval that adapts to recent errors.

    In adaptive mode the interval is re-derived before each call from the
    error rate over the last 5 minutes:
      * more than 10% errors: twice the base interval, capped at 2s;
      * less than 1% errors: 20% below the base interval, floored at 0.1s;
      * otherwise: the base interval.

    The caps are applied so the adaptation always moves in the intended
    direction. Slowing down never ends up faster than the base, and speeding
    up never ends up slower than the base.
    """

    _HISTORY_WINDOW = 300        # seconds of call/error history considered
    _SLOW_DOWN_ERROR_RATE = 0.1  # above this, back off
    _SPEED_UP_ERROR_RATE = 0.01  # below this, speed up
    _MAX_SLOW_INTERVAL = 2.0     # seconds
    _MIN_FAST_INTERVAL = 0.1     # seconds

    def __init__(self, sleep_interval: float = 0.2, adaptive: bool = True):
        """
        Args:
            sleep_interval: Base minimum number of seconds between calls.
            adaptive: Whether to adjust the interval from the recent error rate.
        """
        self.base_sleep_interval = sleep_interval
        self.current_sleep_interval = sleep_interval
        self.adaptive = adaptive
        self.last_call_time = 0
        self.consecutive_errors = 0
        self.success_count = 0
        # Timestamps (time.time()) of recent calls and errors
        self.call_history = []
        self.error_history = []
        calls_per_second = 1.0 / sleep_interval if sleep_interval > 0 else 5.0
        logger.info(f"Enterprise rate limiter initialized: {sleep_interval}s base interval, "
                    f"~{calls_per_second:.1f} calls/second, adaptive={adaptive}")

    def _adapted_interval(self) -> float:
        """Return the interval implied by the error rate in the current history window."""
        base = self.base_sleep_interval
        recent_error_rate = len(self.error_history) / max(len(self.call_history), 1)
        if recent_error_rate > self._SLOW_DOWN_ERROR_RATE:
            return max(base, min(base * 2, self._MAX_SLOW_INTERVAL))
        if recent_error_rate < self._SPEED_UP_ERROR_RATE:
            return min(base, max(base * 0.8, self._MIN_FAST_INTERVAL))
        return base

    def wait_if_needed(self):
        """Sleep until at least ``current_sleep_interval`` has passed since the last call."""
        current_time = time.time()
        cutoff_time = current_time - self._HISTORY_WINDOW
        self.call_history = [t for t in self.call_history if t > cutoff_time]
        self.error_history = [t for t in self.error_history if t > cutoff_time]

        if self.adaptive:
            self.current_sleep_interval = self._adapted_interval()

        time_since_last_call = current_time - self.last_call_time
        if time_since_last_call < self.current_sleep_interval:
            time.sleep(self.current_sleep_interval - time_since_last_call)

        self.last_call_time = time.time()
        self.call_history.append(self.last_call_time)

    def record_error(self):
        """Record an error. After more than 3 in a row, sleep 2**(n-3)s, capped at 30s."""
        self.consecutive_errors += 1
        self.error_history.append(time.time())
        if self.consecutive_errors > 3:
            backoff_time = min(2 ** (self.consecutive_errors - 3), 30)
            logger.warning(f"Consecutive errors detected. Backing off for {backoff_time}s")
            time.sleep(backoff_time)

    def record_success(self):
        """Record a success and reset the consecutive-error counter."""
        self.consecutive_errors = 0
        self.success_count += 1
class EnterpriseImmutaggingClient:
    """HTTP client for the Immuta REST API with throttling, retries and a circuit breaker.

    Every call goes through :meth:`make_request`, which:
      * refuses to run while the circuit breaker is open;
      * runs a periodic health check;
      * waits on the shared rate limiter;
      * retries 429, 5xx, timeouts and connection errors with backoff;
      * fails fast, without retrying, on any other 4xx response.
    """

    def __init__(self, immuta_url: str, api_key: str, rate_limiter: EnterpriseRateLimiter):
        self.immuta_url = immuta_url.rstrip('/')
        # NOTE(review): the key is sent both as a Bearer token and as x-api-key;
        # presumably only one is required by the target Immuta version - confirm.
        self.headers = {
            'Authorization': f'Bearer {api_key}',
            'Content-Type': 'application/json',
            'x-api-key': api_key
        }
        self.rate_limiter = rate_limiter
        self.session = requests.Session()
        self.session.headers.update(self.headers)

        # Circuit breaker state
        self.circuit_breaker_failures = 0
        self.circuit_breaker_opened_at = None
        self.circuit_breaker_threshold = 10
        self.circuit_breaker_timeout = 300  # seconds

        # Health monitoring
        self.health_check_interval = 600  # seconds
        self.last_health_check = 0

        # Number of HTTP attempts sent by make_request, counting every retry.
        self.total_requests = 0

    def _circuit_breaker_check(self):
        """Raise ImmutaggingError while the breaker is open; reset it once the timeout expires."""
        if self.circuit_breaker_opened_at:
            if time.time() - self.circuit_breaker_opened_at > self.circuit_breaker_timeout:
                logger.info("Circuit breaker timeout expired, attempting to reset")
                self.circuit_breaker_opened_at = None
                self.circuit_breaker_failures = 0
            else:
                raise ImmutaggingError("Circuit breaker is open - too many consecutive failures")

    def _health_check(self):
        """Probe ``GET /tag?limit=1`` at most once every ``health_check_interval`` seconds.

        Failures are logged as warnings and never raised. The attempt time is
        recorded even when the probe fails. Otherwise an unhealthy server would
        be probed again before every request.
        """
        current_time = time.time()
        if current_time - self.last_health_check <= self.health_check_interval:
            return
        self.last_health_check = current_time
        try:
            health_response = self.session.get(
                f"{self.immuta_url}/tag",
                params={"limit": 1},
                timeout=10
            )
            health_response.raise_for_status()
            logger.debug("Health check passed")
        except Exception as e:
            logger.warning(f"Health check failed: {str(e)}")

    def _record_exhausted_request(self, reason: str):
        """Count a request that failed after all retries; open the breaker at the threshold."""
        self.circuit_breaker_failures += 1
        if self.circuit_breaker_failures >= self.circuit_breaker_threshold:
            self.circuit_breaker_opened_at = time.time()
            logger.error(f"Circuit breaker opened due to {reason}")

    @staticmethod
    def _retry_after_seconds(response, default: int = 60) -> int:
        """Return the integer Retry-After delay, or ``default`` if it is missing or non-numeric.

        Retry-After may also be an HTTP-date, which is not parsed here.
        """
        value = response.headers.get('Retry-After')
        if value is None:
            return default
        try:
            return max(int(value), 0)
        except (TypeError, ValueError):
            return default

    def make_request(self, method: str, endpoint: str, data: Optional[Dict] = None,
                     params: Optional[Dict] = None, retries: int = 5) -> requests.Response:
        """Send a GET or POST to ``immuta_url + endpoint``, retrying transient failures.

        Args:
            method: 'GET' or 'POST', case-insensitive.
            endpoint: Path appended to the base URL, e.g. '/tag'.
            data: JSON body, used for POST only.
            params: Query-string parameters.
            retries: Maximum number of HTTP attempts.

        Returns:
            The first response with a status below 400.

        Raises:
            ImmutaggingError: The circuit breaker is open, the method is
                unsupported, or all retries were exhausted. It is also raised
                immediately on a non-429 4xx response; in that case the
                exception carries a ``status_code`` attribute.
        """
        self._circuit_breaker_check()
        http_method = method.upper()
        if http_method not in ('GET', 'POST'):
            raise ImmutaggingError(f"Unsupported HTTP method: {method}")
        self._health_check()
        self.rate_limiter.wait_if_needed()

        url = f"{self.immuta_url}{endpoint}"
        final_attempt = retries - 1
        for attempt in range(retries):
            # Only the network call is inside the try. Status handling below must
            # not be caught by these retry handlers, so 4xx errors fail fast.
            self.total_requests += 1
            try:
                if http_method == 'GET':
                    response = self.session.get(url, params=params, timeout=60)
                else:
                    response = self.session.post(url, json=data, params=params, timeout=60)
            except requests.exceptions.Timeout:
                logger.warning(f"Request timeout (attempt {attempt + 1}/{retries})")
                self.rate_limiter.record_error()
                if attempt == final_attempt:
                    self._record_exhausted_request("excessive timeouts")
                    raise ImmutaggingError(f"Request timeout after {retries} attempts")
                time.sleep(2 ** attempt)
                continue
            except requests.exceptions.ConnectionError as e:
                logger.warning(f"Connection error (attempt {attempt + 1}/{retries}): {str(e)}")
                self.rate_limiter.record_error()
                if attempt == final_attempt:
                    self._record_exhausted_request("connection errors")
                    raise ImmutaggingError(f"Connection failed after {retries} attempts: {str(e)}")
                time.sleep(2 ** attempt)
                continue
            except Exception as e:
                logger.error(f"Unexpected error (attempt {attempt + 1}/{retries}): {str(e)}")
                self.rate_limiter.record_error()
                if attempt == final_attempt:
                    raise ImmutaggingError(f"Unexpected error after {retries} attempts: {str(e)}")
                time.sleep(2 ** attempt)
                continue

            status = response.status_code
            if status == 429 or status >= 500:
                # Retryable: rate limited, or a server-side error.
                self.rate_limiter.record_error()
                if attempt == final_attempt:
                    logger.warning(f"HTTP {status} on final attempt ({attempt + 1}/{retries}); giving up")
                    break  # no point sleeping before giving up
                if status == 429:
                    wait_seconds = self._retry_after_seconds(response)
                    logger.warning(f"Rate limited (429). Waiting {wait_seconds} seconds before retry {attempt + 1}")
                else:
                    wait_seconds = min(2 ** attempt, 60)
                    logger.warning(f"Server error {status}. Retrying in {wait_seconds}s (attempt {attempt + 1})")
                time.sleep(wait_seconds)
                continue

            if status >= 400:
                error_msg = f"Client error {status}: {response.text[:200]}"
                logger.error(error_msg)
                self.rate_limiter.record_error()
                error = ImmutaggingError(error_msg)
                error.status_code = status  # lets callers distinguish e.g. 404
                raise error

            self.rate_limiter.record_success()
            self.circuit_breaker_failures = 0
            return response

        if retries > 0:
            self._record_exhausted_request("repeated rate limiting or server errors")
        raise ImmutaggingError("Exhausted all retry attempts")
class EnterpriseTaggingOrchestrator:
    """Coordinate validation, tagging, progress persistence and reporting.

    Intended workflow, driven by the caller:
      1. ``validate_and_cache_references`` checks every referenced schema
         project, data source, column and tag against the Immuta API. It
         caches the valid sets in ``self.validation_cache``.
      2. ``process_records`` applies missing tags one column group at a time.
         Each row's outcome goes to the ProgressTracker so a later run can resume.
      3. ``generate_reports`` exports CSV and JSON summaries.

    SIGINT and SIGTERM set ``shutdown_requested``. Long loops check it and
    stop after the current item.
    """

    def __init__(self, args, progress_tracker):
        """Build the API client and install the shutdown handlers.

        Args:
            args: Parsed CLI namespace. Reads ``sleep``, ``adaptive_rate_limiting``,
                ``immuta_url``, ``api_key``, ``dry_run`` and ``immuta_environment``.
            progress_tracker: ProgressTracker used to persist per-row status.
        """
        self.args = args
        self.progress_tracker = progress_tracker
        # Initialised before the signal handlers are installed, so a signal
        # that arrives during construction is not overwritten afterwards.
        self.shutdown_requested = False

        rate_limiter = EnterpriseRateLimiter(
            sleep_interval=args.sleep,
            adaptive=args.adaptive_rate_limiting
        )
        self.api_client = EnterpriseImmutaggingClient(args.immuta_url, args.api_key, rate_limiter)

        # Filled by validate_and_cache_references with these keys:
        #   'schema_projects' (set of names), 'data_sources' (name -> id),
        #   'columns' (set of 'datasource|column'), 'tags' (set of names).
        self.validation_cache = {}
        # Maps '<data_source_id>_<lower-cased column>' to the set of existing tag names.
        self.column_tags_cache = {}

        self.stats = {
            'start_time': time.time(),
            'validation_time': 0,
            'processing_time': 0,
            'total_api_calls': 0,
            'cache_hits': 0,
            'cache_misses': 0
        }

        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)
        atexit.register(self._cleanup)

    def _signal_handler(self, signum, frame):
        """Request a graceful shutdown; loops stop at their next check."""
        logger.info(f"Received signal {signum}. Initiating graceful shutdown...")
        self.shutdown_requested = True
        global shutdown_requested
        shutdown_requested = True

    def _cleanup(self):
        """atexit hook: log the final progress summary."""
        if hasattr(self, 'progress_tracker'):
            summary = self.progress_tracker.get_progress_summary()
            logger.info(f"Final progress summary: {summary}")

    @contextmanager
    def timer(self, operation_name: str):
        """Log how long the wrapped block took, even if it raises."""
        start_time = time.time()
        try:
            yield
        finally:
            elapsed = time.time() - start_time
            logger.info(f"{operation_name} completed in {elapsed:.2f} seconds")

    def _record_status(self, record: ColumnTag, status: str, error_message: Optional[str] = None):
        """Persist ``status``, with an optional message, for ``record`` in the progress DB."""
        self.progress_tracker.update_record_status(
            record.row_id, record.database, record.schema, record.table,
            record.column, record.immuta_tag, status, error_message
        )

    def validate_and_cache_references(self, records: List[ColumnTag]) -> bool:
        """Validate all reference data used by ``records`` and cache the valid sets.

        Returns:
            True only if every record references a valid schema project, data
            source, column and tag. Invalid records are marked
            'validation_failed' in the progress DB. Returns False without
            marking any record if validation raised or a shutdown was requested.
            In those cases the caches may be incomplete and would wrongly reject
            valid rows.
        """
        logger.info("Starting comprehensive validation of reference data...")
        validation_start = time.time()
        try:
            with self.timer("Reference validation"):
                schema_projects = {r.immuta_schema_project for r in records}
                data_sources = {r.immuta_datasource for r in records}
                tags = {r.immuta_tag for r in records}
                logger.info(f"Validating {len(schema_projects)} schema projects, "
                            f"{len(data_sources)} data sources, {len(tags)} tags")

                validation_results = {}
                try:
                    logger.info("Validating schema projects...")
                    validation_results['schema_projects'] = self._validate_schema_projects(schema_projects)

                    logger.info("Validating data sources...")
                    validation_results['data_sources'] = self._validate_data_sources(data_sources)

                    # Columns can only be checked for data sources that resolved to an id.
                    if validation_results['data_sources']:
                        logger.info("Validating columns...")
                        validation_results['columns'] = self._validate_columns(validation_results['data_sources'])
                    else:
                        validation_results['columns'] = set()

                    logger.info("Validating tags...")
                    validation_results['tags'] = self._validate_tags(tags)

                    self.validation_cache = validation_results

                    if self.shutdown_requested:
                        logger.warning("Shutdown requested during validation; skipping record-level validation")
                        return False
                    return self._validate_records(records)
                except Exception as e:
                    logger.error(f"Validation failed: {str(e)}")
                    return False
        finally:
            self.stats['validation_time'] = time.time() - validation_start

    def _validate_schema_projects(self, schema_projects: Set[str]) -> Set[str]:
        """Return the names that exist as non-deleted projects of type 'schema'.

        Each name is searched with ``GET /project``, and only an exact name
        match counts. Lookup errors are logged and the name is treated as invalid.
        """
        valid_projects = set()
        total = len(schema_projects)
        for i, project in enumerate(schema_projects, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating schema project {i}/{total}: {project}")
                params = {
                    "size": 999999999,
                    "sortField": "name",
                    "sortOrder": "asc",
                    "nameOnly": "true",
                    "searchText": project
                }
                response = self.api_client.make_request('GET', '/project', params=params)
                data = response.json()
                if data.get('count', 0) > 0:
                    for proj in data.get('hits', []):
                        if (project == proj.get('name') and
                                (proj.get('type') or '').lower() == 'schema' and
                                not proj.get('deleted', True)):
                            valid_projects.add(project)
                            break
            except Exception as e:
                logger.error(f"Error validating schema project {project}: {str(e)}")
            if i % 10 == 0:
                logger.info(f"Schema project validation progress: {i}/{total}")
        logger.info(f"Schema project validation complete: {len(valid_projects)}/{total} valid")
        return valid_projects

    def _validate_data_sources(self, data_sources: Set[str]) -> Dict[str, int]:
        """Map each non-deleted data source name, as returned by the API, to its id.

        A 404 means the data source does not exist. It is logged at debug level
        only; the affected records are reported later by ``_validate_records``.
        """
        valid_sources = {}
        total = len(data_sources)
        for i, data_source in enumerate(data_sources, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating data source {i}/{total}: {data_source}")
                encoded_name = urllib.parse.quote(data_source, safe='')
                response = self.api_client.make_request('GET', f'/dataSource/name/{encoded_name}')
                data = response.json()
                if not data.get('deleted', True):
                    valid_sources[data['name']] = data['id']
            except Exception as e:
                # The client raises ImmutaggingError with a status_code attribute
                # for 4xx responses. A raw HTTPError carries it on e.response.
                status = getattr(e, 'status_code', None)
                if status is None and isinstance(e, requests.exceptions.HTTPError):
                    status = getattr(e.response, 'status_code', None)
                if status == 404:
                    logger.debug(f"Data source not found: {data_source}")
                else:
                    logger.error(f"Error validating data source {data_source}: {str(e)}")
            if i % 10 == 0:
                logger.info(f"Data source validation progress: {i}/{total}")
        logger.info(f"Data source validation complete: {len(valid_sources)}/{total} valid")
        return valid_sources

    def _validate_columns(self, data_source_map: Dict[str, int]) -> Set[str]:
        """Return the 'datasource|column' keys for every column in each data source's dictionary."""
        valid_columns = set()
        total = len(data_source_map)
        for i, (data_source_name, data_source_id) in enumerate(data_source_map.items(), 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Fetching columns for data source {i}/{total}: {data_source_name}")
                response = self.api_client.make_request('GET', f'/dictionary/{data_source_id}')
                data = response.json()
                for column in data.get('metadata', []):
                    column_name = column.get('name', '')
                    if column_name:
                        valid_columns.add(f'{data_source_name}|{column_name}')
            except Exception as e:
                logger.error(f"Error fetching columns for {data_source_name}: {str(e)}")
            if i % 10 == 0:
                logger.info(f"Column validation progress: {i}/{total}")
        logger.info(f"Column validation complete: {len(valid_columns)} valid columns found")
        return valid_columns

    def _validate_tags(self, tags: Set[str]) -> Set[str]:
        """Return the tags that exist as curated tags in the 'XXX-Classification.' hierarchy.

        Tags outside that prefix are never looked up, so they are always invalid.
        """
        valid_tags = set()
        total = len(tags)
        for i, tag in enumerate(tags, 1):
            if self.shutdown_requested:
                break
            try:
                logger.debug(f"Validating tag {i}/{total}: {tag}")
                if tag.startswith("XXX-Classification."):
                    params = {
                        "source": "curated",
                        "searchText": tag,
                        "excludedHierarchies": '["Discovered","New","Skip Stats Job","XXX-DataSource","DataProperties"]',
                        "includeAllSystemTags": "false",
                        "limit": 999999999
                    }
                    response = self.api_client.make_request('GET', '/tag', params=params)
                    data = response.json()
                    # Accept an exact name match that is either a parent tag
                    # (hasLeafNodes) or a non-deleted leaf. NOTE(review): parent
                    # tags are accepted even when flagged deleted - confirm intent.
                    if isinstance(data, list) and any(
                        isinstance(tag_info, dict) and tag == tag_info.get('name') and
                        (tag_info.get('hasLeafNodes', True) or not tag_info.get('deleted', True))
                        for tag_info in data
                    ):
                        valid_tags.add(tag)
                else:
                    logger.debug(f"Tag {tag} is outside the XXX-Classification hierarchy; treated as invalid")
            except Exception as e:
                logger.error(f"Error validating tag {tag}: {str(e)}")
            if i % 10 == 0:
                logger.info(f"Tag validation progress: {i}/{total}")
        logger.info(f"Tag validation complete: {len(valid_tags)}/{total} valid")
        return valid_tags

    def _validate_records(self, records: List[ColumnTag]) -> bool:
        """Check each record against the cached reference sets.

        Every invalid record is marked 'validation_failed' with its reasons.

        Returns:
            True if all records are valid.
        """
        valid_schema_projects = self.validation_cache.get('schema_projects', set())
        valid_data_sources = self.validation_cache.get('data_sources', {})
        valid_columns = self.validation_cache.get('columns', set())
        valid_tags = self.validation_cache.get('tags', set())

        failed_records = []
        for record in records:
            errors = []
            if record.immuta_schema_project not in valid_schema_projects:
                errors.append(f"Invalid schema project: {record.immuta_schema_project}")
            if record.immuta_datasource not in valid_data_sources:
                errors.append(f"Invalid data source: {record.immuta_datasource}")
            if record.immuta_column_key not in valid_columns:
                errors.append(f"Invalid column: {record.immuta_column_key}")
            if record.immuta_tag not in valid_tags:
                errors.append(f"Invalid tag: {record.immuta_tag}")
            if errors:
                failed_records.append((record, errors))
                self._record_status(record, 'validation_failed', "; ".join(errors))

        if failed_records:
            logger.error(f"Validation failed for {len(failed_records)} records")
            for record, errors in failed_records[:10]:
                logger.error(f"Row {record.row_id}: {'; '.join(errors)}")
            if len(failed_records) > 10:
                logger.error(f"... and {len(failed_records) - 10} more validation failures")
            return False

        logger.info(f"All {len(records)} records passed validation")
        return True

    def process_records(self, records: List[ColumnTag]) -> Dict[str, int]:
        """Apply tags for every record not already done in the progress DB.

        Records are grouped by (data source, column) so each column's existing
        tags are looked up once.

        Returns:
            Counts keyed by 'success', 'failed', 'skipped' and 'retried'. The
            'retried' key is kept for compatibility and is never incremented.
        """
        logger.info(f"Starting processing of {len(records)} records...")
        processing_start = time.time()

        # Rows already 'success' or 'skipped' in the DB are excluded (resume support).
        existing_progress = self.progress_tracker.load_existing_progress()
        pending_records = [r for r in records if r.row_id not in existing_progress]
        if len(pending_records) < len(records):
            logger.info(f"Resuming from previous run. Processing {len(pending_records)} remaining records "
                        f"(skipping {len(records) - len(pending_records)} already processed)")

        column_groups: Dict[str, List[ColumnTag]] = defaultdict(list)
        for record in pending_records:
            column_groups[f"{record.immuta_datasource}|{record.column}"].append(record)
        logger.info(f"Grouped records into {len(column_groups)} unique columns")

        stats = {'success': 0, 'failed': 0, 'skipped': 0, 'retried': 0}
        total_columns = len(column_groups)

        for processed_columns, (column_key, column_records) in enumerate(column_groups.items(), 1):
            if self.shutdown_requested:
                logger.info("Shutdown requested. Stopping processing gracefully.")
                break
            try:
                logger.debug(f"Processing column group {processed_columns}/{total_columns}: {column_key}")
                column_stats = self._process_column_group(column_records)
                for key in stats:
                    stats[key] += column_stats.get(key, 0)

                if processed_columns % 100 == 0:
                    # Measure from the start of processing, not script start, so
                    # validation time does not skew the rate and ETA.
                    elapsed = max(time.time() - processing_start, 1e-6)
                    rate = processed_columns / elapsed * 3600  # columns per hour
                    eta = (total_columns - processed_columns) / max(rate / 3600, 0.1)
                    logger.info(f"Progress: {processed_columns}/{total_columns} column groups "
                                f"({processed_columns/total_columns*100:.1f}%) - "
                                f"Rate: {rate:.1f} columns/hour - "
                                f"ETA: {timedelta(seconds=int(eta))}")
                    logger.info(f"Current stats: Success={stats['success']}, "
                                f"Failed={stats['failed']}, Skipped={stats['skipped']}")
            except Exception as e:
                logger.error(f"Error processing column group {column_key}: {str(e)}")
                for record in column_records:
                    self._record_status(record, 'failed', str(e))
                stats['failed'] += len(column_records)

        self.stats['processing_time'] = time.time() - processing_start
        total_processed = sum(stats.values())
        logger.info(f"Processing complete. Total processed: {total_processed}, "
                    f"Success: {stats['success']}, Failed: {stats['failed']}, "
                    f"Skipped: {stats['skipped']}")
        return stats

    def _process_column_group(self, column_records: List[ColumnTag]) -> Dict[str, int]:
        """Apply every requested tag for one column, reusing a single existing-tags lookup.

        All records are expected to share a data source and column, as grouped
        by ``process_records``. Each record is stored as 'skipped' (tag already
        present), 'success' (added, or would be added in dry-run) or 'failed'.

        Returns:
            Counts keyed by 'success', 'failed' and 'skipped'.

        Raises:
            KeyError: If the group's data source is missing from the validation cache.
        """
        stats = {'success': 0, 'failed': 0, 'skipped': 0}
        if not column_records:
            return stats

        first_record = column_records[0]
        data_source_id = self.validation_cache['data_sources'][first_record.immuta_datasource]

        try:
            # The returned set is the cache entry; adding to it keeps the cache current.
            existing_tags = self._get_existing_column_tags_cached(data_source_id, first_record.column)
            for record in column_records:
                try:
                    if record.immuta_tag in existing_tags:
                        stats['skipped'] += 1
                        self._record_status(record, 'skipped', 'Tag already exists')
                        logger.debug(f"Row {record.row_id}: Tag {record.immuta_tag} already exists")
                        continue

                    if self.args.dry_run:
                        # NOTE(review): dry-run rows are stored as 'success', which
                        # load_existing_progress treats as done. Reusing this
                        # progress DB for a real run would skip them.
                        stats['success'] += 1
                        self._record_status(record, 'success', 'DRY RUN - would have added tag')
                        logger.debug(f"DRY RUN: Would add tag {record.immuta_tag} to column {record.column}")
                        continue

                    if self._add_single_tag(record, data_source_id):
                        stats['success'] += 1
                        existing_tags.add(record.immuta_tag)
                        self._record_status(record, 'success')
                        logger.debug(f"Row {record.row_id}: Successfully added tag {record.immuta_tag}")
                    else:
                        stats['failed'] += 1
                        self._record_status(record, 'failed', 'Tag addition failed - see logs')
                except Exception as e:
                    stats['failed'] += 1
                    error_msg = f"Error processing record: {str(e)}"
                    logger.error(f"Row {record.row_id}: {error_msg}")
                    self._record_status(record, 'failed', error_msg)
        except Exception as e:
            logger.error(f"Catastrophic failure for column group: {str(e)}")
            for record in column_records:
                self._record_status(record, 'failed', f"Column group processing failed: {str(e)}")
            # Every row was just overwritten as 'failed'; report exactly that.
            stats = {'success': 0, 'failed': len(column_records), 'skipped': 0}
        return stats

    def _get_existing_column_tags_cached(self, data_source_id: int, column_name: str) -> Set[str]:
        """Return the non-deleted tag names currently on a column.

        Results are cached per (data source id, lower-cased column name). On a
        miss, the data source's dictionary is fetched once and every column in
        it is cached. The other columns of that data source then need no
        further API calls. Existing cache entries are never overwritten.
        Column names match case-insensitively, and the first dictionary entry
        wins if names collide.

        Returns:
            The cached set itself, so callers may add to it. On any error an
            empty set is returned and nothing is cached.
        """
        cache_key = f"{data_source_id}_{column_name.lower()}"
        if cache_key in self.column_tags_cache:
            self.stats['cache_hits'] += 1
            return self.column_tags_cache[cache_key]

        self.stats['cache_misses'] += 1
        try:
            response = self.api_client.make_request('GET', f'/dictionary/{data_source_id}')
            data = response.json()

            tags_by_column: Dict[str, Set[str]] = {}
            for column_metadata in data.get('metadata', []):
                name = column_metadata.get('name', '').lower()
                if name in tags_by_column:
                    continue
                tags_by_column[name] = {
                    tag['name'] for tag in (column_metadata.get('tags') or [])
                    if not tag.get('deleted', True)
                }
        except Exception as e:
            logger.error(f"Error getting existing tags for column {column_name}: {str(e)}")
            return set()

        for name, tag_names in tags_by_column.items():
            self.column_tags_cache.setdefault(f"{data_source_id}_{name}", tag_names)
        # The requested column may be absent from the dictionary; cache it as empty.
        existing_tags = self.column_tags_cache.setdefault(cache_key, set())

        # Bound memory by dropping the oldest 20% of entries (in insertion order).
        if len(self.column_tags_cache) > 10000:
            items_to_remove = len(self.column_tags_cache) // 5
            for key in list(self.column_tags_cache.keys())[:items_to_remove]:
                del self.column_tags_cache[key]
            logger.debug(f"Cache pruned: removed {items_to_remove} entries")

        return existing_tags

    def _add_single_tag(self, record: ColumnTag, data_source_id: int) -> bool:
        """POST ``record.immuta_tag`` to the column and confirm it appears in the response.

        Returns:
            True if the response lists the tag as non-deleted. False on an
            unexpected response, or after 3 failed attempts with 1s and 2s backoff.
        """
        max_retries = 3
        payload = [{
            "name": record.immuta_tag,
            "source": "curated"
        }]
        # The column name is embedded in the URL path, so percent-encode it.
        # Spaces, '/', '#' or '?' in a name would otherwise corrupt the request.
        column_identifier = urllib.parse.quote(f"{data_source_id}_{record.column.lower()}", safe='')
        endpoint = f'/tag/column/{column_identifier}'

        for attempt in range(max_retries):
            try:
                response = self.api_client.make_request('POST', endpoint, data=payload)
                response_data = response.json()
                if not isinstance(response_data, list):
                    # Do not retry: re-POSTing would not change the response shape.
                    logger.error(f"Row {record.row_id}: Unexpected response format when adding tag {record.immuta_tag}")
                    return False
                if any(isinstance(response_tag, dict) and
                       record.immuta_tag == response_tag.get('name') and
                       not response_tag.get('deleted', True)
                       for response_tag in response_data):
                    return True
                logger.error(f"Row {record.row_id}: Tag {record.immuta_tag} not found in response after adding")
                return False
            except Exception as e:
                logger.warning(f"Row {record.row_id}: Attempt {attempt + 1} failed for tag {record.immuta_tag}: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                else:
                    logger.error(f"Row {record.row_id}: All {max_retries} attempts failed for tag {record.immuta_tag}")
        return False

    def generate_reports(self, output_dir: str):
        """Write detailed CSV, summary JSON and, if needed, failed-records CSV reports.

        Returns:
            A dict of the written file paths. 'failed_records' is None when no
            row failed.
        """
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        detailed_results_file = os.path.join(output_dir, f"detailed_results_{timestamp}.csv")
        self.progress_tracker.export_results(detailed_results_file)
        logger.info(f"Detailed results exported to: {detailed_results_file}")

        cache_hits = self.stats.get('cache_hits', 0)
        cache_misses = self.stats.get('cache_misses', 0)
        summary_file = os.path.join(output_dir, f"summary_report_{timestamp}.json")
        summary_data = {
            'execution_summary': {
                'start_time': datetime.fromtimestamp(self.stats['start_time']).isoformat(),
                'end_time': datetime.now().isoformat(),
                'total_duration_seconds': time.time() - self.stats['start_time'],
                'environment': self.args.immuta_environment,
                'script_version': '2.0.0-enterprise'
            },
            'progress_summary': self.progress_tracker.get_progress_summary(),
            'performance_stats': {
                # Prefer the client's live request counter when it exists, because
                # self.stats['total_api_calls'] is never incremented.
                'total_api_calls': getattr(self.api_client, 'total_requests',
                                           self.stats.get('total_api_calls', 0)),
                'cache_hits': cache_hits,
                'cache_misses': cache_misses,
                'cache_hit_rate': cache_hits / max(cache_hits + cache_misses, 1),
                'validation_time_seconds': self.stats.get('validation_time', 0),
                'processing_time_seconds': self.stats.get('processing_time', 0),
            }
        }
        with open(summary_file, 'w', encoding='utf-8') as f:
            json.dump(summary_data, f, indent=2)
        logger.info(f"Summary report generated: {summary_file}")

        failed_records = self.progress_tracker.get_failed_records()
        failed_file = None
        if failed_records:
            failed_file = os.path.join(output_dir, f"failed_records_{timestamp}.csv")
            with open(failed_file, 'w', newline='', encoding='utf-8') as f:
                writer = csv.DictWriter(f, fieldnames=list(failed_records[0].keys()))
                writer.writeheader()
                writer.writerows(failed_records)
            logger.info(f"Failed records report generated: {failed_file}")

        return {
            'detailed_results': detailed_results_file,
            'summary_report': summary_file,
            'failed_records': failed_file
        }
def _non_negative_float(text):
    """argparse ``type=``: parse a finite float that is >= 0."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # The chained comparison is False for NaN as well as for negatives and inf.
    if not (0 <= value < float('inf')):
        raise argparse.ArgumentTypeError(f"must be a finite number >= 0, got {text!r}")
    return value


def _positive_int(text):
    """argparse ``type=``: parse an int that is >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) parses ``sys.argv[1:]``, as before; passing an
            explicit list makes the parser usable from tests.

    Returns:
        argparse.Namespace holding the parsed options.

    Exits via ``SystemExit`` (argparse's standard behaviour) on invalid input.
    """
    parser = argparse.ArgumentParser(
        description='Enterprise-scale Immuta column tagging with comprehensive error handling and recovery.',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Enterprise Features:
- Progress tracking with SQLite database
- Resume functionality for interrupted runs
- Comprehensive error handling and retry logic
- Circuit breaker pattern for API failures
- Adaptive rate limiting based on error patterns
- Detailed reporting and analytics
- Graceful shutdown handling
Examples:
# Basic enterprise run
%(prog)s immuta-nonprod --precheck-only
# Production run with all optimizations
%(prog)s immuta-prod --sleep 0.2 --adaptive-rate-limiting --checkpoint-interval 1000
# Resume interrupted run
%(prog)s immuta-prod --resume --progress-db ./progress/run_20241201.db
# Retry failed records only
%(prog)s immuta-prod --retry-failed --progress-db ./progress/run_20241201.db
""",
    )

    # Basic arguments
    parser.add_argument('immuta_environment',
                        choices=['immuta-prod', 'immuta-nonprod'],
                        help='Immuta environment')
    # A negative/NaN/infinite sleep would only fail (or hang) later, mid-run.
    parser.add_argument('--sleep', '-s', type=_non_negative_float, default=0.2,
                        help='Base sleep interval between API calls (default: %(default)s)')
    parser.add_argument('--precheck-only', '-pc', action='store_true',
                        help='Only validate data, do not perform tagging')
    parser.add_argument('--input-file', '-i', type=str,
                        help='Override default input file path')
    parser.add_argument('--dry-run', action='store_true',
                        help='Perform validation and simulate tagging without actual API calls')
    parser.add_argument('--verbose', '-v', action='store_true',
                        help='Enable verbose logging')

    # Enterprise features
    parser.add_argument('--progress-db', type=str,
                        help='Path to progress database (default: auto-generated)')
    parser.add_argument('--resume', action='store_true',
                        help='Resume from previous interrupted run')
    parser.add_argument('--retry-failed', action='store_true',
                        help='Retry only failed records from previous run')
    # NOTE(review): the >= 1 bounds below assume 0 is meaningless for these
    # counters (e.g. zero attempts / a zero checkpoint modulus) - confirm.
    parser.add_argument('--checkpoint-interval', type=_positive_int, default=1000,
                        help='Progress checkpoint interval (default: %(default)s)')
    # Enabled by default. A store_true flag with default=True could never be
    # switched off, so the paired --no-... flag writes False to the same dest.
    parser.add_argument('--adaptive-rate-limiting', action='store_true', default=True,
                        help='Enable adaptive rate limiting based on error patterns')
    parser.add_argument('--no-adaptive-rate-limiting', dest='adaptive_rate_limiting',
                        action='store_false',
                        help='Disable adaptive rate limiting (enabled by default)')
    parser.add_argument('--max-retries', type=_positive_int, default=3,
                        help='Maximum retries per operation (default: %(default)s)')
    parser.add_argument('--circuit-breaker-threshold', type=_positive_int, default=10,
                        help='Circuit breaker failure threshold (default: %(default)s)')
    parser.add_argument('--output-dir', type=str, default='./reports',
                        help='Output directory for reports (default: %(default)s)')
    parser.add_argument('--log-dir', type=str, default='./logs',
                        help='Log directory (default: %(default)s)')
    return parser.parse_args(argv)
def load_input_data(file_path: str) -> List[ColumnTag]:
    """Load tagging requests from a CSV file and build ColumnTag records.

    The file must start with a header row. Columns are read by position,
    not by header name: database, schema, table, column, tag; any further
    columns are ignored. Database/schema/table/column values are stripped
    and upper-cased; the tag is only stripped. Rows that are too short, have
    an empty required value, or fail to parse are skipped. The first 10 such
    problems (plus a count of the rest) are logged as warnings.

    Reported line numbers count data records from 2 (the header is line 1),
    so they match physical lines only when every record occupies exactly one
    line (no blank lines, no multi-line quoted fields).

    Args:
        file_path: Path to the CSV file, read as UTF-8 (a leading BOM is tolerated).

    Returns:
        The valid ColumnTag records, in file order.

    Raises:
        ImmutaggingError: if the file is missing or unreadable, or if it
            yields no valid records.
    """
    logger.info(f"Loading input data from: {file_path}")
    max_errors_shown = 10
    records = []
    errors = []
    try:
        # newline='' is what the csv module requires for correct handling of
        # newlines inside quoted fields.
        with open(file_path, 'r', encoding='utf-8-sig', newline='') as csvfile:
            reader = csv.DictReader(csvfile)
            for line_number, row in enumerate(reader, start=2):  # header is line 1
                try:
                    values = list(row.values())
                    if len(values) < 5:
                        errors.append(f"Line {line_number}: Insufficient columns ({len(values)} < 5)")
                        continue
                    # Missing trailing cells arrive as None from DictReader.
                    database, schema, table, column = (
                        (value or "").strip().upper() for value in values[:4]
                    )
                    tag = (values[4] or "").strip()
                    if not all((database, schema, table, column, tag)):
                        errors.append(f"Line {line_number}: Missing required data")
                        continue
                    records.append(ColumnTag(
                        row_id=line_number,
                        database=database,
                        schema=schema,
                        table=table,
                        column=column,
                        immuta_tag=tag,
                        immuta_schema_project=f'{database}-{schema}',
                        immuta_datasource=f'{database}-{schema}-{table}',
                        immuta_sql_schema=f'{database.lower()}-{schema.lower()}',
                        immuta_column_key=f'{database}-{schema}-{table}|{column.lower()}',
                        source_file_line=line_number,
                    ))
                except Exception as e:
                    # Best effort per row: record the problem and keep going.
                    errors.append(f"Line {line_number}: Error parsing row - {str(e)}")
    except FileNotFoundError as e:
        raise ImmutaggingError(f"Input file not found: {file_path}") from e
    except Exception as e:
        raise ImmutaggingError(f"Error reading input file: {str(e)}") from e

    if errors:
        logger.warning(f"Found {len(errors)} parsing errors:")
        for error in errors[:max_errors_shown]:
            logger.warning(f"  {error}")
        if len(errors) > max_errors_shown:
            logger.warning(f"  ... and {len(errors) - max_errors_shown} more errors")

    # Fail before logging success so an empty result is not reported as loaded.
    if not records:
        raise ImmutaggingError("No valid records found in input file")
    logger.info(f"Successfully loaded {len(records)} records from input file")
    return records
def _resolve_environment(args):
    """Return ``(input_file, immuta_url, api_key_file)`` for ``args.immuta_environment``."""
    if args.immuta_environment == 'immuta-nonprod':
        input_file = args.input_file or './input-files/columns-to-tag-nonprod.csv'
        immuta_url = 'https://immuta-nonprod.eda.xxx.co.nz'
        api_key_file = f'{Path.home()}/.immuta/api-key-immuta-nonprod.json'
    else:
        input_file = args.input_file or './input-files/columns-to-tag.csv'
        immuta_url = 'https://immuta.eda.xxx.co.nz'
        api_key_file = f'{Path.home()}/.immuta/api-key-immuta.json'
    return input_file, immuta_url, api_key_file


def _load_api_key(api_key_file):
    """Return the ``'api-key'`` value from a JSON key file, or raise ImmutaggingError."""
    try:
        with open(api_key_file, 'r', encoding='utf-8') as f:
            api_key = json.load(f)['api-key']
    except Exception as e:
        raise ImmutaggingError(f"Error loading API key from {api_key_file}: {str(e)}") from e
    # Validate outside the try so this error is not re-wrapped by the handler
    # above, and so a non-string value (e.g. JSON null) is reported clearly.
    if not isinstance(api_key, str):
        raise ImmutaggingError(f"Error loading API key from {api_key_file}: API key is not a string")
    if not api_key.strip():
        raise ImmutaggingError(f"Error loading API key from {api_key_file}: API key is empty")
    return api_key


def _resolve_progress_db(args):
    """Return ``--progress-db`` if given, else a new timestamped path under ./progress."""
    if args.progress_db:
        return args.progress_db
    # NOTE(review): --resume / --retry-failed without --progress-db get a
    # brand-new database path here; confirm the orchestrator handles that.
    os.makedirs('./progress', exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return f'./progress/immuta_progress_{timestamp}.db'


def main():
    """Run the end-to-end tagging workflow and return a process exit code.

    Steps: parse arguments, set up logging, resolve environment settings and
    the API key, open the progress tracker, load and validate the input, then
    (unless ``--precheck-only``) process the records and write reports.

    Returns:
        0 on success, after a successful pre-check, or when some records
        failed but failures did not outnumber successes; 1 when validation
        fails, on any fatal error, or when failures outnumber successes;
        130 when interrupted with Ctrl+C at any phase.
    """
    global logger, shutdown_requested
    args = parse_arguments()

    logger, _main_log_file, _error_log_file = setup_logging(
        log_level=logging.DEBUG if args.verbose else logging.INFO,
        log_dir=args.log_dir
    )
    banner = "=" * 80
    logger.info(banner)
    logger.info("IMMUTA ENTERPRISE COLUMN TAGGING SCRIPT STARTED")
    logger.info(banner)
    logger.info(f"Environment: {args.immuta_environment}")
    logger.info(f"Dry run: {args.dry_run}")
    logger.info(f"Resume mode: {args.resume}")
    logger.info(f"Retry failed mode: {args.retry_failed}")
    logger.info(f"Sleep interval: {args.sleep}s")
    logger.info(f"Adaptive rate limiting: {args.adaptive_rate_limiting}")

    try:
        input_file, immuta_url, api_key_file = _resolve_environment(args)
        api_key = _load_api_key(api_key_file)

        progress_db_path = _resolve_progress_db(args)
        logger.info(f"Progress database: {progress_db_path}")
        progress_tracker = ProgressTracker(progress_db_path)

        # The orchestrator reads connection details from args.
        args.immuta_url = immuta_url
        args.api_key = api_key
        orchestrator = EnterpriseTaggingOrchestrator(args, progress_tracker)

        records = load_input_data(input_file)
        logger.info(f"Loaded {len(records)} records for processing")
        # Rough lower bound: counts only the inter-call sleep, not API latency.
        estimated_time = len(records) * args.sleep / 3600  # hours
        logger.info(f"Estimated processing time: {estimated_time:.1f} hours at {args.sleep}s interval")

        # Validation phase
        logger.info("Starting validation phase...")
        validation_start = time.time()
        if not orchestrator.validate_and_cache_references(records):
            logger.error("Validation failed. Stopping execution.")
            return 1
        logger.info(f"Validation completed in {time.time() - validation_start:.1f} seconds")

        if args.precheck_only:
            logger.info("Pre-check only mode. Validation completed successfully.")
            return 0

        # Processing phase
        logger.info("Starting processing phase...")
        processing_start = time.time()
        processing_stats = orchestrator.process_records(records)
        logger.info(f"Processing completed in {time.time() - processing_start:.1f} seconds")
        logger.info(f"Final statistics: {processing_stats}")

        logger.info("Generating reports...")
        reports = orchestrator.generate_reports(args.output_dir)
        logger.info("Reports generated:")
        for report_type, report_path in reports.items():
            if report_path:
                logger.info(f"  {report_type}: {report_path}")

        # Partial failure is only an error exit when failures outnumber successes.
        failed = processing_stats['failed']
        if failed > 0:
            logger.warning(f"Completed with {failed} failures")
            return 1 if failed > processing_stats['success'] else 0
        logger.info("All operations completed successfully")
        return 0
    except KeyboardInterrupt:
        # Handled for every phase, not just processing.
        logger.info("Received interrupt signal. Graceful shutdown in progress...")
        return 130  # Standard exit code for Ctrl+C
    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        logger.error(f"Traceback: {traceback.format_exc()}")
        return 1
    finally:
        logger.info(banner)
        logger.info("IMMUTA ENTERPRISE COLUMN TAGGING SCRIPT COMPLETED")
        logger.info(banner)
if __name__ == "__main__":
    # Use main()'s return value as the process exit status
    # (equivalent to sys.exit(main())).
    raise SystemExit(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment