"""
Production-Grade Task Queue System - ENHANCED
Distributed, persistent, scalable task queue with Redis backend, worker coordination,
dependency management, retry logic, full orchestration integration, scheduler, priority
routing, worker load balancing, and graceful shutdown.
"""
import json
import time
import uuid
import pickle
import logging
import threading
from dataclasses import dataclass, field, asdict
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import redis
from redis import Redis
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================================
# TASK DEFINITIONS & SCHEMA
# ============================================================================
class TaskType(Enum):
"""Complete enumeration of all task types in the system"""
# Content Generation Pipeline
GENERATE_SCRIPT = "generate_script"
GENERATE_VISUALS = "generate_visuals"
GENERATE_AUDIO = "generate_audio"
GENERATE_THUMBNAIL = "generate_thumbnail"
COMPOSE_VIDEO = "compose_video"
# Optimization & ML
OPTIMIZE_RL = "optimize_rl"
UPDATE_PRETRAIN = "update_pretrain"
TRAIN_MODEL = "train_model"
AB_TEST = "ab_test"
# Distribution
POST_CONTENT = "post_content"
SCHEDULE_POST = "schedule_post"
VIRAL_BOOST = "viral_boost"
# Analysis & Monitoring
ANALYZE_PERFORMANCE = "analyze_performance"
ANALYZE_METRICS = "analyze_metrics"
SCRAPE_TRENDING = "scrape_trending"
# Maintenance
CLEANUP = "cleanup"
BACKUP = "backup"
HEALTH_CHECK = "health_check"
class TaskPriority(Enum):
"""Priority levels with numeric values for sorting"""
CRITICAL = 0 # System-critical, immediate execution
HIGH = 1 # High priority, viral candidates
NORMAL = 2 # Standard tasks
LOW = 3 # Background tasks
BACKGROUND = 4 # Cleanup, maintenance
class TaskStatus(Enum):
"""Complete task lifecycle states"""
PENDING = "pending" # Created, not yet queued
QUEUED = "queued" # In queue, waiting
CLAIMED = "claimed" # Claimed by worker
RUNNING = "running" # Currently executing
COMPLETED = "completed" # Successfully completed
FAILED = "failed" # Failed, no more retries
RETRY = "retry" # Failed, will retry
CANCELLED = "cancelled" # Manually cancelled
TIMEOUT = "timeout" # Execution timeout
BLOCKED = "blocked" # Waiting on dependencies
@dataclass
class TaskPayload:
"""
Complete task payload with all metadata for distributed execution.
Serializable for Redis storage.
"""
# Core identification
task_id: str
task_type: TaskType
priority: TaskPriority
# Timing
created_at: float
scheduled_for: float # Unix timestamp
started_at: Optional[float] = None
completed_at: Optional[float] = None
deadline: Optional[float] = None # Hard deadline
# Dependencies
depends_on: List[str] = field(default_factory=list)
blocks: List[str] = field(default_factory=list)
# Execution metadata
metadata: Dict[str, Any] = field(default_factory=dict)
result: Optional[Dict[str, Any]] = None
# Retry tracking
retry_count: int = 0
max_retries: int = 3
last_error: Optional[str] = None
error_history: List[Dict] = field(default_factory=list)
# Status tracking
status: TaskStatus = TaskStatus.PENDING
worker_id: Optional[str] = None
claimed_at: Optional[float] = None
lease_expires_at: Optional[float] = None
# Content-specific fields
video_id: Optional[str] = None
account_id: Optional[str] = None
platform: Optional[str] = None
# RL/ML optimization flags
is_viral_candidate: bool = False
predicted_engagement: float = 0.0
predicted_virality: float = 0.0
ab_test_variant: Optional[str] = None
# Execution context
timeout_seconds: int = 300
idempotency_key: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to serializable dictionary"""
data = asdict(self)
data['task_type'] = self.task_type.value
data['priority'] = self.priority.value
data['status'] = self.status.value
return data
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'TaskPayload':
"""Create from dictionary"""
data['task_type'] = TaskType(data['task_type'])
data['priority'] = TaskPriority(data['priority'])
data['status'] = TaskStatus(data['status'])
return cls(**data)
def __lt__(self, other: 'TaskPayload') -> bool:
"""Priority queue comparison"""
if self.priority.value != other.priority.value:
return self.priority.value < other.priority.value
return self.scheduled_for < other.scheduled_for
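# Serialization round-trip sketch (illustrative): a TaskPayload flattens to a
# JSON-safe dict for transport and can be rebuilt on the consuming side, e.g.
#   payload = TaskPayload(task_id=str(uuid.uuid4()), task_type=TaskType.GENERATE_SCRIPT,
#                         priority=TaskPriority.NORMAL, created_at=time.time(),
#                         scheduled_for=time.time())
#   restored = TaskPayload.from_dict(json.loads(json.dumps(payload.to_dict())))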
# ============================================================================
# PRIORITY ROUTER
# ============================================================================
class PriorityRouter:
"""
Routes tasks to appropriate priority queues based on:
- Predicted virality
- Account load
- System capacity
- Business rules
"""
def __init__(self, redis_client: Redis):
self.redis = redis_client
self.rules: List[Callable] = []
# Virality thresholds
self.viral_threshold_high = 0.8 # 80%+ predicted virality -> HIGH
self.viral_threshold_medium = 0.5 # 50%+ -> NORMAL
# Load balancing thresholds
self.high_priority_max_queue_size = 100
self.account_max_concurrent_posts = 5
logger.info("PriorityRouter initialized")
def route_task(self, payload: TaskPayload) -> TaskPriority:
"""
Determine optimal priority for task based on multiple factors.
"""
# System critical tasks always get CRITICAL priority
if payload.task_type in [TaskType.HEALTH_CHECK, TaskType.BACKUP]:
return TaskPriority.CRITICAL
# Check custom routing rules first
for rule in self.rules:
priority = rule(payload)
if priority:
return priority
# Viral candidate routing
if payload.is_viral_candidate or payload.predicted_virality >= self.viral_threshold_high:
logger.info(f"Task {payload.task_id} routed to HIGH (viral candidate)")
return TaskPriority.HIGH
# Deadline-based routing
if payload.deadline:
time_until_deadline = payload.deadline - time.time()
if time_until_deadline < 300: # Less than 5 minutes
logger.warning(f"Task {payload.task_id} routed to HIGH (deadline approaching)")
return TaskPriority.HIGH
# Account load balancing - prevent single account overload
if payload.account_id:
account_load = self._get_account_load(payload.account_id)
if account_load >= self.account_max_concurrent_posts:
logger.info(f"Task {payload.task_id} routed to LOW (account overload)")
return TaskPriority.LOW
# Engagement-based routing
if payload.predicted_engagement >= self.viral_threshold_medium:
return TaskPriority.NORMAL
# Background/maintenance tasks
if payload.task_type in [TaskType.CLEANUP, TaskType.ANALYZE_PERFORMANCE]:
return TaskPriority.BACKGROUND
# Default
return payload.priority or TaskPriority.NORMAL
def _get_account_load(self, account_id: str) -> int:
"""Get current load for account"""
key = f"viral_factory:account_load:{account_id}"
return int(self.redis.get(key) or 0)
def add_routing_rule(self, rule: Callable[[TaskPayload], Optional[TaskPriority]]):
"""Add custom routing rule"""
self.rules.append(rule)
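        # Illustrative rule (not part of the original set): route anything the
        # caller tags as urgent in metadata straight to CRITICAL, e.g.
        #   router.add_routing_rule(
        #       lambda p: TaskPriority.CRITICAL if p.metadata.get("urgent") else None)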
def get_recommended_queue(self, task_type: TaskType) -> str:
"""Get recommended queue for task type"""
# Route different task types to specialized queues if needed
routing_map = {
TaskType.GENERATE_SCRIPT: "content_generation",
TaskType.GENERATE_VISUALS: "content_generation",
TaskType.COMPOSE_VIDEO: "video_processing",
TaskType.POST_CONTENT: "distribution",
TaskType.OPTIMIZE_RL: "ml_training",
}
return routing_map.get(task_type, "default")
# ============================================================================
# WORKER MANAGER
# ============================================================================
@dataclass
class WorkerInfo:
"""Worker metadata and health info"""
worker_id: str
started_at: float
last_heartbeat: float
task_types: List[TaskType]
current_tasks: List[str]
total_completed: int = 0
total_failed: int = 0
cpu_percent: float = 0.0
memory_mb: float = 0.0
status: str = "active" # active, idle, overloaded, dead
class WorkerManager:
"""
Manages worker pool, load balancing, and dynamic scaling.
"""
def __init__(self, redis_client: Redis, namespace: str = "viral_factory"):
self.redis = redis_client
self.namespace = namespace
# Worker tracking
self.workers: Dict[str, WorkerInfo] = {}
self._lock = threading.RLock()
# Configuration
self.max_tasks_per_worker = 5
self.worker_timeout = 180 # 3 minutes no heartbeat = dead
self.scale_up_threshold = 0.8 # 80% capacity
self.scale_down_threshold = 0.2 # 20% capacity
# Monitoring
self._monitoring_active = False
self._monitor_thread: Optional[threading.Thread] = None
logger.info("WorkerManager initialized")
def register_worker(self, worker_id: str, task_types: List[TaskType]) -> bool:
"""Register new worker"""
with self._lock:
self.workers[worker_id] = WorkerInfo(
worker_id=worker_id,
started_at=time.time(),
last_heartbeat=time.time(),
task_types=task_types,
current_tasks=[]
)
# Store in Redis
key = f"{self.namespace}:worker:{worker_id}"
self.redis.setex(key, self.worker_timeout, json.dumps({
'worker_id': worker_id,
'task_types': [t.value for t in task_types],
'started_at': time.time()
}))
logger.info(f"Worker registered: {worker_id}")
return True
def update_heartbeat(self, worker_id: str, cpu: float = 0, memory: float = 0):
"""Update worker heartbeat and health metrics"""
with self._lock:
if worker_id not in self.workers:
logger.warning(f"Heartbeat from unknown worker: {worker_id}")
return False
worker = self.workers[worker_id]
worker.last_heartbeat = time.time()
worker.cpu_percent = cpu
worker.memory_mb = memory
            # Update Redis heartbeat and keep the registration record alive
            hb_key = f"{self.namespace}:worker:{worker_id}:heartbeat"
            self.redis.setex(hb_key, self.worker_timeout, time.time())
            self.redis.expire(f"{self.namespace}:worker:{worker_id}", self.worker_timeout)
            return True
def assign_task(self, worker_id: str, task_id: str):
"""Assign task to worker"""
with self._lock:
if worker_id in self.workers:
self.workers[worker_id].current_tasks.append(task_id)
def complete_task(self, worker_id: str, task_id: str, success: bool = True):
"""Mark task completed by worker"""
with self._lock:
if worker_id in self.workers:
worker = self.workers[worker_id]
if task_id in worker.current_tasks:
worker.current_tasks.remove(task_id)
if success:
worker.total_completed += 1
else:
worker.total_failed += 1
def get_available_workers(self, task_type: TaskType) -> List[str]:
"""Get workers available for task type"""
available = []
with self._lock:
for worker_id, worker in self.workers.items():
# Check if worker handles this task type
if task_type not in worker.task_types:
continue
# Check if worker is alive
if time.time() - worker.last_heartbeat > self.worker_timeout:
continue
# Check if worker has capacity
if len(worker.current_tasks) >= self.max_tasks_per_worker:
continue
available.append(worker_id)
return available
def get_least_loaded_worker(self, task_type: TaskType) -> Optional[str]:
"""Get worker with lowest load for task type"""
available = self.get_available_workers(task_type)
if not available:
return None
# Sort by current load
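        # (Load is measured purely by in-flight task count; the CPU/memory
        # figures reported via heartbeats are collected but not used here.)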
with self._lock:
return min(available, key=lambda wid: len(self.workers[wid].current_tasks))
def get_cluster_capacity(self) -> Dict[str, Any]:
"""Get overall cluster capacity metrics"""
with self._lock:
total_workers = len(self.workers)
active_workers = sum(1 for w in self.workers.values()
if time.time() - w.last_heartbeat < self.worker_timeout)
total_capacity = active_workers * self.max_tasks_per_worker
current_load = sum(len(w.current_tasks) for w in self.workers.values())
return {
'total_workers': total_workers,
'active_workers': active_workers,
'total_capacity': total_capacity,
'current_load': current_load,
'utilization': current_load / total_capacity if total_capacity > 0 else 0,
'available_slots': total_capacity - current_load
}
def should_scale_up(self) -> bool:
"""Check if cluster should scale up"""
capacity = self.get_cluster_capacity()
return capacity['utilization'] >= self.scale_up_threshold
def should_scale_down(self) -> bool:
"""Check if cluster should scale down"""
capacity = self.get_cluster_capacity()
return capacity['utilization'] <= self.scale_down_threshold
def cleanup_dead_workers(self) -> int:
"""Remove dead workers from tracking"""
now = time.time()
removed = 0
with self._lock:
dead_workers = [
wid for wid, worker in self.workers.items()
if now - worker.last_heartbeat > self.worker_timeout
]
for wid in dead_workers:
logger.warning(f"Removing dead worker: {wid}")
del self.workers[wid]
removed += 1
return removed
def start_monitoring(self, interval: float = 30.0):
"""Start background worker monitoring"""
if self._monitoring_active:
return
self._monitoring_active = True
def monitor_loop():
while self._monitoring_active:
try:
# Cleanup dead workers
removed = self.cleanup_dead_workers()
if removed > 0:
logger.warning(f"Cleaned up {removed} dead workers")
# Check scaling needs
if self.should_scale_up():
logger.info("Cluster needs to scale UP")
elif self.should_scale_down():
logger.info("Cluster can scale DOWN")
# Log capacity
capacity = self.get_cluster_capacity()
logger.info(f"Cluster capacity: {capacity['current_load']}/{capacity['total_capacity']} "
f"({capacity['utilization']:.1%} utilization)")
except Exception as e:
logger.error(f"Worker monitoring error: {e}")
time.sleep(interval)
self._monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
self._monitor_thread.start()
logger.info("Worker monitoring started")
def stop_monitoring(self):
"""Stop worker monitoring"""
self._monitoring_active = False
if self._monitor_thread:
self._monitor_thread.join(timeout=5)
# ============================================================================
# SCHEDULER
# ============================================================================
@dataclass
class ScheduledTask:
"""Scheduled task configuration"""
schedule_id: str
task_type: TaskType
metadata: Dict[str, Any]
priority: TaskPriority
cron_pattern: Optional[str] = None # Future: cron-like scheduling
interval_seconds: Optional[int] = None # Recurring interval
scheduled_time: Optional[float] = None # One-time execution
last_run: Optional[float] = None
next_run: Optional[float] = None
enabled: bool = True
class Scheduler:
"""
Task scheduler with support for:
- One-time scheduled tasks
- Recurring tasks (interval-based)
- Burst pacing
- Account load balancing
- SLA deadline enforcement
"""
def __init__(self, redis_client: Redis, task_queue_manager, namespace: str = "viral_factory"):
self.redis = redis_client
self.queue_manager = task_queue_manager
self.namespace = namespace
# Scheduled tasks
self.schedules: Dict[str, ScheduledTask] = {}
self._lock = threading.RLock()
# Burst pacing configuration
self.burst_window_seconds = 3600 # 1 hour
self.max_posts_per_burst = 10
self.min_post_interval = 60 # 1 minute between posts
# Scheduler thread
self._scheduler_active = False
self._scheduler_thread: Optional[threading.Thread] = None
logger.info("Scheduler initialized")
def schedule_task(self, task_type: TaskType, metadata: Dict[str, Any],
scheduled_time: float, priority: TaskPriority = TaskPriority.NORMAL,
deadline: Optional[float] = None) -> str:
"""Schedule task for one-time execution"""
schedule_id = str(uuid.uuid4())
scheduled_task = ScheduledTask(
schedule_id=schedule_id,
task_type=task_type,
metadata=metadata,
priority=priority,
scheduled_time=scheduled_time,
next_run=scheduled_time
)
with self._lock:
self.schedules[schedule_id] = scheduled_task
# Store in Redis
key = f"{self.namespace}:schedule:{schedule_id}"
self.redis.set(key, json.dumps({
'schedule_id': schedule_id,
'task_type': task_type.value,
'metadata': metadata,
'scheduled_time': scheduled_time,
'deadline': deadline
}))
logger.info(f"Task scheduled: {schedule_id} for {datetime.fromtimestamp(scheduled_time)}")
return schedule_id
def schedule_recurring(self, task_type: TaskType, metadata: Dict[str, Any],
interval_seconds: int, priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Schedule recurring task"""
schedule_id = str(uuid.uuid4())
next_run = time.time() + interval_seconds
scheduled_task = ScheduledTask(
schedule_id=schedule_id,
task_type=task_type,
metadata=metadata,
priority=priority,
interval_seconds=interval_seconds,
next_run=next_run
)
with self._lock:
self.schedules[schedule_id] = scheduled_task
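            # Note: unlike one-time schedules, recurring schedules are kept only
            # in this process's memory; they are not persisted to Redis and will
            # not survive a restart.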
logger.info(f"Recurring task scheduled: {schedule_id} (every {interval_seconds}s)")
return schedule_id
def cancel_schedule(self, schedule_id: str) -> bool:
"""Cancel scheduled task"""
with self._lock:
if schedule_id in self.schedules:
del self.schedules[schedule_id]
# Remove from Redis
key = f"{self.namespace}:schedule:{schedule_id}"
self.redis.delete(key)
logger.info(f"Schedule cancelled: {schedule_id}")
return True
return False
def check_burst_limit(self, account_id: str) -> bool:
"""Check if account can post within burst limits"""
key = f"{self.namespace}:burst:{account_id}"
# Get recent post count
now = time.time()
window_start = now - self.burst_window_seconds
# Use sorted set to track posts in time window
self.redis.zremrangebyscore(key, 0, window_start) # Cleanup old
recent_count = self.redis.zcard(key)
return recent_count < self.max_posts_per_burst
def record_post(self, account_id: str, task_id: str):
"""Record post for burst tracking"""
key = f"{self.namespace}:burst:{account_id}"
self.redis.zadd(key, {task_id: time.time()})
self.redis.expire(key, self.burst_window_seconds * 2)
def enforce_min_interval(self, account_id: str) -> float:
"""Get delay needed to enforce minimum interval"""
key = f"{self.namespace}:last_post:{account_id}"
last_post_time = self.redis.get(key)
if not last_post_time:
return 0
last_post_time = float(last_post_time)
elapsed = time.time() - last_post_time
if elapsed < self.min_post_interval:
return self.min_post_interval - elapsed
return 0
def _process_schedules(self):
"""Process scheduled tasks (called by scheduler thread)"""
now = time.time()
tasks_to_submit = []
with self._lock:
for schedule in list(self.schedules.values()):
if not schedule.enabled:
continue
if schedule.next_run and now >= schedule.next_run:
# Check burst limits for account-based tasks
account_id = schedule.metadata.get('account_id')
if account_id:
if not self.check_burst_limit(account_id):
logger.warning(f"Burst limit reached for account {account_id}, delaying")
schedule.next_run = now + 300 # Delay 5 minutes
continue
# Check minimum interval
delay = self.enforce_min_interval(account_id)
if delay > 0:
logger.info(f"Enforcing min interval for {account_id}, delay {delay}s")
schedule.next_run = now + delay
continue
tasks_to_submit.append(schedule)
# Submit tasks outside lock
for schedule in tasks_to_submit:
try:
# Submit to queue
task_id = self.queue_manager.submit_task(
task_type=schedule.task_type,
metadata=schedule.metadata,
priority=schedule.priority,
scheduled_for=time.time()
)
logger.info(f"Scheduled task submitted: {task_id}")
# Record post for burst tracking
account_id = schedule.metadata.get('account_id')
if account_id:
self.record_post(account_id, task_id)
key = f"{self.namespace}:last_post:{account_id}"
self.redis.setex(key, self.burst_window_seconds, time.time())
# Update schedule
with self._lock:
schedule.last_run = now
if schedule.interval_seconds:
# Recurring - schedule next run
schedule.next_run = now + schedule.interval_seconds
else:
# One-time - disable
schedule.enabled = False
self.cancel_schedule(schedule.schedule_id)
except Exception as e:
logger.error(f"Error submitting scheduled task: {e}")
def start(self, check_interval: float = 10.0):
"""Start scheduler background thread"""
if self._scheduler_active:
return
self._scheduler_active = True
def scheduler_loop():
while self._scheduler_active:
try:
self._process_schedules()
except Exception as e:
logger.error(f"Scheduler error: {e}")
time.sleep(check_interval)
self._scheduler_thread = threading.Thread(target=scheduler_loop, daemon=True)
self._scheduler_thread.start()
logger.info("Scheduler started")
def stop(self):
"""Stop scheduler"""
self._scheduler_active = False
if self._scheduler_thread:
self._scheduler_thread.join(timeout=5)
logger.info("Scheduler stopped")
# ============================================================================
# REDIS-BACKED PERSISTENT QUEUE (Enhanced with Integration Points)
# ============================================================================
class RedisTaskQueue:
"""
Production-grade distributed task queue with Redis backend.
Enhanced with full orchestration integration.
"""
# Redis key prefixes
TASK_KEY = "task:{task_id}"
QUEUE_KEY = "queue:{priority}:{task_type}"
DELAYED_QUEUE = "delayed_tasks"
CLAIMED_KEY = "claimed:{worker_id}"
DEPENDENCY_KEY = "dependencies:{task_id}"
STATUS_INDEX = "status:{status}"
TYPE_INDEX = "type:{task_type}"
WORKER_HEARTBEAT = "worker:{worker_id}:heartbeat"
METRICS_KEY = "metrics:queue"
LEASE_LOCK = "lease:lock"
SHUTDOWN_KEY = "system:shutdown"
def __init__(self, redis_client: Redis, namespace: str = "viral_factory"):
self.redis = redis_client
self.namespace = namespace
self._lock = threading.RLock()
# Local cache
self._local_cache: Dict[str, Tuple[TaskPayload, float]] = {}
self._cache_ttl = 60
# Hooks
self._enqueue_hooks: List[Callable] = []
self._dequeue_hooks: List[Callable] = []
self._status_hooks: List[Callable] = []
self._claim_hooks: List[Callable] = []
# Metrics
self._local_metrics = {
'total_enqueued': 0,
'total_dequeued': 0,
'total_completed': 0,
'total_failed': 0,
'total_retries': 0,
'total_timeouts': 0
}
# Worker settings
self.lease_duration = 300
self.heartbeat_interval = 30
# Maintenance
self._maintenance_active = False
self._maintenance_thread: Optional[threading.Thread] = None
# Shutdown flag
self._shutdown_requested = False
logger.info("RedisTaskQueue initialized")
def _key(self, pattern: str, **kwargs) -> str:
"""Generate namespaced Redis key"""
key = pattern.format(**kwargs)
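        # e.g. _key(self.TASK_KEY, task_id="abc") -> "viral_factory:task:abc"
        # (with the default namespace)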
return f"{self.namespace}:{key}"
def enqueue(self, payload: TaskPayload) -> str:
"""Enqueue task with validation and persistence"""
# Check shutdown
if self._shutdown_requested:
raise RuntimeError("Queue is shutting down, not accepting new tasks")
if not payload.task_id:
payload.task_id = str(uuid.uuid4())
self._validate_payload(payload)
# Idempotency check
if payload.idempotency_key:
existing = self._check_idempotency(payload.idempotency_key)
if existing:
return existing
        # Determine the initial status before persisting
        now = time.time()
        if payload.depends_on and not self.are_dependencies_met(payload):
            payload.status = TaskStatus.BLOCKED
        elif payload.scheduled_for > now:
            payload.status = TaskStatus.PENDING
        else:
            payload.status = TaskStatus.QUEUED
        # Record dependency edges
        if payload.depends_on:
            dep_key = self._key(self.DEPENDENCY_KEY, task_id=payload.task_id)
            self.redis.sadd(dep_key, *payload.depends_on)
        # Persist and index the task with its final status
        self._store_task(payload)
        self._index_task(payload)
        # Route to the delayed set or to an active priority queue
        if payload.status == TaskStatus.PENDING:
            self.redis.zadd(
                self._key(self.DELAYED_QUEUE),
                {payload.task_id: payload.scheduled_for}
            )
        elif payload.status == TaskStatus.QUEUED:
            self._add_to_queue(payload)
        self._increment_metric('total_enqueued')
        # Idempotency mapping (24 h TTL)
        if payload.idempotency_key:
            idem_key = self._key(f"idempotency:{payload.idempotency_key}")
            self.redis.setex(idem_key, 86400, payload.task_id)
        self._execute_hooks(self._enqueue_hooks, payload)
        logger.debug(f"Task enqueued: {payload.task_id}")
        return payload.task_id
def _validate_payload(self, payload: TaskPayload):
"""Validate task payload"""
if not payload.task_type:
raise ValueError("task_type is required")
if payload.scheduled_for < 0:
raise ValueError("scheduled_for must be >= 0")
if payload.max_retries < 0:
raise ValueError("max_retries must be >= 0")
if payload.timeout_seconds <= 0:
raise ValueError("timeout_seconds must be > 0")
def _check_idempotency(self, idempotency_key: str) -> Optional[str]:
"""Check if task with idempotency key exists"""
idem_key = self._key(f"idempotency:{idempotency_key}")
result = self.redis.get(idem_key)
return result.decode('utf-8') if result else None
def _add_to_queue(self, payload: TaskPayload):
"""Add task to priority queue"""
queue_key = self._key(
self.QUEUE_KEY,
priority=payload.priority.value,
task_type=payload.task_type.value
)
self.redis.zadd(queue_key, {payload.task_id: payload.scheduled_for})
def _index_task(self, payload: TaskPayload):
"""Add task to indexes"""
status_key = self._key(self.STATUS_INDEX, status=payload.status.value)
self.redis.sadd(status_key, payload.task_id)
type_key = self._key(self.TYPE_INDEX, task_type=payload.task_type.value)
self.redis.sadd(type_key, payload.task_id)
def _remove_from_indexes(self, payload: TaskPayload):
"""Remove task from indexes"""
status_key = self._key(self.STATUS_INDEX, status=payload.status.value)
self.redis.srem(status_key, payload.task_id)
type_key = self._key(self.TYPE_INDEX, task_type=payload.task_type.value)
self.redis.srem(type_key, payload.task_id)
def dequeue(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]:
"""Dequeue next task for worker"""
self._update_worker_heartbeat(worker_id)
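        # Scan priority queues highest-first (CRITICAL=0 ... BACKGROUND=4), one
        # ZPOPMIN per (priority, task_type) pair, so the per-poll cost grows
        # with the number of task types this worker handles.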
for priority in TaskPriority:
task_types_to_check = task_types or list(TaskType)
for task_type in task_types_to_check:
queue_key = self._key(
self.QUEUE_KEY,
priority=priority.value,
task_type=task_type.value
)
result = self.redis.zpopmin(queue_key, 1)
if not result:
continue
task_id = result[0][0].decode('utf-8') if isinstance(result[0][0], bytes) else result[0][0]
task = self._claim_task(task_id, worker_id)
if task:
self._execute_hooks(self._dequeue_hooks, task)
self._increment_metric('total_dequeued')
return task
return None
    def claim_next_task(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]:
        """Claim next available task"""
        self._process_delayed_tasks()
        task = self.dequeue(worker_id, task_types)
        if not task:
            return None
        if not self.are_dependencies_met(task):
            # Release the claim and park the task as BLOCKED until its
            # parents complete; _unblock_dependents will re-queue it.
            claimed_key = self._key(self.CLAIMED_KEY, worker_id=worker_id)
            self.redis.srem(claimed_key, task.task_id)
            old_status = task.status
            task.worker_id = None
            task.claimed_at = None
            task.lease_expires_at = None
            task.status = TaskStatus.BLOCKED
            self._update_task_status(task, old_status)
            return None
        return task
    def _claim_task(self, task_id: str, worker_id: str) -> Optional[TaskPayload]:
        """Atomically claim a task under a short-lived Redis lock"""
        lock_key = self._key(self.LEASE_LOCK)
        lock = self.redis.lock(lock_key, timeout=5)
        # Only release the lock if it was actually acquired; releasing an
        # unacquired redis-py lock raises LockError.
        if not lock.acquire(blocking=True, blocking_timeout=5):
            return None
        try:
            task = self._get_task(task_id)
            if not task:
                return None
            if task.status in [TaskStatus.CLAIMED, TaskStatus.RUNNING]:
                return None
            old_status = task.status
            task.status = TaskStatus.CLAIMED
            task.worker_id = worker_id
            task.claimed_at = time.time()
            task.lease_expires_at = time.time() + self.lease_duration
            self._update_task_status(task, old_status)
            claimed_key = self._key(self.CLAIMED_KEY, worker_id=worker_id)
            self.redis.sadd(claimed_key, task_id)
            self._execute_hooks(self._claim_hooks, task)
            return task
        finally:
            lock.release()
def start_task(self, task_id: str) -> bool:
"""Mark task as running"""
task = self._get_task(task_id)
if not task:
return False
old_status = task.status
task.status = TaskStatus.RUNNING
task.started_at = time.time()
self._update_task_status(task, old_status)
return True
def complete_task(self, task_id: str, result: Optional[Dict] = None) -> bool:
"""Mark task completed"""
task = self._get_task(task_id)
if not task:
return False
old_status = task.status
task.status = TaskStatus.COMPLETED
task.completed_at = time.time()
task.result = result
if task.worker_id:
claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id)
self.redis.srem(claimed_key, task_id)
self._update_task_status(task, old_status)
self._increment_metric('total_completed')
self._unblock_dependents(task_id)
return True
def fail_task(self, task_id: str, error: str, should_retry: bool = True) -> bool:
"""Mark task failed with retry logic"""
task = self._get_task(task_id)
if not task:
return False
old_status = task.status
task.last_error = error
task.error_history.append({
'error': error,
'timestamp': time.time(),
'attempt': task.retry_count
})
if should_retry and task.retry_count < task.max_retries:
task.retry_count += 1
task.status = TaskStatus.RETRY
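            # Exponential backoff capped at 5 minutes:
            # retry 1 waits 10 s, retry 2 waits 20 s, retry 3 waits 40 s, ...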
delay = min(300, 2 ** task.retry_count * 5)
task.scheduled_for = time.time() + delay
self.redis.zadd(
self._key(self.DELAYED_QUEUE),
{task_id: task.scheduled_for}
)
self._increment_metric('total_retries')
else:
task.status = TaskStatus.FAILED
self._increment_metric('total_failed')
if task.worker_id:
claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id)
self.redis.srem(claimed_key, task_id)
task.worker_id = None
self._update_task_status(task, old_status)
return True
def timeout_task(self, task_id: str) -> bool:
"""Mark task timed out"""
task = self._get_task(task_id)
if not task:
return False
old_status = task.status
task.status = TaskStatus.TIMEOUT
task.last_error = "Task execution timeout"
if task.worker_id:
claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id)
self.redis.srem(claimed_key, task_id)
self._update_task_status(task, old_status)
self._increment_metric('total_timeouts')
return self.fail_task(task_id, "Execution timeout", should_retry=True)
def cancel_task(self, task_id: str) -> bool:
"""Cancel task"""
task = self._get_task(task_id)
if not task:
return False
if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
return False
old_status = task.status
task.status = TaskStatus.CANCELLED
for priority in TaskPriority:
for task_type in TaskType:
queue_key = self._key(
self.QUEUE_KEY,
priority=priority.value,
task_type=task_type.value
)
self.redis.zrem(queue_key, task_id)
self.redis.zrem(self._key(self.DELAYED_QUEUE), task_id)
self._update_task_status(task, old_status)
return True
def _update_task_status(self, task: TaskPayload, old_status: TaskStatus):
"""Update task status"""
old_key = self._key(self.STATUS_INDEX, status=old_status.value)
self.redis.srem(old_key, task.task_id)
new_key = self._key(self.STATUS_INDEX, status=task.status.value)
self.redis.sadd(new_key, task.task_id)
self._store_task(task)
self._execute_hooks(self._status_hooks, task)
def are_dependencies_met(self, task: TaskPayload) -> bool:
"""Check if dependencies are met"""
if not task.depends_on:
return True
for dep_id in task.depends_on:
dep_task = self._get_task(dep_id)
if not dep_task or dep_task.status != TaskStatus.COMPLETED:
return False
return True
    def _unblock_dependents(self, task_id: str):
        """Unblock dependent tasks whose parents have now all completed"""
        for status in [TaskStatus.BLOCKED, TaskStatus.PENDING]:
            status_key = self._key(self.STATUS_INDEX, status=status.value)
            blocked_ids = self.redis.smembers(status_key)
            for blocked_id in blocked_ids:
                if isinstance(blocked_id, bytes):
                    blocked_id = blocked_id.decode('utf-8')
                blocked_task = self._get_task(blocked_id)
                if not blocked_task or task_id not in blocked_task.depends_on:
                    continue
                if self.are_dependencies_met(blocked_task):
                    blocked_task.status = TaskStatus.QUEUED
                    self._add_to_queue(blocked_task)
                    self._update_task_status(blocked_task, status)
def _update_worker_heartbeat(self, worker_id: str):
"""Update worker heartbeat"""
heartbeat_key = self._key(self.WORKER_HEARTBEAT, worker_id=worker_id)
self.redis.setex(heartbeat_key, self.heartbeat_interval * 3, time.time())
def renew_lease(self, task_id: str, worker_id: str) -> bool:
"""Renew task lease"""
task = self._get_task(task_id)
if not task or task.worker_id != worker_id:
return False
task.lease_expires_at = time.time() + self.lease_duration
self._store_task(task)
self._update_worker_heartbeat(worker_id)
return True
def reclaim_expired_leases(self) -> int:
"""Reclaim expired leases"""
now = time.time()
reclaimed = 0
for status in [TaskStatus.CLAIMED, TaskStatus.RUNNING]:
status_key = self._key(self.STATUS_INDEX, status=status.value)
task_ids = self.redis.smembers(status_key)
for task_id in task_ids:
if isinstance(task_id, bytes):
task_id = task_id.decode('utf-8')
task = self._get_task(task_id)
if not task:
continue
if task.lease_expires_at and now > task.lease_expires_at:
if task.worker_id:
claimed_key = self._key(self.CLAIMED_KEY, worker_id=task.worker_id)
self.redis.srem(claimed_key, task_id)
task.worker_id = None
task.claimed_at = None
task.lease_expires_at = None
task.status = TaskStatus.QUEUED
self._add_to_queue(task)
self._update_task_status(task, status)
reclaimed += 1
return reclaimed
def _process_delayed_tasks(self):
"""Move delayed tasks to active queue"""
now = time.time()
delayed_key = self._key(self.DELAYED_QUEUE)
ready_tasks = self.redis.zrangebyscore(delayed_key, 0, now)
for task_id in ready_tasks:
if isinstance(task_id, bytes):
task_id = task_id.decode('utf-8')
task = self._get_task(task_id)
if not task:
self.redis.zrem(delayed_key, task_id)
continue
task.status = TaskStatus.QUEUED
self._add_to_queue(task)
self._store_task(task)
self.redis.zrem(delayed_key, task_id)
def _get_task(self, task_id: str) -> Optional[TaskPayload]:
"""Get task with caching"""
if task_id in self._local_cache:
cached, timestamp = self._local_cache[task_id]
if time.time() - timestamp < self._cache_ttl:
return cached
task_key = self._key(self.TASK_KEY, task_id=task_id)
data = self.redis.get(task_key)
if not data:
return None
task = pickle.loads(data)
self._local_cache[task_id] = (task, time.time())
return task
def _store_task(self, task: TaskPayload):
"""Store task in Redis"""
task_key = self._key(self.TASK_KEY, task_id=task.task_id)
self.redis.set(task_key, pickle.dumps(task))
self._local_cache[task.task_id] = (task, time.time())
def get_task(self, task_id: str) -> Optional[TaskPayload]:
"""Public API to get task"""
return self._get_task(task_id)
def get_tasks_by_status(self, status: TaskStatus, limit: int = 100) -> List[TaskPayload]:
"""Get tasks by status"""
status_key = self._key(self.STATUS_INDEX, status=status.value)
task_ids = self.redis.smembers(status_key)
tasks = []
for task_id in list(task_ids)[:limit]:
if isinstance(task_id, bytes):
task_id = task_id.decode('utf-8')
task = self._get_task(task_id)
if task:
tasks.append(task)
return tasks
def _increment_metric(self, metric: str, amount: int = 1):
"""Increment metric"""
self._local_metrics[metric] = self._local_metrics.get(metric, 0) + amount
metrics_key = self._key(self.METRICS_KEY)
self.redis.hincrby(metrics_key, metric, amount)
    def get_metrics(self) -> Dict[str, Any]:
        """Get queue metrics.

        The Redis hash is the cluster-wide source of truth; local counters are
        already folded into it by _increment_metric, so they are not added
        again here (which would double-count this process's work).
        """
        metrics_key = self._key(self.METRICS_KEY)
        redis_metrics = self.redis.hgetall(metrics_key)
        decoded = {
            k.decode('utf-8') if isinstance(k, bytes) else k:
            int(v.decode('utf-8')) if isinstance(v, bytes) else int(v)
            for k, v in redis_metrics.items()
        }
        # Make sure every known metric appears, even if never incremented
        for key in self._local_metrics:
            decoded.setdefault(key, 0)
        return decoded
def _execute_hooks(self, hooks: List[Callable], *args):
"""Execute hooks"""
for hook in hooks:
try:
hook(*args)
except Exception as e:
logger.error(f"Hook execution error: {e}")
def register_enqueue_hook(self, hook: Callable):
self._enqueue_hooks.append(hook)
def register_status_hook(self, hook: Callable):
self._status_hooks.append(hook)
def start_maintenance(self, interval: float = 60.0):
"""Start maintenance thread"""
if self._maintenance_active:
return
self._maintenance_active = True
def maintenance_loop():
while self._maintenance_active:
try:
reclaimed = self.reclaim_expired_leases()
if reclaimed > 0:
logger.info(f"Reclaimed {reclaimed} expired leases")
self._process_delayed_tasks()
except Exception as e:
logger.error(f"Maintenance error: {e}")
time.sleep(interval)
self._maintenance_thread = threading.Thread(target=maintenance_loop, daemon=True)
self._maintenance_thread.start()
logger.info("Maintenance started")
def stop_maintenance(self):
"""Stop maintenance"""
self._maintenance_active = False
if self._maintenance_thread:
self._maintenance_thread.join(timeout=5)
def request_shutdown(self):
"""Request graceful shutdown"""
self._shutdown_requested = True
shutdown_key = self._key(self.SHUTDOWN_KEY)
self.redis.set(shutdown_key, "1", ex=3600)
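        # Other queue instances may poll this key (it expires after an hour);
        # this instance itself relies on the local _shutdown_requested flag.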
logger.warning("Shutdown requested")
def cancel_all_pending(self) -> int:
"""Cancel all pending/queued tasks"""
cancelled = 0
for status in [TaskStatus.PENDING, TaskStatus.QUEUED]:
tasks = self.get_tasks_by_status(status, limit=10000)
for task in tasks:
if self.cancel_task(task.task_id):
cancelled += 1
logger.info(f"Cancelled {cancelled} pending tasks")
return cancelled
# ============================================================================
# HIGH-LEVEL ORCHESTRATION MANAGER
# ============================================================================
class TaskQueueManager:
"""
High-level orchestration manager coordinating all components.
"""
def __init__(self, redis_url: str = "redis://localhost:6379/0"):
# Initialize Redis
self.redis_client = redis.from_url(redis_url, decode_responses=False)
# Core queue
self.queue = RedisTaskQueue(self.redis_client)
# Orchestration components
self.priority_router = PriorityRouter(self.redis_client)
self.worker_manager = WorkerManager(self.redis_client)
self.scheduler = Scheduler(self.redis_client, self)
# Start background services
self.queue.start_maintenance()
self.worker_manager.start_monitoring()
self.scheduler.start()
logger.info("TaskQueueManager initialized with full orchestration")
def submit_task(self, task_type: TaskType, metadata: Dict[str, Any],
priority: Optional[TaskPriority] = None,
depends_on: List[str] = None,
scheduled_for: Optional[float] = None,
deadline: Optional[float] = None) -> str:
"""Submit task with automatic priority routing"""
# Create payload
payload = TaskPayload(
task_id="",
task_type=task_type,
priority=priority or TaskPriority.NORMAL,
created_at=time.time(),
scheduled_for=scheduled_for or time.time(),
depends_on=depends_on or [],
metadata=metadata,
deadline=deadline
)
# Route priority if not explicitly set
if not priority:
payload.priority = self.priority_router.route_task(payload)
logger.info(f"Task auto-routed to priority: {payload.priority.value}")
# Submit to queue
return self.queue.enqueue(payload)
def get_next_task(self, worker_id: str, task_types: Optional[List[TaskType]] = None) -> Optional[TaskPayload]:
"""Get next task for worker with load balancing"""
# Update worker heartbeat
self.worker_manager.update_heartbeat(worker_id)
# Check worker capacity
worker_tasks = self.worker_manager.workers.get(worker_id)
if worker_tasks and len(worker_tasks.current_tasks) >= self.worker_manager.max_tasks_per_worker:
logger.warning(f"Worker {worker_id} at capacity")
return None
# Get task
task = self.queue.claim_next_task(worker_id, task_types)
if task:
self.worker_manager.assign_task(worker_id, task.task_id)
logger.info(f"Task {task.task_id} assigned to worker {worker_id}")
return task
def mark_completed(self, task_id: str, worker_id: str, result: Dict = None):
"""Mark task completed"""
self.queue.complete_task(task_id, result)
self.worker_manager.complete_task(worker_id, task_id, success=True)
def mark_failed(self, task_id: str, worker_id: str, error: str):
"""Mark task failed"""
self.queue.fail_task(task_id, error)
self.worker_manager.complete_task(worker_id, task_id, success=False)
def schedule_task(self, task_type: TaskType, metadata: Dict[str, Any],
scheduled_time: float, priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Schedule task for future execution"""
return self.scheduler.schedule_task(task_type, metadata, scheduled_time, priority)
def schedule_recurring(self, task_type: TaskType, metadata: Dict[str, Any],
interval_seconds: int, priority: TaskPriority = TaskPriority.NORMAL) -> str:
"""Schedule recurring task"""
return self.scheduler.schedule_recurring(task_type, metadata, interval_seconds, priority)
def get_stats(self) -> Dict[str, Any]:
"""Get comprehensive system statistics"""
return {
'queue_metrics': self.queue.get_metrics(),
'worker_capacity': self.worker_manager.get_cluster_capacity(),
'active_workers': len(self.worker_manager.workers),
'scheduled_tasks': len(self.scheduler.schedules)
}
def graceful_shutdown(self, timeout: float = 60.0):
"""
Graceful shutdown with proper cleanup.
1. Stop accepting new tasks
2. Wait for running tasks to complete
3. Cancel pending tasks
4. Stop all background threads
"""
logger.warning("=== INITIATING GRACEFUL SHUTDOWN ===")
# Request shutdown
self.queue.request_shutdown()
# Stop scheduler from creating new tasks
self.scheduler.stop()
logger.info("Scheduler stopped")
# Wait for running tasks
start_time = time.time()
while time.time() - start_time < timeout:
running_tasks = self.queue.get_tasks_by_status(TaskStatus.RUNNING)
if not running_tasks:
break
logger.info(f"Waiting for {len(running_tasks)} running tasks...")
time.sleep(5)
# Cancel remaining tasks
cancelled = self.queue.cancel_all_pending()
logger.info(f"Cancelled {cancelled} pending tasks")
# Stop background services
self.queue.stop_maintenance()
self.worker_manager.stop_monitoring()
# Close Redis connection
self.redis_client.close()
logger.warning("=== SHUTDOWN COMPLETE ===")
# ============================================================================
# EXAMPLE USAGE
# ============================================================================
if __name__ == "__main__":
# Initialize full system
manager = TaskQueueManager("redis://localhost:6379/0")
# Register custom routing rule
def high_value_rule(task: TaskPayload) -> Optional[TaskPriority]:
if task.metadata.get('high_value'):
return TaskPriority.HIGH
return None
manager.priority_router.add_routing_rule(high_value_rule)
# Register worker
manager.worker_manager.register_worker("worker_001", [TaskType.GENERATE_SCRIPT])
# Submit immediate task
task_id = manager.submit_task(
task_type=TaskType.GENERATE_SCRIPT,
metadata={"niche": "tech", "high_value": True},
deadline=time.time() + 3600
)
logger.info(f"Submitted task: {task_id}")
# Schedule future task
future_task_id = manager.schedule_task(
task_type=TaskType.POST_CONTENT,
metadata={"platform": "tiktok", "account_id": "acc_001"},
scheduled_time=time.time() + 300, # 5 minutes
priority=TaskPriority.HIGH
)
logger.info(f"Scheduled task: {future_task_id}")
# Schedule recurring health check
recurring_id = manager.schedule_recurring(
task_type=TaskType.HEALTH_CHECK,
metadata={},
interval_seconds=3600, # Every hour
priority=TaskPriority.CRITICAL
)
logger.info(f"Recurring task: {recurring_id}")
# Worker gets task
task = manager.get_next_task("worker_001")
if task:
logger.info(f"Worker got task: {task.task_id}")
# Simulate work
time.sleep(2)
# Complete
manager.mark_completed(task.task_id, "worker_001", {"status": "success"})
logger.info("Task completed")
# Get statistics
stats = manager.get_stats()
logger.info(f"System stats: {stats}")
# Graceful shutdown
time.sleep(5)
manager.graceful_shutdown()