Last active
February 9, 2025 17:13
-
-
Save abelardojarab/a255b202752a7519d4c58aca4ab5c4ea to your computer and use it in GitHub Desktop.
Parallel flame graph generation using Ray
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import collections | |
import json | |
import re | |
import ray | |
import graphviz | |
# Start the Ray runtime as a module-level side effect, before any actor is
# created further down. NOTE(review): ignore_reinit_error=True presumably makes
# a repeated ray.init() call harmless when Ray is already running (e.g. when
# this code is re-executed in the same interpreter) — confirm against Ray docs.
ray.init(ignore_reinit_error=True)
@ray.remote
class StackProcessor:
    """Ray actor that folds batches of trace lines into its own stack tree.

    The tree maps a function name to a two-item list ``[hits, children]``,
    where ``children`` is another mapping of the same shape. A flat
    per-function occurrence counter is maintained alongside the tree.
    """

    # Compiled once instead of being looked up for every line. Expected line
    # shape: "<digits>.<digits> <frame>;<frame>;...".
    _TRACE_LINE = re.compile(r"(\d+\.\d+)\s+(.+)")

    def __init__(self):
        """Create an actor with an empty tree and empty counters."""
        self.stack_tree = collections.defaultdict(
            lambda: [0, collections.defaultdict(dict)]
        )
        self.total_function_counts = collections.defaultdict(int)

    def process_trace(self, trace_lines):
        """Parse a batch of trace lines and record every stack found.

        Blank lines, lines starting with ``#`` and lines that do not match
        the expected ``<timestamp> <frames>`` shape are skipped.

        Args:
            trace_lines: Iterable of raw trace strings.
        """
        for raw_line in trace_lines:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            match = self._TRACE_LINE.match(line)
            if match is None:
                continue
            # Strip each frame and drop empty ones so inputs such as
            # "a;;b", "a;b;" or "a; b" do not create bogus "" / " b" nodes.
            frames = [frame.strip() for frame in match.group(2).split(";")]
            frames = [frame for frame in frames if frame]
            if frames:
                self._insert_stack(frames)

    def _insert_stack(self, stack_trace):
        """Walk the tree along ``stack_trace``, bumping each visited node.

        Args:
            stack_trace: Sequence of function names, outermost first.
        """
        node = self.stack_tree
        for func in stack_trace:
            # Explicit membership test: the child mappings are
            # defaultdict(dict), whose factory would create the wrong shape.
            if func not in node:
                node[func] = [0, collections.defaultdict(dict)]
            entry = node[func]
            entry[0] += 1
            self.total_function_counts[func] += 1
            node = entry[1]  # descend into this function's children

    def get_results(self):
        """Return ``(stack_tree, total_function_counts)`` for this actor."""
        return self.stack_tree, self.total_function_counts
def merge_trees(global_tree, local_tree):
    """Accumulate ``local_tree`` into ``global_tree`` in place.

    Both trees map a function name to ``[hits, children]``, where
    ``children`` is a nested mapping of the same shape. Hit counts of
    matching paths are summed; paths missing from ``global_tree`` are created.

    Args:
        global_tree: Mapping that receives the merged counts (mutated).
        local_tree: Mapping whose counts are added; it is only read.
    """
    # Explicit work list of (destination, source) mapping pairs in place of
    # recursive calls.
    pending = [(global_tree, local_tree)]
    while pending:
        target, source = pending.pop()
        for name, (hits, subtree) in source.items():
            if name not in target:
                target[name] = [0, collections.defaultdict(dict)]
            entry = target[name]
            entry[0] += hits
            pending.append((entry[1], subtree))
class ParallelKernelFlameGraph:
    """Aggregate stack traces with Ray actors and render the merged call tree.

    Trace lines are split into chunks, each chunk is processed by its own
    ``StackProcessor`` actor, and the per-actor results are merged into
    ``global_stack_tree`` and ``total_function_counts``.
    """

    def __init__(self, num_workers=4):
        """Create an empty aggregator.

        Args:
            num_workers: Upper bound on the number of actors started per
                call to ``process_traces_parallel``.
        """
        self.num_workers = num_workers
        # Function name -> [hits, children], nested to arbitrary depth.
        self.global_stack_tree = collections.defaultdict(
            lambda: [0, collections.defaultdict(dict)]
        )
        # Function name -> occurrences summed over all stacks.
        self.total_function_counts = collections.defaultdict(int)

    @staticmethod
    def _split_into_chunks(trace_lines, num_workers):
        """Split ``trace_lines`` into at most ``num_workers`` non-empty chunks."""
        if not trace_lines:
            return []
        worker_count = max(1, num_workers)
        # Ceiling division: never yields a zero step (which would make
        # range() raise) and never yields more chunks than workers.
        chunk_size = -(-len(trace_lines) // worker_count)
        return [
            trace_lines[start:start + chunk_size]
            for start in range(0, len(trace_lines), chunk_size)
        ]

    def process_traces_parallel(self, trace_lines):
        """Process ``trace_lines`` on Ray actors and merge their results.

        Args:
            trace_lines: Sequence of raw trace strings (must support
                ``len()`` and slicing). An empty sequence is a no-op.
        """
        trace_chunks = self._split_into_chunks(trace_lines, self.num_workers)
        if not trace_chunks:
            return  # nothing to do; do not start any actors

        workers = [StackProcessor.remote() for _ in trace_chunks]
        # Block until every actor has consumed its chunk.
        ray.get([
            worker.process_trace.remote(chunk)
            for worker, chunk in zip(workers, trace_chunks)
        ])

        # Merge results from all workers.
        results = ray.get([worker.get_results.remote() for worker in workers])
        for local_tree, local_counts in results:
            merge_trees(self.global_stack_tree, local_tree)
            for func, count in local_counts.items():
                self.total_function_counts[func] += count

    def visualize_tree(self, output_file="/mnt/data/parallel_stack_tree"):
        """Render the merged call tree to a PNG with Graphviz.

        Args:
            output_file: Output path without extension.

        Returns:
            ``output_file`` with ``".png"`` appended.
        """
        dot = graphviz.Digraph(format="png", graph_attr={"rankdir": "TB"})
        next_id = 0

        def add_nodes_edges(node, parent=None):
            nonlocal next_id
            for func, (count, children) in node.items():
                # One Graphviz node per tree position. Keying nodes by the
                # function name would collapse the same function reached via
                # different callers into a single box whose label is
                # overwritten by whichever count was emitted last.
                node_id = f"n{next_id}"
                next_id += 1
                label = f"{func}\\n({count} calls)"
                dot.node(node_id, label=label, shape="box", style="filled", fillcolor="lightblue")
                if parent is not None:
                    dot.edge(parent, node_id)
                add_nodes_edges(children, node_id)

        add_nodes_edges(self.global_stack_tree)
        dot.render(output_file)
        return output_file + ".png"
# Sample input: one entry per captured stack, formatted as
# "<timestamp> <frame>;<frame>;..." with the outermost frame first.
kernel_trace = [
    "1618998023.123 funcA;funcB;funcC",
    "1618998023.456 funcA;funcB",
    "1618998023.789 funcA",
    "1618998024.001 funcD;funcE",
    "1618998024.123 funcA;funcB;funcC",
    "1618998024.456 funcA",
    "1618998025.123 funcX;funcB;funcC",
    "1618998025.456 funcX;funcB",
]

# Aggregate the sample stacks using two workers.
flame_graph = ParallelKernelFlameGraph(num_workers=2)
flame_graph.process_traces_parallel(kernel_trace)

# Render the merged tree and keep the path of the generated image.
output_path = flame_graph.visualize_tree()
# Bare expression: shows the path when run in a notebook/REPL cell,
# and has no effect when executed as a plain script.
output_path
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment