Skip to content

Instantly share code, notes, and snippets.

@razhangwei
Last active October 31, 2019 17:07
Show Gist options
  • Save razhangwei/75ee3503373eb29d9c7fb896d48d6c7a to your computer and use it in GitHub Desktop.
Save razhangwei/75ee3503373eb29d9c7fb896d48d6c7a to your computer and use it in GitHub Desktop.
Parallel run wrapper with paramiko #Parallel

Usage:

printf '%s\n' {0..100} | ./ssh_wrapper.py echo --cpus 2 --gpus 1 --email -N 2
printf 'myscript.py --foo %s\n' {0..100} | ./ssh_wrapper.py python --cpus 2 --gpus 1 --email -N 2
Host Memory CPU GPU Alive
nebula-1 125.6 32 0 1
nebula-2 94.1 32 0 1
nebula-3 63.1 12 0 1
nebula-4 62.8 12 0 1
nebula-5 62.8 12 0 1
nebula-6 62.8 12 0 1
nebula-7 125.7 24 0 1
nebula-8 125.7 24 0 0
nebula-9 125.7 24 0 1
quasar-20 141.9 24 0 1
quasar-21 141.9 24 0 0
quasar-22 141.9 24 0 1
quasar-27 31.3 8 0 0
quasar-28 31.5 8 0 0
quasar-30 31.3 16 0 0
quasar-31 31.3 16 0 1
quasar-32 31.5 16 0 1
quasar-33 23.4 16 0 1
quasar-34 23.4 16 0 1
quasar-37 62.7 12 0 0
quasar-38 15.5 8 0 0
quasar-39 62.7 12 0 1
quasar-40 62.7 12 0 1
quasar-41 62.7 12 0 0
quasar-42 63 12 0 1
quasar-43 63 12 0 1
quasar-44 62.7 12 0 1
quasar-46 126.1 16 0 1
quasar-47 62.7 12 0 1
quasar-48 62.7 12 0 0
quasar-50 126.1 24 0 1
quasar-51 126.1 24 0 1
quasar-52 126.1 24 0 1
nova-1 504 80 1 1
nova-2 504 80 2 1
nova-2 504 80 2 1
#!/bin/python3
import argparse
import os
import os.path as osp
import random
import shlex
import shutil
import sys
import time
import warnings
from multiprocessing import Pool, current_process
from os.path import expanduser
from subprocess import check_output

import paramiko
def get_parser():
    """Build the command-line parser for the SSH parallel-run wrapper.

    The single positional argument is the executable to run. The options
    describe per-job resource requirements (``--cpus``, ``--gpus``,
    ``--mem``), scheduling knobs (``-N``, ``--sleep``), e-mail notification,
    and SSH settings (``--pw``, ``--host_file``).

    Returns:
        argparse.ArgumentParser: a parser ready for ``parse_args()``.
    """
    ap = argparse.ArgumentParser()

    # (flags, keyword options) pairs, registered in this exact order so the
    # generated --help output lists them the same way.
    specs = (
        # Job description and resource requirements.
        (('exec',), dict(metavar='executable', type=str)),
        (('--name',), dict(type=str, default=None, help="Default: None")),
        (('--cpus',), dict(type=int, default=1)),
        (('--gpus',), dict(type=int, default=0)),
        (('--mem',), dict(type=float, default=1, help="Default: 1 (GB)")),
        # Scheduling.
        (("-N", "--jobs_per_host"),
         dict(type=int, default=1, help="Default: 1")),
        (("--sleep",),
         dict(type=int, default=0,
              help="Sleep certain seconds to avoid congestion. Default: 0")),
        (("--email",),
         dict(action='store_true', help="Enable email notification.")),
        # SSH connection.
        (("--pw",),
         dict(metavar='password', type=str, default=None,
              help="Password. If not set, read it from ~/.parallel_run_pw ")),
        (("--host_file",),
         dict(type=str, default="~/hosts.tsv",
              help="Host description file.")),
    )
    for flags, options in specs:
        ap.add_argument(*flags, **options)
    return ap
def get_hosts(args):
    """Return the names of alive hosts that meet the requested resources.

    Reads the whitespace-separated host table at ``args.host_file`` (``~`` is
    expanded). The first line is a header and is skipped. Every other
    non-blank, non-``#`` line is ``<host> <memory> <cpus> <gpus> <alive>``.
    The four numeric columns are parsed as floats, and any further columns
    are ignored.

    A host is selected when ``alive`` is non-zero and it has at least
    ``args.mem`` memory, ``args.cpus`` CPUs and ``args.gpus`` GPUs. All three
    checks are inclusive, so a host that exactly matches the request
    qualifies.

    Args:
        args: Parsed options providing ``host_file``, ``cpus``, ``mem`` and
            ``gpus``.

    Returns:
        list[str]: Selected host names, in file order.

    Raises:
        OSError: If the host file cannot be opened.
        ValueError: If a data line has fewer than four numeric columns or a
            non-numeric value in them.
    """
    hosts = []
    with open(expanduser(args.host_file)) as fin:
        fin.readline()  # skip the header row
        for line in fin:
            fields = line.split()
            # Blank lines (e.g. a trailing newline) and comments are not hosts.
            if not fields or fields[0].startswith("#"):
                continue
            mem, cpus, gpus, alive = map(float, fields[1:5])
            if (alive and cpus >= args.cpus and mem >= args.mem
                    and gpus >= args.gpus):
                hosts.append(fields[0])
    return hosts
def runCommand(cmd, pw):
    """Run one shell command on a remote host over SSH and echo its output.

    The host is chosen from the module-level ``hosts`` list by the current
    pool worker's index, so each worker keeps using the same host slot. The
    remote command runs in the same directory path as the local working
    directory. Each output line is printed with a ``[host]:`` prefix.

    NOTE: this reads the module-level globals ``hosts`` and ``args``, which
    are only assigned inside the ``__main__`` block. It therefore relies on
    the pool workers inheriting them (the fork start method); under spawn
    they would be undefined.

    Args:
        cmd: Shell command line to execute on the remote host.
        pw: Password passed to ``paramiko.SSHClient.connect``.

    Returns:
        bool: ``False`` if the SSH connection could not be established.
        Otherwise ``True`` once the command's output has been fully read.
        The remote exit status is not checked.
    """
    worker = current_process()
    # Pool workers carry a sequential ``_identity`` (1, 2, ...). Using it
    # gives every worker a distinct slot, whereas pid % len(hosts) can map
    # two workers to the same host. Outside a pool the tuple is empty, so
    # fall back to the pid.
    slot = worker._identity[-1] if worker._identity else worker.pid
    host = hosts[slot % len(hosts)]

    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        with warnings.catch_warnings():
            # NOTE: To ignore CryptographyDeprecationWarning;
            # remove it in later version of paramiko (2.4.2)
            warnings.simplefilter("ignore")
            client.connect(host, password=pw)
    except Exception as exc:
        print("***** failed to connect host {}: {} *****".format(host, exc))
        client.close()
        return False

    try:
        # Random start delay staggers job launches across hosts.
        if args.sleep > 0:
            time.sleep(random.random() * args.sleep * len(hosts))
        print("Running on {}: {}".format(host, cmd))
        # Quote the path so directories with spaces survive the remote shell.
        # Use `&&` so the command does not run in the remote home directory
        # when the directory is missing on the remote host.
        remote_cmd = "cd {} && {}".format(shlex.quote(os.getcwd()), cmd)
        _stdin, stdout, _stderr = client.exec_command(remote_cmd, get_pty=True)
        for line in stdout:
            print("[{}]: {}".format(host, line), end="")
    finally:
        # Always release the SSH connection, even if streaming fails.
        client.close()
    return True
if __name__ == "__main__":
    # Parse options. `args` and `hosts` are intentionally module-level
    # globals: runCommand() reads both of them inside the pool workers.
    args = get_parser().parse_args()

    # NOTE(review): send_email() is neither defined nor imported in this
    # script; presumably it is meant to be provided elsewhere. Fail up front
    # with a clear message instead of a NameError later on.
    if args.email and "send_email" not in globals():
        sys.exit("--email requires a send_email() helper, which is not defined.")

    if args.pw is None:
        with open(expanduser("~/.parallel_run_pw"), 'r') as fin:
            args.pw = fin.readline().strip()

    # Eligible hosts, each repeated `jobs_per_host` times. Each slot gets one
    # pool worker.
    hosts = get_hosts(args)
    if not hosts:
        sys.exit("No host in {} satisfies the requested resources.".format(
            args.host_file))
    hosts = hosts * args.jobs_per_host

    # Resolve the executable on this machine. The resulting path is used
    # verbatim on the remote hosts, which assumes they see the same path
    # (e.g. a shared filesystem) — TODO confirm.
    exec_path = shutil.which(args.exec)
    if exec_path is None:
        sys.exit("Executable not found: {}".format(args.exec))
    args.exec = osp.expanduser(exec_path).strip()

    # One command per non-blank stdin line. A blank line would launch the
    # bare executable, which can block on the remote pty (e.g. `python`).
    cmds = [args.exec + " " + line.strip()
            for line in sys.stdin if line.strip()]
    if not cmds:
        sys.exit("No commands were read from stdin.")
    n_jobs = len(cmds)

    if args.email:
        text = """---Arguments---
{}
---Hosts---
{}
---Commands---
{}
---Job Name---
{}
""".format(" ".join(sys.argv[1:]), hosts, "\n".join(cmds), args.name)
        send_email(subject="Experiment Notification", text=text + "\nStart.")

    with Pool(len(hosts)) as p:
        results = p.starmap(runCommand, zip(cmds, [args.pw] * n_jobs))

    # runCommand() returns False when it could not open the SSH connection.
    n_failed = results.count(False)
    if n_failed:
        print("{} of {} commands were not run (SSH connection failed).".format(
            n_failed, n_jobs))

    if args.email:
        summary = "\nFinished."
        if n_failed:
            summary += "\n{} of {} commands were not run.".format(n_failed, n_jobs)
        send_email(subject="Experiment Notification", text=text + summary)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment