Created
April 8, 2022 18:39
-
-
Save ian-r-rose/b826ef75e3cdc6e83e3b31878eaf1305 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import datetime | |
import distributed | |
from distributed.diagnostics import SchedulerPlugin | |
from distributed.utils import key_split, key_split_group | |
class TaskGroupStatistics(SchedulerPlugin):
    """Scheduler plugin that keeps a summary record for each task group.

    Whenever a key goes from ``processing`` to ``memory``, the record for
    that key's group is refreshed from ``scheduler.task_groups``.  The
    records are served through a ``get_task_groups`` scheduler handler.
    """

    def __init__(self):
        """Initialize the plugin with no records and no scheduler attached."""
        self.groups = {}
        self.scheduler = None

    def start(self, scheduler):
        """Attach to ``scheduler`` and register the ``get_task_groups`` handler.

        Called on scheduler start as well as on registration time.
        """
        self.scheduler = scheduler
        scheduler.handlers["get_task_groups"] = self.get_task_groups

    def transition(self, key, start, finish, *args, **kwargs):
        """Refresh the group record when ``key`` moves from processing to memory.

        Any other transition is ignored.  Extra positional and keyword
        arguments are accepted and unused.
        """
        if self.scheduler is None:
            # Should not get here if initialization has happened correctly
            return
        if start != "processing" or finish != "memory":
            return

        prefix_name = key_split(key)
        group_name = key_split_group(key)

        # The record is created before the scheduler lookup, so a failed
        # lookup still leaves an (empty) entry behind for this group.
        record = self.groups.setdefault(group_name, {})
        group = self.scheduler.task_groups[group_name]

        record["prefix"] = prefix_name
        record["duration"] = group.duration
        # Timestamps are stored as strings of naive local datetimes.
        started_at = datetime.datetime.fromtimestamp(group.start)
        record["start"] = str(started_at)
        stopped_at = datetime.datetime.fromtimestamp(group.stop)
        record["stop"] = str(stopped_at)
        record["nbytes"] = group.nbytes_total

    async def get_task_groups(self, comm):
        """Handler body: return the mapping of group name to its record."""
        return self.groups

    def restart(self, scheduler):
        """Replace the collected records with an empty mapping."""
        self.groups = {}
if __name__ == "__main__":
    # Demo: run two small task groups on a local cluster, fetch the
    # statistics gathered by the plugin and dump them to ``tmp.csv``.
    import time

    import pandas

    def slowinc(x):
        """Return ``x + 1`` after a 0.1 s pause."""
        time.sleep(0.1)
        return x + 1

    def slowdouble(x):
        """Return ``x * 2`` after a 0.1 s pause."""
        time.sleep(0.1)
        return x * 2

    # Context managers close the client first and then the cluster, and do
    # so even when one of the steps below raises.
    with distributed.LocalCluster() as cluster, distributed.Client(cluster) as client:
        client.register_scheduler_plugin(TaskGroupStatistics())

        futs = client.map(slowinc, range(10))
        futs2 = client.map(slowdouble, futs)
        client.gather(futs2)

        # Call the handler that the plugin registered on the scheduler.
        tg_data = client.sync(client.scheduler.get_task_groups)

        # One row per task group, indexed by group name.
        df = pandas.DataFrame.from_dict(tg_data, orient="index")
        df.index.name = "group"
        df.to_csv("tmp.csv")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment