Pull messages from Pub/Sub, clearing task instances
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
import base64
import datetime as dt
import json
import logging

from airflow import AirflowException, settings
from airflow.models import BaseOperator, DagBag, DagRun, TaskInstance
from airflow.utils.decorators import apply_defaults
from airflow.utils.state import State
# Must be imported like this (as opposed to a package-relative import), as imports
# always check sys.path: https://stackoverflow.com/a/46212814/4624156
from pubsub_sensor import PubSubHook
from sqlalchemy import or_


# Mirrors the DagRunOrder helper from Airflow's TriggerDagRunOperator (a
# container for the run_id and payload handed to a python_callable); kept
# for compatibility, although this operator does not use it.
class DagRunOrder(object):
    def __init__(self, run_id=None, payload=None):
        self.run_id = run_id
        self.payload = payload


class PubSubTrigger(BaseOperator):
    """
    Pulls messages from a Pub/Sub subscription and triggers a DAG run for the
    DAG mapped to each message's ``dataSourceId``, provided the message
    reports a successful run.

    :param project: the GCP project ID that owns the Pub/Sub subscription
    :type project: str
    :param subscription: the Pub/Sub subscription to pull messages from
    :type subscription: str
    :param ack_messages: if True, acknowledge messages immediately after
        pulling (currently unused; ``execute`` acknowledges messages itself)
    :type ack_messages: bool
    :param return_immediately: if True, the pull request returns straight away
        even when no messages are available, instead of blocking
    :type return_immediately: bool
    :param max_messages: the maximum number of messages to pull per request
    :type max_messages: int
    :param gcp_conn_id: the Airflow connection ID to use when connecting to GCP
    :type gcp_conn_id: str
    :param delegate_to: the account to impersonate, if any
    :type delegate_to: str
    """
    # Maps the dataSourceId carried in a Pub/Sub message to the DAG it triggers
    dfp_pubsub_name = "dfp_dt"
    dfp_dag_name = "dfp-import"
    youtube_pubsub_name = "youtube_content_owner"
    youtube_dag_name = "youtube-video-import"

    template_fields = tuple()
    template_ext = tuple()
    ui_color = '#ffefeb'

    @apply_defaults
    def __init__(
            self,
            project,
            subscription,
            ack_messages=False,
            return_immediately=False,
            max_messages=10,
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            *args, **kwargs):
        super(PubSubTrigger, self).__init__(*args, **kwargs)
        self.project = project
        self.subscription = subscription
        self.ack_messages = ack_messages
        self.return_immediately = return_immediately
        self.max_messages = max_messages
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to

    def clear_task_instances(self, tis, session, activate_dag_runs=True):
        """
        Clears a set of task instances, but makes sure the running ones
        get killed.

        Pulled from models.py.
        """
        job_ids = []
        for ti in tis:
            if ti.state == State.RUNNING:
                if ti.job_id:
                    ti.state = State.SHUTDOWN
                    job_ids.append(ti.job_id)
            else:
                session.delete(ti)
        if job_ids:
            from airflow.jobs import BaseJob as BJ
            for job in session.query(BJ).filter(BJ.id.in_(job_ids)).all():
                job.state = State.SHUTDOWN
        if activate_dag_runs:
            # Set the affected DagRuns back to RUNNING so the scheduler picks
            # up the freshly cleared task instances
            execution_dates = {ti.execution_date for ti in tis}
            dag_ids = {ti.dag_id for ti in tis}
            drs = session.query(DagRun).filter(
                DagRun.dag_id.in_(dag_ids),
                DagRun.execution_date.in_(execution_dates),
            ).all()
            for dr in drs:
                dr.state = State.RUNNING
                dr.start_date = dt.datetime.now()

    def fetch_then_clear_dag_tis(self,
                                 dag,
                                 start_date=None,
                                 end_date=None,
                                 only_failed=False,
                                 only_running=False,
                                 include_subdags=False):
        """
        Clear all task instances of the associated DAG.

        Pulled from models.py.

        :param dag: DAG (or DagRun) whose task instances to clear
        :param start_date: earliest execution_date to clear
        :param end_date: latest execution_date to clear
        :param only_failed: only clear failed task instances
        :param only_running: only clear running task instances
        :param include_subdags: also clear task instances in subdags
        """
        logging.info("Clearing DAG tis")
        session = settings.Session()
        TI = TaskInstance
        tis = session.query(TI)
        if include_subdags:
            # Crafting the right filter for dag_id and task_ids combo
            conditions = []
            for dag in dag.subdags + [dag]:
                conditions.append(
                    TI.dag_id.like(dag.dag_id) & TI.task_id.in_(dag.task_ids)
                )
            tis = tis.filter(or_(*conditions))
        else:
            tis = session.query(TI).filter(TI.dag_id == dag.dag_id)
            if hasattr(dag, 'task_ids'):
                # DagRun objects carry no task_ids attribute, so this filter
                # only applies when an actual DAG object is passed in
                tis = tis.filter(TI.task_id.in_(dag.task_ids))
        if start_date:
            tis = tis.filter(TI.execution_date >= start_date)
        if end_date:
            tis = tis.filter(TI.execution_date <= end_date)
        if only_failed:
            tis = tis.filter(TI.state == State.FAILED)
        if only_running:
            tis = tis.filter(TI.state == State.RUNNING)
        # Materialise the query before rows are mutated, then commit so the
        # deletes and state changes actually persist
        self.clear_task_instances(tis.all(), session)
        session.commit()

    def decode_pubsub(self, pubsub_message):
        """
        Convert the pubsub message data contents to a JSON object

        :param pubsub_message: Raw PubSub message
        :return: JSON object of the data contents
        """
        # The Pub/Sub pull API returns the message body base64-encoded;
        # base64.b64decode works on both Python 2 and 3, whereas
        # str.decode('base64') is Python-2-only
        b64_decoded_data = base64.b64decode(pubsub_message.get('message').get('data'))
        logging.info('pubsub_decoded: {}'.format(b64_decoded_data))
        return json.loads(b64_decoded_data)
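
    # A pulled message is assumed to look roughly like the (hypothetical)
    # transfer-run notification sketched below; only the ``dataSourceId``,
    # ``state`` and ``runTime`` fields of the decoded payload are read:
    #
    #   pubsub_message = {"ackId": "...",
    #                     "message": {"data": "<base64-encoded JSON>"}}
    #   decoded data   = {"dataSourceId": "dfp_dt",
    #                     "state": "SUCCEEDED",
    #                     "runTime": "2018-09-21T12:30:00Z"}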

    def get_run_status(self, data_json):
        state = data_json.get('state')
        logging.info('state: {}'.format(state))
        return state == 'SUCCEEDED'

    def get_target_dag(self, data_json):
        data_source_id = data_json.get('dataSourceId')
        logging.info('data_source_id: {}, type: {}'.format(data_source_id, type(data_source_id)))
        if data_source_id == self.youtube_pubsub_name:
            target_dag = self.youtube_dag_name
        elif data_source_id == self.dfp_pubsub_name:
            target_dag = self.dfp_dag_name
        else:
            target_dag = None
        return target_dag

    def get_run_time(self, data_json):
        """
        :param data_json: decoded PubSub message payload
        :return: runTime extracted from the PubSub message, truncated to the day
        """
        # Keep only the date portion (the first 10 characters, "YYYY-MM-DD")
        run_time_string = data_json.get('runTime')[:10]
        run_time_datetime = dt.datetime.strptime(run_time_string, "%Y-%m-%d")
        return run_time_datetime
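
    # For example, a (hypothetical) payload of {"runTime": "2018-09-21T12:30:00Z"}
    # yields dt.datetime(2018, 9, 21, 0, 0): the time of day is discarded, so all
    # notifications for the same day map to the same execution_date.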

    def get_ack_id(self, pubsub_message):
        """
        Fetch the acknowledgement ID from the pubsub message

        :param pubsub_message: raw PubSub message
        :return: acknowledgement ID
        """
        logging.info("AckId: {}".format(pubsub_message.get('ackId')))
        return pubsub_message.get('ackId')

    def acknowledge_run_ack_ids(self, ack_ids, hook):
        """
        :param ack_ids: List of ack_ids of messages that have had a DAG run triggered
        :param hook: PubSubHook used to acknowledge the messages
        """
        if ack_ids:
            hook.acknowledge(self.project, self.subscription, ack_ids)
            logging.info("Acknowledged IDs: {}".format(ack_ids))

    def execute(self, context):
        hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
                          delegate_to=self.delegate_to)
        messages = hook.pull(
            self.project, self.subscription, self.max_messages,
            self.return_immediately)
        triggered_ack_ids = []
        logging.info('Number of messages from PubSub: {}'.format(len(messages)))
        for message_json in messages:
            data_json = self.decode_pubsub(message_json)
            if not self.get_run_status(data_json):
                # If the transfer failed then acknowledge this message and go to the next one
                logging.info('skipping message: {}'.format(data_json))
                triggered_ack_ids.append(self.get_ack_id(message_json))
                continue
            target_dag = self.get_target_dag(data_json)
            if target_dag is None:
                # Unknown data source: acknowledge anyway so the message is not redelivered forever
                logging.info('no DAG mapped for message: {}'.format(data_json))
                triggered_ack_ids.append(self.get_ack_id(message_json))
                continue
            run_time_datetime = self.get_run_time(data_json)
            self.trigger_dag(dag_id=target_dag, execution_date=run_time_datetime)
            triggered_ack_ids.append(self.get_ack_id(message_json))
        self.acknowledge_run_ack_ids(triggered_ack_ids, hook)

    def trigger_dag(self, dag_id, run_id=None, execution_date=None):
        dagbag = DagBag()
        if dag_id not in dagbag.dags:
            raise AirflowException("Dag id {} not found".format(dag_id))
        dag = dagbag.get_dag(dag_id)
        logging.info("dag_id: {}, dag: {}".format(dag_id, dag.__dict__))
        if execution_date is None:
            logging.info("Creating new execution_date")
            execution_date = dt.datetime.utcnow()
        assert isinstance(execution_date, dt.datetime)
        execution_date = execution_date.replace(microsecond=0)
        if not run_id:
            run_id = "pstrig__{0}".format(dt.datetime.utcnow().isoformat())
            logging.info("No run_id provided. Using: {}".format(run_id))
        # Check for existing DAG runs with this execution date
        dr = DagRun.find(dag_id=dag_id, execution_date=execution_date)
        if dr:
            logging.info("Clearing")
            # Clear all task instances for the DAG runs found; clearing puts the
            # existing runs back into RUNNING, so no new DagRun is created
            for item in dr:
                logging.info("item: {}".format(item.__dict__))
                self.fetch_then_clear_dag_tis(item, start_date=execution_date, end_date=execution_date)
            return dr
        logging.info("No DagRun found with dag_id: {} and execution_date: {}".format(dag_id, execution_date))
        logging.info("Creating a new DagRun")
        # Create a new DAG run and execute
        trigger = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            external_trigger=True,
        )
        logging.info("{} started".format(dag_id))
        return trigger
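

# Example usage: a minimal sketch of wiring this operator into a polling DAG.
# The DAG name, schedule, GCP project, and subscription below are assumptions
# for illustration, not part of this gist:
#
#     from datetime import datetime, timedelta
#     from airflow import DAG
#
#     with DAG(dag_id='pubsub-poll',
#              start_date=datetime(2018, 1, 1),
#              schedule_interval=timedelta(minutes=5),
#              catchup=False) as dag:
#         PubSubTrigger(
#             task_id='pull_and_trigger',
#             project='my-gcp-project',                   # hypothetical project ID
#             subscription='transfer-run-notifications',  # hypothetical subscription
#             max_messages=10,
#         )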