import time

from airflow.configuration import conf
from airflow.utils.log.logging_mixin import LoggingMixin


class JobDispatcherExecutor(LoggingMixin):
    def __init__(self, celery_executor, kubernetes_executor):
        """
        Wraps a CeleryExecutor and a KubernetesExecutor and dispatches each
        task to one of them based on the task's operator (see _router).
        """
        self.celery_executor = celery_executor
        self.kubernetes_executor = kubernetes_executor
        self.is_sensor_service_enabled = conf.getboolean('smart_sensor', 'use_smart_sensor')
        self.operators_support_sensor_service = set(
            map(lambda l: l.strip(), conf.get('smart_sensor', 'sensors_enabled').split(','))
        )
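
    # The two smart-sensor settings above are read from the [smart_sensor]
    # section of airflow.cfg. A minimal sketch of what this class expects
    # (the sensor names are illustrative, not prescribed by this code):
    #
    #     [smart_sensor]
    #     use_smart_sensor = True
    #     sensors_enabled = NamedHivePartitionSensor, MetastorePartitionSensor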

    @property
    def queued_tasks(self):
        queued_tasks = self.celery_executor.queued_tasks.copy()
        queued_tasks.update(self.kubernetes_executor.queued_tasks)
        return queued_tasks

    @property
    def running(self):
        running = self.celery_executor.running.copy()
        running.update(self.kubernetes_executor.running)
        return running

    def start(self):  # pragma: no cover
        """
        Executors may need to get things started. For example LocalExecutor
        starts N workers.
        """
        self.celery_executor.start()
        self.kubernetes_executor.start()

    def queue_command(self, simple_task_instance, command, priority=1, queue=None):
        # queue the command on the executor selected by the task's operator
        task_operator = simple_task_instance.operator
        executor = self._router(task_operator)
        self.log.debug("Using executor: %s for %s, with operator: %s",
                       executor.__class__.__name__, simple_task_instance.key, task_operator)
        executor.queue_command(simple_task_instance, command, priority, queue)

    def queue_task_instance(
            self,
            task_instance,
            mark_success=False,
            pickle_id=None,
            ignore_all_deps=False,
            ignore_depends_on_past=False,
            ignore_task_deps=False,
            ignore_ti_state=False,
            pool=None,
            cfg_path=None):
        task_operator = task_instance.operator
        executor = self._router(task_operator)
        self.log.debug("Using executor: %s to queue_task_instance for %s",
                       executor.__class__.__name__, task_instance.key)
        executor.queue_task_instance(
            task_instance,
            mark_success,
            pickle_id,
            ignore_all_deps,
            ignore_depends_on_past,
            ignore_task_deps,
            ignore_ti_state,
            pool,
            cfg_path
        )

    def has_task(self, task_instance):
        """
        Checks if a task is either queued or running in either of the
        underlying executors.

        :param task_instance: TaskInstance
        :return: True if the task is known to this executor
        """
        return self.celery_executor.has_task(task_instance) or self.kubernetes_executor.has_task(task_instance)

    def sync(self):
        """
        Sync will get called periodically by the heartbeat method.
        Executors should override this to gather task statuses.
        """
        self.celery_executor.sync()
        self.kubernetes_executor.sync()

    def heartbeat(self):
        # Triggering new jobs on both executors; log slow heartbeats.
        start_at = time.time()
        try:
            self.celery_executor.heartbeat()
            self.kubernetes_executor.heartbeat()
        finally:
            perf = time.time() - start_at
            if perf >= 2:
                self.log.info("Executor heartbeat perf: %s", perf)

    def trigger_tasks(self, open_slots):
        """
        Trigger tasks

        :param open_slots: Number of open slots
        """
        # Pass the slot count through to both executors; each one still
        # enforces its own parallelism limits on its own queue.
        self.celery_executor.trigger_tasks(open_slots)
        self.kubernetes_executor.trigger_tasks(open_slots)

    def change_state(self, key, state):
        # it is ok to add the key to the event buffer for both executors,
        # as get_event_buffer will check the existence of the dag
        self.celery_executor.change_state(key, state)
        self.kubernetes_executor.change_state(key, state)

    def get_event_buffer(self, dag_ids=None):
        """
        Returns and flushes the event buffer. If dag_ids is specified,
        it will only return and flush events for the given dag_ids;
        otherwise it returns and flushes all.

        :param dag_ids: the dag_ids to return events for; if None, returns all
        :return: a dict of events
        """
        cleared_events_from_celery = self.celery_executor.get_event_buffer(dag_ids)
        cleared_events_from_kubernetes = self.kubernetes_executor.get_event_buffer(dag_ids)
        tmp_copy = cleared_events_from_celery.copy()
        tmp_copy.update(cleared_events_from_kubernetes)  # update is in place
        return tmp_copy

    def execute_async(self,
                      key,
                      command,
                      queue=None,
                      executor_config=None):  # pragma: no cover
        """
        This method will execute the command asynchronously.
        """
        self.celery_executor.execute_async(key, command, queue, executor_config)
        self.kubernetes_executor.execute_async(key, command, queue, executor_config)

    def end(self):  # pragma: no cover
        """
        This method is called when the caller is done submitting jobs and
        wants to wait synchronously for all previously submitted jobs to
        finish.
        """
        self.celery_executor.end()
        self.kubernetes_executor.end()

    def terminate(self):
        """
        This method is called when the daemon receives a SIGTERM.
        """
        self.celery_executor.terminate()
        self.kubernetes_executor.terminate()

    def _router(self, operator):
        # Smart-sensor-eligible operators run on Celery; everything else
        # is dispatched to Kubernetes.
        if self.is_sensor_service_enabled and operator in self.operators_support_sensor_service:
            return self.celery_executor
        return self.kubernetes_executor
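

if __name__ == "__main__":
    # A minimal usage sketch, not part of the dispatcher itself. It assumes
    # Airflow 2.0-style executor import paths (the executor modules moved
    # between Airflow versions) and a fully configured Celery and Kubernetes
    # deployment behind them.
    from airflow.executors.celery_executor import CeleryExecutor
    from airflow.executors.kubernetes_executor import KubernetesExecutor

    executor = JobDispatcherExecutor(CeleryExecutor(), KubernetesExecutor())
    executor.start()
    # In a real scheduler loop, task instances would be queued here via
    # queue_task_instance(); smart-sensor-eligible operators route to Celery,
    # everything else to Kubernetes.
    executor.heartbeat()
    executor.end()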