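# NOTE: This gist is a single function excerpted from a larger JTM worker
# module. A best-guess list of the module-level imports it relies on (not
# shown in the original snippet):
#
#   import functools
#   import multiprocessing as mp
#   import os
#   import resource
#   import signal
#   import sys
#
#   import pika
#   import psutil
#
# It also references helpers and globals defined elsewhere in JTM, e.g. logger,
# make_dir, setup_custom_logger, get_free_memory, run_sh_command, proc_clean,
# conn_clean, RmqConnectionHB, on_task_request, the recv_*/send_*/check_processes
# procs, PARENT_PROCESS_ID, and UNIQ_WORKER_ID.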
def worker(ctx: object, heartbeat_interval_param: int, custom_log_dir: str,
           custom_job_log_dir_name: str, pool_name_param: str, dry_run: bool,
           slurm_job_id_param: int, worker_type_param: str, cluster_name_param: str,
           worker_clone_time_rate_param: float, num_workers_per_node_param: int,
           worker_id_param: str, charging_account_param: str,
           num_nodes_to_request_param: int, num_cores_to_request_param: int,
           constraint_param: str, mem_per_node_to_request_param: str,
           mem_per_cpu_to_request_param: str,
           qos_param: str, job_time_to_request_param: str) -> int:
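    """Set up and run a JTM worker.

    Reads site/RMQ/SLURM settings from the config, optionally writes and
    sbatch-es a Slurm batch script that launches one or more workers per node
    (static/dynamic worker types), pins CPU affinity and a per-worker memory
    limit, and then starts the heartbeat, task-kill, and poison-message helper
    processes before consuming task requests from RabbitMQ.
    """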
    global CONFIG
    CONFIG = ctx.obj['config']
    debug = ctx.obj['debug']
    # config file has precedence
    config_debug = CONFIG.configparser.getboolean("SITE", "debug")
    if config_debug:
        debug = config_debug
    global DEBUG
    DEBUG = debug
    global WORKER_TYPE
    WORKER_TYPE = CONFIG.constants.WORKER_TYPE
    global HB_MSG
    HB_MSG = CONFIG.constants.HB_MSG
    global VERSION
    VERSION = CONFIG.constants.VERSION
    global COMPUTE_RESOURCES
    COMPUTE_RESOURCES = CONFIG.constants.COMPUTE_RESOURCES
    global TASK_TYPE
    TASK_TYPE = CONFIG.constants.TASK_TYPE
    global DONE_FLAGS
    DONE_FLAGS = CONFIG.constants.DONE_FLAGS
    global NUM_WORKER_PROCS
    NUM_WORKER_PROCS = CONFIG.constants.NUM_WORKER_PROCS
    global TASK_KILL_TIMEOUT_MINUTE
    TASK_KILL_TIMEOUT_MINUTE = CONFIG.constants.TASK_KILL_TIMEOUT_MINUTE
    global CNAME
    CNAME = CONFIG.configparser.get("SITE", "instance_name")
    global JTM_HOST_NAME
    JTM_HOST_NAME = CONFIG.configparser.get("SITE", "jtm_host_name")
    global JTM_INNER_REQUEST_Q
    JTM_INNER_REQUEST_Q = CONFIG.configparser.get("JTM", "jtm_inner_request_q")
    global CTR
    CTR = CONFIG.configparser.getfloat("JTM", "clone_time_rate")
    global JTM_INNER_MAIN_EXCH
    JTM_INNER_MAIN_EXCH = CONFIG.configparser.get("JTM", "jtm_inner_main_exch")
    global JTM_CLIENT_HB_EXCH
    JTM_CLIENT_HB_EXCH = CONFIG.configparser.get("JTM", "jtm_client_hb_exch")
    global JTM_WORKER_HB_EXCH
    JTM_WORKER_HB_EXCH = CONFIG.configparser.get("JTM", "jtm_worker_hb_exch")
    global CLIENT_HB_Q_POSTFIX
    CLIENT_HB_Q_POSTFIX = CONFIG.configparser.get("JTM", "client_hb_q_postfix")
    global WORKER_HB_Q_POSTFIX
    WORKER_HB_Q_POSTFIX = CONFIG.configparser.get("JTM", "worker_hb_q_postfix")
    global JTM_TASK_KILL_EXCH
    JTM_TASK_KILL_EXCH = CONFIG.configparser.get("JTM", "jtm_task_kill_exch")
    global JTM_TASK_KILL_Q
    JTM_TASK_KILL_Q = CONFIG.configparser.get("JTM", "jtm_task_kill_q")
    global JTM_WORKER_POISON_EXCH
    JTM_WORKER_POISON_EXCH = CONFIG.configparser.get("JTM", "jtm_worker_poison_exch")
    global JTM_WORKER_POISON_Q
    JTM_WORKER_POISON_Q = CONFIG.configparser.get("JTM", "jtm_worker_poison_q")
    global NUM_PROCS_CHECK_INTERVAL
    NUM_PROCS_CHECK_INTERVAL = CONFIG.configparser.getfloat("JTM", "num_procs_check_interval")
    global ENV_ACTIVATION
    ENV_ACTIVATION = CONFIG.configparser.get("JTM", "env_activation")
    WORKER_CONFIG_FILE = CONFIG.configparser.get("JTM", "worker_config_file")
    RMQ_HOST = CONFIG.configparser.get("RMQ", "host")
    RMQ_PORT = CONFIG.configparser.get("RMQ", "port")
    USER_NAME = CONFIG.configparser.get("SITE", "user_name")
    PRODUCTION = False
    if CONFIG.configparser.get("JTM", "run_mode") == "prod":
        PRODUCTION = True
    JOBTIME = CONFIG.configparser.get("SLURM", "jobtime")
    CONSTRAINT = CONFIG.configparser.get("SLURM", "constraint")
    CHARGE_ACCNT = CONFIG.configparser.get("SLURM", "charge_accnt")
    QOS = CONFIG.configparser.get("SLURM", "qos")
    PARTITION = CONFIG.configparser.get("SLURM", "partition")
    MEMPERCPU = CONFIG.configparser.get("SLURM", "mempercpu")
    MEMPERNODE = CONFIG.configparser.get("SLURM", "mempernode")
    NWORKERS = CONFIG.configparser.getint("JTM", "num_workers_per_node")
    NCPUS = CONFIG.configparser.getint("SLURM", "ncpus")
    global FILE_CHECK_INTERVAL
    FILE_CHECK_INTERVAL = CONFIG.configparser.getfloat("JTM", "file_check_interval")
    global FILE_CHECKING_MAX_TRIAL
    FILE_CHECKING_MAX_TRIAL = CONFIG.configparser.getint("JTM", "file_checking_max_trial")
    global FILE_CHECK_INT_INC
    FILE_CHECK_INT_INC = CONFIG.configparser.getfloat("JTM", "file_check_int_inc")
    # Job dir setting
    job_script_dir_name = os.path.join(CONFIG.configparser.get("JTM", "log_dir"), "job")
    if custom_job_log_dir_name:
        job_script_dir_name = custom_job_log_dir_name
    make_dir(job_script_dir_name)
    # Log dir setting
    log_dir_name = os.path.join(CONFIG.configparser.get("JTM", "log_dir"), "log")
    if custom_log_dir:
        log_dir_name = custom_log_dir
    make_dir(log_dir_name)
    print("JTM Worker, version: {}".format(VERSION))
    # Set the unique worker id if one is provided in the params
    if worker_id_param:
        global UNIQ_WORKER_ID
        UNIQ_WORKER_ID = worker_id_param
    # Logger setting
    log_level = "info"
    if DEBUG:
        log_level = "debug"
    setup_custom_logger(log_level, log_dir_name,
                        1, 1,
                        worker_id=UNIQ_WORKER_ID)
    logger.info("\n*****************\nDebug mode is %s\n*****************"
                % ("ON" if DEBUG else "OFF"))
    # # Todo: site specific setting --> remove
    # CORI_KNL_CHARGE_ACCNT = CONFIG.configparser.get("SLURM", "knl_charge_accnt")
    # CORI_KNL_QOS = CONFIG.configparser.get("SLURM", "knl_qos")
    heartbeat_interval = CONFIG.configparser.getfloat("JTM", "worker_hb_send_interval")
    logger.info("Set jtm log file location to %s", log_dir_name)
    logger.info("Set jtm job file location to %s", job_script_dir_name)
    logger.info("RabbitMQ broker: %s", RMQ_HOST)
    logger.info("RabbitMQ port: %s", RMQ_PORT)
    logger.info("Pika version: %s", pika.__version__)
    logger.info("JTM user name: %s", USER_NAME)
    logger.info("Unique worker ID: %s", UNIQ_WORKER_ID)
    logger.info("\n*****************\nRun mode is %s\n*****************"
                % ("PROD" if PRODUCTION else "DEV"))
    logger.info("env activation: %s", ENV_ACTIVATION)
    logger.info("JTM config file: %s" % (CONFIG.config_file))
# Slurm config | |
num_nodes_to_request = 0 | |
if num_nodes_to_request_param: | |
num_nodes_to_request = num_nodes_to_request_param | |
# Todo | |
# Cori and JGI Cloud are exclusive allocation. So this is not needed. | |
# assert mem_per_node_to_request_param is not None, "-N needs --mem-per-cpu (-mc) setting." | |
# 11.13.2018 decided to remove all default values from argparse | |
num_workers_per_node = num_workers_per_node_param if num_workers_per_node_param else NWORKERS | |
assert num_workers_per_node > 0 | |
mem_per_cpu_to_request = mem_per_cpu_to_request_param if mem_per_cpu_to_request_param else MEMPERCPU | |
mem_per_node_to_request = mem_per_node_to_request_param if mem_per_node_to_request_param else MEMPERNODE | |
assert mem_per_cpu_to_request | |
assert mem_per_node_to_request | |
num_cpus_to_request = num_cores_to_request_param if num_cores_to_request_param else NCPUS | |
assert num_cpus_to_request | |
    # Set CPU affinity for limiting the number of cores to use
    if worker_type_param != "manual" and worker_id_param and worker_id_param.find('_') != -1:
        # ex)
        # total_cpu_num = 32, num_workers_per_node_param = 4
        # split_cpu_num = 8
        # worker_number - 1 == 0 --> [0, 1, 2, 3, 4, 5, 6, 7]
        # worker_number - 1 == 1 --> [8, 9, 10, 11, 12, 13, 14, 15]
        proc = psutil.Process(PARENT_PROCESS_ID)
        try:
            # Use the appended worker id number as worker_number
            # ex) 5wZwyCM8rxgNtERsU8znJU_1 --> extract "1" --> worker number
            worker_number = int(worker_id_param.split('_')[-1]) - 1
        except ValueError:
            logger.exception("Not an expected worker ID. Cancelling CPU affinity setting")
        else:
            # Note: may need to use num_cpus_to_request outside LBL
            total_cpu_num = psutil.cpu_count()
            logger.info("Total number of cores available: {}".format(total_cpu_num))
            split_cpu_num = int(total_cpu_num / num_workers_per_node)
            cpu_affinity_list = list(range(worker_number * split_cpu_num,
                                           ((worker_number + 1) * split_cpu_num)))
            logger.info("Set CPU affinity to use: {}".format(cpu_affinity_list))
            try:
                proc.cpu_affinity(cpu_affinity_list)
            except Exception as e:
                logger.exception("Failed to set the CPU usage limit: %s" % (e))
                sys.exit(1)
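        # Note: on Linux, the CPU affinity set on the parent process above is
        # inherited by child processes, so the helper processes forked later in
        # this function are confined to the same cores.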
    # Set memory upper limit
    # Todo: May need to use all free_memory on Cori and Lbl
    system_free_mem_bytes = get_free_memory()
    logger.info("Total available memory (MBytes): %d"
                % (system_free_mem_bytes / 1024.0 / 1024.0))
    if worker_type_param != "manual" and num_workers_per_node > 1:
        try:
            mem_per_node_to_request_byte = int(mem_per_node_to_request.lower()
                                               .replace("gb", "")
                                               .replace("g", "")) * 1024.0 * 1024.0 * 1024.0
            logger.info("Requested memory for this worker (MBytes): %d"
                        % (mem_per_node_to_request_byte / 1024.0 / 1024.0))
            # If the requested mempernode is larger than the system's available mem space
            if system_free_mem_bytes < mem_per_node_to_request_byte:
                logger.critical("Requested memory space is not available")
                logger.critical("Available space: %d (MBytes)"
                                % (system_free_mem_bytes / 1024.0 / 1024.0))
                logger.critical("Requested space: %d (MBytes)"
                                % (mem_per_node_to_request_byte / 1024.0 / 1024.0))
                # Option 1
                # mem_per_node_to_request_byte = system_free_mem_bytes
                # Option 2
                raise MemoryError
            MEM_LIMIT_PER_WORKER_BYTES = int(mem_per_node_to_request_byte /
                                             num_workers_per_node)
        except Exception as e:
            logger.exception("Failed to compute the memory limit: %s", mem_per_node_to_request)
            logger.exception(e)
            sys.exit(1)
        try:
            soft, hard = resource.getrlimit(resource.RLIMIT_AS)
            resource.setrlimit(resource.RLIMIT_AS, (MEM_LIMIT_PER_WORKER_BYTES, hard))
            logger.info("Set the memory usage upper limit (MBytes): %d"
                        % (MEM_LIMIT_PER_WORKER_BYTES / 1024.0 / 1024.0))
        except Exception as e:
            logger.exception("Failed to set the memory usage limit: %s", mem_per_node_to_request)
            logger.exception(e)
            sys.exit(1)
    job_time_to_request = job_time_to_request_param if job_time_to_request_param else JOBTIME
    constraint = constraint_param if constraint_param else CONSTRAINT
    charging_account = charging_account_param if charging_account_param else CHARGE_ACCNT
    qos = qos_param if qos_param else QOS
    global THIS_WORKER_TYPE
    THIS_WORKER_TYPE = worker_type_param
    job_name = "jtm_worker_" + pool_name_param
    # Set task queue name
    inner_task_request_queue = None
    if heartbeat_interval_param:
        heartbeat_interval = heartbeat_interval_param
    tp_name = ""
    if pool_name_param:
        tp_name = pool_name_param
    assert pool_name_param is not None, "User pool name is not set"
    inner_task_request_queue = JTM_INNER_REQUEST_Q + "." + pool_name_param
    worker_clone_time_rate = worker_clone_time_rate_param if worker_clone_time_rate_param else CTR
    if THIS_WORKER_TYPE in ("static", "dynamic"):
        assert cluster_name_param != "" and \
            cluster_name_param != "local", "Static or dynamic worker needs a cluster setting (-cl)."
    slurm_job_id = slurm_job_id_param
    cluster_name = cluster_name_param
    if cluster_name == "cori" and mem_per_cpu_to_request != "" and \
            float(mem_per_cpu_to_request.replace("GB", "").replace("G", "").replace("gb", "")) > 1.0:
        logger.critical("--mem-per-cpu on Cori shouldn't be larger than 1GB. Use '--mem' instead.")
        sys.exit(1)
    logger.info("RabbitMQ broker: %s", RMQ_HOST)
    logger.info("Task queue name: %s", inner_task_request_queue)
    logger.info("Worker type: %s", THIS_WORKER_TYPE)
    if slurm_job_id == 0 and THIS_WORKER_TYPE in ["static", "dynamic"]:
        batch_job_script_file = os.path.join(job_script_dir_name, "jtm_%s_worker_%s.job" %
                                             (THIS_WORKER_TYPE, UNIQ_WORKER_ID))
        batch_job_script_str = ""
        batch_job_misc_params = ""
        worker_config = CONFIG.config_file if CONFIG else ""
        if WORKER_CONFIG_FILE:
            worker_config = WORKER_CONFIG_FILE
        if cluster_name in ("cori", "lawrencium", "jgi_cloud", "jaws_lbl_gov", "lbl", "jgi_cluster"):
            with open(batch_job_script_file, "w") as jf:
                batch_job_script_str += "#!/bin/bash -l"
                if cluster_name == "cori":
                    if num_nodes_to_request_param:
                        batch_job_script_str += """
#SBATCH -N %(num_nodes_to_request)d
#SBATCH --mem=%(mem)s""" % dict(num_nodes_to_request=num_nodes_to_request, mem=mem_per_node_to_request)
                        batch_job_misc_params += " -N %(num_nodes_to_request)d -m %(mem)s" % \
                            dict(num_nodes_to_request=num_nodes_to_request,
                                 mem=mem_per_node_to_request)
                        if num_cores_to_request_param:
                            batch_job_script_str += """
#SBATCH -c %(num_cores)d""" % dict(num_cores=num_cpus_to_request)
                            batch_job_misc_params += " -c %(num_cores)d" % \
                                dict(num_cores=num_cpus_to_request)
                    else:
                        batch_job_script_str += """
#SBATCH -c %(num_cores)d""" % dict(num_cores=num_cpus_to_request)
                        batch_job_misc_params += " -c %(num_cores)d" % \
                            dict(num_cores=num_cpus_to_request)
                        if mem_per_node_to_request:
                            batch_job_script_str += """
#SBATCH --mem=%(mem)s""" % dict(mem=mem_per_node_to_request)
                            batch_job_misc_params += " -m %(mem)s " % \
                                dict(mem=mem_per_node_to_request)
                        else:
                            batch_job_script_str += """
#SBATCH --mem-per-cpu=%(mempercore)s""" % dict(mempercore=mem_per_cpu_to_request)
                            batch_job_misc_params += " -mc %(mempercore)s" % \
                                dict(mempercore=mem_per_cpu_to_request)
                    if worker_id_param:
                        batch_job_misc_params += " -wi %(worker_id)s_${i}" % \
                            dict(worker_id=UNIQ_WORKER_ID)
                    ###########################
                    if 1:
                        # Need to set both --qos=genepool (or genepool_shared) _and_ -A fungalp
                        # OR
                        # no qos _and_ -A m342 _and_ -C haswell
                        # Note: currently constraint in ["haswell" | "knl"]
                        if constraint == "haswell":
                            if qos_param:
                                batch_job_script_str += """
#SBATCH -q %(qosname)s""" % dict(qosname=qos)
                                batch_job_misc_params += " -q %(qosname)s" % dict(qosname=qos)
                            else:
                                batch_job_script_str += """
#SBATCH -q %(qosname)s""" % dict(qosname=qos)
                            batch_job_script_str += """
#SBATCH -C haswell"""
                            if charging_account == "m342":
                                batch_job_misc_params += " -A %(sa)s" % dict(sa="m342")
                            batch_job_script_str += """
#SBATCH -A %(charging_account)s""" % dict(charging_account=charging_account)
                        elif constraint == "knl":
                            # Note: Basic KNL setting = "-q regular -A m342 -C knl"
                            #
                            # Note: KNL MCDRAM setting -> cache or flat
                            #   cache mode - MCDRAM is configured entirely as a last-level cache (L3)
                            #   flat mode - MCDRAM is configured entirely as addressable memory
                            #   ex) #SBATCH -C knl,quad,cache
                            #   ex) #SBATCH -C knl,quad,flat
                            #       --> srun <srun options> numactl -p 1 yourapplication.x
                            #
                            # Note: for knl, we should use m342
                            #
                            # Note: for knl, charging_account can be set via runtime (like lanl, m3408)
                            #
                            batch_job_script_str += """
#SBATCH -C knl
#SBATCH -A %(charging_account)s
#SBATCH -q %(qosname)s""" % \
                                dict(charging_account=charging_account, qosname=qos)
                            batch_job_misc_params += " -A %(charging_account)s -q %(qosname)s" % \
                                dict(charging_account=charging_account, qosname=qos)
                        elif constraint == "skylake":
                            # Example usage with skylake for Brian F.
                            # 120G
                            # ======================
                            # -t 48:00:00 -c 16 --job-name=mga-627530 --mem=115G --qos=genepool_special
                            # --exclusive -A gtrqc
                            #
                            # 250G
                            # ======================
                            # -t 96:00:00 -c 72 --job-name=mga-627834 --mem=240G -C skylake --qos=jgi_exvivo
                            # -A gtrqc
                            #
                            # 500G
                            # ======================
                            # -t 96:00:00 -c 72 --job-name=mga-627834 --mem=240G -C skylake --qos=jgi_exvivo
                            # -A gtrqc
                            batch_job_script_str += """
#SBATCH -C skylake
#SBATCH -A %(charging_account)s
#SBATCH -q %(qosname)s""" % \
                                dict(charging_account=charging_account, qosname=qos)
                            batch_job_misc_params += " -A %(charging_account)s -q %(qosname)s" % \
                                dict(charging_account=charging_account, qosname=qos)
                    excl_param = ""
                    if constraint != "skylake":
                        excl_param = "#SBATCH --exclusive"
                    tq_param = ""
                    if pool_name_param:
                        tq_param = "-p " + pool_name_param
                    batch_job_script_str += """
#SBATCH -t %(wall_time)s
#SBATCH --job-name=%(job_name)s
#SBATCH -o %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.out
#SBATCH -e %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.err
%(exclusive)s
module unload python
%(env_activation_cmd)s
%(export_jtm_config_file)s
for i in {1..%(num_workers_per_node)d}
do
echo "jobid: $SLURM_JOB_ID"
jtm %(set_jtm_config_file)s %(debug)s worker --slurm_job_id $SLURM_JOB_ID \
-cl cori \
-wt %(worker_type)s \
-t %(wall_time)s \
--clone_time_rate %(clone_time_rate)f %(task_queue)s \
--num_worker_per_node %(num_workers_per_node)d \
-C %(constraint)s \
-m %(mem)s \
%(other_params)s &
sleep 1
done
wait
""" % \
                        dict(debug="--debug" if DEBUG else "",
                             wall_time=job_time_to_request,
                             job_dir=job_script_dir_name,
                             worker_id=UNIQ_WORKER_ID,
                             worker_type=THIS_WORKER_TYPE,
                             clone_time_rate=worker_clone_time_rate,
                             task_queue=tq_param,
                             num_workers_per_node=num_workers_per_node,
                             env_activation_cmd=ENV_ACTIVATION,
                             other_params=batch_job_misc_params,
                             constraint=constraint,
                             mem=mem_per_node_to_request,
                             job_name=job_name,
                             exclusive=excl_param,
                             export_jtm_config_file="export JTM_CONFIG_FILE=%s" % worker_config,
                             set_jtm_config_file="--config=%s" % worker_config)
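                    # For reference, with purely illustrative values (-N 1, --mem=118G,
                    # qos=genepool, account=fungalp, 24:00:00 walltime; none of these
                    # are defaults), the header assembled above renders roughly as:
                    #
                    #   #!/bin/bash -l
                    #   #SBATCH -N 1
                    #   #SBATCH --mem=118G
                    #   #SBATCH -q genepool
                    #   #SBATCH -C haswell
                    #   #SBATCH -A fungalp
                    #   #SBATCH -t 24:00:00
                    #   #SBATCH --job-name=jtm_worker_<pool>
                    #   #SBATCH -o <job_dir>/jtm_dynamic_worker_<worker_id>.out
                    #   #SBATCH -e <job_dir>/jtm_dynamic_worker_<worker_id>.err
                    #   #SBATCH --exclusive
                    #
                    # followed by the for-loop that backgrounds one `jtm worker`
                    # process per requested worker and then waits for all of them.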
                elif cluster_name in ("lawrencium", "jgi_cloud", "jaws_lbl_gov", "jgi_cluster", "lbl"):
                    if worker_id_param:
                        batch_job_misc_params += " -wi %(worker_id)s_${i}" \
                            % dict(worker_id=UNIQ_WORKER_ID)
                    tp_param = ""
                    if pool_name_param:
                        tp_param = "-p " + pool_name_param
                    # Partition, QOS, and charge account currently come straight from the
                    # config for all of the LBL-side clusters
                    part_param = PARTITION
                    qos_param = QOS
                    charge_param = CHARGE_ACCNT
                    nnode_param = 1
                    if num_nodes_to_request_param:
                        nnode_param = num_nodes_to_request
                    mnode_param = "#SBATCH --mem=%(mem)s" \
                        % dict(mem=mem_per_node_to_request)
                    batch_job_script_str += """
#SBATCH --time=%(wall_time)s
#SBATCH --job-name=%(job_name)s
#SBATCH --partition=%(partition_name)s
#SBATCH --qos=%(qosname)s
#SBATCH --account=%(charging_account)s
#SBATCH --nodes=%(num_nodes_to_request)d
%(mem_per_node_setting)s
#SBATCH -o %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.out
#SBATCH -e %(job_dir)s/jtm_%(worker_type)s_worker_%(worker_id)s.err
%(env_activation_cmd)s
%(export_jtm_config_file)s
for i in {1..%(num_workers_per_node)d}
do
echo "jobid: $SLURM_JOB_ID"
jtm %(set_jtm_config_file)s %(debug)s worker --slurm_job_id $SLURM_JOB_ID \
-cl %(lbl_cluster_name)s \
-wt %(worker_type)s \
-t %(wall_time)s \
--clone_time_rate %(clone_time_rate)f %(task_queue)s \
--num_worker_per_node %(num_workers_per_node)d \
-m %(mem)s \
%(other_params)s &
sleep 1
done
wait
""" % \
                        dict(debug="--debug" if DEBUG else "",
                             wall_time=job_time_to_request,
                             job_name=job_name,
                             partition_name=part_param,
                             qosname=qos_param,
                             charging_account=charge_param,
                             num_nodes_to_request=nnode_param,
                             mem_per_node_setting=mnode_param,
                             worker_id=UNIQ_WORKER_ID,
                             job_dir=job_script_dir_name,
                             env_activation_cmd=ENV_ACTIVATION,
                             num_workers_per_node=num_workers_per_node,
                             mem=mem_per_node_to_request,
                             lbl_cluster_name=cluster_name,
                             worker_type=THIS_WORKER_TYPE,
                             clone_time_rate=worker_clone_time_rate,
                             task_queue=tp_param,
                             other_params=batch_job_misc_params,
                             export_jtm_config_file="export JTM_CONFIG_FILE=%s" % worker_config,
                             set_jtm_config_file="--config=%s" % worker_config)
                jf.writelines(batch_job_script_str)
            os.chmod(batch_job_script_file, 0o775)
            if dry_run:
                print(batch_job_script_str)
                sys.exit(0)
            sbatch_cmd = "sbatch --parsable %s" % (batch_job_script_file)
            _, _, ec = run_sh_command(sbatch_cmd, log=logger)
            assert ec == 0, "Failed to run 'jtm worker' to sbatch dynamic worker."
            return ec
        elif cluster_name == "aws":
            pass
    # If it's spawned by sbatch
    # Todo: need to record job_id, worker_id, worker_type, starting_time, wallclocktime
    #   scontrol show jobid -dd <jobid> ==> EndTime
    #   scontrol show jobid <jobid> ==> EndTime
    #   sstat --format=AveCPU,AvePages,AveRSS,AveVMSize,JobID -j <jobid> --allsteps
    #
    # If endtime - starttime <= 10%, execute sbatch again
    # if slurm_job_id != 0 and THIS_WORKER_TYPE == "static":
    #     logger.debug("worker_type: {}".format(THIS_WORKER_TYPE))
    #     logger.debug("slurm_job_id: {}".format(slurm_job_id))
    # A dynamic worker creates two children when it approaches the wallclock time
    # limit, taking the task queue length into account.
    # It also maintains the already requested number of workers:
    # if no more workers are needed, it won't call sbatch.
    # elif slurm_job_id != 0 and THIS_WORKER_TYPE == "dynamic":
    #     logger.debug("worker_type: {}".format(THIS_WORKER_TYPE))
    #     logger.debug("slurm_job_id: {}".format(slurm_job_id))
    # Remote broker (rmq.nersc.gov)
    rmq_conn = RmqConnectionHB(config=CONFIG)
    conn = rmq_conn.open()
    ch = conn.channel()
    # ch.confirm_delivery()
    ch.exchange_declare(exchange=JTM_INNER_MAIN_EXCH,
                        exchange_type="direct",
                        passive=False,
                        durable=True,
                        auto_delete=False)
    # Declare task receiving queue (client --> worker)
    #
    # If a queue is durable, RabbitMQ will never lose the queue.
    # If a queue is exclusive, then when the channel that declared
    # the queue is closed, the queue is deleted.
    # If a queue is auto-deleted, then when there are no
    # subscriptions left on that queue it will be deleted.
    #
    ch.queue_declare(queue=inner_task_request_queue,
                     durable=True,
                     exclusive=False,
                     auto_delete=True)
    ch.queue_bind(exchange=JTM_INNER_MAIN_EXCH,
                  queue=inner_task_request_queue,
                  routing_key=inner_task_request_queue)
    logger.info("Waiting for a request...")
    logger.debug("Main pid = {}".format(PARENT_PROCESS_ID))
    pid_list = []
    # Start the task termination process
    try:
        task_kill_proc_hdl = mp.Process(target=recv_task_kill_request_proc)
        task_kill_proc_hdl.start()
        pid_list.append(task_kill_proc_hdl)
    except Exception as e:
        logger.exception("recv_task_kill_request_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Start the heartbeat-send process (worker --> client)
    try:
        send_hb_to_client_proc_hdl = mp.Process(target=send_hb_to_client_proc,
                                                args=(heartbeat_interval,
                                                      slurm_job_id,
                                                      mem_per_node_to_request,
                                                      mem_per_cpu_to_request,
                                                      num_cpus_to_request,
                                                      job_time_to_request,
                                                      worker_clone_time_rate,
                                                      inner_task_request_queue,
                                                      tp_name,
                                                      num_workers_per_node,
                                                      JTM_WORKER_HB_EXCH,
                                                      WORKER_HB_Q_POSTFIX))
        send_hb_to_client_proc_hdl.start()
        pid_list.append(send_hb_to_client_proc_hdl)
    except Exception as e:
        logger.exception("send_hb_to_client_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    logger.info("Start sending my heartbeat to the client every %d sec to %s"
                % (heartbeat_interval, WORKER_HB_Q_POSTFIX))
    # Start the poison-message receive process (reproduce-or-die requests)
    try:
        recv_poison_proc_hdl = mp.Process(target=recv_reproduce_or_die_proc,
                                          args=(pool_name_param,
                                                cluster_name,
                                                mem_per_node_to_request,
                                                mem_per_cpu_to_request,
                                                num_nodes_to_request,
                                                num_cpus_to_request,
                                                job_time_to_request,
                                                worker_clone_time_rate,
                                                num_workers_per_node,
                                                JTM_WORKER_POISON_EXCH,
                                                JTM_WORKER_POISON_Q))
        recv_poison_proc_hdl.start()
        pid_list.append(recv_poison_proc_hdl)
    except Exception as e:
        logger.exception("recv_reproduce_or_die_proc: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Start the client-heartbeat receive process (client --> worker)
    try:
        recv_hb_from_client_proc_hdl = mp.Process(target=recv_hb_from_client_proc2,
                                                  args=(inner_task_request_queue,
                                                        JTM_CLIENT_HB_EXCH,
                                                        CLIENT_HB_Q_POSTFIX))
        recv_hb_from_client_proc_hdl.start()
        pid_list.append(recv_hb_from_client_proc_hdl)
    except Exception as e:
        logger.exception("Worker termination request: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    # Check the total number of child processes
    try:
        check_processes_hdl = mp.Process(target=check_processes,
                                         args=(pid_list,))
        check_processes_hdl.start()
        pid_list.append(check_processes_hdl)
    except Exception as e:
        logger.exception("check_processes: {}".format(e))
        proc_clean(pid_list)
        conn_clean(conn, ch)
        sys.exit(1)
    def signal_handler(signum, frame):
        proc_clean(pid_list)
    signal.signal(signal.SIGTERM, signal_handler)
    # Waiting for a request
    ch.basic_qos(prefetch_count=1)
    # OLD
    # try:
    #     ch.basic_consume(queue=inner_task_request_queue,
    #                      on_message_callback=do_work,
    #                      auto_ack=False)
    # except OSError as err:
    #     logger.exception("Worker terminated: {}".format(err))
    #     proc_clean()
    #     conn_clean()
    #     sys.exit(1)
    # NEW
    # Ref) https://github.com/pika/pika/blob/1.0.1/examples/
    #      https://stackoverflow.com/questions/51752890/how-to-disable-heartbeats-with-pika-and-rabbitmq
    #      https://github.com/pika/pika/blob/master/examples/basic_consumer_threaded.py
    threads = []
    on_message_callback = functools.partial(on_task_request,
                                            args=(conn, threads))
    ch.basic_consume(queue=inner_task_request_queue,
                     on_message_callback=on_message_callback)
    try:
        ch.start_consuming()
    except KeyboardInterrupt:
        proc_clean(pid_list)
        conn_clean(conn, ch)
    # Wait for all to complete
    # Note: prefetch_count=1 ==> #thread = 1
    for thread in threads:
        thread.join()
    if conn:
        conn.close()
    return 0
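
# Illustrative manual launch (placeholder values, not defaults): this mirrors the
# `jtm worker` command line that the generated batch scripts run, minus the
# Slurm-specific flags:
#
#   jtm --config=/path/to/jtm.ini worker \
#       -wt manual \
#       -p my_pool \
#       -c 8 -m 10G -t 12:00:00 \
#       --clone_time_rate 0.2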