SkeinSparkSubmitOperator
from wmf_airflow_common.operators import skein | |
op1 = skein.SkeinSparkSubmitOperator( | |
spark_submit_kwargs={ | |
'application': 'hdfs:///user/otto/spark-examples_2.11-2.4.4.jar', | |
'spark_submit': '/usr/bin/spark2-submit', | |
'master': 'yarn', | |
'deploy_mode': 'client', | |
'java_class': 'org.apache.spark.examples.SparkPi', | |
'application_args': ['10'], | |
} | |
) | |
# Normally, this would be run by Airflow:
op1.execute_callable()
# -> returns a YARN application id. Viewing the yarn logs after it finishes shows it works.
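# Hedged aside (not in the original demo): execute_callable() returns the skein/YARN
# application id, so the skein library's client API can be used to check on it. Note
# that `skein` above is the wmf operators module, so import the library under another
# name here.
import skein as skein_lib
app1_id = op1.execute_callable()  # re-submits the same spec; returns the application id
print(skein_lib.Client().application_report(app1_id).state)  # e.g. RUNNING, then FINISHED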
# Let's try with a packed conda env with a custom python and pyspark! | |
op2 = skein.SkeinSparkSubmitOperator( | |
# This is needed for our custom spark to find hadoop libs. | |
# We can probably automate this. | |
preamble = ['SPARK_DIST_CLASSPATH=$(hadoop classpath)'], | |
spark_submit_kwargs={ | |
# this is my pyspark application file, which is shipped along with the conda dist env. | |
'application': 'conda_env/bin/myproject_spark.py', | |
# this will be also unpacked on the skein app master. | |
'archives': 'hdfs:///user/otto/c5.tgz#conda_env', | |
'driver_java_options': '"-Dhttp.proxyHost=http://webproxy.eqiad.wmnet -Dhttp.proxyPort=8080 -Dhttps.proxyHost=http://webproxy.eqiad.wmnet -Dhttps.proxyPort=8080"', | |
'spark_submit': 'conda_env/bin/spark-submit', | |
'master': 'yarn', | |
'deploy_mode': 'client', | |
'conf': { | |
'spark.yarn.maxAppAttempts': 1 | |
} | |
}, | |
skein_master_kwargs={ | |
'env': { | |
'REQUESTS_CA_BUNDLE': '/etc/ssl/certs', | |
'HADOOP_CONF_DIR': '/etc/hadoop/conf', | |
'SPARK_HOME': 'conda_env/lib/python3.7/site-packages/pyspark', | |
'PYSPARK_PYTHON': 'conda_env/bin/python', | |
}, | |
'log_level': 'DEBUG', | |
}, | |
skein_client_kwargs={ | |
'log_level': 'DEBUG', | |
} | |
) | |
op2.execute_callable() | |
# it works! | |
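
# Hedged sketch (not in the original gist) of how the operator might be wired into an
# actual DAG instead of being executed by hand; dag_id, start_date and schedule_interval
# below are illustrative.
from datetime import datetime
from airflow import DAG

with DAG(
    dag_id='skein_spark_submit_example',
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    spark_pi = skein.SkeinSparkSubmitOperator(
        task_id='spark_pi',
        spark_submit_kwargs={
            'application': 'hdfs:///user/otto/spark-examples_2.11-2.4.4.jar',
            'spark_submit': '/usr/bin/spark2-submit',
            'master': 'yarn',
            'deploy_mode': 'client',
            'java_class': 'org.apache.spark.examples.SparkPi',
            'application_args': ['10'],
        },
    )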
import os | |
import skein | |
from airflow.operators.python import PythonOperator | |
# hacked SparkSubmitHook from upstream. I think we can do this without hacking it. | |
from wmf_airflow_common.hooks.spark import SparkSubmitHook | |
def skein_execute(client_kwargs, app_spec): | |
client = skein.Client(**client_kwargs) | |
return client.submit(app_spec) | |
class SkeinBashOperator(PythonOperator): | |
def __init__( | |
self, | |
task_id=None, | |
script=None, | |
queue=None, | |
        skein_application_spec_kwargs=None,
        skein_client_kwargs=None,
        skein_master_kwargs=None,
    ):
        # Use None defaults: mutable default arguments would be shared (and mutated)
        # across operator instances.
        skein_application_spec_kwargs = dict(skein_application_spec_kwargs or {})
        skein_client_kwargs = dict(skein_client_kwargs or {})
        skein_master_kwargs = dict(skein_master_kwargs or {})
        if script:
            skein_master_kwargs['script'] = script
        if queue:
            skein_application_spec_kwargs['queue'] = queue
print('SkeinBashOperator script', script) | |
self.skein_master = skein.Master( | |
**skein_master_kwargs | |
) | |
self.skein_application_spec = skein.ApplicationSpec( | |
master=self.skein_master, | |
**skein_application_spec_kwargs | |
) | |
super().__init__( | |
task_id=task_id, | |
python_callable=skein_execute, | |
op_kwargs={ | |
'client_kwargs': skein_client_kwargs, | |
'app_spec': self.skein_application_spec, | |
} | |
) | |
class SkeinSparkSubmitOperator(SkeinBashOperator): | |
def __init__( | |
self, | |
task_id='SkeinSparkSubmitOperator_test_otto', | |
        preamble=None,
        skein_master_kwargs=None,
        skein_client_kwargs=None,
        spark_submit_kwargs=None,
    ):
        # Use None defaults: mutable default arguments would be shared (and mutated)
        # across operator instances. Skein master resources are sized below, based on
        # the spark deploy mode, unless the caller provides them.
        skein_master_kwargs = dict(skein_master_kwargs or {})
        skein_client_kwargs = dict(skein_client_kwargs or {})
        spark_submit_kwargs = dict(spark_submit_kwargs or {})
        spark_submit_hook = SparkSubmitHook(**spark_submit_kwargs)
command = spark_submit_hook._build_spark_submit_command() | |
if preamble: | |
command = preamble + command | |
print('command\n', command) | |
script = ' '.join(command) | |
print('script\n', script) | |
        # In spark yarn client mode, the spark driver runs in the skein master.
        # Make sure skein asks for enough resources to run the driver.
        if 'resources' not in skein_master_kwargs:
            if spark_submit_hook._deploy_mode == 'cluster':
                skein_master_kwargs['resources'] = skein.Resources(
                    memory=1024,
                    vcores=1,
                )
            else:
                # Fall back to 1024 MiB / 1 core if no driver settings were given.
                skein_master_kwargs['resources'] = skein.Resources(
                    memory=spark_submit_hook._driver_memory or 1024,
                    vcores=int(spark_submit_hook._driver_cores or 1) + 1,
                )
if spark_submit_hook._archives and 'files' not in skein_master_kwargs: | |
skein_master_kwargs['files'] = {} | |
for archive in spark_submit_hook._archives.split(','): | |
if '#' in archive: | |
(uri, alias) = archive.rsplit('#', 1) | |
else: | |
uri = archive | |
alias = os.path.basename(archive) | |
skein_master_kwargs['files'][alias] = uri | |
        # Initialize the SkeinBashOperator parent.
super().__init__( | |
task_id=task_id, | |
script=script, | |
queue=spark_submit_hook._queue, | |
skein_client_kwargs=skein_client_kwargs, | |
skein_master_kwargs=skein_master_kwargs, | |
) |
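
# Hedged sketch (not in the original code) of the "we can probably automate this" idea
# from the example DAG: derive the SPARK_DIST_CLASSPATH preamble and the skein master
# env from the packed conda env alias instead of repeating them per task. The paths
# below mirror the example and are assumptions about the packed env layout.
def conda_spark_defaults(alias='conda_env', python_version='3.7'):
    preamble = ['SPARK_DIST_CLASSPATH=$(hadoop classpath)']
    env = {
        'HADOOP_CONF_DIR': '/etc/hadoop/conf',
        'SPARK_HOME': f'{alias}/lib/python{python_version}/site-packages/pyspark',
        'PYSPARK_PYTHON': f'{alias}/bin/python',
    }
    return preamble, env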
# I had to patch the SparkSubmitHook to do a few things I needed, and also to not use airflow connections for my tests. | |
# I think it should be possible to use the provided SparkSubmitHook though. | |
# Taken from: https://github.com/apache/airflow/blob/main/airflow/providers/apache/spark/hooks/spark_submit.py
# | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# | |
import os | |
import re | |
import subprocess | |
import time | |
from typing import Any, Dict, Iterator, List, Optional, Union | |
from airflow.configuration import conf as airflow_conf | |
from airflow.exceptions import AirflowException | |
from airflow.hooks.base import BaseHook | |
from airflow.security.kerberos import renew_from_kt | |
from airflow.utils.log.logging_mixin import LoggingMixin | |
try: | |
from airflow.kubernetes import kube_client | |
except (ImportError, NameError): | |
pass | |
class SparkSubmitHook(BaseHook, LoggingMixin): | |
""" | |
This hook is a wrapper around the spark-submit binary to kick off a spark-submit job. | |
It requires that the "spark-submit" binary is in the PATH or the spark_home to be | |
supplied. | |
:param conf: Arbitrary Spark configuration properties | |
:type conf: dict | |
:param files: Upload additional files to the executor running the job, separated by a | |
comma. Files will be placed in the working directory of each executor. | |
For example, serialized objects. | |
:type files: str | |
:param py_files: Additional python files used by the job, can be .zip, .egg or .py. | |
:type py_files: str | |
    :param archives: Archives that spark should unzip (and possibly tag with #ALIAS) into
        the application working directory.
    :type archives: str
:param driver_class_path: Additional, driver-specific, classpath settings. | |
:type driver_class_path: str | |
:param jars: Submit additional jars to upload and place them in executor classpath. | |
:type jars: str | |
:param java_class: the main class of the Java application | |
:type java_class: str | |
:param packages: Comma-separated list of maven coordinates of jars to include on the | |
driver and executor classpaths | |
:type packages: str | |
:param exclude_packages: Comma-separated list of maven coordinates of jars to exclude | |
while resolving the dependencies provided in 'packages' | |
:type exclude_packages: str | |
:param repositories: Comma-separated list of additional remote repositories to search | |
for the maven coordinates given with 'packages' | |
:type repositories: str | |
:param total_executor_cores: (Standalone & Mesos only) Total cores for all executors | |
(Default: all the available cores on the worker) | |
:type total_executor_cores: int | |
:param executor_cores: (Standalone, YARN and Kubernetes only) Number of cores per | |
executor (Default: 2) | |
:type executor_cores: int | |
:param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G) | |
:type executor_memory: str | |
:param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G) (Default: 1G) | |
:type driver_memory: str | |
:param keytab: Full path to the file that contains the keytab | |
:type keytab: str | |
:param principal: The name of the kerberos principal used for keytab | |
:type principal: str | |
:param proxy_user: User to impersonate when submitting the application | |
:type proxy_user: str | |
:param name: Name of the job (default airflow-spark) | |
:type name: str | |
:param num_executors: Number of executors to launch | |
:type num_executors: int | |
:param status_poll_interval: Seconds to wait between polls of driver status in cluster | |
mode (Default: 1) | |
:type status_poll_interval: int | |
:param application_args: Arguments for the application being submitted | |
:type application_args: list | |
:param env_vars: Environment variables for spark-submit. It | |
supports yarn and k8s mode too. | |
:type env_vars: dict | |
:param verbose: Whether to pass the verbose flag to spark-submit process for debugging | |
:type verbose: bool | |
    :param spark_submit: The spark-submit command or full path to the binary to use.
        Some distros may use spark2-submit.
    :type spark_submit: str
""" | |
conn_name_attr = 'conn_id' | |
conn_type = 'spark' | |
hook_name = 'Spark' | |
@staticmethod | |
def get_ui_field_behaviour() -> Dict: | |
"""Returns custom field behaviour""" | |
return { | |
"hidden_fields": ['schema', 'login', 'password'], | |
"relabeling": {}, | |
} | |
def __init__( | |
self, | |
application: str, | |
application_args: Optional[List[Any]] = None, | |
conf: Optional[Dict[str, Any]] = None, | |
master: Optional[str] = None, | |
deploy_mode: Optional[str] = None, | |
queue: Optional[str] = None, | |
files: Optional[str] = None, | |
py_files: Optional[str] = None, | |
archives: Optional[str] = None, | |
driver_class_path: Optional[str] = None, | |
jars: Optional[str] = None, | |
java_class: Optional[str] = None, | |
packages: Optional[str] = None, | |
exclude_packages: Optional[str] = None, | |
repositories: Optional[str] = None, | |
total_executor_cores: Optional[int] = None, | |
executor_cores: Optional[int] = None, | |
executor_memory: Optional[str] = None, | |
driver_memory: Optional[str] = None, | |
driver_cores: Optional[int] = None, | |
driver_java_options: Optional[str] = None, | |
keytab: Optional[str] = None, | |
principal: Optional[str] = None, | |
proxy_user: Optional[str] = None, | |
name: str = 'default-name', | |
num_executors: Optional[int] = None, | |
status_poll_interval: int = 1, | |
env_vars: Optional[Dict[str, Any]] = None, | |
verbose: bool = False, | |
spark_submit: str = 'spark-submit', | |
) -> None: | |
super().__init__() | |
self._application = application | |
self._application_args = application_args | |
self._conf = conf or {} | |
self._master = master | |
self._deploy_mode = deploy_mode | |
self._queue = queue | |
self._files = files | |
self._py_files = py_files | |
self._archives = archives | |
self._driver_class_path = driver_class_path | |
self._jars = jars | |
self._java_class = java_class | |
self._packages = packages | |
self._exclude_packages = exclude_packages | |
self._repositories = repositories | |
self._total_executor_cores = total_executor_cores | |
self._executor_cores = executor_cores | |
self._executor_memory = executor_memory | |
self._driver_memory = driver_memory | |
self._driver_cores = driver_cores | |
self._driver_java_options = driver_java_options | |
self._keytab = keytab | |
self._principal = principal | |
self._proxy_user = proxy_user | |
self._name = name | |
self._num_executors = num_executors | |
self._status_poll_interval = status_poll_interval | |
self._env_vars = env_vars | |
self._verbose = verbose | |
self._submit_sp: Optional[Any] = None | |
self._yarn_application_id: Optional[str] = None | |
self._kubernetes_driver_pod: Optional[str] = None | |
self._spark_submit = spark_submit | |
        self._is_yarn = 'yarn' in (self._master or '')
        # _is_kubernetes is still referenced below (in _build_spark_submit_command,
        # submit() and the log parsing), so keep it defined even though kubernetes
        # support is untested in this hacked hook.
        self._is_kubernetes = 'k8s' in (self._master or '')
        # if self._is_kubernetes and kube_client is None:
        #     raise RuntimeError(
        #         f"{self._master} specified, but kubernetes dependencies are not installed!"
        #     )
self._should_track_driver_status = self._resolve_should_track_driver_status() | |
self._driver_id: Optional[str] = None | |
self._driver_status: Optional[str] = None | |
self._spark_exit_code: Optional[int] = None | |
self._env: Optional[Dict[str, Any]] = None | |
def _resolve_should_track_driver_status(self) -> bool: | |
""" | |
Determines whether or not this hook should poll the spark driver status through | |
subsequent spark-submit status requests after the initial spark-submit request | |
:return: if the driver status should be tracked | |
""" | |
        return 'spark://' in (self._master or '') and self._deploy_mode == 'cluster'
# def _resolve_connection(self) -> Dict[str, Any]: | |
# # Build from connection master or default to yarn if not available | |
# conn_data = { | |
# 'master': 'yarn', | |
# 'queue': None, | |
# 'deploy_mode': None, | |
# 'spark_home': None, | |
# 'spark_binary': self._spark_binary or "spark-submit", | |
# 'namespace': None, | |
# } | |
# try: | |
# # Master can be local, yarn, spark://HOST:PORT, mesos://HOST:PORT and | |
# # k8s://https://<HOST>:<PORT> | |
# conn = self.get_connection(self._conn_id) | |
# if conn.port: | |
# conn_data['master'] = f"{conn.host}:{conn.port}" | |
# else: | |
# conn_data['master'] = conn.host | |
# # Determine optional yarn queue from the extra field | |
# extra = conn.extra_dejson | |
# conn_data['queue'] = extra.get('queue') | |
# conn_data['deploy_mode'] = extra.get('deploy-mode') | |
# conn_data['spark_home'] = extra.get('spark-home') | |
# conn_data['spark_binary'] = self._spark_binary or extra.get('spark-binary', "spark-submit") | |
# conn_data['namespace'] = extra.get('namespace') | |
# except AirflowException: | |
# self.log.info( | |
# "Could not load connection string %s, defaulting to %s", self._conn_id, conn_data['master'] | |
# ) | |
# if 'spark.kubernetes.namespace' in self._conf: | |
# conn_data['namespace'] = self._conf['spark.kubernetes.namespace'] | |
# return conn_data | |
def get_conn(self) -> Any: | |
pass | |
def _mask_cmd(self, connection_cmd: Union[str, List[str]]) -> str: | |
# Mask any password related fields in application args with key value pair | |
# where key contains password (case insensitive), e.g. HivePassword='abc' | |
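        # e.g. (illustrative addition): HivePassword='abc' is rendered as HivePassword='******'.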
connection_cmd_masked = re.sub( | |
r"(" | |
r"\S*?" # Match all non-whitespace characters before... | |
r"(?:secret|password)" # ...literally a "secret" or "password" | |
# word (not capturing them). | |
r"\S*?" # All non-whitespace characters before either... | |
r"(?:=|\s+)" # ...an equal sign or whitespace characters | |
# (not capturing them). | |
r"(['\"]?)" # An optional single or double quote. | |
r")" # This is the end of the first capturing group. | |
r"(?:(?!\2\s).)*" # All characters between optional quotes | |
# (matched above); if the value is quoted, | |
# it may contain whitespace. | |
r"(\2)", # Optional matching quote. | |
r'\1******\3', | |
' '.join(connection_cmd), | |
flags=re.I, | |
) | |
return connection_cmd_masked | |
def _build_spark_submit_command(self) -> List[str]: | |
""" | |
Construct the spark-submit command to execute. | |
:return: full command to be executed | |
""" | |
connection_cmd = [self._spark_submit] | |
# The url of the spark master | |
connection_cmd += ["--master", self._master] | |
if self._deploy_mode: | |
connection_cmd += ["--deploy-mode", self._deploy_mode] | |
for key in self._conf: | |
connection_cmd += ["--conf", f"{key}={str(self._conf[key])}"] | |
if self._env_vars and (self._is_kubernetes or self._is_yarn): | |
if self._is_yarn: | |
tmpl = "spark.yarn.appMasterEnv.{}={}" | |
# Allow dynamic setting of hadoop/yarn configuration environments | |
self._env = self._env_vars | |
else: | |
tmpl = "spark.kubernetes.driverEnv.{}={}" | |
for key in self._env_vars: | |
connection_cmd += ["--conf", tmpl.format(key, str(self._env_vars[key]))] | |
elif self._env_vars and self._deploy_mode != "cluster": | |
self._env = self._env_vars # Do it on Popen of the process | |
elif self._env_vars and self._deploy_mode == "cluster": | |
raise AirflowException("SparkSubmitHook env_vars is not supported in standalone-cluster mode.") | |
# if self._is_kubernetes and self._connection['namespace']: | |
# connection_cmd += [ | |
# "--conf", | |
# f"spark.kubernetes.namespace={self._namespace}", | |
# ] | |
if self._files: | |
connection_cmd += ["--files", self._files] | |
if self._py_files: | |
connection_cmd += ["--py-files", self._py_files] | |
if self._archives: | |
connection_cmd += ["--archives", self._archives] | |
if self._driver_class_path: | |
connection_cmd += ["--driver-class-path", self._driver_class_path] | |
if self._jars: | |
connection_cmd += ["--jars", self._jars] | |
if self._packages: | |
connection_cmd += ["--packages", self._packages] | |
if self._exclude_packages: | |
connection_cmd += ["--exclude-packages", self._exclude_packages] | |
if self._repositories: | |
connection_cmd += ["--repositories", self._repositories] | |
if self._num_executors: | |
connection_cmd += ["--num-executors", str(self._num_executors)] | |
if self._total_executor_cores: | |
connection_cmd += ["--total-executor-cores", str(self._total_executor_cores)] | |
if self._executor_cores: | |
connection_cmd += ["--executor-cores", str(self._executor_cores)] | |
if self._executor_memory: | |
connection_cmd += ["--executor-memory", self._executor_memory] | |
if self._driver_memory: | |
connection_cmd += ["--driver-memory", self._driver_memory] | |
if self._driver_cores: | |
connection_cmd += ["--driver-cores", self._driver_cores] | |
if self._driver_java_options: | |
connection_cmd += ["--driver-java-options", self._driver_java_options] | |
if self._keytab: | |
connection_cmd += ["--keytab", self._keytab] | |
if self._principal: | |
connection_cmd += ["--principal", self._principal] | |
if self._proxy_user: | |
connection_cmd += ["--proxy-user", self._proxy_user] | |
if self._name: | |
connection_cmd += ["--name", self._name] | |
if self._java_class: | |
connection_cmd += ["--class", self._java_class] | |
if self._verbose: | |
connection_cmd += ["--verbose"] | |
if self._queue: | |
connection_cmd += ["--queue", self._queue] | |
# The actual script to execute | |
connection_cmd += [self._application] | |
# Append any application arguments | |
if self._application_args: | |
connection_cmd += self._application_args | |
# self.log.info("Spark-Submit cmd: %s", self._mask_cmd(connection_cmd)) | |
return connection_cmd | |
def _build_track_driver_status_command(self) -> List[str]: | |
""" | |
Construct the command to poll the driver status. | |
:return: full command to be executed | |
""" | |
curl_max_wait_time = 30 | |
spark_host = self._master | |
if spark_host.endswith(':6066'): | |
spark_host = spark_host.replace("spark://", "http://") | |
connection_cmd = [ | |
"/usr/bin/curl", | |
"--max-time", | |
str(curl_max_wait_time), | |
f"{spark_host}/v1/submissions/status/{self._driver_id}", | |
] | |
self.log.info(connection_cmd) | |
# The driver id so we can poll for its status | |
if self._driver_id: | |
pass | |
else: | |
raise AirflowException( | |
"Invalid status: attempted to poll driver status but no driver id is known. Giving up." | |
) | |
else: | |
connection_cmd = [self._spark_submit] | |
# The url to the spark master | |
connection_cmd += ["--master", self._master] | |
# The driver id so we can poll for its status | |
if self._driver_id: | |
connection_cmd += ["--status", self._driver_id] | |
else: | |
raise AirflowException( | |
"Invalid status: attempted to poll driver status but no driver id is known. Giving up." | |
) | |
self.log.debug("Poll driver status cmd: %s", connection_cmd) | |
return connection_cmd | |
def submit(self, application: str = "", **kwargs: Any) -> None: | |
""" | |
Remote Popen to execute the spark-submit job | |
:param application: Submitted application, jar or py file | |
:type application: str | |
:param kwargs: extra arguments to Popen (see subprocess.Popen) | |
""" | |
        spark_submit_cmd = self._build_spark_submit_command()
if self._env: | |
env = os.environ.copy() | |
env.update(self._env) | |
kwargs["env"] = env | |
self._submit_sp = subprocess.Popen( | |
spark_submit_cmd, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
bufsize=-1, | |
universal_newlines=True, | |
**kwargs, | |
) | |
self._process_spark_submit_log(iter(self._submit_sp.stdout)) # type: ignore | |
returncode = self._submit_sp.wait() | |
# Check spark-submit return code. In Kubernetes mode, also check the value | |
# of exit code in the log, as it may differ. | |
if returncode or (self._is_kubernetes and self._spark_exit_code != 0): | |
if self._is_kubernetes: | |
raise AirflowException( | |
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}. " | |
f"Kubernetes spark exit code is: {self._spark_exit_code}" | |
) | |
else: | |
raise AirflowException( | |
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. Error code is: {returncode}." | |
) | |
self.log.debug("Should track driver: %s", self._should_track_driver_status) | |
# We want the Airflow job to wait until the Spark driver is finished | |
if self._should_track_driver_status: | |
if self._driver_id is None: | |
raise AirflowException( | |
"No driver id is known: something went wrong when executing the spark submit command" | |
) | |
# We start with the SUBMITTED status as initial status | |
self._driver_status = "SUBMITTED" | |
# Start tracking the driver status (blocking function) | |
self._start_driver_status_tracking() | |
if self._driver_status != "FINISHED": | |
raise AirflowException( | |
f"ERROR : Driver {self._driver_id} badly exited with status {self._driver_status}" | |
) | |
def _process_spark_submit_log(self, itr: Iterator[Any]) -> None: | |
""" | |
Processes the log files and extracts useful information out of it. | |
If the deploy-mode is 'client', log the output of the submit command as those | |
are the output logs of the Spark worker directly. | |
Remark: If the driver needs to be tracked for its status, the log-level of the | |
spark deploy needs to be at least INFO (log4j.logger.org.apache.spark.deploy=INFO) | |
:param itr: An iterator which iterates over the input of the subprocess | |
""" | |
# Consume the iterator | |
for line in itr: | |
line = line.strip() | |
# If we run yarn cluster mode, we want to extract the application id from | |
# the logs so we can kill the application when we stop it unexpectedly | |
if self._is_yarn and self._deploy_mode == 'cluster': | |
match = re.search('(application[0-9_]+)', line) | |
if match: | |
self._yarn_application_id = match.groups()[0] | |
self.log.info("Identified spark driver id: %s", self._yarn_application_id) | |
# If we run Kubernetes cluster mode, we want to extract the driver pod id | |
# from the logs so we can kill the application when we stop it unexpectedly | |
elif self._is_kubernetes: | |
match = re.search(r'\s*pod name: ((.+?)-([a-z0-9]+)-driver)', line) | |
if match: | |
self._kubernetes_driver_pod = match.groups()[0] | |
self.log.info("Identified spark driver pod: %s", self._kubernetes_driver_pod) | |
# Store the Spark Exit code | |
match_exit_code = re.search(r'\s*[eE]xit code: (\d+)', line) | |
if match_exit_code: | |
self._spark_exit_code = int(match_exit_code.groups()[0]) | |
# if we run in standalone cluster mode and we want to track the driver status | |
# we need to extract the driver id from the logs. This allows us to poll for | |
# the status using the driver id. Also, we can kill the driver when needed. | |
elif self._should_track_driver_status and not self._driver_id: | |
match_driver_id = re.search(r'(driver-[0-9\-]+)', line) | |
if match_driver_id: | |
self._driver_id = match_driver_id.groups()[0] | |
self.log.info("identified spark driver id: %s", self._driver_id) | |
self.log.info(line) | |
def _process_spark_status_log(self, itr: Iterator[Any]) -> None: | |
""" | |
Parses the logs of the spark driver status query process | |
:param itr: An iterator which iterates over the input of the subprocess | |
""" | |
driver_found = False | |
valid_response = False | |
# Consume the iterator | |
for line in itr: | |
line = line.strip() | |
# A valid Spark status response should contain a submissionId | |
if "submissionId" in line: | |
valid_response = True | |
# Check if the log line is about the driver status and extract the status. | |
if "driverState" in line: | |
self._driver_status = line.split(' : ')[1].replace(',', '').replace('\"', '').strip() | |
driver_found = True | |
self.log.debug("spark driver status log: %s", line) | |
if valid_response and not driver_found: | |
self._driver_status = "UNKNOWN" | |
def _start_driver_status_tracking(self) -> None: | |
""" | |
Polls the driver based on self._driver_id to get the status. | |
Finish successfully when the status is FINISHED. | |
Finish failed when the status is ERROR/UNKNOWN/KILLED/FAILED. | |
Possible status: | |
SUBMITTED | |
Submitted but not yet scheduled on a worker | |
RUNNING | |
Has been allocated to a worker to run | |
FINISHED | |
Previously ran and exited cleanly | |
RELAUNCHING | |
Exited non-zero or due to worker failure, but has not yet | |
started running again | |
UNKNOWN | |
The status of the driver is temporarily not known due to | |
master failure recovery | |
KILLED | |
A user manually killed this driver | |
FAILED | |
The driver exited non-zero and was not supervised | |
ERROR | |
Unable to run or restart due to an unrecoverable error | |
(e.g. missing jar file) | |
""" | |
# When your Spark Standalone cluster is not performing well | |
# due to misconfiguration or heavy loads. | |
# it is possible that the polling request will timeout. | |
# Therefore we use a simple retry mechanism. | |
missed_job_status_reports = 0 | |
max_missed_job_status_reports = 10 | |
# Keep polling as long as the driver is processing | |
while self._driver_status not in ["FINISHED", "UNKNOWN", "KILLED", "FAILED", "ERROR"]: | |
# Sleep for n seconds as we do not want to spam the cluster | |
time.sleep(self._status_poll_interval) | |
self.log.debug("polling status of spark driver with id %s", self._driver_id) | |
poll_drive_status_cmd = self._build_track_driver_status_command() | |
status_process: Any = subprocess.Popen( | |
poll_drive_status_cmd, | |
stdout=subprocess.PIPE, | |
stderr=subprocess.STDOUT, | |
bufsize=-1, | |
universal_newlines=True, | |
) | |
self._process_spark_status_log(iter(status_process.stdout)) | |
returncode = status_process.wait() | |
if returncode: | |
if missed_job_status_reports < max_missed_job_status_reports: | |
missed_job_status_reports += 1 | |
else: | |
raise AirflowException( | |
f"Failed to poll for the driver status {max_missed_job_status_reports} times: " | |
f"returncode = {returncode}" | |
) | |
def _build_spark_driver_kill_command(self) -> List[str]: | |
""" | |
Construct the spark-submit command to kill a driver. | |
:return: full command to kill a driver | |
""" | |
connection_cmd = [self._spark_submit] | |
# The url to the spark master | |
connection_cmd += ["--master", self._master] | |
# The actual kill command | |
if self._driver_id: | |
connection_cmd += ["--kill", self._driver_id] | |
self.log.debug("Spark-Kill cmd: %s", connection_cmd) | |
return connection_cmd | |
def on_kill(self) -> None: | |
"""Kill Spark submit command""" | |
self.log.debug("Kill Command is being called") | |
if self._should_track_driver_status: | |
if self._driver_id: | |
self.log.info('Killing driver %s on cluster', self._driver_id) | |
kill_cmd = self._build_spark_driver_kill_command() | |
with subprocess.Popen( | |
kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE | |
) as driver_kill: | |
self.log.info( | |
"Spark driver %s killed with return code: %s", self._driver_id, driver_kill.wait() | |
) | |
if self._submit_sp and self._submit_sp.poll() is None: | |
self.log.info('Sending kill signal to %s', self._spark_submit) | |
self._submit_sp.kill() | |
if self._yarn_application_id: | |
kill_cmd = f"yarn application -kill {self._yarn_application_id}".split() | |
env = {**os.environ, **(self._env or {})} | |
if self._keytab is not None and self._principal is not None: | |
# we are ignoring renewal failures from renew_from_kt | |
# here as the failure could just be due to a non-renewable ticket, | |
# we still attempt to kill the yarn application | |
renew_from_kt(self._principal, self._keytab, exit_on_fail=False) | |
env = os.environ.copy() | |
env["KRB5CCNAME"] = airflow_conf.get('kerberos', 'ccache') | |
with subprocess.Popen( | |
kill_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE | |
) as yarn_kill: | |
self.log.info("YARN app killed with return code: %s", yarn_kill.wait()) | |
# if self._kubernetes_driver_pod: | |
# self.log.info('Killing pod %s on Kubernetes', self._kubernetes_driver_pod) | |
# # Currently only instantiate Kubernetes client for killing a spark pod. | |
# try: | |
# import kubernetes | |
# client = kube_client.get_kube_client() | |
# api_response = client.delete_namespaced_pod( | |
# self._kubernetes_driver_pod, | |
# self._connection['namespace'], | |
# body=kubernetes.client.V1DeleteOptions(), | |
# pretty=True, | |
# ) | |
# self.log.info("Spark on K8s killed with response: %s", api_response) | |
# except kube_client.ApiException: | |
# self.log.exception("Exception when attempting to kill Spark on K8s") |