Skip to content

Instantly share code, notes, and snippets.

@BasPH
Last active July 20, 2022 12:27
Show Gist options
  • Save BasPH/e2ab6cf75e6a659dcb3df10d359d6439 to your computer and use it in GitHub Desktop.
Save BasPH/e2ab6cf75e6a659dcb3df10d359d6439 to your computer and use it in GitHub Desktop.
Proposed Airflow SLA checker
import logging
import random
import string
import threading
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, Set
# Prefix every record with the emitting thread's name so output from the main thread and the
# SLA worker thread can be told apart.
logging.basicConfig(
    format="%(threadName)s %(message)s",
    level=logging.INFO,
)
@dataclass(frozen=True)
class SLACheck:
    """
    A single SLA check to perform for one task of one DAG run.

    The dataclass is frozen, which makes instances immutable and hashable so they can be stored in a
    set; two checks with equal field values are considered the same check.
    """

    # Identifier of the DAG run the check belongs to.
    run_id: str
    # Identifier of the task to check within that DAG run.
    task_id: str
    # Timestamp at which the check is scheduled.
    timestamp: int
class SLAManager:
    """
    This class is responsible for maintaining a calendar of checks for SLAs.

    It keeps a dictionary of {timestamp: {sla checks}} and a worker thread that waits until the next
    timestamp is reached. Once reached, it executes the scheduled SLA checks and waits again until
    the next timestamp. Scheduling a check earlier than the current earliest one wakes the worker so
    that it can recompute how long to wait.

    Each instance owns its own calendar. Checks are scheduled from caller threads while the worker
    thread consumes them, so every access to the calendar is serialised with a lock.
    """

    def __init__(self):
        # Per-instance state: a mutable class-level calendar would be shared by all instances.
        self._sla_calendar: Dict[float, Set["SLACheck"]] = defaultdict(set)
        # Re-entrant so that len(self) can be called while the calendar lock is already held.
        self._lock = threading.RLock()
        # Created up front so that scheduling and shutdown are safe before start() is called.
        self._event = threading.Event()
        self._thread = None
        self._stop_flag = False

    def start(self):
        """Start the background worker thread; does nothing if it is already running."""
        if self._thread is not None and self._thread.is_alive():
            logging.warning("%s is already running.", self.__class__.__name__)
            return
        logging.info("Starting %s", self.__class__.__name__)
        self._stop_flag = False  # Allow a restart after a previous shutdown()
        self._thread = threading.Thread(target=self._main_loop)
        self._thread.daemon = True  # Stop this thread if main thread dies
        self._thread.start()

    def shutdown(self):
        """Ask the worker thread to exit and wait for it; safe to call when it was never started."""
        logging.info("Shutting down.")
        logging.info("There were %s scheduled SLA checks.", len(self))
        self._stop_flag = True
        self._event.set()  # Wake the worker so it notices the stop flag
        worker = self._thread
        if worker is not None:
            worker.join()
        self._thread = None
        self._event.clear()  # Leave a clean slate for a possible restart

    def schedule_sla_check(self, timestamp_: int, run_id: str, task_id: str):
        """
        Schedule an SLA check at the given timestamp.

        :param timestamp_: The timestamp at which to schedule an SLA check.
        :param run_id: The run_id of the DAG run to check.
        :param task_id: The task_id of the task to check.
        """
        sla_check = SLACheck(run_id=run_id, task_id=task_id, timestamp=timestamp_)
        with self._lock:
            # Determine if waiting time should be reset when a check is scheduled prior to the
            # currently earliest scheduled check.
            should_reset_wait_interval = not self._sla_calendar or timestamp_ < min(self._sla_calendar)
            self._sla_calendar[timestamp_].add(sla_check)
            total_checks = len(self)

        logging.info(
            "Added SLA check %s at timestamp %s (in %s seconds). Total # of scheduled checks = %s.",
            sla_check,
            timestamp_,
            timestamp_ - time.time(),
            total_checks,
        )

        if should_reset_wait_interval:
            logging.info("Added SLA check prior to the earliest timestamp. Resetting wait interval.")
            self._event.set()  # Set internal flag to True (cancelling any wait() that might happen now)

    def _main_loop(self):
        """Infinitely running loop which waits until there's work to do."""
        # Start with a work pass so that checks scheduled before start() are taken into account.
        wait_seconds = self.do_work()
        while True:
            if wait_seconds == threading.TIMEOUT_MAX:
                logging.info("No SLAs to check. Waiting till the end of time.")
            else:
                logging.info("Waiting %s seconds...", wait_seconds)

            self._event.wait(wait_seconds)  # Block until timeout or internal flag is set to True
            if self._stop_flag:
                logging.info("Exiting main loop.")
                break

            # Clear BEFORE reading the calendar: a check scheduled after this point sets the flag
            # again, so the next wait() returns immediately instead of missing the wake-up.
            self._event.clear()
            wait_seconds = self.do_work()

    def __len__(self) -> int:
        """Return the total number of scheduled SLA checks."""
        with self._lock:
            return sum(len(checks) for checks in self._sla_calendar.values())

    def do_work(self) -> float:
        """
        Check SLAs and return the number of seconds to wait for the next SLA check.

        At most one timestamp (the earliest) is processed per call. If the following timestamp is
        already due as well, 0.0 is returned so that the caller comes straight back.

        :return: Number of seconds to wait for the next SLA check, between 0.0 and
            ``threading.TIMEOUT_MAX``. ``threading.TIMEOUT_MAX`` is returned when nothing is scheduled.
        """
        with self._lock:
            if not self._sla_calendar:
                # Nothing scheduled; min() of an empty calendar would raise ValueError.
                return threading.TIMEOUT_MAX

            first_timestamp = min(self._sla_calendar)
            is_due = first_timestamp <= time.time()
            due_checks = self._sla_calendar.pop(first_timestamp) if is_due else set()
            remaining_checks = len(self)
            next_timestamp = min(self._sla_calendar, default=None)

        if is_due:
            for sla_check in due_checks:
                logging.info(
                    "Checking SLA at timestamp %s (current timestamp = %s, diff = %s seconds) -> %s",
                    first_timestamp,
                    time.time(),
                    time.time() - first_timestamp,
                    sla_check,
                )
            logging.info(
                "Processed SLA check at timestamp %s. Removed %s SLA check(s). SLA checks remaining = %s.",
                first_timestamp,
                len(due_checks),
                remaining_checks,
            )

        if next_timestamp is None:
            return threading.TIMEOUT_MAX

        # Clamp to what Event.wait() accepts: never negative (next check already overdue) and never
        # above TIMEOUT_MAX (a larger timeout raises OverflowError).
        return min(max(next_timestamp - time.time(), 0.0), threading.TIMEOUT_MAX)
if __name__ == "__main__":
    # Demo driver: keeps the main thread alive by scheduling a few randomly timed SLA checks every
    # five seconds until the process is interrupted.
    sla_checks_to_add = 3
    max_seconds = 30

    def _random_identifier() -> str:
        """Return ten random ASCII letters to use as a fake run_id / task_id."""
        letters = [random.choice(string.ascii_letters) for _ in range(10)]
        return ''.join(letters)

    sla_manager = SLAManager()
    sla_manager.start()

    try:
        while True:
            # Simulate activity to keep main thread alive
            for _ in range(sla_checks_to_add):
                offset_seconds = random.randint(1, max_seconds)
                random_timestamp = int(time.time() + offset_seconds)
                random_run_id = _random_identifier()
                random_task_id = _random_identifier()
                sla_manager.schedule_sla_check(
                    timestamp_=random_timestamp,
                    run_id=random_run_id,
                    task_id=random_task_id,
                )
            time.sleep(5)
    except (KeyboardInterrupt, SystemExit):
        # Stop the worker thread cleanly on Ctrl-C or interpreter exit.
        sla_manager.shutdown()
@BasPH
Copy link
Author

BasPH commented Jul 20, 2022

Note: SLAManager._sla_calendar (a dict) is not thread-safe, so every access to it should be guarded by a threading.Lock. The thread-safety error can be reproduced by setting max_seconds very low and sla_checks_to_add very high.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment