Last active
October 20, 2021 12:22
-
-
Save fpytloun/25d4772d8ba58f391849 to your computer and use it in GitHub Desktop.
Cluster execution tool
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# -*- coding: utf-8 -*- | |
""" | |
Tool for commands execution over clusters | |
""" | |
import argparse
import base64
import getpass
import logging
import os
import signal
import socket
import subprocess
import sys
import threading
import time
import urllib

import paramiko
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.WARN)
lg = logging.getLogger()

# Shared state; the worker threads and main() read and write these.
sshPool = dict()      # connect address -> paramiko.SSHClient, reused across commands
sshFailed = list()    # connect addresses whose connection failed
sshDone = list()      # hostnames whose runRemote worker has finished
sshHosts = list()     # hosts requested on the command line (assigned by main)
threadLimiter = None  # BoundedSemaphore created by main() to cap concurrent workers
exitcode = 0          # a non-zero remote exit status recorded by a worker (for -e)
def _buildParser():
    """Return the command-line parser (-h/--help is handled manually by main)."""
    parser = argparse.ArgumentParser(description='Execute command on cluster', add_help=False)
    # Required
    group_req = parser.add_argument_group('Required arguments')
    group_req.add_argument('command', help="Command to be executed. Use -- after arguments accepting multiple values.", nargs='?')
    # Optional
    group_opt = parser.add_argument_group('Optional arguments')
    group_opt.add_argument('--system-ssh', dest='system_ssh', action='store_true', help="Call system SSH client instead of Pythonish (worse escaping, obsolete)")
    group_opt.add_argument('--serial', '--no-parallel', dest='serial', action='store_true', help="Execute commands on hosts one-by-one")
    group_opt.add_argument('-t', '--threads', dest='threads', type=int, default=120, help="Execute commands on hosts in x threads (default 120)")
    group_opt.add_argument('-f', '--file', dest='file', help="Read list of nodes from file (ignores cluster)")
    group_opt.add_argument('-m', '--machines', dest='machines', nargs='+', default=[], help="List of machines to operate on")
    group_opt.add_argument('-u', '--user', dest='sshUser', help="SSH user to connect with, defaults to the current user", default=getpass.getuser())
    group_opt.add_argument('-K', '--key-file', dest='sshKeyFile', help="SSH user key file to connect with")
    group_opt.add_argument('-d', '--domain', dest='domain', help="Location domain to use")
    group_opt.add_argument('-e', '--exitcode', dest='exitcode', action='store_true', help="Exit with non-zero exit code if command return non-zero exit code")
    # Output switchers
    group_out = parser.add_argument_group('Output switchers')
    group_out.add_argument('--debug', dest='debug', action='store_true')
    # run() prints the connect address with -L when this is set
    group_out.add_argument('--ip', dest='ip', action='store_true', help="With -L, also show the address used to connect")
    # Action switchers
    group_act = parser.add_argument_group('Action switchers')
    group_act.add_argument('-h', '--help', dest='help', action='store_true', help="Show this help")
    group_act.add_argument('-L', '--list-nodes', dest='list_nodes', action='store_true', help="List nodes where command would be executed")
    group_act.add_argument('-I', '--interactive', '--shell', dest='interactive', action='store_true', help="Run in interactive mode, same as if command is -")
    group_act.add_argument('-U', '--upload', dest='upload', help="Upload file to [command] on nodes")
    return parser


def _readMachines(args):
    """
    Return the list of hosts to operate on.

    Reads one host per line from -f/--file ("-" means stdin); blank lines are
    skipped. Otherwise uses -m/--machines.

    :raises RuntimeError: when neither -f nor -m was given
    """
    if args.file:
        if args.file == '-':
            lines = sys.stdin.readlines()
        else:
            try:
                with open(args.file, 'r') as fh:
                    lines = fh.readlines()
            except (IOError, OSError) as e:
                lg.error("Can't open file %s: %s" % (args.file, e))
                sys.exit(1)
        # Strip newlines/whitespace and drop empty lines so they don't become
        # bogus empty host names
        return [line.strip() for line in lines if line.strip()]
    if args.machines:
        return args.machines
    raise RuntimeError("You need to submit list of hosts to connect to")


def _waitForPool(pool):
    """
    Poll until every thread in *pool* has finished.

    :returns: True when all threads finished, False when the wait was
        interrupted by KeyboardInterrupt/SystemExit
    """
    try:
        while True:
            alive = [thread for thread in pool if thread.is_alive()]
            if not alive:
                return True
            lg.debug("Waiting for %i threads" % len(alive))
            time.sleep(0.5)
    except (KeyboardInterrupt, SystemExit):
        lg.debug("Received keyboard interrupt. Cleaning threads and exitting.")
        # Python threads cannot be killed from outside; the caller closes the
        # SSH connections (sshCleanup), which is expected to unblock them.
        for thread in pool:
            if thread.is_alive():
                lg.debug("Thread %s is still running, abandoning it" % thread.name)
        return False


def _interactive(machines, args):
    """
    Read commands from a prompt and run each one on all *machines*.

    Never returns: exits 0 on exit/quit/EOF/Ctrl-C at the prompt, and exits 1
    when waiting for a command's workers is interrupted.
    """
    import readline
    readline.parse_and_bind('tab: complete')
    readline.parse_and_bind('set editing-mode vi')
    while True:
        try:
            # input() is the Python 3 name of raw_input()
            args.command = input("$> ")
        except (KeyboardInterrupt, SystemExit, EOFError):
            lg.debug("Interrupted")
            sshCleanup()
            print()
            sys.exit(0)
        if args.command in ['exit', 'quit']:
            sshCleanup()
            sys.exit(0)
        if args.command:
            # Connections stay open in sshPool between commands
            pool = run(machines, args)
            if not _waitForPool(pool):
                sshCleanup()
                sys.exit(1)


def main():
    """
    Command-line entry point.

    Parses the arguments and collects the host list (-f/--file or
    -m/--machines). It then runs the command on every host in worker
    threads. Concurrency is bounded by a semaphore: -t/--threads, or 1 with
    --serial. With -I/--interactive (or command "-"), commands are read from
    a prompt instead.

    Exits with 1 if any host failed or the wait was interrupted. With
    -e/--exitcode it exits with the recorded non-zero remote status.
    Otherwise it returns normally.
    """
    global threadLimiter, sshHosts

    # Print a progress report on SIGINFO / SIGUSR1 where supported
    if hasattr(signal, 'SIGINFO'):
        signal.signal(signal.SIGINFO, siginfo_handler)
    if hasattr(signal, 'SIGUSR1'):
        signal.signal(signal.SIGUSR1, siginfo_handler)

    parser = _buildParser()
    args = parser.parse_args()
    if args.debug:
        lg.setLevel(logging.DEBUG)
    if args.help:
        parser.print_help()
        sys.exit(0)
    if args.interactive:
        args.command = '-'
    if args.threads == 0:
        # Easter egg; decode so the art is printed, not a bytes repr
        print(base64.b64decode('ICAgICAgIF8gICAgIF8KICAgICAgIFxgXCAvYC8KICAgICAgICBcIFYgLyAgICAgICAgICAgICAgIAogICAgICAgIC8uIC5cICAgICAgIAogICAgICAgPVwgVCAvPSAgICAgICAgICAgICAgICAgIAogICAgICAgIC8gXiBcICAgICAKICAgICAgIC9cXCAvL1wKICAgICBfX1wgIiAiIC9fXyAgICAgICAgICAgCiAgICAoX19fXy9eXF9fX18pCiAgWW91J3JlIGEgVGVhcG90IQo=').decode('utf-8'), end='')
        sys.exit(1)
    if args.threads < 0:
        parser.error("-t/--threads must be a positive number")

    # --serial must be checked first: --threads always has a value (default
    # 120), so testing it first made --serial a no-op
    threadLimiter = threading.BoundedSemaphore(1 if args.serial else args.threads)

    # Can't read from stdin for multiple options
    if args.file == '-' and args.command == '-':
        lg.error("Can't read nodes and command from stdin, try to use -m option instead of -f")
        sys.exit(1)
    lg.debug("Command: %s" % args.command)

    machines = _readMachines(args)
    sshHosts = machines

    if args.command == '-':
        _interactive(machines, args)  # exits the process itself

    # Normal mode: run once on all hosts
    pool = run(machines, args)
    try:
        finished = _waitForPool(pool)
    finally:
        sshCleanup()
    if sshFailed:
        lg.error("Failed connections (%s/%s): %s" % (len(sshFailed), len(sshHosts), ','.join(sshFailed)))
        sys.exit(1)
    if not finished:
        sys.exit(1)
    if args.exitcode:
        sys.exit(exitcode)
def wait_threads():
    """
    Wait until all threads except the main one are done, then close SSH connections.

    Usually not what we want: it also waits for paramiko's transport
    threads, which keep running as long as their connections are open.
    On KeyboardInterrupt/SystemExit/EOFError the wait is abandoned, because
    Python threads cannot be killed from outside. sshCleanup() is always
    called.
    """
    try:
        while threading.active_count() > 1:
            lg.debug("Waiting for %i threads" % (threading.active_count() - 1))
            time.sleep(0.5)
    except (KeyboardInterrupt, SystemExit, EOFError):
        current = threading.current_thread()
        for thread in threading.enumerate():
            if thread is not current:
                lg.debug("Thread %s is still running, abandoning it" % thread.name)
    finally:
        sshCleanup()
def run(machines, args):
    """
    Start one worker thread per host and return the started threads.

    :param machines: list of host names, or dict mapping host name to a node
        dict (keys ``hostname``, ``ip_public``, ``ip``, ``instance_id``)
    :param args: parsed command-line namespace from main()
    :returns: list of started threads; empty with -L/--list-nodes, which only
        prints the hosts

    Every node dict gets a ``connect`` key: the host name, or
    ``<host>.<domain>`` when -d/--domain is given. The worker is uploadFile
    (-U), runSSH (--system-ssh) or runRemote (default).
    """
    lg.debug("Hosts: %s" % machines)
    if isinstance(machines, list):
        machines = {
            host: {
                'hostname': host,
                'ip_public': host,
                'ip': host,
                'instance_id': None,
            }
            for host in machines
        }

    if machines and not args.list_nodes:
        # Same for every host, so check/warn once before starting any thread
        if not args.command:
            lg.error("'command' option have to be set")
            sys.exit(1)
        if args.system_ssh and not args.upload:
            lg.warning('Using system SSH is obsolete and may be buggy. Avoid using this option!')

    # Not every parser defines --ip; a missing attribute means "off"
    showConnect = getattr(args, 'ip', False)
    pool = []
    for hostname, node in machines.items():
        if args.domain:
            node['connect'] = "%s.%s" % (hostname, args.domain)
        else:
            node['connect'] = hostname

        if args.list_nodes:
            if showConnect:
                print("{0:<25}{1}{2:>20}".format(hostname, node['instance_id'], node['connect']))
            else:
                print("{0:<25}{1}".format(hostname, node['instance_id']))
            continue

        if args.upload:
            target = uploadFile
            targs = (node, args.upload, args.command, args.sshUser, args.sshKeyFile)
        elif args.system_ssh:
            target = runSSH
            targs = (node['connect'], args.command, args.sshUser, args.sshKeyFile)
        else:
            target = runRemote
            targs = (node, args.command, args.sshUser, args.sshKeyFile)

        # Named after the host so siginfo_handler can report it. Daemon, so a
        # worker stuck on a dead host does not keep the process alive after
        # main() has given up waiting (Ctrl-C).
        t = threading.Thread(name=hostname, target=target, args=targs)
        t.daemon = True
        t.start()
        pool.append(t)
    return pool
def runSSH(name, command, user, keyFile=None):
    """
    Execute *command* on host *name* with the system ``ssh`` client.

    ``/etc/profile`` is sourced on the remote side first. The ssh process's
    stdout and stderr are merged and echoed line by line, prefixed with
    *name*. A non-zero exit status of the ssh process (normally the remote
    command's status) is stored in the global ``exitcode``. Concurrency is
    bounded by the global ``threadLimiter``.

    :param name: address to connect to
    :param command: shell command executed by the remote shell
    :param user: SSH user name
    :param keyFile: optional private key; omitted from the ssh call when None
    :returns: False if ssh could not be started, otherwise None
    """
    global exitcode
    with threadLimiter:
        try:
            lg.debug("Execute '%s' on '%s'" % (command, name))
            remoteCommand = 'source /etc/profile >/dev/null;%s' % command
            argv = ['ssh', '-qAY',
                    '-o', 'UserKnownHostsFile=/dev/null',
                    '-o', 'StrictHostKeyChecking=no']
            if keyFile:
                argv += ['-o', 'IdentityFile=%s' % keyFile]
            # argv list, no local shell: the command reaches the remote shell
            # verbatim instead of being re-quoted inside "..."
            argv += ['%s@%s' % (user, name), '--', remoteCommand]
            try:
                # stdin from /dev/null so parallel ssh processes don't compete
                # for the terminal (or for the interactive prompt)
                proc = subprocess.Popen(argv, stdin=subprocess.DEVNULL,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.STDOUT)
            except OSError as e:
                lg.error("Can't run ssh for %s: %s" % (name, e))
                sshFailed.append(name)
                return False
            for raw in proc.stdout:
                sys.stdout.write("%s: %s" % (name, raw.decode('utf-8', 'replace')))
                sys.stdout.flush()
            proc.stdout.close()
            status = proc.wait()
            lg.debug("Exit status: %s" % status)
            if status != 0:
                exitcode = status
        finally:
            sshDone.append(name)
def sshCleanup():
    """
    Close every connection cached in ``sshPool``.

    Iterates over a snapshot of the pool. Worker threads that are still
    running (e.g. after Ctrl-C) may insert new entries, and iterating the
    live dict would then raise "dictionary changed size during iteration".
    """
    for connect, client in list(sshPool.items()):
        lg.debug("Closing connection to %s" % connect)
        client.close()
def _sshConnect(node, user, keyFile=None):
    """
    Return a connected paramiko.SSHClient for *node*, or None on failure.

    Clients are cached in ``sshPool`` under ``node['connect']`` and reused.
    The client is stored before connecting, so sshCleanup() closes it even
    when the connection fails. On failure the address goes to ``sshFailed``.
    """
    connect = node['connect']
    client = sshPool.get(connect)
    if client is not None:
        return client
    client = paramiko.SSHClient()
    sshPool[connect] = client
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        client.load_system_host_keys()
    except paramiko.SSHException as e:
        lg.error("Can't load system known hosts: %s" % e)
        sshFailed.append(connect)
        return None
    try:
        client.connect(connect, username=user, timeout=5, key_filename=keyFile)
    except socket.timeout:
        # Must come before socket.error: socket.timeout is a subclass of it,
        # so this branch used to be unreachable
        lg.error("Timeout during connecting to %s (%s)" % (connect, node['hostname']))
        sshFailed.append(connect)
        return None
    except (socket.gaierror, socket.error) as e:
        lg.error("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
        sshFailed.append(connect)
        return None
    except paramiko.SSHException as e:
        lg.error("Can't connect to %s (%s) as user %s: %s" % (connect, node['hostname'], user, e))
        sshFailed.append(connect)
        return None
    return client


def _sshOpenSession(trans, node):
    """
    Open a session channel on transport *trans*, or return None.

    paramiko raises SSHException here when the connection is no longer
    usable, so this is the check for that case. On failure the address goes
    to ``sshFailed``.
    """
    try:
        return trans.open_session()
    except paramiko.SSHException as e:
        lg.error("Connection to %s (%s) no longer active: %s", node['connect'], node['hostname'], e)
        sshFailed.append(node['connect'])
        return None


def runRemote(node, command, user, keyFile=None):
    """
    Execute *command* on *node* with paramiko and echo its output.

    A ``hostname`` probe with a 5 second timeout runs first to test the
    connection. The real command then runs with ``/etc/profile`` sourced and
    no timeout. Output lines are prefixed with ``node['hostname']``. A
    non-zero exit status is stored in the global ``exitcode``. Concurrency
    is bounded by ``threadLimiter``, and the host is appended to ``sshDone``
    when finished.

    :param node: dict with at least ``connect`` (address) and ``hostname``
    :param command: shell command to run
    :param user: SSH user name
    :param keyFile: optional private key file
    :returns: False if the host is unreachable or the probe fails; True when
        *command* is exactly ``hostname`` (answered by the probe itself);
        otherwise None
    """
    global exitcode
    with threadLimiter:
        try:
            lg.debug("Execute '%s' on '%s'" % (command, node))
            remoteCommand = 'source /etc/profile >/dev/null;%s' % command
            connect = node['connect']
            if connect in sshFailed:
                return False
            client = _sshConnect(node, user, keyFile)
            if client is None:
                return False
            trans = client.get_transport()
            if not trans:
                lg.error("Can't get transport for connection %s (%s)", connect, node['hostname'])
                sshFailed.append(connect)
                return False

            # Probe channel with a 5 second timeout to test the connection
            chan = _sshOpenSession(trans, node)
            if chan is None:
                return False
            try:
                chan.get_pty()
                chan.settimeout(5)
                output = chan.makefile()
                chan.exec_command('hostname')
                for line in output:
                    # Compare the user's command, not the /etc/profile-prefixed one
                    if command != 'hostname':
                        lg.debug("Connected to %s (hostname %s)" % (connect, line.replace("\r\n", "\n")))
                    else:
                        sys.stdout.write("%s: %s" % (connect, line.replace("\r\n", "\n")))
                        sys.stdout.flush()
                        return True
            except socket.timeout:
                lg.error("Timeout during communication with %s (%s)" % (connect, node['hostname']))
                return False
            finally:
                chan.close()

            # Channel without timeout for the real command
            chan = _sshOpenSession(trans, node)
            if chan is None:
                return False
            try:
                chan.settimeout(None)
                chan.get_pty()
                output = chan.makefile()
                chan.exec_command(remoteCommand)
                for line in output:
                    sys.stdout.write("%s: %s" % (node['hostname'], line.replace("\r\n", "\n")))
                    sys.stdout.flush()
                # Read the status before closing; paramiko returns -1 when no
                # exit status was received
                status = chan.recv_exit_status()
            finally:
                chan.close()
            lg.debug("Exit status: %s" % status)
            if status != -1 and status != 0:
                exitcode = status
        finally:
            sshDone.append(node['hostname'])
def uploadFile(node, localFile, remoteFile, user, keyFile=None):
    """
    Upload *localFile* to *remoteFile* on *node* over SFTP with paramiko.

    Uses a dedicated SSH connection (not ``sshPool``) and closes it
    afterwards. Concurrency is bounded by ``threadLimiter``, like the other
    workers. On failure the connect address is appended to ``sshFailed``, so
    main() reports it and exits non-zero. (Calling sys.exit() in a worker
    thread would only end that thread.)

    :returns: True on success, False on failure
    """
    connect = node['connect']
    with threadLimiter:
        lg.debug("Upload '%s' on '%s:%s'" % (localFile, node['hostname'], remoteFile))
        ssh = paramiko.SSHClient()
        try:
            ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            try:
                ssh.load_system_host_keys()
                ssh.connect(connect, username=user, timeout=5, key_filename=keyFile)
            except (socket.error, paramiko.SSHException) as e:
                lg.error("Can't connect to %s (%s): %s" % (connect, node['hostname'], e))
                sshFailed.append(connect)
                return False
            try:
                ftp = ssh.open_sftp()
                try:
                    ftp.put(localFile, remoteFile)
                finally:
                    ftp.close()
            except (OSError, IOError, paramiko.SSHException) as e:
                lg.error("Upload of %s to %s:%s failed: %s" % (localFile, node['hostname'], remoteFile, e))
                sshFailed.append(connect)
                return False
            return True
        finally:
            ssh.close()
            sshDone.append(node['hostname'])
def siginfo_handler(signum, frame):
    """
    Signal handler (SIGINFO/SIGUSR1) that prints a progress report.

    Prints nothing unless threads other than the main thread are running.
    Reports finished hosts, pooled and failed connections, and the names of
    all non-main threads.
    """
    if threading.active_count() <= 1:
        return
    mainThread = threading.main_thread()
    nodes_active = [thread.name for thread in threading.enumerate() if thread is not mainThread]
    print("--")
    print("Done: %s/%s" % (len(sshDone), len(sshHosts)))
    print("SSH connections: %s" % len(sshPool))
    print("Connections failed: %s" % len(sshFailed))
    print("Threads count: %s" % threading.active_count())
    print("Thread names: %s" % ','.join(nodes_active))
    print("--")
if __name__ == "__main__":
    # main() normally exits by itself; a None return maps to exit status 0
    raise SystemExit(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment