Skip to content

Instantly share code, notes, and snippets.

@abelardojarab
Last active February 9, 2025 17:13
Show Gist options
  • Save abelardojarab/a255b202752a7519d4c58aca4ab5c4ea to your computer and use it in GitHub Desktop.
Save abelardojarab/a255b202752a7519d4c58aca4ab5c4ea to your computer and use it in GitHub Desktop.
Parallel flame graph generation using Ray
import collections
import json
import re
import ray
import graphviz
# Start Ray as a module-level side effect. Per Ray's documentation,
# ignore_reinit_error=True suppresses the error from calling ray.init() again
# when Ray is already initialised (e.g. when re-running this code in one session).
ray.init(ignore_reinit_error=True)
@ray.remote
class StackProcessor:
    """Ray actor that folds trace lines into its own local stack tree.

    The tree is a nested mapping ``{frame: [count, children]}``: ``count`` is
    incremented every time an inserted stack passes through that node, and
    ``children`` is another mapping of the same shape holding the next frames.
    """

    # "<digits>.<digits> <stack>"; compiled once instead of on every line.
    _TRACE_LINE = re.compile(r"(\d+\.\d+)\s+(.+)")

    def __init__(self):
        self.stack_tree = collections.defaultdict(
            lambda: [0, collections.defaultdict(dict)]
        )
        # Occurrences of each frame name across all inserted stacks,
        # regardless of where in the stack it appeared.
        self.total_function_counts = collections.defaultdict(int)

    def process_trace(self, trace_lines):
        """Parse a batch of trace lines and add their stacks to the local tree.

        Blank lines, lines starting with ``#`` and lines that do not match
        ``<timestamp> <frame>;<frame>;...`` are skipped. The timestamp is
        matched but not used.

        Args:
            trace_lines: Iterable of raw trace line strings.
        """
        for raw_line in trace_lines:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            match = self._TRACE_LINE.match(line)
            if match is None:
                continue
            # Strip each frame and drop empty ones so inputs such as "a;b;" or
            # "a; b" do not create bogus '' / ' b' nodes in the tree.
            frames = [
                frame
                for frame in (part.strip() for part in match.group(2).split(";"))
                if frame
            ]
            if frames:
                self._insert_stack(frames)

    def _insert_stack(self, stack_trace):
        """Insert one stack (sequence of frame names, outermost first) into the tree."""
        node = self.stack_tree
        for func in stack_trace:
            if func not in node:
                node[func] = [0, collections.defaultdict(dict)]
            entry = node[func]
            entry[0] += 1
            self.total_function_counts[func] += 1
            node = entry[1]  # descend into this frame's children

    def get_results(self):
        """Return ``(stack_tree, total_function_counts)`` accumulated so far."""
        return self.stack_tree, self.total_function_counts
def merge_trees(global_tree, local_tree):
    """Merge ``local_tree`` into ``global_tree`` in place.

    Both trees are nested mappings of the form ``{frame: [count, children]}``.
    Counts of frames present in both trees are summed; frames missing from
    ``global_tree`` are created. ``local_tree`` is not modified.

    The traversal uses an explicit work list instead of recursion, so trees
    deeper than the interpreter's recursion limit merge without raising
    ``RecursionError``.

    Args:
        global_tree: Mapping that receives the merged counts (mutated).
        local_tree: Mapping whose counts are added to ``global_tree``.
    """
    # Each entry pairs a destination subtree with the source subtree to fold in.
    pending = [(global_tree, local_tree)]
    while pending:
        target, source = pending.pop()
        for func, (count, children) in source.items():
            if func not in target:
                target[func] = [0, collections.defaultdict(dict)]
            entry = target[func]
            entry[0] += count
            if children:
                pending.append((entry[1], children))
class ParallelKernelFlameGraph:
    """Build a merged stack tree from trace lines with Ray actors and draw it.

    ``global_stack_tree`` is a nested mapping ``{frame: [count, children]}``
    and ``total_function_counts`` maps each frame name to its summed count
    across all workers.
    """

    def __init__(self, num_workers=4):
        """Create an empty graph that will use up to ``num_workers`` actors."""
        self.num_workers = num_workers
        self.global_stack_tree = collections.defaultdict(
            lambda: [0, collections.defaultdict(dict)]
        )
        self.total_function_counts = collections.defaultdict(int)

    def _chunk_traces(self, trace_lines):
        """Split ``trace_lines`` into at most ``num_workers`` non-empty slices."""
        if self.num_workers < 1:
            raise ValueError(
                f"num_workers must be at least 1, got {self.num_workers!r}"
            )
        total = len(trace_lines)
        # Ceiling division: never yields a zero step (which range() rejects)
        # and never produces more chunks than there are workers.
        chunk_size = -(-total // self.num_workers)
        return [
            trace_lines[start:start + chunk_size]
            for start in range(0, total, chunk_size)
        ]

    def process_traces_parallel(self, trace_lines):
        """Process ``trace_lines`` on Ray actors and merge the results into this object.

        Args:
            trace_lines: Sliceable sequence of raw trace line strings.

        Raises:
            ValueError: If ``num_workers`` is smaller than 1 and there is
                input to process.
        """
        if len(trace_lines) == 0:
            # Nothing to do; avoid spawning actors for empty input.
            return

        trace_chunks = self._chunk_traces(trace_lines)
        workers = [StackProcessor.remote() for _ in trace_chunks]

        # Block until every actor has finished its chunk.
        ray.get(
            [
                worker.process_trace.remote(chunk)
                for worker, chunk in zip(workers, trace_chunks)
            ]
        )

        # Fetch all per-actor results, then fold them in worker order.
        results = ray.get([worker.get_results.remote() for worker in workers])
        for local_tree, local_counts in results:
            merge_trees(self.global_stack_tree, local_tree)
            for func, count in local_counts.items():
                self.total_function_counts[func] += count

    def visualize_tree(self, output_file="/mnt/data/parallel_stack_tree"):
        """Render the merged stack tree as a PNG with Graphviz.

        Every tree node gets its own synthetic Graphviz id, so a function that
        appears on several call paths is drawn once per path with that path's
        count, and function names containing characters that Graphviz treats
        specially in node names (such as ``:``) cannot corrupt the edges.

        Args:
            output_file: Output path without the ``.png`` extension.

        Returns:
            ``output_file`` with ``".png"`` appended.
        """
        dot = graphviz.Digraph(format="png", graph_attr={"rankdir": "TB"})

        next_id = 0
        # (subtree, id of the Graphviz node that owns it) pairs still to draw;
        # the root level has no parent node.
        pending = [(self.global_stack_tree, None)]
        while pending:
            subtree, parent_id = pending.pop()
            for func, (count, children) in subtree.items():
                node_id = f"n{next_id}"
                next_id += 1
                label = f"{func}\\n({count} calls)"
                dot.node(
                    node_id,
                    label=label,
                    shape="box",
                    style="filled",
                    fillcolor="lightblue",
                )
                if parent_id is not None:
                    dot.edge(parent_id, node_id)
                if children:
                    pending.append((children, node_id))

        dot.render(output_file)
        return output_file + ".png"
# Simulated kernel trace. Each entry is "<timestamp> <frame>;<frame>;...":
# a decimal timestamp, whitespace, then semicolon-separated frame names with
# the outermost frame first (the first frame becomes a root of the stack tree).
kernel_trace = [
"1618998023.123 funcA;funcB;funcC",
"1618998023.456 funcA;funcB",
"1618998023.789 funcA",
"1618998024.001 funcD;funcE",
"1618998024.123 funcA;funcB;funcC",
"1618998024.456 funcA",
"1618998025.123 funcX;funcB;funcC",
"1618998025.456 funcX;funcB"
]
# Split the sample trace across Ray actors and merge their per-actor trees
flame_graph = ParallelKernelFlameGraph(num_workers=2)
flame_graph.process_traces_parallel(kernel_trace)
# Render the merged tree with Graphviz; visualize_tree returns the PNG path
output_path = flame_graph.visualize_tree()
# Bare expression: has no effect when run as a script; an interactive
# session (REPL/notebook) would echo the path.
output_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment