| """ | |
| Production-Grade Task Queue System - ENHANCED | |
| Distributed, persistent, scalable task queue with Redis backend, worker coordination, | |
| dependency management, retry logic, full orchestration integration, scheduler, priority | |
| routing, worker load balancing, and graceful shutdown. | |
| """ | |
| import json | |
| import time | |
| import uuid | |
| import pickle | |
| import logging | |
| from dataclasses import dataclass, field, asdict | |
| from enum import Enum | |
| from typing import Any, Dict, List, Optional, Callable, Tuple | |
| import threading | |
| import redis | |
| from redis import Redis | |
| from datetime import datetime | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # ============================================================================ | |
| # TASK DEFINITIONS & SCHEMA | |
| # ============================================================================ | |
| class TaskType(Enum): | |
| """Complete enumeration of all task types in the system""" | |
| # Content Generation Pipeline | |
| GENERATE_SCRIPT = "generate_script" | |
| GENERATE_VISUALS = "generate_visuals" | |
| GENERATE_AUDIO = "generate_audio" | |
| GENERATE_THUMBNAIL = "generate_thumbnail" | |
| COMPOSE_VIDEO = "compose_video" | |
| # Optimization & ML | |
| OPTIMIZE_RL = "optimize_rl" | |
| UPDATE_PRETRAIN = "update_pretrain" | |
| TRAIN_MODEL = "train_model" | |
| AB_TEST = "ab_test" | |
| # Distribution | |
| POST_CONTENT = "post_content" | |
| SCHEDULE_POST = "schedule_post" | |
| VIRAL_BOOST = "viral_boost" | |
| # Analysis & Monitoring | |
| ANALYZE_PERFORMANCE = "analyze_performance" | |
| ANALYZE_METRICS = "analyze_metrics" | |
| SCRAPE_TRENDING = "scrape_trending" | |
| # Maintenance | |
| CLEANUP = "cleanup" | |
| BACKUP = "backup" | |
| HEALTH_CHECK = "health_check" | |
| class TaskPriority(Enum): | |
| """Priority levels with numeric values for sorting""" | |
| CRITICAL = 0 # System-critical, immediate execution | |
| HIGH = 1 # High priority, viral candidates | |
| NORMAL = 2 # Standard tasks | |
| LOW = 3 # Background tasks | |
| BACKGROUND = 4 # Cleanup, maintenance | |
| class TaskStatus(Enum): | |
| """Complete task lifecycle states""" | |
| PENDING = "pending" # Created, not yet queued | |
| QUEUED = "queued" # In queue, waiting | |
| CLAIMED = "claimed" # Claimed by worker | |
| RUNNING = "running" # Currently executing | |
| COMPLETED = "completed" # Successfully completed | |
| FAILED = "failed" # Failed, no more retries | |
| RETRY = "retry" # Failed, will retry | |
| CANCELLED = "cancelled" # Manually cancelled | |
| TIMEOUT = "timeout" # Execution timeout | |
| BLOCKED = "blocked" # Waiting on dependencies | |
| @dataclass | |
| class TaskPayload: | |
| """ | |
| Complete task payload with all metadata for distributed execution. | |
| Serializable for Redis storage. | |
| """ | |
| # Core identification | |
| task_id: str | |
| task_type: TaskType | |
| priority: TaskPriority | |
| # Timing | |
| created_at: float | |
| scheduled_for: float # Unix timestamp | |
| started_at: Optional[float] = None | |
| completed_at: Optional[float] = None | |
| deadline: Optional[float] = None # Hard deadline | |
| # Dependencies | |
| depends_on: List[str] = field(default_factory=list) | |
| blocks: List[str] = field(default_factory=list) | |
| # Execution metadata | |
| metadata: Dict[str, Any] = field(default_factory=dict) | |
| result: Optional[Dict[str, Any]] = None | |
| # Retry tracking | |
| retry_count: int = 0 | |
| max_retries: int = 3 | |
| last_error: Optional[str] = None | |
| error_history: List[Dict] = field(default_factory=list) | |
| # Status tracking | |
| status: TaskStatus = TaskStatus.PENDING | |
| worker_id: Optional[str] = None | |
| claimed_at: Optional[float] = None | |
| lease_expires_at: Optional[float] = None | |
| # Content-specific fields | |
| video_id: Optional[str] = None | |
| account_id: Optional[str] = None | |
| platform: Optional[str] = None | |
| # RL/ML optimization flags | |
| is_viral_candidate: bool = False | |
| predicted_engagement: float = 0.0 | |
| predicted_virality: float = 0.0 | |
| ab_test_variant: Optional[str] = None | |
| # Execution context | |
| timeout_seconds: int = 300 | |
| idempotency_key: Optional[str] = None | |
| def to_dict(self) -> Dict[str, Any]: | |
| """Convert to serializable dictionary""" | |
| data = asdict(self) | |
| data['task_type'] = self.task_type.value | |
| data['priority'] = self.priority.value | |
| data['status'] = self.status.value | |
| return data | |
| @classmethod | |
| def from_dict(cls, data: Dict[str, Any]) -> 'TaskPayload': | |
| """Create from dictionary""" | |
| data['task_type'] = TaskType(data['task_type']) | |
| data['priority'] = TaskPriority(data['priority']) | |
| data['status'] = TaskStatus(data['status']) | |
| return cls(**data) | |
| def __lt__(self, other: 'TaskPayload') -> bool: | |
| """Priority queue comparison""" | |
| if self.priority.value != other.priority.value: | |
| return self.priority.value < other.priority.value | |
| return self.scheduled_for < other.scheduled_for | |
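| # Illustrative sketch (not part of the original module): round-trip a | |
| # TaskPayload through to_dict()/from_dict(), e.g. for JSON transport instead | |
| # of the pickle storage used below. All field values here are made up. | |
| def _example_payload_roundtrip() -> TaskPayload: | |
|     """Build a payload, serialize it to a JSON-safe dict, and reconstruct it.""" | |
|     original = TaskPayload( | |
|         task_id=str(uuid.uuid4()), | |
|         task_type=TaskType.GENERATE_SCRIPT, | |
|         priority=TaskPriority.NORMAL, | |
|         created_at=time.time(), | |
|         scheduled_for=time.time(), | |
|         metadata={"niche": "tech"}, | |
|     ) | |
|     restored = TaskPayload.from_dict(json.loads(json.dumps(original.to_dict()))) | |
|     assert restored.task_id == original.task_id | |
|     return restored | |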
| # ============================================================================ | |
| # PRIORITY ROUTER | |
| # ============================================================================ | |
| class PriorityRouter: | |
| """ | |
| Routes tasks to appropriate priority queues based on: | |
| - Predicted virality | |
| - Account load | |
| - System capacity | |
| - Business rules | |
| """ | |
| def __init__(self, redis_client: Redis): | |
| self.redis = redis_client | |
| self.rules: List[Callable] = [] | |
| # Virality thresholds | |
| self.viral_threshold_high = 0.8 # 80%+ predicted virality -> HIGH | |
| self.viral_threshold_medium = 0.5 # 50%+ -> NORMAL | |
| # Load balancing thresholds | |
| self.high_priority_max_queue_size = 100 | |
| self.account_max_concurrent_posts = 5 | |
| logger.info("PriorityRouter initialized") | |
| def route_task(self, payload: TaskPayload) -> TaskPriority: | |
| """ | |
| Determine optimal priority for task based on multiple factors. | |
| """ | |
| # System critical tasks always get CRITICAL priority | |
| if payload.task_type in [TaskType.HEALTH_CHECK, TaskType.BACKUP]: | |
| return TaskPriority.CRITICAL | |
| # Check custom routing rules first | |
| for rule in self.rules: | |
| priority = rule(payload) | |
| if priority: | |
| return priority | |
| # Viral candidate routing | |
| if payload.is_viral_candidate or payload.predicted_virality >= self.viral_threshold_high: | |
| logger.info(f"Task {payload.task_id} routed to HIGH (viral candidate)") | |
| return TaskPriority.HIGH | |
| # Deadline-based routing | |
| if payload.deadline: | |
| time_until_deadline = payload.deadline - time.time() | |
| if time_until_deadline < 300: # Less than 5 minutes | |
| logger.warning(f"Task {payload.task_id} routed to HIGH (deadline approaching)") | |
| return TaskPriority.HIGH | |
| # Account load balancing - prevent single account overload | |
| if payload.account_id: | |
| account_load = self._get_account_load(payload.account_id) | |
| if account_load >= self.account_max_concurrent_posts: | |
| logger.info(f"Task {payload.task_id} routed to LOW (account overload)") | |
| return TaskPriority.LOW | |
| # Engagement-based routing | |
| if payload.predicted_engagement >= self.viral_threshold_medium: | |
| return TaskPriority.NORMAL | |
| # Background/maintenance tasks | |
| if payload.task_type in [TaskType.CLEANUP, TaskType.ANALYZE_PERFORMANCE]: | |
| return TaskPriority.BACKGROUND | |
| # Default | |
| return payload.priority or TaskPriority.NORMAL | |
| def _get_account_load(self, account_id: str) -> int: | |
| """Get current load for account""" | |
| key = f"viral_factory:account_load:{account_id}" | |
| return int(self.redis.get(key) or 0) | |
| def add_routing_rule(self, rule: Callable[[TaskPayload], Optional[TaskPriority]]): | |
| """Add custom routing rule""" | |
| self.rules.append(rule) | |
| def get_recommended_queue(self, task_type: TaskType) -> str: | |
| """Get recommended queue for task type""" | |
| # Route different task types to specialized queues if needed | |
| routing_map = { | |
| TaskType.GENERATE_SCRIPT: "content_generation", | |
| TaskType.GENERATE_VISUALS: "content_generation", | |
| TaskType.COMPOSE_VIDEO: "video_processing", | |
| TaskType.POST_CONTENT: "distribution", | |
| TaskType.OPTIMIZE_RL: "ml_training", | |
| } | |
| return routing_map.get(task_type, "default") | |
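| # Illustrative sketch (not part of the original module): a custom routing rule | |
| # that bumps TikTok distribution tasks to HIGH priority. Register it with | |
| # PriorityRouter.add_routing_rule(); the "tiktok" platform value is only an | |
| # example. | |
| def _example_platform_rule(task: TaskPayload) -> Optional[TaskPriority]: | |
|     """Prioritize POST_CONTENT tasks targeting TikTok; defer to other rules otherwise.""" | |
|     if task.task_type == TaskType.POST_CONTENT and task.platform == "tiktok": | |
|         return TaskPriority.HIGH | |
|     return None | |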
| # ============================================================================ | |
| # WORKER MANAGER | |
| # ============================================================================ | |
| @dataclass | |
| class WorkerInfo: | |
| """Worker metadata and health info""" | |
| worker_id: str | |
| started_at: float | |
| last_heartbeat: float | |
| task_types: List[TaskType] | |
| current_tasks: List[str] | |
| total_completed: int = 0 | |
| total_failed: int = 0 | |
| cpu_percent: float = 0.0 | |
| memory_mb: float = 0.0 | |
| status: str = "active" # active, idle, overloaded, dead | |
| class WorkerManager: | |
| """ | |
| Manages worker pool, load balancing, and dynamic scaling. | |
| """ | |
| def __init__(self, redis_client: Redis, namespace: str = "viral_factory"): | |
| self.redis = redis_client | |
| self.namespace = namespace | |
| # Worker tracking | |
| self.workers: Dict[str, WorkerInfo] = {} | |
| self._lock = threading.RLock() | |
| # Configuration | |
| self.max_tasks_per_worker = 5 | |
| self.worker_timeout = 180 # 3 minutes no heartbeat = dead | |
| self.scale_up_threshold = 0.8 # 80% capacity | |
| self.scale_down_threshold = 0.2 # 20% capacity | |
| # Monitoring | |
| self._monitoring_active = False | |
| self._monitor_thread: Optional[threading.Thread] = None | |
| logger.info("WorkerManager initialized") | |
| def register_worker(self, worker_id: str, task_types: List[TaskType]) -> bool: | |
| """Register new worker""" | |
| with self._lock: | |
| self.workers[worker_id] = WorkerInfo( | |
| worker_id=worker_id, | |
| started_at=time.time(), | |
| last_heartbeat=time.time(), | |
| task_types=task_types, | |
| current_tasks=[] | |
| ) | |
| # Store in Redis | |
| key = f"{self.namespace}:worker:{worker_id}" | |
| self.redis.setex(key, self.worker_timeout, json.dumps({ | |
| 'worker_id': worker_id, | |
| 'task_types': [t.value for t in task_types], | |
| 'started_at': time.time() | |
| })) | |
| logger.info(f"Worker registered: {worker_id}") | |
| return True | |
| def update_heartbeat(self, worker_id: str, cpu: float = 0, memory: float = 0): | |
| """Update worker heartbeat and health metrics""" | |
| with self._lock: | |
| if worker_id not in self.workers: | |
| logger.warning(f"Heartbeat from unknown worker: {worker_id}") | |
| return False | |
| worker = self.workers[worker_id] | |
| worker.last_heartbeat = time.time() | |
| worker.cpu_percent = cpu | |
| worker.memory_mb = memory | |
| # Update Redis | |
| key = f"{self.namespace}:worker:{worker_id}:heartbeat" | |
| self.redis.setex(key, self.worker_timeout, time.time()) | |
| return True | |
| def assign_task(self, worker_id: str, task_id: str): | |
| """Assign task to worker""" | |
| with self._lock: | |
| if worker_id in self.workers: | |
| self.workers[worker_id].current_tasks.append(task_id) | |
| def complete_task(self, worker_id: str, task_id: str, success: bool = True): | |
| """Mark task completed by worker""" | |
| with self._lock: | |
| if worker_id in self.workers: | |
| worker = self.workers[worker_id] | |
| if task_id in worker.current_tasks: | |
| worker.current_tasks.remove(task_id) | |
| if success: | |
| worker.total_completed += 1 | |
| else: | |
| worker.total_failed += 1 | |
| def get_available_workers(self, task_type: TaskType) -> List[str]: | |
| """Get workers available for task type""" | |
| available = [] | |
| with self._lock: | |
| for worker_id, worker in self.workers.items(): | |
| # Check if worker handles this task type | |
| if task_type not in worker.task_types: | |
| continue | |
| # Check if worker is alive | |
| if time.time() - worker.last_heartbeat > self.worker_timeout: | |
| continue | |
| # Check if worker has capacity | |
| if len(worker.current_tasks) >= self.max_tasks_per_worker: | |
| continue | |
| available.append(worker_id) | |
| return available | |
| def get_least_loaded_worker(self, task_type: TaskType) -> Optional[str]: | |
| """Get worker with lowest load for task type""" | |
| available = self.get_available_workers(task_type) | |
| if not available: | |
| return None | |
| # Sort by current load | |
| with self._lock: | |
| return min(available, key=lambda wid: len(self.workers[wid].current_tasks)) | |
| def get_cluster_capacity(self) -> Dict[str, Any]: | |
| """Get overall cluster capacity metrics""" | |
| with self._lock: | |
| total_workers = len(self.workers) | |
| active_workers = sum(1 for w in self.workers.values() | |
| if time.time() - w.last_heartbeat < self.worker_timeout) | |
| total_capacity = active_workers * self.max_tasks_per_worker | |
| current_load = sum(len(w.current_tasks) for w in self.workers.values()) | |
| return { | |
| 'total_workers': total_workers, | |
| 'active_workers': active_workers, | |
| 'total_capacity': total_capacity, | |
| 'current_load': current_load, | |
| 'utilization': current_load / total_capacity if total_capacity > 0 else 0, | |
| 'available_slots': total_capacity - current_load | |
| } | |
| def should_scale_up(self) -> bool: | |
| """Check if cluster should scale up""" | |
| capacity = self.get_cluster_capacity() | |
| return capacity['utilization'] >= self.scale_up_threshold | |
| def should_scale_down(self) -> bool: | |
| """Check if cluster should scale down""" | |
| capacity = self.get_cluster_capacity() | |
| return capacity['utilization'] <= self.scale_down_threshold | |
| def cleanup_dead_workers(self) -> int: | |
| """Remove dead workers from tracking""" | |
| now = time.time() | |
| removed = 0 | |
| with self._lock: | |
| dead_workers = [ | |
| wid for wid, worker in self.workers.items() | |
| if now - worker.last_heartbeat > self.worker_timeout | |
| ] | |
| for wid in dead_workers: | |
| logger.warning(f"Removing dead worker: {wid}") | |
| del self.workers[wid] | |
| removed += 1 | |
| return removed | |
| def start_monitoring(self, interval: float = 30.0): | |
| """Start background worker monitoring""" | |
| if self._monitoring_active: | |
| return | |
| self._monitoring_active = True | |
| def monitor_loop(): | |
| while self._monitoring_active: | |
| try: | |
| # Cleanup dead workers | |
| removed = self.cleanup_dead_workers() | |
| if removed > 0: | |
| logger.warning(f"Cleaned up {removed} dead workers") | |
| # Check scaling needs | |
| if self.should_scale_up(): | |
| logger.info("Cluster needs to scale UP") | |
| elif self.should_scale_down(): | |
| logger.info("Cluster can scale DOWN") | |
| # Log capacity | |
| capacity = self.get_cluster_capacity() | |
| logger.info(f"Cluster capacity: {capacity['current_load']}/{capacity['total_capacity']} " | |
| f"({capacity['utilization']:.1%} utilization)") | |
| except Exception as e: | |
| logger.error(f"Worker monitoring error: {e}") | |
| time.sleep(interval) | |
| self._monitor_thread = threading.Thread(target=monitor_loop, daemon=True) | |
| self._monitor_thread.start() | |
| logger.info("Worker monitoring started") | |
| def stop_monitoring(self): | |
| """Stop worker monitoring""" | |
| self._monitoring_active = False | |
| if self._monitor_thread: | |
| self._monitor_thread.join(timeout=5) | |
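| # Illustrative sketch (not part of the original module): a minimal worker | |
| # heartbeat loop built on WorkerManager.update_heartbeat(). psutil is assumed | |
| # to be installed for CPU/memory sampling; the loop degrades to zeros if not. | |
| def _example_heartbeat_loop(manager: WorkerManager, worker_id: str, | |
|                             stop_event: threading.Event, interval: float = 30.0): | |
|     """Send periodic heartbeats with basic resource metrics until stopped.""" | |
|     try: | |
|         import psutil | |
|         process = psutil.Process() | |
|     except ImportError: | |
|         process = None | |
|     while not stop_event.is_set(): | |
|         cpu = process.cpu_percent(interval=None) if process else 0.0 | |
|         memory = process.memory_info().rss / (1024 * 1024) if process else 0.0 | |
|         manager.update_heartbeat(worker_id, cpu=cpu, memory=memory) | |
|         stop_event.wait(interval) | |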
| # ============================================================================ | |
| # SCHEDULER | |
| # ============================================================================ | |
| @dataclass | |
| class ScheduledTask: | |
| """Scheduled task configuration""" | |
| schedule_id: str | |
| task_type: TaskType | |
| metadata: Dict[str, Any] | |
| priority: TaskPriority | |
| cron_pattern: Optional[str] = None # Future: cron-like scheduling | |
| interval_seconds: Optional[int] = None # Recurring interval | |
| scheduled_time: Optional[float] = None # One-time execution | |
| last_run: Optional[float] = None | |
| next_run: Optional[float] = None | |
| enabled: bool = True | |
| class Scheduler: | |
| """ | |
| Task scheduler with support for: | |
| - One-time scheduled tasks | |
| - Recurring tasks (interval-based) | |
| - Burst pacing | |
| - Account load balancing | |
| - SLA deadline enforcement | |
| """ | |
| def __init__(self, redis_client: Redis, task_queue_manager, namespace: str = "viral_factory"): | |
| self.redis = redis_client | |
| self.queue_manager = task_queue_manager | |
| self.namespace = namespace | |
| # Scheduled tasks | |
| self.schedules: Dict[str, ScheduledTask] = {} | |
| self._lock = threading.RLock() | |
| # Burst pacing configuration | |
| self.burst_window_seconds = 3600 # 1 hour | |
| self.max_posts_per_burst = 10 | |
| self.min_post_interval = 60 # 1 minute between posts | |
| # Scheduler thread | |
| self._scheduler_active = False | |
| self._scheduler_thread: Optional[threading.Thread] = None | |
| logger.info("Scheduler initialized") | |
| def schedule_task(self, task_type: TaskType, metadata: Dict[str, Any], | |
| scheduled_time: float, priority: TaskPriority = TaskPriority.NORMAL, | |
| deadline: Optional[float] = None) -> str: | |
| """Schedule task for one-time execution""" | |
| schedule_id = str(uuid.uuid4()) | |
| scheduled_task = ScheduledTask( | |
| schedule_id=schedule_id, | |
| task_type=task_type, | |
| metadata=metadata, | |
| priority=priority, | |
| scheduled_time=scheduled_time, | |
| next_run=scheduled_time | |
| ) | |
| with self._lock: | |
| self.schedules[schedule_id] = scheduled_task | |
| # Store in Redis | |
| key = f"{self.namespace}:schedule:{schedule_id}" | |
| self.redis.set(key, json.dumps({ | |
| 'schedule_id': schedule_id, | |
| 'task_type': task_type.value, | |
| 'metadata': metadata, | |
| 'scheduled_time': scheduled_time, | |
| 'deadline': deadline | |
| })) | |
| logger.info(f"Task scheduled: {schedule_id} for {datetime.fromtimestamp(scheduled_time)}") | |
| return schedule_id | |
| def schedule_recurring(self, task_type: TaskType, metadata: Dict[str, Any], | |
| interval_seconds: int, priority: TaskPriority = TaskPriority.NORMAL) -> str: | |
| """Schedule recurring task""" | |
| schedule_id = str(uuid.uuid4()) | |
| next_run = time.time() + interval_seconds | |
| scheduled_task = ScheduledTask( | |
| schedule_id=schedule_id, | |
| task_type=task_type, | |
| metadata=metadata, | |
| priority=priority, | |
| interval_seconds=interval_seconds, | |
| next_run=next_run | |
| ) | |
| with self._lock: | |
| self.schedules[schedule_id] = scheduled_task | |
| logger.info(f"Recurring task scheduled: {schedule_id} (every {interval_seconds}s)") | |
| return schedule_id | |
| def cancel_schedule(self, schedule_id: str) -> bool: | |
| """Cancel scheduled task""" | |
| with self._lock: | |
| if schedule_id in self.schedules: | |
| del self.schedules[schedule_id] | |
| # Remove from Redis | |
| key = f"{self.namespace}:schedule:{schedule_id}" | |
| self.redis.delete(key) | |
| logger.info(f"Schedule cancelled: {schedule_id}") | |
| return True | |
| return False | |
| def check_burst_limit(self, account_id: str) -> bool: | |
| """Check if account can post within burst limits""" | |
| key = f"{self.namespace}:burst:{account_id}" | |
| # Get recent post count | |
| now = time.time() | |
| window_start = now - self.burst_window_seconds | |
| # Use sorted set to track posts in time window | |
| self.redis.zremrangebyscore(key, 0, window_start) # Cleanup old | |
| recent_count = self.redis.zcard(key) | |
| return recent_count < self.max_posts_per_burst | |
| def record_post(self, account_id: str, task_id: str): | |
| """Record post for burst tracking""" | |
| key = f"{self.namespace}:burst:{account_id}" | |
| self.redis.zadd(key, {task_id: time.time()}) | |
| self.redis.expire(key, self.burst_window_seconds * 2) | |
| def enforce_min_interval(self, account_id: str) -> float: | |
| """Get delay needed to enforce minimum interval""" | |
| key = f"{self.namespace}:last_post:{account_id}" | |
| last_post_time = self.redis.get(key) | |
| if not last_post_time: | |
| return 0 | |
| last_post_time = float(last_post_time) | |
| elapsed = time.time() - last_post_time | |
| if elapsed < self.min_post_interval: | |
| return self.min_post_interval - elapsed | |
| return 0 | |
| def _process_schedules(self): | |
| """Process scheduled tasks (called by scheduler thread)""" | |
| now = time.time() | |
| tasks_to_submit = [] | |
| with self._lock: | |
| for schedule in list(self.schedules.values()): | |
| if not schedule.enabled: | |
| continue | |
| if schedule.next_run and now >= schedule.next_run: | |
| # Check burst limits for account-based tasks | |
| account_id = schedule.metadata.get('account_id') | |
| if account_id: | |
| if not self.check_burst_limit(account_id): | |
| logger.warning(f"Burst limit reached for account {account_id}, delaying") | |
| schedule.next_run = now + 300 # Delay 5 minutes | |
| continue | |
| # Check minimum interval | |
| delay = self.enforce_min_interval(account_id) | |
| if delay > 0: | |
| logger.info(f"Enforcing min interval for {account_id}, delay {delay}s") | |
| schedule.next_run = now + delay | |
| continue | |
| tasks_to_submit.append(schedule) | |
| # Submit tasks outside lock | |
| for schedule in tasks_to_submit: | |
| try: | |
| # Submit to queue | |
| task_id = self.queue_manager.submit_task( | |
| task_type=schedule.task_type, | |
| metadata=schedule.metadata, | |
| priority=schedule.priority, | |
| scheduled_for=time.time() | |
| ) | |
| logger.info(f"Scheduled task submitted: {task_id}") | |
| # Record post for burst tracking | |
| account_id = schedule.metadata.get('account_id') | |
| if account_id: | |
| self.record_post(account_id, task_id) | |
| key = f"{self.namespace}:last_post:{account_id}" | |
| self.redis.setex(key, self.burst_window_seconds, time.time()) | |
| # Update schedule | |
| with self._lock: | |
| schedule.last_run = now | |
| if schedule.interval_seconds: | |
| # Recurring - schedule next run | |
| schedule.next_run = now + schedule.interval_seconds | |
| else: | |
| # One-time - disable | |
| schedule.enabled = False | |
| self.cancel_schedule(schedule.schedule_id) | |
| except Exception as e: | |
| logger.error(f"Error submitting scheduled task: {e}") | |
| def start(self, check_interval: float = 10.0): | |
| """Start scheduler background thread""" | |
| if self._scheduler_active: | |
| return | |
| self._scheduler_active = True | |
| def scheduler_loop(): | |
| while self._scheduler_active: | |
| try: | |
| self._process_schedules() | |
| except Exception as e: | |
| logger.error(f"Scheduler error: {e}") | |
| time.sleep(check_interval) | |
| self._scheduler_thread = threading.Thread(target=scheduler_loop, daemon=True) | |
| self._scheduler_thread.start() | |
| logger.info("Scheduler started") | |
| def stop(self): | |
| """Stop scheduler""" | |
| self._scheduler_active = False | |
| if self._scheduler_thread: | |
| self._scheduler_thread.join(timeout=5) | |
| logger.info("Scheduler stopped") | |
| # ============================================================================ | |
| # REDIS-BACKED PERSISTENT QUEUE (Enhanced with Integration Points) | |
| # ============================================================================ | |
| class RedisTaskQueue: | |
| """ | |
| Production-grade distributed task queue with Redis backend. | |
| Enhanced with full orchestration integration. | |
| """ | |
| # Redis key prefixes | |
| TASK_KEY = "task:{task_id}" | |
| QUEUE_KEY = "queue:{priority}:{task_type}" | |
| DELAYED_QUEUE = "delayed_tasks" | |
| CLAIMED_KEY = "claimed:{worker_id}" | |
| DEPENDENCY_KEY = "dependencies:{task_id}" | |
| STATUS_INDEX = "status:{status}" | |
| TYPE_INDEX = "type:{task_type}" | |
| WORKER_HEARTBEAT = "worker:{worker_id}:heartbeat" | |
| METRICS_KEY = "metrics:queue" | |
| LEASE_LOCK = "lease:lock" | |
| SHUTDOWN_KEY = "system:shutdown" | |
| def __init__(self, redis_client: Redis, namespace: str = "viral_factory"): | |
| self.redis = redis_client | |
| self.namespace = namespace | |
| self._lock = threading.RLock() | |
| # Local cache | |
| self._local_cache: Dict[str, Tuple[TaskPayload, float]] = {} | |
| self._cache_ttl = 60 | |
| # Hooks | |
| self._enqueue_hooks: List[Callable] = [] | |
| self._dequeue_hooks: List[Callable] = [] | |
| self._status_hooks: List[Callable] = [] | |
| self._claim_hooks: List[Callable] = [] | |
| # Metrics | |
| self._local_metrics = { | |
| 'total_enqueued': 0, | |
| 'total_dequeued': 0, | |
| 'total_completed': 0, | |
| 'total_failed': 0, | |
| 'total_retries': 0, | |
| 'total_timeouts': 0 | |
| } | |
| # Worker settings | |
| self.lease_duration = 300 | |
| self.heartbeat_interval = 30 | |
| # Maintenance | |
| self._maintenance_active = False | |
| self._maintenance_thread: Optional[threading.Thread] = None | |
| # Shutdown flag | |
| self._shutdown_requested = False | |
| logger.info("RedisTaskQueue initialized") | |
| def _key(self, pattern: str, **kwargs) -> str: | |
| """Generate namespaced Redis key""" | |
| key = pattern.format(**kwargs) | |
| return f"{self.namespace}:{key}" | |
| def enqueue(self, payload: TaskPayload) -> str: | |
| """Enqueue task with validation and persistence""" | |
| # Check shutdown | |
| if self._shutdown_requested: | |
| raise RuntimeError("Queue is shutting down, not accepting new tasks") | |
| if not payload.task_id: | |
| payload.task_id = str(uuid.uuid4()) | |
| self._validate_payload(payload) | |
| # Idempotency check | |
| if payload.idempotency_key: | |
| existing = self._check_idempotency(payload.idempotency_key) | |
| if existing: | |
| return existing | |
| payload.status = TaskStatus.QUEUED | |
| # Handle scheduling | |
| now = time.time() | |
| if payload.scheduled_for > now: | |
| self.redis.zadd( | |
| self._key(self.DELAYED_QUEUE), | |
| {payload.task_id: payload.scheduled_for} | |
| ) | |
| payload.status = TaskStatus.PENDING | |
| else: | |
| self._add_to_queue(payload) | |
| # Dependencies | |
| if payload.depends_on: | |
| dep_key = self._key(self.DEPENDENCY_KEY, task_id=payload.task_id) | |
| self.redis.sadd(dep_key, *payload.depends_on) | |
| payload.status = TaskStatus.BLOCKED | |
| # Store and index only after the final status is known so the persisted task | |
| # and the status indexes stay consistent | |
| task_key = self._key(self.TASK_KEY, task_id=payload.task_id) | |
| self.redis.set(task_key, pickle.dumps(payload)) | |
| self._index_task(payload) | |
| self._increment_metric('total_enqueued') | |
| # Idempotency mapping | |
| if payload.idempotency_key: | |
| idem_key = self._key(f"idempotency:{payload.idempotency_key}") | |
| self.redis.setex(idem_key, 86400, payload.task_id) | |
| self._execute_hooks(self._enqueue_hooks, payload) | |
| logger.debug(f"Task enqueued: {payload.task_id}") | |
| return payload.task_id | |
| def _validate_payload(self, payload: TaskPayload): | |
| """Validate task payload""" | |
| if not payload.task_type: | |
| raise ValueError("task_type is required") | |
| if payload.scheduled_for < 0: | |
| raise ValueError("scheduled_for must be >= 0") | |
| if payload.max_retries < 0: | |
| raise ValueError("max_retries must be >= 0") | |
| if payload.timeout_seconds <= 0: | |
| raise ValueError("timeout_seconds must be > 0") | |
| def _check_idempotency(self, idempotency_key: str) -> Optional[str]: | |
| """Check if task with idempotency key exists""" | |
| idem_key = self._key(f"idempotency:{idempotency_key}") | |
| result = self.redis.get(idem_key) | |
| return result.decode('utf-8') if result else None | |
| def _add_to_queue(self, payload: TaskPayload): | |
| """Add task to priority queue""" | |
| queue_key = self._key( | |
| self.QUEUE_KEY, | |
| priority=payload.priority.value, | |
| task_type=payload.task_type.value | |
| ) | |
| self.redis.zadd(queue_key, {payload.task_id: payload.scheduled_for}) | |
| def _index_task(self, payload: TaskPayload): | |
| """Add task to indexes""" | |
| status_key = self._key(self.STATUS_INDEX, status=payload.status.value) | |
| self.redis.sadd(status_key, payload.task_id) | |
| type_key = self._key(self.TYPE_INDEX, task_type=payload.task_type.value) | |
| self.redis.sadd(type_key, payload.task_id) | |
| def _remove_from_indexes(self, payload: TaskPayload): | |
| """Remove task from indexes""" | |
| status_key = self._key(self.STATUS_INDEX, status=payload.status.value) | |
| self.redis.srem(status_key, payload.task_id) | |
| type_key = self._key(self.TYPE_INDEX, task_type=payload.task_type.value) | |
| self.redis.srem(type_key, payload.task_id) | |
| def dequeue(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]: | |
| """Dequeue next task for worker""" | |
| self._update_worker_heartbeat(worker_id) | |
| for priority in TaskPriority: | |
| task_types_to_check = task_types or list(TaskType) | |
| for task_type in task_types_to_check: | |
| queue_key = self._key( | |
| self.QUEUE_KEY, | |
| priority=priority.value, | |
| task_type=task_type.value | |
| ) | |
| result = self.redis.zpopmin(queue_key, 1) | |
| if not result: | |
| continue | |
| task_id = result[0][0].decode('utf-8') if isinstance(result[0][0], bytes) else result[0][0] | |
| task = self._claim_task(task_id, worker_id) | |
| if task: | |
| self._execute_hooks(self._dequeue_hooks, task) | |
| self._increment_metric('total_dequeued') | |
| return task | |
| return None | |
| def claim_next_task(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]: | |
| """Claim next available task""" | |
| self._process_delayed_tasks() | |
| task = self.dequeue(worker_id, task_types) | |
| if not task: | |
| return None | |
| if not self.are_dependencies_met(task): | |
| task.status = TaskStatus.BLOCKED | |
| self._add_to_queue(task) | |
| self._store_task(task) | |
| return None | |
| return task | |
| def _claim_task(self, task_id: str, worker_id: str) -> Optional[TaskPayload]: | |
| """Atomically claim task""" | |
| lock_key = self._key(self.LEASE_LOCK) | |
| lock = self.redis.lock(lock_key, timeout=5) | |
| # Only release the lock if it was actually acquired | |
| if not lock.acquire(blocking=True, blocking_timeout=5): | |
| return None | |
| try: | |
| task = self._get_task(task_id) | |
| if not task: | |
| return None | |
| if task.status in [TaskStatus.CLAIMED, TaskStatus.RUNNING]: | |
| return None | |
| task.status = TaskStatus.CLAIMED | |
| task.worker_id = worker_id | |
| task.claimed_at = time.time() | |
| task.lease_expires_at = time.time() + self.lease_duration | |
| self._store_task(task) | |
| claimed_key = self._key(self.CLAIMED_KEY, worker_id=worker_id) | |
| self.redis.sadd(claimed_key, task_id) | |
| self._execute_hooks(self._claim_hooks, task) | |
| return task | |
| finally: | |
| lock.release() | |
| def start_task(self, task_id: str) -> bool: | |
| """Mark task as running""" | |
| task = self._get_task(task_id) | |
| if not task: | |
| return False | |
| old_status = task.status | |
| task.status = TaskStatus.RUNNING | |
| task.started_at = time.time() | |
| self._update_task_status(task, old_status) | |
| return True | |
| def complete_task(self, task_id: str, result: Optional[Dict] = None) -> bool: | |
| """Mark task completed""" | |
| task = self._get_task(task_id) | |
| if not task: | |
| return False | |
| old_status = task.status | |
| task.status = TaskStatus.COMPLETED | |
| task.completed_at = time.time() | |
| task.result = result | |
| if task.worker_id: | |
| claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id) | |
| self.redis.srem(claimed_key, task_id) | |
| self._update_task_status(task, old_status) | |
| self._increment_metric('total_completed') | |
| self._unblock_dependents(task_id) | |
| return True | |
| def fail_task(self, task_id: str, error: str, should_retry: bool = True) -> bool: | |
| """Mark task failed with retry logic""" | |
| task = self._get_task(task_id) | |
| if not task: | |
| return False | |
| old_status = task.status | |
| task.last_error = error | |
| task.error_history.append({ | |
| 'error': error, | |
| 'timestamp': time.time(), | |
| 'attempt': task.retry_count | |
| }) | |
| if should_retry and task.retry_count < task.max_retries: | |
| task.retry_count += 1 | |
| task.status = TaskStatus.RETRY | |
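| # Exponential backoff: 10s, 20s, 40s, ... capped at 300s | |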
| delay = min(300, 2 ** task.retry_count * 5) | |
| task.scheduled_for = time.time() + delay | |
| self.redis.zadd( | |
| self._key(self.DELAYED_QUEUE), | |
| {task_id: task.scheduled_for} | |
| ) | |
| self._increment_metric('total_retries') | |
| else: | |
| task.status = TaskStatus.FAILED | |
| self._increment_metric('total_failed') | |
| if task.worker_id: | |
| claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id) | |
| self.redis.srem(claimed_key, task_id) | |
| task.worker_id = None | |
| self._update_task_status(task, old_status) | |
| return True | |
| def timeout_task(self, task_id: str) -> bool: | |
| """Mark task timed out""" | |
| task = self._get_task(task_id) | |
| if not task: | |
| return False | |
| old_status = task.status | |
| task.status = TaskStatus.TIMEOUT | |
| task.last_error = "Task execution timeout" | |
| if task.worker_id: | |
| claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id) | |
| self.redis.srem(claimed_key, task_id) | |
| self._update_task_status(task, old_status) | |
| self._increment_metric('total_timeouts') | |
| return self.fail_task(task_id, "Execution timeout", should_retry=True) | |
| def cancel_task(self, task_id: str) -> bool: | |
| """Cancel task""" | |
| task = self._get_task(task_id) | |
| if not task: | |
| return False | |
| if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]: | |
| return False | |
| old_status = task.status | |
| task.status = TaskStatus.CANCELLED | |
| for priority in TaskPriority: | |
| for task_type in TaskType: | |
| queue_key = self._key( | |
| self.QUEUE_KEY, | |
| priority=priority.value, | |
| task_type=task_type.value | |
| ) | |
| self.redis.zrem(queue_key, task_id) | |
| self.redis.zrem(self._key(self.DELAYED_QUEUE), task_id) | |
| self._update_task_status(task, old_status) | |
| return True | |
| def _update_task_status(self, task: TaskPayload, old_status: TaskStatus): | |
| """Update task status""" | |
| old_key = self._key(self.STATUS_INDEX, status=old_status.value) | |
| self.redis.srem(old_key, task.task_id) | |
| new_key = self._key(self.STATUS_INDEX, status=task.status.value) | |
| self.redis.sadd(new_key, task.task_id) | |
| self._store_task(task) | |
| self._execute_hooks(self._status_hooks, task) | |
| def are_dependencies_met(self, task: TaskPayload) -> bool: | |
| """Check if dependencies are met""" | |
| if not task.depends_on: | |
| return True | |
| for dep_id in task.depends_on: | |
| dep_task = self._get_task(dep_id) | |
| if not dep_task or dep_task.status != TaskStatus.COMPLETED: | |
| return False | |
| return True | |
| def _unblock_dependents(self, task_id: str): | |
| """Unblock dependent tasks""" | |
| for status in [TaskStatus.BLOCKED, TaskStatus.PENDING]: | |
| status_key = self._key(self.STATUS_INDEX, status=status.value) | |
| blocked_ids = self.redis.smembers(status_key) | |
| for blocked_id in blocked_ids: | |
| if isinstance(blocked_id, bytes): | |
| blocked_id = blocked_id.decode('utf-8') | |
| blocked_task = self._get_task(blocked_id) | |
| if not blocked_task: | |
| continue | |
| if task_id in blocked_task.depends_on: | |
| if self.are_dependencies_met(blocked_task): | |
| blocked_task.status = TaskStatus.QUEUED | |
| self._add_to_queue(blocked_task) | |
| self._update_task_status(blocked_task, status) | |
| def _update_worker_heartbeat(self, worker_id: str): | |
| """Update worker heartbeat""" | |
| heartbeat_key = self._key(self.WORKER_HEARTBEAT, worker_id=worker_id) | |
| self.redis.setex(heartbeat_key, self.heartbeat_interval * 3, time.time()) | |
| def renew_lease(self, task_id: str, worker_id: str) -> bool: | |
| """Renew task lease""" | |
| task = self._get_task(task_id) | |
| if not task or task.worker_id != worker_id: | |
| return False | |
| task.lease_expires_at = time.time() + self.lease_duration | |
| self._store_task(task) | |
| self._update_worker_heartbeat(worker_id) | |
| return True | |
| def reclaim_expired_leases(self) -> int: | |
| """Reclaim expired leases""" | |
| now = time.time() | |
| reclaimed = 0 | |
| for status in [TaskStatus.CLAIMED, TaskStatus.RUNNING]: | |
| status_key = self._key(self.STATUS_INDEX, status=status.value) | |
| task_ids = self.redis.smembers(status_key) | |
| for task_id in task_ids: | |
| if isinstance(task_id, bytes): | |
| task_id = task_id.decode('utf-8') | |
| task = self._get_task(task_id) | |
| if not task: | |
| continue | |
| if task.lease_expires_at and now > task.lease_expires_at: | |
| if task.worker_id: | |
| claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id) | |
| self.redis.srem(claimed_key, task_id) | |
| task.worker_id = None | |
| task.claimed_at = None | |
| task.lease_expires_at = None | |
| task.status = TaskStatus.QUEUED | |
| self._add_to_queue(task) | |
| self._update_task_status(task, status) | |
| reclaimed += 1 | |
| return reclaimed | |
| def _process_delayed_tasks(self): | |
| """Move delayed tasks to active queue""" | |
| now = time.time() | |
| delayed_key = self._key(self.DELAYED_QUEUE) | |
| ready_tasks = self.redis.zrangebyscore(delayed_key, 0, now) | |
| for task_id in ready_tasks: | |
| if isinstance(task_id, bytes): | |
| task_id = task_id.decode('utf-8') | |
| task = self._get_task(task_id) | |
| if not task: | |
| self.redis.zrem(delayed_key, task_id) | |
| continue | |
| task.status = TaskStatus.QUEUED | |
| self._add_to_queue(task) | |
| self._store_task(task) | |
| self.redis.zrem(delayed_key, task_id) | |
| def _get_task(self, task_id: str) -> Optional[TaskPayload]: | |
| """Get task with caching""" | |
| if task_id in self._local_cache: | |
| cached, timestamp = self._local_cache[task_id] | |
| if time.time() - timestamp < self._cache_ttl: | |
| return cached | |
| task_key = self._key(self.TASK_KEY, task_id=task_id) | |
| data = self.redis.get(task_key) | |
| if not data: | |
| return None | |
| task = pickle.loads(data) | |
| self._local_cache[task_id] = (task, time.time()) | |
| return task | |
| def _store_task(self, task: TaskPayload): | |
| """Store task in Redis""" | |
| task_key = self._key(self.TASK_KEY, task_id=task.task_id) | |
| self.redis.set(task_key, pickle.dumps(task)) | |
| self._local_cache[task.task_id] = (task, time.time()) | |
| def get_task(self, task_id: str) -> Optional[TaskPayload]: | |
| """Public API to get task""" | |
| return self._get_task(task_id) | |
| def get_tasks_by_status(self, status: TaskStatus, limit: int = 100) -> List[TaskPayload]: | |
| """Get tasks by status""" | |
| status_key = self._key(self.STATUS_INDEX, status=status.value) | |
| task_ids = self.redis.smembers(status_key) | |
| tasks = [] | |
| for task_id in list(task_ids)[:limit]: | |
| if isinstance(task_id, bytes): | |
| task_id = task_id.decode('utf-8') | |
| task = self._get_task(task_id) | |
| if task: | |
| tasks.append(task) | |
| return tasks | |
| def _increment_metric(self, metric: str, amount: int = 1): | |
| """Increment metric""" | |
| self._local_metrics[metric] = self._local_metrics.get(metric, 0) + amount | |
| metrics_key = self._key(self.METRICS_KEY) | |
| self.redis.hincrby(metrics_key, metric, amount) | |
| def get_metrics(self) -> Dict[str, Any]: | |
| """Get queue metrics""" | |
| metrics_key = self._key(self.METRICS_KEY) | |
| redis_metrics = self.redis.hgetall(metrics_key) | |
| decoded = { | |
| k.decode('utf-8') if isinstance(k, bytes) else k: | |
| int(v.decode('utf-8')) if isinstance(v, bytes) else int(v) | |
| for k, v in redis_metrics.items() | |
| } | |
| # Redis totals already include this process's hincrby updates, so fall back | |
| # to local counters only for keys Redis has not recorded | |
| for key, value in self._local_metrics.items(): | |
| decoded.setdefault(key, value) | |
| return decoded | |
| def _execute_hooks(self, hooks: List[Callable], *args): | |
| """Execute hooks""" | |
| for hook in hooks: | |
| try: | |
| hook(*args) | |
| except Exception as e: | |
| logger.error(f"Hook execution error: {e}") | |
| def register_enqueue_hook(self, hook: Callable): | |
| self._enqueue_hooks.append(hook) | |
| def register_status_hook(self, hook: Callable): | |
| self._status_hooks.append(hook) | |
| def start_maintenance(self, interval: float = 60.0): | |
| """Start maintenance thread""" | |
| if self._maintenance_active: | |
| return | |
| self._maintenance_active = True | |
| def maintenance_loop(): | |
| while self._maintenance_active: | |
| try: | |
| reclaimed = self.reclaim_expired_leases() | |
| if reclaimed > 0: | |
| logger.info(f"Reclaimed {reclaimed} expired leases") | |
| self._process_delayed_tasks() | |
| except Exception as e: | |
| logger.error(f"Maintenance error: {e}") | |
| time.sleep(interval) | |
| self._maintenance_thread = threading.Thread(target=maintenance_loop, daemon=True) | |
| self._maintenance_thread.start() | |
| logger.info("Maintenance started") | |
| def stop_maintenance(self): | |
| """Stop maintenance""" | |
| self._maintenance_active = False | |
| if self._maintenance_thread: | |
| self._maintenance_thread.join(timeout=5) | |
| def request_shutdown(self): | |
| """Request graceful shutdown""" | |
| self._shutdown_requested = True | |
| shutdown_key = self._key(self.SHUTDOWN_KEY) | |
| self.redis.set(shutdown_key, "1", ex=3600) | |
| logger.warning("Shutdown requested") | |
| def cancel_all_pending(self) -> int: | |
| """Cancel all pending/queued tasks""" | |
| cancelled = 0 | |
| for status in [TaskStatus.PENDING, TaskStatus.QUEUED]: | |
| tasks = self.get_tasks_by_status(status, limit=10000) | |
| for task in tasks: | |
| if self.cancel_task(task.task_id): | |
| cancelled += 1 | |
| logger.info(f"Cancelled {cancelled} pending tasks") | |
| return cancelled | |
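| # Illustrative sketch (not part of the original module): a minimal worker loop | |
| # built directly on RedisTaskQueue. `handlers` maps each TaskType the worker | |
| # supports to a callable; the handler mapping and stop_event are assumptions | |
| # for the example. | |
| def _example_worker_loop(queue: RedisTaskQueue, worker_id: str, | |
|                          handlers: Dict[TaskType, Callable[[TaskPayload], Dict]], | |
|                          stop_event: threading.Event, poll_interval: float = 2.0): | |
|     """Claim tasks, run the matching handler, and report success or failure.""" | |
|     while not stop_event.is_set(): | |
|         task = queue.claim_next_task(worker_id, list(handlers.keys())) | |
|         if not task: | |
|             stop_event.wait(poll_interval) | |
|             continue | |
|         queue.start_task(task.task_id) | |
|         try: | |
|             result = handlers[task.task_type](task) | |
|             queue.complete_task(task.task_id, result) | |
|         except Exception as exc: | |
|             queue.fail_task(task.task_id, str(exc)) | |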
| # ============================================================================ | |
| # HIGH-LEVEL ORCHESTRATION MANAGER | |
| # ============================================================================ | |
| class TaskQueueManager: | |
| """ | |
| High-level orchestration manager coordinating all components. | |
| """ | |
| def __init__(self, redis_url: str = "redis://localhost:6379/0"): | |
| # Initialize Redis | |
| self.redis_client = redis.from_url(redis_url, decode_responses=False) | |
| # Core queue | |
| self.queue = RedisTaskQueue(self.redis_client) | |
| # Orchestration components | |
| self.priority_router = PriorityRouter(self.redis_client) | |
| self.worker_manager = WorkerManager(self.redis_client) | |
| self.scheduler = Scheduler(self.redis_client, self) | |
| # Start background services | |
| self.queue.start_maintenance() | |
| self.worker_manager.start_monitoring() | |
| self.scheduler.start() | |
| logger.info("TaskQueueManager initialized with full orchestration") | |
| def submit_task(self, task_type: TaskType, metadata: Dict[str, Any], | |
| priority: Optional[TaskPriority] = None, | |
| depends_on: List[str] = None, | |
| scheduled_for: Optional[float] = None, | |
| deadline: Optional[float] = None) -> str: | |
| """Submit task with automatic priority routing""" | |
| # Create payload | |
| payload = TaskPayload( | |
| task_id="", | |
| task_type=task_type, | |
| priority=priority or TaskPriority.NORMAL, | |
| created_at=time.time(), | |
| scheduled_for=scheduled_for or time.time(), | |
| depends_on=depends_on or [], | |
| metadata=metadata, | |
| deadline=deadline | |
| ) | |
| # Route priority if not explicitly set | |
| if not priority: | |
| payload.priority = self.priority_router.route_task(payload) | |
| logger.info(f"Task auto-routed to priority: {payload.priority.value}") | |
| # Submit to queue | |
| return self.queue.enqueue(payload) | |
| def get_next_task(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]: | |
| """Get next task for worker with load balancing""" | |
| # Update worker heartbeat | |
| self.worker_manager.update_heartbeat(worker_id) | |
| # Check worker capacity | |
| worker_info = self.worker_manager.workers.get(worker_id) | |
| if worker_info and len(worker_info.current_tasks) >= self.worker_manager.max_tasks_per_worker: | |
| logger.warning(f"Worker {worker_id} at capacity") | |
| return None | |
| # Get task | |
| task = self.queue.claim_next_task(worker_id, task_types) | |
| if task: | |
| self.worker_manager.assign_task(worker_id, task.task_id) | |
| logger.info(f"Task {task.task_id} assigned to worker {worker_id}") | |
| return task | |
| def mark_completed(self, task_id: str, worker_id: str, result: Dict = None): | |
| """Mark task completed""" | |
| self.queue.complete_task(task_id, result) | |
| self.worker_manager.complete_task(worker_id, task_id, success=True) | |
| def mark_failed(self, task_id: str, worker_id: str, error: str): | |
| """Mark task failed""" | |
| self.queue.fail_task(task_id, error) | |
| self.worker_manager.complete_task(worker_id, task_id, success=False) | |
| def schedule_task(self, task_type: TaskType, metadata: Dict[str, Any], | |
| scheduled_time: float, priority: TaskPriority = TaskPriority.NORMAL) -> str: | |
| """Schedule task for future execution""" | |
| return self.scheduler.schedule_task(task_type, metadata, scheduled_time, priority) | |
| def schedule_recurring(self, task_type: TaskType, metadata: Dict[str, Any], | |
| interval_seconds: int, priority: TaskPriority = TaskPriority.NORMAL) -> str: | |
| """Schedule recurring task""" | |
| return self.scheduler.schedule_recurring(task_type, metadata, interval_seconds, priority) | |
| def get_stats(self) -> Dict[str, Any]: | |
| """Get comprehensive system statistics""" | |
| return { | |
| 'queue_metrics': self.queue.get_metrics(), | |
| 'worker_capacity': self.worker_manager.get_cluster_capacity(), | |
| 'active_workers': len(self.worker_manager.workers), | |
| 'scheduled_tasks': len(self.scheduler.schedules) | |
| } | |
| def graceful_shutdown(self, timeout: float = 60.0): | |
| """ | |
| Graceful shutdown with proper cleanup. | |
| 1. Stop accepting new tasks | |
| 2. Wait for running tasks to complete | |
| 3. Cancel pending tasks | |
| 4. Stop all background threads | |
| """ | |
| logger.warning("=== INITIATING GRACEFUL SHUTDOWN ===") | |
| # Request shutdown | |
| self.queue.request_shutdown() | |
| # Stop scheduler from creating new tasks | |
| self.scheduler.stop() | |
| logger.info("Scheduler stopped") | |
| # Wait for running tasks | |
| start_time = time.time() | |
| while time.time() - start_time < timeout: | |
| running_tasks = self.queue.get_tasks_by_status(TaskStatus.RUNNING) | |
| if not running_tasks: | |
| break | |
| logger.info(f"Waiting for {len(running_tasks)} running tasks...") | |
| time.sleep(5) | |
| # Cancel remaining tasks | |
| cancelled = self.queue.cancel_all_pending() | |
| logger.info(f"Cancelled {cancelled} pending tasks") | |
| # Stop background services | |
| self.queue.stop_maintenance() | |
| self.worker_manager.stop_monitoring() | |
| # Close Redis connection | |
| self.redis_client.close() | |
| logger.warning("=== SHUTDOWN COMPLETE ===") | |
| # ============================================================================ | |
| # EXAMPLE USAGE | |
| # ============================================================================ | |
| if __name__ == "__main__": | |
| # Initialize full system | |
| manager = TaskQueueManager("redis://localhost:6379/0") | |
| # Register custom routing rule | |
| def high_value_rule(task: TaskPayload) -> Optional[TaskPriority]: | |
| if task.metadata.get('high_value'): | |
| return TaskPriority.HIGH | |
| return None | |
| manager.priority_router.add_routing_rule(high_value_rule) | |
| # Register worker | |
| manager.worker_manager.register_worker("worker_001", [TaskType.GENERATE_SCRIPT]) | |
| # Submit immediate task | |
| task_id = manager.submit_task( | |
| task_type=TaskType.GENERATE_SCRIPT, | |
| metadata={"niche": "tech", "high_value": True}, | |
| deadline=time.time() + 3600 | |
| ) | |
| logger.info(f"Submitted task: {task_id}") | |
| # Schedule future task | |
| future_task_id = manager.schedule_task( | |
| task_type=TaskType.POST_CONTENT, | |
| metadata={"platform": "tiktok", "account_id": "acc_001"}, | |
| scheduled_time=time.time() + 300, # 5 minutes | |
| priority=TaskPriority.HIGH | |
| ) | |
| logger.info(f"Scheduled task: {future_task_id}") | |
| # Schedule recurring health check | |
| recurring_id = manager.schedule_recurring( | |
| task_type=TaskType.HEALTH_CHECK, | |
| metadata={}, | |
| interval_seconds=3600, # Every hour | |
| priority=TaskPriority.CRITICAL | |
| ) | |
| logger.info(f"Recurring task: {recurring_id}") | |
| # Worker gets task | |
| task = manager.get_next_task("worker_001") | |
| if task: | |
| logger.info(f"Worker got task: {task.task_id}") | |
| # Simulate work | |
| time.sleep(2) | |
| # Complete | |
| manager.mark_completed(task.task_id, "worker_001", {"status": "success"}) | |
| logger.info("Task completed") | |
| # Get statistics | |
| stats = manager.get_stats() | |
| logger.info(f"System stats: {stats}") | |
| # Graceful shutdown | |
| time.sleep(5) | |
| manager.graceful_shutdown() |