Skip to content

Instantly share code, notes, and snippets.

@y2k-shubham
Last active March 3, 2019 18:42
Show Gist options
  • Save y2k-shubham/76eb1c9cc2e08628726ae8a770efb8db to your computer and use it in GitHub Desktop.
[Apache-Airflow] ssh_utils & MultiCmdSSHOperator
from typing import List, Optional, Dict, Any
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.exceptions import AirflowException
from airflow.utils import apply_defaults
import ssh_utils
class MultiCmdSSHOperator(SSHOperator):
    """
    An extension of SSHOperator that executes a list of commands
    (rather than just a single command).

    If the optional list of commands is not provided in __init__,
    it must be supplied somehow before the execute(..) method runs
    (for example in pre_execute(..)), otherwise an exception is thrown.

    :param commands: Optional list of commands to be executed
    :type commands: Optional[List[str]]
    :param ssh_hook_args: Optional dictionary of keyword arguments forwarded to SSHHook
    :type ssh_hook_args: Optional[Dict[str, Any]]
    """

    @apply_defaults
    def __init__(self,
                 commands: Optional[List[str]] = None,
                 ssh_hook_args: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Normalize any supplied iterable of commands to a plain list;
        # keep None when nothing was supplied (it may be set later,
        # e.g. in pre_execute(..))
        self.commands: Optional[List[str]] = list(commands) if commands else None
        self.ssh_hook_args: Optional[Dict[str, Any]] = ssh_hook_args

    def execute(self, context) -> Optional[List[Dict[str, Optional[str]]]]:
        """
        Run each command in self.commands over SSH.

        :return: list of {"command": .., "output": ..} dicts when
            do_xcom_push is set, otherwise None
        :raises AirflowException: if no commands have been supplied
        """
        if not self.commands:
            raise AirflowException("no commands specified so nothing to execute here.")
        # instantiate SSHHook, forwarding any extra hook arguments
        # (bug fix: original line was missing its closing parenthesis and
        # silently ignored ssh_hook_args despite documenting it)
        self.ssh_hook: SSHHook = SSHHook(ssh_conn_id=self.ssh_conn_id,
                                         **(self.ssh_hook_args or {}))
        # collect the return value of every command run over SSH
        xcoms_list: List[Dict[str, Optional[str]]] = []
        for cmd in self.commands:
            # keep self.command in sync with the running command,
            # for consistency with SSHOperator (not strictly required)
            self.command: str = cmd
            # run command and gather its output in a dictionary
            output: Optional[str] = ssh_utils.execute_command_for_operator(operator=self, command=cmd)
            xcoms_list.append({
                "command": cmd,
                "output": output
            })
        if self.do_xcom_push:
            return xcoms_list
from typing import Optional, Any, Dict
from airflow.contrib.hooks.ssh_hook import SSHHook
from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from select import select
def execute_command_for_operator(operator: BaseOperator,
                                 ssh_hook: Optional[SSHHook] = None,
                                 command: Optional[str] = None) -> Any:
    """
    Utility method that runs a command over SSH for an operator.

    Although operator can be of any type, if the ssh_hook and command
    params are not provided it is assumed to be an SSHOperator and
    the said params are read from it.

    :param operator: An operator for which command is to be run remotely over SSHHook
    :type operator: BaseOperator or SSHOperator
    :param ssh_hook: Optional SSHHook over which command is to be run remotely
    :type ssh_hook: Optional[SSHHook]
    :param command: Optional command to be run remotely
    :type command: Optional[str]
    :return: Aggregated stdout (byte-string) of running the command,
        returned only when operator.do_xcom_push is set
    :rtype: Optional[bytes]
    :raises AirflowException: if the remote command exits with a non-zero status
    """
    # fall back to reading the hook / command off the (assumed SSH-) operator
    if not ssh_hook:
        ssh_hook = operator.ssh_hook
    if not command:
        command = operator.command
    operator.log.info(f"Executing command:-\n{command}")
    with ssh_hook.get_conn() as ssh_client:
        # Code borrowed from
        #  - airflow.contrib.operators.SSHOperator.execute() method
        #  - airflow.contrib.operators.SFTPOperator.execute() method
        # execute command over SSH
        stdin, stdout, stderr = ssh_client.exec_command(command=command)
        channel = stdout.channel
        # we never feed the remote stdin, so close it immediately
        stdin.close()
        channel.shutdown_write()
        # byte-strings to hold aggregated output / error
        agg_stdout = b""
        agg_stderr = b""
        # capture any initial output in case channel is closed already
        stdout_buffer_length = len(stdout.channel.in_buffer)
        if stdout_buffer_length > 0:
            agg_stdout += stdout.channel.recv(stdout_buffer_length)
        # select() timeout is loop-invariant — hoist it out of the loop
        # (original recomputed the hasattr lookup on every iteration)
        timeout: Optional[int] = getattr(operator, "timeout", None)
        # read from both stdout and stderr until the channel is drained
        while not channel.closed or \
                channel.recv_ready() or \
                channel.recv_stderr_ready():
            readq, _, _ = select([channel], [], [], timeout)
            for c in readq:
                if c.recv_ready():
                    line = stdout.channel.recv(len(c.in_buffer))
                    agg_stdout += line
                    operator.log.info(line.decode("utf-8").strip("\n"))
                if c.recv_stderr_ready():
                    line = stderr.channel.recv_stderr(len(c.in_stderr_buffer))
                    agg_stderr += line
                    operator.log.warning(line.decode("utf-8").strip("\n"))
            if stdout.channel.exit_status_ready() \
                    and not stderr.channel.recv_stderr_ready() \
                    and not stdout.channel.recv_ready():
                stdout.channel.shutdown_read()
                stdout.channel.close()
                break
        stdout.close()
        stderr.close()
        exit_status = stdout.channel.recv_exit_status()
        # bug fix: original used `is 0` (identity check on an int literal,
        # a SyntaxWarning on modern CPython) instead of `== 0`
        if exit_status == 0:
            # returning output if do_xcom_push is set
            if getattr(operator, "do_xcom_push", False):
                # pickle the output (assume core.xcom_pickling is enabled)
                return agg_stdout
        else:
            error_msg = agg_stderr.decode("utf-8")
            raise AirflowException("error running cmd: {0}, error: {1}"
                                   .format(command, error_msg))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment