Pull from Pub/Sub
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

import base64
import datetime as dt
import json
import logging

from airflow import AirflowException, settings
from airflow.models import BaseOperator, DagBag, DagRun, TaskInstance
from airflow.utils.decorators import apply_defaults
from airflow.utils.state import State
# Must be declared like this (as opposed to from pubsub_sensor) as imports always check sys.path
# https://stackoverflow.com/a/46212814/4624156
from pubsub_sensor import PubSubHook
from sqlalchemy import or_


class DagRunOrder(object):
    def __init__(self, run_id=None, payload=None):
        self.run_id = run_id
        self.payload = payload


class PubSubTrigger(BaseOperator):
    """
    Pulls messages from a Pub/Sub subscription and, for each message that
    reports a successful transfer run, triggers a run of the matching DAG.

    :param project: the GCP project ID that owns the subscription
    :type project: str
    :param subscription: the Pub/Sub subscription to pull from
    :type subscription: str
    :param ack_messages: whether pulled messages should be acknowledged
    :type ack_messages: bool
    :param return_immediately: whether the pull should return immediately
        rather than wait for messages to arrive
    :type return_immediately: bool
    :param max_messages: maximum number of messages to pull per execution
    :type max_messages: int
    :param gcp_conn_id: the Airflow connection ID to use for Google Cloud
    :type gcp_conn_id: str
    :param delegate_to: the account to impersonate, if any
    :type delegate_to: str
    """
    dfp_pubsub_name = "dfp_dt"
    dfp_dag_name = "dfp-import"
    youtube_pubsub_name = "youtube_content_owner"
    youtube_dag_name = "youtube-video-import"

    template_fields = tuple()
    template_ext = tuple()
    ui_color = '#ffefeb'

    @apply_defaults
    def __init__(
            self,
            project,
            subscription,
            ack_messages=False,
            return_immediately=False,
            max_messages=10,
            gcp_conn_id='google_cloud_default',
            delegate_to=None,
            *args, **kwargs):
        super(PubSubTrigger, self).__init__(*args, **kwargs)
        self.project = project
        self.subscription = subscription
        self.ack_messages = ack_messages
        self.return_immediately = return_immediately
        self.max_messages = max_messages
        self.gcp_conn_id = gcp_conn_id
        self.delegate_to = delegate_to

    def decode_pubsub(self, pubsub_message):
        """
        Convert the Pub/Sub message data contents to a JSON object.

        :param pubsub_message: raw Pub/Sub message
        :return: JSON object of the data contents
        """
        # The message data arrives base64-encoded.
        b64_decoded_data = base64.b64decode(pubsub_message.get('message').get('data')).decode('utf-8')
        logging.info('pubsub_decoded: {}'.format(b64_decoded_data))
        return json.loads(b64_decoded_data)

    def get_run_status(self, data_json):
        state = data_json.get('state')
        logging.info('state: {}'.format(state))
        return state == 'SUCCEEDED'

    def get_target_dag(self, data_json):
        data_source_id = data_json.get('dataSourceId')
        logging.info('data_source_id: {}, type: {}'.format(data_source_id, type(data_source_id)))
        if data_source_id == self.youtube_pubsub_name:
            target_dag = self.youtube_dag_name
        elif data_source_id == self.dfp_pubsub_name:
            target_dag = self.dfp_dag_name
        else:
            target_dag = None
        return target_dag

    def get_run_time(self, data_json):
        """
        :param data_json: decoded Pub/Sub message payload
        :return: the runTime extracted from the message, as a date-level datetime
        """
        run_time_string = data_json.get('runTime')[:10]
        run_time_datetime = dt.datetime.strptime(run_time_string, "%Y-%m-%d")
        return run_time_datetime

    def get_ack_id(self, pubsub_message):
        """
        Fetch the acknowledgement ID from the Pub/Sub message.

        :param pubsub_message: raw Pub/Sub message
        :return: acknowledgement ID
        """
        logging.info("AckId: {}".format(pubsub_message.get('ackId')))
        return pubsub_message.get('ackId')

    def acknowledge_run_ack_ids(self, ack_ids, hook):
        """
        :param ack_ids: list of ack_ids of messages that have had a DAG run triggered
        :param hook: PubSubHook used to acknowledge the messages
        """
        if ack_ids:
            hook.acknowledge(self.project, self.subscription, ack_ids)
            logging.info("Acknowledged IDs: {}".format(ack_ids))

    def execute(self, context):
        hook = PubSubHook(gcp_conn_id=self.gcp_conn_id,
                          delegate_to=self.delegate_to)
        messages = hook.pull(
            self.project, self.subscription, self.max_messages,
            self.return_immediately)
        triggered_ack_ids = []
        logging.info('Number of messages from PubSub: {}'.format(len(messages)))
        for message_json in messages:
            data_json = self.decode_pubsub(message_json)
            if not self.get_run_status(data_json):
                # If the transfer failed then acknowledge this message and go to the next one
                logging.info('skipping message: {}'.format(data_json))
                triggered_ack_ids.append(self.get_ack_id(message_json))
                continue
            target_dag = self.get_target_dag(data_json)
            run_time_datetime = self.get_run_time(data_json)
            self.trigger_dag(dag_id=target_dag, execution_date=run_time_datetime)
            triggered_ack_ids.append(self.get_ack_id(message_json))
        self.acknowledge_run_ack_ids(triggered_ack_ids, hook)

    def trigger_dag(self, dag_id, run_id=None, execution_date=None):
        dagbag = DagBag()
        if dag_id not in dagbag.dags:
            raise AirflowException("Dag id {} not found".format(dag_id))
        dag = dagbag.get_dag(dag_id)
        logging.info("dag_id: {}, dag: {}".format(dag_id, dag.__dict__))
        if execution_date is None:
            logging.info("Creating new execution_date")
            execution_date = dt.datetime.utcnow()
        assert isinstance(execution_date, dt.datetime)
        execution_date = execution_date.replace(microsecond=0)
        if not run_id:
            run_id = "pstrig__{0}".format(dt.datetime.utcnow().isoformat())
            logging.info("No run_id provided. Using: {}".format(run_id))
        trigger = dag.create_dagrun(
            run_id=run_id,
            execution_date=execution_date,
            state=State.RUNNING,
            external_trigger=True,
        )
        logging.info("{} started".format(dag_id))
        return trigger
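

# Illustrative usage sketch, not part of the original gist: one way this
# operator might be wired into a DAG that polls the subscription on a
# schedule. The DAG id, schedule, project and subscription names below are
# assumptions for demonstration only, and the DAG itself would normally live
# in its own file rather than in this module.
#
# from airflow import DAG
#
# with DAG(dag_id='pubsub-poll',
#          start_date=dt.datetime(2019, 1, 1),
#          schedule_interval='*/10 * * * *',
#          catchup=False) as dag:
#     pull_and_trigger = PubSubTrigger(
#         task_id='pull_pubsub_and_trigger',
#         project='my-gcp-project',          # assumed project id
#         subscription='bq-transfer-runs',   # assumed subscription name
#         max_messages=10,
#     )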