@odidere
Forked from szczeles/SparkOperator.py
Created October 19, 2020 21:03
SparkOperator for Airflow, designed to simplify work with Spark on YARN. It wraps spark-submit in Airflow DAGs, retrieves the application id and tracking URL from the submission logs, and ensures the YARN application is killed on timeout.
import logging
from subprocess import Popen, PIPE, STDOUT

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
'''
SparkOperator for Airflow, designed to simplify work with Spark on YARN.
It wraps spark-submit in Airflow DAGs, retrieves the application id
and tracking URL from the submission logs, and ensures the YARN
application is killed on timeout.

Example usage:

my_task = SparkOperator(script='hdfs:///location/of/pyspark/script.py',
                        driver_memory='2g',
                        extra_params='--files hdfs:///some/extra/file',
                        execution_timeout=timedelta(minutes=30),
                        dag=dag)
'''
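# Editor's note (not part of the original gist): besides logging the output,
# the operator publishes two XCom keys from execute_spark_submit(),
# 'application-id' and 'tracking-url', so downstream tasks can locate the job
# on YARN, e.g. via ti.xcom_pull(task_ids='my_task', key='tracking-url').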
class SparkOperator(BaseOperator):

    @apply_defaults
    def __init__(self, script, driver_memory='1g', spark_version='2.0.2',
                 extra_params='', yarn_queue='default', *args, **kwargs):
        # Derive the task id from the script file name, e.g.
        # 'hdfs:///jobs/daily_report.py' -> 'daily_report'.
        self.task_id = script.split('/')[-1].split('.')[0]
        super(SparkOperator, self).__init__(*args, **dict(kwargs, task_id=self.task_id))
        self.env = {'HADOOP_CONF_DIR': '/etc/hadoop/conf',
                    'PYSPARK_PYTHON': '/opt/conda/bin/python3'}
        self.spark_version = spark_version
        self.spark_params = ('--master yarn --deploy-mode cluster '
                             '--queue {yarn_queue} --driver-memory {driver_memory} '
                             '{extra_params} {script}').format(
            yarn_queue=yarn_queue,
            driver_memory=driver_memory,
            extra_params=extra_params,
            script=script
        )
    def execute(self, context):
        try:
            self.execute_spark_submit(context)
        except AirflowTaskTimeout:
            # On timeout, make sure the YARN application does not keep
            # running after the Airflow task is killed.
            self.on_kill()
            raise
    def execute_spark_submit(self, context):
        # The spark-submit binary is expected to be suffixed with the Spark
        # version, e.g. /usr/local/bin/spark-submit2.0.2.
        command = '/usr/local/bin/spark-submit{spark_version} {params}'.format(
            spark_version=self.spark_version, params=self.spark_params)
        logging.info("Running command: " + command)
        self.applicationId = None
        self.sp = Popen(command, stdout=PIPE, stderr=STDOUT, env=self.env, shell=True)
        logging.info("Output:")
        for line in iter(self.sp.stdout.readline, b''):
            line = line.decode('utf-8').strip()
            # Extract the YARN application id and tracking URL from the
            # spark-submit output and expose them via XCom.
            if 'impl.YarnClientImpl: Submitted application' in line:
                self.applicationId = line.split(' ')[-1]
                context['ti'].xcom_push(key='application-id', value=self.applicationId)
            if 'tracking URL: ' in line:
                context['ti'].xcom_push(key='tracking-url', value=line.split(' ')[-1])
            logging.info(line)
        self.sp.wait()
        logging.info("spark-submit exited with return code {0}".format(self.sp.returncode))
        if self.sp.returncode:
            raise AirflowException("Spark application failed")
    def on_kill(self):
        logging.info('Sending SIGTERM signal to spark-submit')
        self.sp.terminate()
        if self.applicationId:
            # spark-submit may already be gone, so also kill the
            # application on YARN explicitly.
            logging.info('Killing application on YARN...')
            yarn_kill = Popen('yarn application -kill ' + self.applicationId,
                              stdout=PIPE, stderr=STDOUT, env=self.env, shell=True)
            logging.info('...done, yarn command return code: ' + str(yarn_kill.wait()))
        else:
            logging.info('Not killing any application on YARN')
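
For completeness, below is a minimal sketch of wiring the operator into a DAG and consuming the XCom values it pushes. This is an illustration only: the DAG name, schedule, and the print_tracking_url helper are assumptions, not part of the original gist, and the snippet targets the Airflow 1.x-style API the operator itself uses (apply_defaults, provide_context).

from datetime import datetime, timedelta

from airflow import DAG
from airflow.operators.python_operator import PythonOperator

dag = DAG('spark_example',
          start_date=datetime(2020, 10, 1),
          schedule_interval='@daily')

# SparkOperator derives its task_id from the script file name,
# so this task's id is 'script'.
spark_task = SparkOperator(script='hdfs:///location/of/pyspark/script.py',
                           driver_memory='2g',
                           execution_timeout=timedelta(minutes=30),
                           dag=dag)

def print_tracking_url(**context):
    # Hypothetical downstream consumer: read the values SparkOperator
    # pushed to XCom during execute_spark_submit().
    app_id = context['ti'].xcom_pull(task_ids='script', key='application-id')
    url = context['ti'].xcom_pull(task_ids='script', key='tracking-url')
    print('YARN application {0} tracked at {1}'.format(app_id, url))

report = PythonOperator(task_id='print_tracking_url',
                        python_callable=print_tracking_url,
                        provide_context=True,
                        dag=dag)

spark_task >> report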