Last active
December 26, 2023 06:09
-
-
Save simonLeary42/25dafb478dcddb1ba2b4a5a2e7dea2a9 to your computer and use it in GitHub Desktop.
Like `htop`, but for all the nodes in a Slurm cluster.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import os | |
import grp | |
import sys | |
import json | |
import shutil | |
import subprocess as subp | |
from typing import List | |
# cache files that are read instead of running `sinfo` when they exist
SINFO_CACHE_FILE_PATH = "/modules/user-resources/cache/sinfo.json"
SINFO_N_CACHE_FILE_PATH = "/modules/user-resources/cache/sinfo-N.json"

# a node reporting any of these states is treated as down
DOWN_STATES = {"DOWN", "DRAIN", "NOT_RESPONDING"}

# name of this script without its directory
MY_FILENAME = os.path.basename(sys.argv[0])
def any_elem_is_in_list(any_of_these:list, in_this_list:list) -> bool:
    """
    return True as soon as one element of `any_of_these` is found in `in_this_list`
    return False when there is no common element (including when either is empty)
    """
    for candidate in any_of_these:
        if candidate in in_this_list:
            return True
    return False
def split_commas_strip_remove_empty_strings(list_str:str) -> list:
    """
    split `list_str` on commas, strip whitespace around each piece,
    and drop the pieces that are empty after stripping
    """
    stripped_pieces = (piece.strip() for piece in list_str.split(","))
    return [piece for piece in stripped_pieces if piece != ""]
def closest_element_index(_list, target) -> int:
    """
    return the index of the list element which is closest to target
    ties go to the earliest element, and an empty list gives -1
    """
    distances = [abs(element - target) for element in _list]
    if not distances:
        return -1
    # min() keeps the first of equal minima, and index() finds its first position
    return distances.index(min(distances))
def generate_progress_bar(frac: float, _len=15) -> str:
    """
    render `frac` as a text progress bar such as "[#####        ]"

    frac: fraction to display, clamped into the range 0..1
    _len: total width of the bar including the enclosing brackets

    returns "[]" when `_len` leaves no room between the brackets
    (previously `_len == 2` raised ZeroDivisionError)
    """
    if frac < 0:
        frac = 0
    if frac > 1:
        frac = 1
    inner_len = _len - 2 # subtract beginning and end characters
    if inner_len <= 0:
        return '[]'
    # round `frac` to the nearest character length fraction k/inner_len
    # min() keeps the smallest k when two candidates are equally close
    num_chars = min(range(inner_len + 1), key=lambda k: abs(k / inner_len - frac))
    return '[' + ('#' * num_chars) + (' ' * (inner_len - num_chars)) + ']'
def fmt_table(table) -> List[str]:
    """
    I would use tabulate but I don't want nonstandard imports

    table: list of rows, where the first row is the header
    returns the formatted lines: a centered header, a "=" rule, then one
    left-justified line per remaining row
    an empty table gives an empty list
    raises ValueError if any row has more elements than the header row
    """
    if not table:
        return []
    header_row = table[0]
    num_columns = len(header_row)
    # explicit check rather than `assert`, which is stripped by `python -O`
    for row in table:
        if len(row) > num_columns:
            raise ValueError(
                f"row has {len(row)} elements but the header row has only {num_columns}"
            )
    column_widths = [ 0 ] * num_columns
    for row in table:
        for i,element in enumerate(row):
            column_widths[i] = max(column_widths[i], len(str(element)))
    column_widths = [ x + 5 for x in column_widths ] # room for whitespace on either side
    # each header cell is one character narrower than its column to make room for the '|'
    header = '|'.join(
        str(column_header).center(column_widths[i]-1)
        for i,column_header in enumerate(header_row)
    )
    output_lines = [header, "=" * len(header)]
    for row in table[1:]:
        output_lines.append(''.join(
            str(value).ljust(column_widths[i]) for i,value in enumerate(row)
        ))
    return output_lines
def pipe_output_pager_exit(argv, output_lines):
    """
    run the pager `argv`, feed it `output_lines` joined by newlines on stdin,
    wait for it to finish, then exit this process with status 0

    argv: a program path or an argument list, passed straight to Popen
    output_lines: list of strings to display
    """
    payload = '\n'.join(output_lines).encode()
    with subp.Popen(argv, stdin=subp.PIPE, stdout=sys.stdout) as proc:
        try:
            # communicate() writes the payload, closes stdin and waits for the pager
            proc.communicate(input=payload)
        except BrokenPipeError:
            # the pager quit before reading everything (e.g. the user pressed `q`)
            pass
    sys.exit(0)
def print_output_exit(output_lines):
    """
    write each of `output_lines` to stdout followed by a newline,
    then exit this process with status 0
    """
    sys.stdout.writelines(f"{line}\n" for line in output_lines)
    raise SystemExit(0)
def read_file_or_exec_command(file_path:str, argv:List[str]):
    """
    return the text of `file_path` when it names an existing regular file,
    otherwise run `argv` and return its raw stdout
    the path "none" (in any letter case) always selects the command
    """
    use_file = file_path.lower() != "none" and os.path.isfile(file_path)
    if not use_file:
        return subp.check_output(argv)
    with open(file_path, mode='r', encoding="utf8") as cache_file:
        contents = cache_file.read()
    return contents
class SlurmNodeUsageAnalyzer:
    """
    collect node, partition, job and association data from slurm and render
    a per-node table of idle CPU cores, idle memory and idle GPUs

    constructing an instance immediately gathers and parses the slurm data,
    afterwards `node_usage()` returns the formatted report lines
    """

    def __init__(self):
        self.my_posix_groups = self._lookup_my_posix_groups()
        self.sinfo_n, self.sinfo, self.squeue, self.my_associations = None, None, None, None
        self.my_slurm_accounts, self.my_qos = list(), list()
        self.nodes, self.partitions, self.node_partitions = dict(), dict(), dict()
        self.down_nodes = set()
        self.num_untrackable_gpus = 0
        print("collecting info from slurm...", file=sys.stderr)
        self.get_slurm_input()
        self.parse_slurm_input()

    @staticmethod
    def _lookup_my_posix_groups() -> List[str]:
        """
        return the names of the groups this process belongs to

        the previous implementation tested `uid in g.gr_mem`, but `gr_mem` holds
        user names (strings), so an integer uid never matched and the result was
        always empty; asking the OS for this process's group ids is reliable and
        also covers the primary group
        """
        gids = {os.getgid(), os.getegid(), *os.getgroups()}
        names = []
        for gid in sorted(gids):
            try:
                names.append(grp.getgrgid(gid).gr_name)
            except KeyError:
                continue # gid has no entry in the group database
        return names

    @staticmethod
    def _count_gpus(gres:str) -> int:
        """
        sum the counts of every "gpu:..." entry in a comma separated gres string
        example: "gpu:2080ti:8(S:0-1),mps:100" -> 8
        entries whose count can not be parsed are skipped instead of crashing
        """
        total = 0
        for resource in gres.split(','):
            resource = resource.strip()
            if not resource.startswith("gpu:"):
                continue
            # drop a trailing "(S:0-1)" style suffix before taking the count field
            count_str = resource.split('(')[0].split(':')[-1]
            try:
                total += int(count_str)
            except ValueError:
                continue
        return total

    @staticmethod
    def _qos_names(qos_field) -> List[str]:
        """
        normalize an association's "qos" value into a flat list of names
        NOTE(review): sacctmgr's JSON appears to report qos as a list of names,
        a comma separated string is accepted as well -- confirm for your slurm version
        """
        if isinstance(qos_field, str):
            return [x.strip() for x in qos_field.split(",") if x.strip() != ""]
        if isinstance(qos_field, (list, tuple, set)):
            return [str(x) for x in qos_field]
        return []

    @staticmethod
    def _free_fraction(free, total) -> float:
        """
        return free/total, or 0.0 when total is not positive (avoids ZeroDivisionError)
        """
        if total <= 0:
            return 0.0
        return free / total

    def get_slurm_input(self):
        """
        load the sinfo data (from the cache files when present, otherwise by running
        sinfo), then run squeue and sacctmgr, and store the decoded JSON on self
        """
        self.sinfo_n = json.loads(read_file_or_exec_command(
            SINFO_N_CACHE_FILE_PATH, ["/usr/bin/sinfo", "--all", "-N", "--json"]
        ))["sinfo"]
        self.sinfo = json.loads(read_file_or_exec_command(
            SINFO_CACHE_FILE_PATH, ["/usr/bin/sinfo", "--all", "--json"]
        ))["sinfo"]
        self.squeue = json.loads(subp.check_output(["/usr/bin/squeue", "--all", "--json"]))
        self.my_associations = json.loads(subp.check_output(
            ["/usr/bin/sacctmgr", "show", "association", "--json", f"user={os.getenv('USER')}"]
        ))

    def parse_slurm_input(self):
        """
        fill in my accounts/qos, the partition lookup tables, the per-node totals
        and the per-node allocations from the raw slurm JSON
        """
        associations = self.my_associations["associations"]
        self.my_slurm_accounts = [x["account"] for x in associations if "account" in x]
        self.my_qos = [
            name for x in associations if "qos" in x for name in self._qos_names(x["qos"])
        ]
        for partition in self.sinfo:
            partition_name = partition["partition"]["name"]
            # self.sinfo is a list, but I don't want to do a linear search when I look for a partition
            self.partitions[partition_name] = partition["partition"]
            # record which partitions each node belongs to
            for hostname in partition["nodes"]["nodes"]:
                self.node_partitions.setdefault(hostname, set()).add(partition_name)
        for sinfo_node in self.sinfo_n:
            name = sinfo_node["nodes"]["nodes"][0]
            if name in self.nodes or name in self.down_nodes:
                continue
            if any(state in DOWN_STATES for state in sinfo_node["node"]["state"]):
                self.down_nodes.add(name)
                continue
            self.nodes[name] = {
                "total_cpus": int(sinfo_node["cpus"]["maximum"]),
                "alloc_cpus": 0,
                "total_gpus": self._count_gpus(sinfo_node["gres"]["total"]),
                "alloc_gpus": 0,
                "total_mem_MB": int(sinfo_node["memory"]["maximum"]),
                "alloc_mem_MB": 0,
            }
        for job in self.squeue["jobs"]:
            if job["job_state"] != "RUNNING":
                continue
            for allocated_node in job["job_resources"]["allocated_nodes"]:
                # down nodes and nodes that sinfo did not report are not tracked
                node = self.nodes.get(allocated_node["nodename"])
                if node is None:
                    continue
                # NOTE(review): every core listed under a socket is counted as allocated
                node["alloc_cpus"] += sum(
                    len(socket["cores"]) for socket in allocated_node["sockets"].values()
                )
                node["alloc_mem_MB"] += allocated_node["memory_allocated"]
            job_gpus = 0
            # example: "cpu=4,mem=40G,node=1,billing=1,gres/gpu=1,gres/gpu:2080ti=1"
            for resource in job["tres_alloc_str"].split(','):
                if resource.startswith("gres/gpu="):
                    job_gpus += int(resource.split('=')[-1])
            # if this job is running on >1 node, we don't know on which nodes the GPUs are allocated
            if job["node_count"]["number"] > 1:
                self.num_untrackable_gpus += job_gpus
                continue
            # at this point there must be exactly 1 job node
            job_node = self.nodes.get(job["nodes"])
            if job_node is None:
                continue # don't bother tracking usage of down or unknown nodes
            job_node["alloc_gpus"] += job_gpus

    def check_partition_access(self, partition_name:str) -> bool:
        """
        slurm says that it already hides partitions that the user doesn't have access to
        but it seems that slurm does not pay attention to allowed accounts and denied accounts
        so I do it myself
        """
        partition = self.partitions[partition_name]
        allowed_accts = split_commas_strip_remove_empty_strings(partition["accounts"]["allowed"])
        denied_accts = split_commas_strip_remove_empty_strings(partition["accounts"]["deny"])
        allowed_qos = split_commas_strip_remove_empty_strings(partition["qos"]["allowed"])
        denied_qos = split_commas_strip_remove_empty_strings(partition["qos"]["deny"])
        allowed_groups = split_commas_strip_remove_empty_strings(partition["groups"]["allowed"])
        # an empty allow list means "no restriction", an empty deny list denies nobody
        if allowed_accts and not any_elem_is_in_list(self.my_slurm_accounts, allowed_accts):
            return False
        if denied_accts and any_elem_is_in_list(self.my_slurm_accounts, denied_accts):
            return False
        if allowed_qos and not any_elem_is_in_list(self.my_qos, allowed_qos):
            return False
        if denied_qos and any_elem_is_in_list(self.my_qos, denied_qos):
            return False
        if allowed_groups and not any_elem_is_in_list(self.my_posix_groups, allowed_groups):
            return False
        return True

    def node_partitions_that_I_can_access(self, hostname:str) -> List[str]:
        """
        return the sorted names of the partitions containing `hostname` that I may use
        a node that is listed in no partition gives an empty list
        """
        partitions = self.node_partitions.get(hostname, ())
        return sorted(x for x in partitions if self.check_partition_access(x))

    def node_usage(self):
        """
        return the report as a list of lines: one table row per tracked node showing
        idle CPU cores, idle memory, idle GPUs and accessible partitions, then footer notes
        """
        node_table = [[ "Hostname", "Idle CPU Cores", "Idle Memory", "Idle GPUs", "Partitions" ]]
        for node, usage in self.nodes.items():
            num_free_cpus = usage["total_cpus"] - usage["alloc_cpus"]
            free_cpu_frac = self._free_fraction(num_free_cpus, usage["total_cpus"])
            cpu_usage = f"{generate_progress_bar(free_cpu_frac)} {num_free_cpus}/{usage['total_cpus']}"
            free_mem_MB = usage["total_mem_MB"] - usage["alloc_mem_MB"]
            free_mem_frac = self._free_fraction(free_mem_MB, usage["total_mem_MB"])
            mem_usage = f"{generate_progress_bar(free_mem_frac)} {(free_mem_MB/1000):.1f} GB"
            if usage["total_gpus"] > 0:
                num_free_gpus = usage["total_gpus"] - usage["alloc_gpus"]
                free_gpu_frac = self._free_fraction(num_free_gpus, usage["total_gpus"])
                # the column is "Idle GPUs", so show the free count like the CPU column does
                gpu_usage = f"{generate_progress_bar(free_gpu_frac)} {num_free_gpus}/{usage['total_gpus']}"
            else:
                gpu_usage = ""
            partitions_to_access = ",".join(self.node_partitions_that_I_can_access(node))
            node_table.append([node, cpu_usage, mem_usage, gpu_usage, partitions_to_access])
        output_lines = fmt_table(node_table)
        output_lines.append("")
        if self.num_untrackable_gpus > 0:
            output_lines.append(f" {self.num_untrackable_gpus} GPUs are shown as idle but are actually in use.")
        output_lines.append(f" to print output to stdout, set the PAGER environment variable to \"NONE\".")
        output_lines.append("")
        output_lines.append("")
        return output_lines
def main():
    """
    build the node usage report and display it

    $PAGER handling:
      "none" (any letter case)  -> print straight to stdout
      path of an existing file   -> run that file as the pager
      anything else              -> split into a command plus arguments and look
                                    the command up on PATH (so "less -R" works)
    when $PAGER is unset or unusable, fall back to `less -S`, then to stdout
    """
    import shlex # only needed here, to split $PAGER into command and arguments

    analyzer = SlurmNodeUsageAnalyzer()
    output_lines = analyzer.node_usage()
    pager_environ = os.environ.get("PAGER", "")
    if pager_environ.lower() == "none":
        print_output_exit(output_lines)
    if os.path.isfile(pager_environ):
        pipe_output_pager_exit(pager_environ, output_lines)
    # if PAGER is defined but not "none" and not a file itself, look for it
    if pager_environ:
        try:
            pager_argv = shlex.split(pager_environ)
        except ValueError:
            pager_argv = [] # unbalanced quotes, treat the pager as not found
        which_pager_environ = shutil.which(pager_argv[0]) if pager_argv else None
        if which_pager_environ:
            pipe_output_pager_exit([which_pager_environ] + pager_argv[1:], output_lines)
        else:
            print(f"$PAGER=\"{pager_environ}\" but I can't find it!", file=sys.stderr)
    which_less = shutil.which("less")
    if which_less:
        pipe_output_pager_exit([which_less, "-S"], output_lines)
    print_output_exit(output_lines)


if __name__=="__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment