#! /usr/bin/env python3
"""CLV acquisition predictions, triggered daily."""
from datetime import datetime, timedelta
from typing import Optional

from airflow import DAG

from helpers import k8s, mlops, mlops_factories, settings
from helpers.custom_operators import ModifiedKubernetesPodOperator, SnapshotStartDateOperator, UpdateConfOperator
from helpers.sensors.bigquery import BigQueryTableUpdatedSensor
from helpers.slack_ids import SlackID
##############################################################################
# DAG Setup
##############################################################################
DAG_ID = "clv_acquisition_predict"
PROJECT_NAME = "ds-cc-clv-acquisition"

# Capturing run_id based on start date instead of datetime.now().
# Note that the run_id in the UI will still be created using execution_date, and cannot be changed.
RUN_ID = mlops.run_id_template(settings)

RELEASE_TYPE = "prod" if settings.IS_PROD_ENVIRONMENT else "dev"
DOCKER_TAG = "prod" if settings.IS_PROD_ENVIRONMENT else "dev"
GA_DEBUG_MODE = False  # True for local development only
GA_TRACKING_ID = "UA-11111111-1" if settings.IS_PROD_ENVIRONMENT else "UA-11111111-2"
GCP_PROJECT = "prodproject" if settings.IS_PROD_ENVIRONMENT else "devproject"
OUTPUT_BQ_PROJECT = "prodproject" if settings.IS_PROD_ENVIRONMENT else "devproject"
GOOGLE_BUCKET = "gcsbucket"

DEV_SLACK_ID = SlackID.DEVELOPER
CHANNEL_ID = SlackID.AUTOBIDDER if settings.IS_PROD_ENVIRONMENT else DEV_SLACK_ID

SENSOR_INTERVAL = 60 * 5
MAX_TOTAL_TASK_RUN_TIME = 60 * 10  # Max expected total DAG run time excluding sensor waiting
CONF = {
    "IMAGE": f"gcr.io/zorodataplatform/{PROJECT_NAME}:{DOCKER_TAG}",
    "NOTIFICATIONS": {
        "SUCCESS": {"SLACK_IDS": [DEV_SLACK_ID, CHANNEL_ID]},
        "FAILURE": {
            "MAINTAINER_SLACK_IDS": [DEV_SLACK_ID],
            "SLACK_IDS": [DEV_SLACK_ID],
        },
    },
    "RELEASE_TYPE": RELEASE_TYPE,
    "PREDICT_BATCH": {
        "CPU": "1",
        "ENV_VARS": {
            "GOOGLE_BUCKET": GOOGLE_BUCKET,
            "GOOGLE_PROJECT_BQ": GCP_PROJECT,
            "OUTPUT_BQ_PROJECT": OUTPUT_BQ_PROJECT,
"RUN_ID": RELEASE_TYPE, | |
        },
        "IMAGE": f"gcr.io/zorodataplatform/{PROJECT_NAME}:{DOCKER_TAG}",
        "MEMORY": "2G",
        "RELEASE_TYPE": RELEASE_TYPE,
        "SCRIPT": "legacy_predict.py",  # "predict.py",
        # Typically takes 1-3 minutes
        "TIMEOUT": int(60 * 10),
    },
}
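
# Illustrative note (an assumption, inferred from `conf_str` in `add_task` below):
# each task resolves its own settings from this conf at render time via the
# `get_task_conf` macro, e.g. `get_task_conf(dag_run, 'PREDICT_BATCH', None)`
# returning the "PREDICT_BATCH" block above.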


##############################################################################
# Helper Functions
##############################################################################
def add_task(
    dag: DAG,
    task_name: str,
    task_suffix: Optional[str] = None,
    task_suffix_sep: str = mlops.TASK_SUFFIX_SEP,
    task_retries: int = mlops.DEFAULT_TASK_RETRIES,
    default_task_timeout: int = mlops.DEFAULT_TASK_TIMEOUT,
    do_xcom_push: bool = False,
    trigger_rule: str = "all_success",
) -> None:
"""Factory function to add add a task to `dag`. | |
Args: | |
dag: DAG to add the task to. | |
task_name: Base name of the task in the DAG. | |
Will be the whole task name if no `task_suffix` is provided. | |
Used to determine the namespace from which to grab `conf`. | |
task_suffix: Suffix to be appended to `task_name` when naming the task, if desired. | |
Useful when repeating a task multiple times in the same dag. | |
Used to determine the sub-namespace from which to grab `conf`, if provided. | |
task_suffix_sep: Characters to separate `task_name` from `task_suffix` in the task name, | |
when `task_suffix` is provided. | |
task_retries: Number of task tries. Supercedes the DAG's default. | |
default_task_timeout: Default task timeout in seconds, if not provided via conf. | |
do_xcom_push: Whether or not to push xcom. | |
Only enable this if you're writing xcom to /airflow/xcom/return.json, | |
or a handshake error (see MLOPS-94) may result. | |
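
    Example (illustrative; this call is hypothetical and not used in this DAG):
        add_task(dag, task_name="PREDICT_BATCH")
        # Registers a pod task with task_id "predict_batch" whose image, env vars,
        # and resources are templated from conf["PREDICT_BATCH"] at render time.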
""" | |
    if task_suffix is not None and not isinstance(task_suffix, str):
        raise ValueError("`task_suffix` must be a string")

    user_defined_filters = {
        "get_secret_name": mlops.get_secret_name,
        "get_mlops_tolerations": k8s.get_mlops_tolerations,
    }
    user_defined_macros = {
        "get_task_conf": mlops.get_task_conf,
    }
    mlops.add_macros_and_filters(
        dag=dag,
        user_defined_filters=user_defined_filters,
        user_defined_macros=user_defined_macros,
    )

    full_task_name = (
        f"{task_name}{task_suffix_sep}{task_suffix}".lower()
        if task_suffix is not None
        else task_name.lower()
    )
    conf_str = f"""get_task_conf(dag_run, {task_name!r}, {task_suffix!r})"""
    task = ModifiedKubernetesPodOperator(  # noqa: F841
        task_id=full_task_name,
        trigger_rule=trigger_rule,
        name=f"mlops_{full_task_name}",
        image=f"{{{{ {conf_str}['IMAGE'] }}}}",
        namespace="default",
        cluster_name=mlops.CLUSTER_NAME,
        cluster_zone=mlops.CLUSTER_ZONE,
        image_pull_policy="Always",
        labels={
            "timeout_seconds": f"{{{{ {conf_str}['TIMEOUT'] | default('{default_task_timeout}', true) }}}}",
        },
        startup_timeout_seconds=600,
        env_vars=f"{{{{ {conf_str} }}}}",
        secrets=f'{{{{ {conf_str}.get("RELEASE_TYPE", "exp") | get_secret_name() }}}}',
        resources={
            "request_memory": f"{{{{ {conf_str}.get('MEMORY', '12G') }}}}",
            "request_cpu": f"{{{{ {conf_str}.get('CPU', '3') }}}}",
            "limit_memory": f"{{{{ {conf_str}.get('MEMORY', '12G') }}}}",
            "limit_cpu": f"{{{{ {conf_str}.get('CPU', '3') }}}}",
            "limit_gpu": f"{{{{ {conf_str}.get('GPU_COUNT') }}}}",
        },
        # Pushes the content of /airflow/xcom/return.json from the container to an XCom when the container ends.
        do_xcom_push=do_xcom_push,
        node_selectors={"zdp/purpose": "mlops"},
        tolerations=f"{{{{ {conf_str}.get('GPU_TYPE') | get_mlops_tolerations }}}}",
        # Params overwrite execution_timeout from conf["TIMEOUT"], if provided.
        execution_timeout=timedelta(seconds=default_task_timeout),
        params={
            "TIMEOUT_CONF": "TIMEOUT",
            "CONF_STR": conf_str,
        },
        retries=task_retries,
        retry_delay=timedelta(seconds=300),
        retry_exponential_backoff=True,
        on_failure_callback=mlops.send_failure_message,
        dag=dag,
    )
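
# For reference (illustrative): with task_name="PREDICT_BATCH" and no suffix,
# `conf_str` above makes the templated `image` field render as
#   {{ get_task_conf(dag_run, 'PREDICT_BATCH', None)['IMAGE'] }}
# which Airflow resolves against the run's conf when the task is templated.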


##############################################################################
# DAG
##############################################################################
with DAG(
    dag_id=DAG_ID,
    schedule_interval="0 17 * * *",  # Daily at 17:00 UTC; time is relatively arbitrary
    max_active_runs=1,
    catchup=False,
    default_args=mlops.get_dag_default_args(),
) as dag:
    t_snapshot_start_date = SnapshotStartDateOperator(task_id="snapshot_start_date")
    t_update_conf = UpdateConfOperator(
        task_id="update_conf", given_conf=CONF, replace=False
    )

    # TODO: Replace with a Deferrable Operator once available in Composer
    t_bq_sensor_s_customer_acq = BigQueryTableUpdatedSensor(
        task_id="bq_sensor_s_customer_acq",
        project_id=GCP_PROJECT,
        dataset_id="some_dataset",
        table_id="s_customer_acquisition",
        comparison_time='{{ dag_run.conf["_zoro_mlops"]["first_attempt_start_date"] }}',
        poke_interval=SENSOR_INTERVAL,
        # Must time out before the next schedule interval (daily) to avoid potential deadlock.
        # timeout is relative to a single task try (i.e. timeout * retries = total time).
        timeout=60 * 60 * 4 - (SENSOR_INTERVAL + MAX_TOTAL_TASK_RUN_TIME),
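        # With SENSOR_INTERVAL = 300 and MAX_TOTAL_TASK_RUN_TIME = 600 as set above,
        # this works out to 14400 - 900 = 13500 seconds (3 h 45 min) per try.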
        # Reschedule mode frees up a worker slot between checks.
        mode="reschedule",
        retries=0,
    )

    partition_date = (
        datetime.now(mlops.CHICAGO_TIMEZONE) - timedelta(days=1)
    ).strftime("%Y%m%d")  # e.g. 20220620

    # This task will always run, even if there's no valid input data
    mlops_factories.add_predict_batch_task(dag)

    (
        t_snapshot_start_date
        >> t_update_conf
        >> t_bq_sensor_s_customer_acq
        >> dag.get_task("predict_batch")
    )