Last active
March 3, 2019 18:42
-
-
Save y2k-shubham/76eb1c9cc2e08628726ae8a770efb8db to your computer and use it in GitHub Desktop.
[Apache-Airflow] ssh_utils & MultiCmdSSHOperator
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List, Optional, Dict, Any | |
from airflow.contrib.hooks.ssh_hook import SSHHook | |
from airflow.contrib.operators.ssh_operator import SSHOperator | |
from airflow.exceptions import AirflowException | |
from airflow.utils import apply_defaults | |
import ssh_utils | |
class MultiCmdSSHOperator(SSHOperator):
    """
    An extension of SSHOperator that executes a list of commands
    (rather than just a single command), one after another, over SSH.

    If the optional list of commands is not provided in __init__, it
    must be supplied before the execute(..) method runs (for example in
    the pre_execute(..) method); otherwise an AirflowException is raised.

    :param commands: Optional list of commands to be executed. A single
        string is treated as one command (not split into characters).
    :type commands: Optional[List[str]]
    :param ssh_hook_args: Optional dictionary of keyword arguments passed
        to the SSHHook constructor when execute(..) builds a fresh hook
    :type ssh_hook_args: Optional[Dict[str, Any]]
    """

    @apply_defaults
    def __init__(self,
                 commands: Optional[List[str]] = None,
                 ssh_hook_args: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        normalized: Optional[List[str]]
        if not commands:
            normalized = None
        elif isinstance(commands, str):
            # list("ls -l") would explode the command into single characters
            normalized = [commands]
        elif isinstance(commands, list):
            normalized = commands
        else:
            # any other iterable of commands (tuple, generator, ...)
            normalized = list(commands)
        self.commands: Optional[List[str]] = normalized
        self.ssh_hook_args: Optional[Dict[str, Any]] = ssh_hook_args

    def _build_ssh_hook(self) -> SSHHook:
        """
        Return the SSHHook that the commands will run over.

        A new hook is built from ``ssh_hook_args`` plus the operator's
        ``ssh_conn_id`` (which takes precedence over an ``ssh_conn_id`` key
        in ``ssh_hook_args``). If neither is set, an ``ssh_hook`` already
        present on the operator is reused.

        :raises AirflowException: if no hook can be obtained
        """
        hook_kwargs: Dict[str, Any] = dict(self.ssh_hook_args or {})
        if self.ssh_conn_id:
            hook_kwargs["ssh_conn_id"] = self.ssh_conn_id
        if hook_kwargs:
            return SSHHook(**hook_kwargs)
        existing_hook = getattr(self, "ssh_hook", None)
        if isinstance(existing_hook, SSHHook):
            return existing_hook
        raise AirflowException("neither ssh_conn_id, ssh_hook_args nor ssh_hook specified.")

    def execute(self, context) -> Optional[List[Dict[str, Any]]]:
        """
        Run every command in ``self.commands`` over SSH, in order.

        :return: if ``do_xcom_push`` is set, a list of
            ``{"command": <cmd>, "output": <output>}`` dicts, one per
            command; ``output`` is whatever
            ssh_utils.execute_command_for_operator returns. Otherwise None.
        :raises AirflowException: if no commands were specified, no SSH hook
            can be obtained, or a command fails
        """
        if not self.commands:
            raise AirflowException("no commands specified so nothing to execute here.")

        # ssh_utils reads the hook from operator.ssh_hook
        self.ssh_hook: SSHHook = self._build_ssh_hook()

        xcoms_list: List[Dict[str, Any]] = []
        for cmd in self.commands:
            # keep self.command in sync with the command being run (as SSHOperator
            # exposes it); not required since cmd is passed explicitly below
            self.command: str = cmd
            output = ssh_utils.execute_command_for_operator(operator=self, command=cmd)
            xcoms_list.append({
                "command": cmd,
                "output": output,
            })

        if self.do_xcom_push:
            return xcoms_list
        return None
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Any, Dict | |
from airflow.contrib.hooks.ssh_hook import SSHHook | |
from airflow.contrib.operators.ssh_operator import SSHOperator | |
from airflow.exceptions import AirflowException | |
from airflow.models import BaseOperator | |
from select import select | |
def execute_command_for_operator(operator: BaseOperator,
                                 ssh_hook: Optional[SSHHook] = None,
                                 command: Optional[str] = None) -> Any:
    """
    Utility method that runs a command over SSH for an operator.

    If the ssh_hook or command params are not provided, they are read from
    ``operator.ssh_hook`` / ``operator.command`` (i.e. the operator is then
    assumed to be an SSHOperator). Otherwise the operator can be of any
    type that has a ``log``. Remote stdout is logged at INFO and remote
    stderr at WARNING as it arrives.

    Code adapted from
    - airflow.contrib.operators.SSHOperator.execute() method
    - airflow.contrib.operators.SFTPOperator.execute() method

    :param operator: An operator for which the command is to be run
        remotely over SSHHook; its ``timeout`` and ``do_xcom_push``
        attributes are honoured when present
    :type operator: BaseOperator or SSHOperator
    :param ssh_hook: Optional SSHHook over which the command is to be run
        remotely (defaults to ``operator.ssh_hook``)
    :type ssh_hook: Optional[SSHHook]
    :param command: Optional command to be run remotely (defaults to
        ``operator.command``)
    :type command: Optional[str]
    :return: Aggregated stdout of the command when the operator has a
        truthy ``do_xcom_push``, otherwise None
    :rtype: Optional[bytes]
    :raises AirflowException: if the command exits with a non-zero status
    """
    if not ssh_hook:
        ssh_hook = operator.ssh_hook
    if not command:
        command = operator.command

    # select() timeout; None (no "timeout" attribute) blocks until readable
    timeout: Optional[int] = getattr(operator, "timeout", None)
    do_xcom_push: bool = bool(getattr(operator, "do_xcom_push", False))

    operator.log.info(f"Executing command:-\n{command}")
    with ssh_hook.get_conn() as ssh_client:
        # execute command over SSH
        stdin, stdout, stderr = ssh_client.exec_command(command=command)
        channel = stdout.channel

        # nothing is written to the remote stdin, so close it right away
        stdin.close()
        channel.shutdown_write()

        # byte-strings to hold aggregated output / error
        agg_stdout = b""
        agg_stderr = b""

        # capture any initial output in case the channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)
        if stdout_buffer_length > 0:
            initial = stdout.channel.recv(stdout_buffer_length)
            agg_stdout += initial
            operator.log.info(initial.decode("utf-8", errors="replace").strip("\n"))

        # read from both stdout and stderr until the channel is closed and drained
        while not channel.closed or \
                channel.recv_ready() or \
                channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], timeout)
            for c in readq:
                if c.recv_ready():
                    chunk = stdout.channel.recv(len(c.in_buffer))
                    agg_stdout += chunk
                    # a raw chunk may end mid-way through a multi-byte character,
                    # so never let decoding for the log abort the command
                    operator.log.info(chunk.decode("utf-8", errors="replace").strip("\n"))
                if c.recv_stderr_ready():
                    chunk = stderr.channel.recv_stderr(len(c.in_stderr_buffer))
                    agg_stderr += chunk
                    operator.log.warning(chunk.decode("utf-8", errors="replace").strip("\n"))
            if stdout.channel.exit_status_ready() \
                    and not stderr.channel.recv_stderr_ready() \
                    and not stdout.channel.recv_ready():
                stdout.channel.shutdown_read()
                stdout.channel.close()
                break

        stdout.close()
        stderr.close()
        exit_status = stdout.channel.recv_exit_status()

    # compare by value: "is 0" relies on int caching and warns on Python 3.8+
    if exit_status != 0:
        error_msg = agg_stderr.decode("utf-8", errors="replace")
        raise AirflowException("error running cmd: {0}, error: {1}"
                               .format(command, error_msg))

    # returning output only if do_xcom_push is set
    # (bytes are pushed as-is; assumes core.enable_xcom_pickling is enabled)
    if do_xcom_push:
        return agg_stdout
    return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment