Created
December 4, 2023 18:29
-
-
Save csghone/dc6dd0049ec6d2452ba808d03eecf4d7 to your computer and use it in GitHub Desktop.
Create GraphViz graph with HTML-like labels
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
import argparse | |
import logging | |
import logging.handlers | |
import os | |
import sys | |
import textwrap | |
import traceback | |
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import attr | |
import marshmallow | |
import marshmallow_dataclass | |
import pygraphviz | |
# Module logger: use these two lines in every file of the project.
logger = logging.getLogger(__name__)
logger.propagate = True

# Only the file that defines main() should call setup_logging().
# LOG_FORMATTER and setup_logging() can be moved to a shared module for reuse.
# Resulting lines look like:
#   <yyyymmdd HH:MM:SS>.<msecs> - <logger> - <LEVEL> - <lineno> - <func> - <message>
LOG_FORMATTER = logging.Formatter(
    fmt=(
        "%(asctime)s.%(msecs)03d - %(name)s - %(levelname)s - "
        "%(lineno)s - %(funcName)s - "
        "%(message)s"
    ),
    datefmt="%Y%m%d %H:%M:%S",
)
def setup_logging(inp_file, level=logging.INFO, enable_console=True):
    """Attach a rotating log file (and optionally a console stream) to the root logger.

    Args:
        inp_file: path of the calling script; its basename names the log file
            ``__<basename>.main__.log`` in the current working directory.
        level: level applied to the root logger.
        enable_console: when true, also log to a ``logging.StreamHandler``.

    Every handler on the root logger -- including any attached before this
    call -- is switched to ``LOG_FORMATTER``.
    """
    root = logging.getLogger()
    new_handlers = [
        logging.handlers.RotatingFileHandler(
            f"__{os.path.basename(inp_file)}.main__.log",
            maxBytes=1000000,
            backupCount=5,
        )
    ]
    if enable_console:
        new_handlers.append(logging.StreamHandler())
    for new_handler in new_handlers:
        root.addHandler(new_handler)
    root.setLevel(level)
    for handler in root.handlers:
        handler.setFormatter(fmt=LOG_FORMATTER)
@dataclass
class TaskPort:
    """Reference to a single named port on a single node of a TaskGraph."""

    # Name of the owning node; must equal some TaskNode.name.
    node: str
    # Port key; one of that node's `inputs` or `outputs` keys.
    port: str


# Directed connection: (source output port, destination input port).
TaskEdge = Tuple[TaskPort, TaskPort]
@marshmallow_dataclass.add_schema
@attr.s(auto_attribs=True, kw_only=True)
class TaskNode:
    """One task (drawn as one box) in a TaskGraph.

    All attributes are keyword-only constructor arguments:
        name: node identifier; TaskGraph.add_node keys the node by it.
        desc: description text shown under the name when rendered.
        group: name of the cluster the node is drawn in (default "A").
        inputs: input port key -> port label; ``None`` becomes ``{}``.
        outputs: output port key -> port label; ``None`` becomes ``{}``.
    """

    name: str
    desc: str
    group: str = attr.ib(default="A")
    inputs: Optional[Dict[str, Any]] = attr.ib(
        factory=dict,
        converter=attr.converters.default_if_none(factory=dict),
    )
    outputs: Optional[Dict[str, Any]] = attr.ib(
        factory=dict,
        converter=attr.converters.default_if_none(factory=dict),
    )
@marshmallow_dataclass.add_schema
@attr.s(auto_attribs=True, kw_only=True)
class TaskGraph:
    """A named set of TaskNodes wired together by port-to-port TaskEdges.

    All attributes are keyword-only constructor arguments:
        name: graph title.
        desc: free-form description.
        nodes: TaskNode.name -> TaskNode (insertion order); ``None`` becomes ``{}``.
        edges: (source output port, destination input port) pairs;
            ``None`` becomes ``[]``.
    """

    name: str
    desc: str
    nodes: Optional[Dict[str, TaskNode]] = attr.ib(
        factory=dict,
        converter=attr.converters.default_if_none(factory=dict),
    )
    edges: Optional[List[TaskEdge]] = attr.ib(
        factory=list,
        converter=attr.converters.default_if_none(factory=list),
    )

    def add_node(self, node: TaskNode):
        """Register *node* under its name.

        Raises:
            ValueError: if a node with the same name is already registered.
        """
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        if node.name in self.nodes:
            raise ValueError(f"Duplicate task node name: {node.name!r}")
        self.nodes[node.name] = node

    def add_edge(self, edge: TaskEdge):
        """Append *edge* after checking both endpoints.

        The source port must be one of its node's ``outputs`` and the
        destination port one of its node's ``inputs``.

        Raises:
            ValueError: if either node is unknown or a port does not exist on
                the expected side of its node.
        """
        src, dst = edge[0], edge[1]
        if src.node not in self.nodes:
            raise ValueError(f"Unknown source node {src.node!r} in edge {edge!r}")
        if dst.node not in self.nodes:
            raise ValueError(f"Unknown destination node {dst.node!r} in edge {edge!r}")
        if src.port not in self.nodes[src.node].outputs:
            raise ValueError(f"Node {src.node!r} has no output port {src.port!r}")
        if dst.port not in self.nodes[dst.node].inputs:
            raise ValueError(f"Node {dst.node!r} has no input port {dst.port!r}")
        self.edges.append(edge)
class GraphvizHtmlGraph:
    """Render a TaskGraph to Graphviz files whose nodes use HTML-like labels.

    Construction does all the work. For each layout mode ("LR" then "TB") it:
    - builds a fresh ``pygraphviz.AGraph``;
    - adds one cluster subgraph per node group, one HTML-table node per
      TaskNode and one edge per TaskEdge;
    - writes ``<out_dir>/<name>.<mode>.dot`` and renders
      ``<out_dir>/<name>.<mode>.dot.svg`` with ``dot``.

    Each node label is a table of three sections: input ports, name /
    description, and output ports. In "LR" mode the sections sit side by side
    and ports are stacked one per row. In "TB" mode the sections are stacked
    and ports share a single row. Port cells carry ``PORT=`` attributes so
    edges attach to the individual input or output.
    """

    def __init__(self, out_dir: Path, task_graph: "TaskGraph"):
        """Render *task_graph* into *out_dir*, once per layout mode.

        Args:
            out_dir: directory receiving the ``.dot`` and ``.dot.svg`` files.
            task_graph: graph to draw.
        """
        self.out_dir = out_dir
        self.task_graph = task_graph
        # Borderless table used for the outer node table and the nested tables.
        self.table_header = '<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">'
        for mode in ["LR", "TB"]:
            # The helpers below read self.mode, so set it before building.
            self.mode = mode
            self.create_graph(f"{task_graph.name}", mode)
            self.create_subgraphs()
            self.create_graph_nodes()
            self.create_graph_edges()
            self.write_graph()

    def create_subgraphs(self):
        """Add one styled cluster subgraph per distinct ``TaskNode.group``."""
        sub_graph_name_set = set()
        for task_node in self.task_graph.nodes.values():
            sub_graph_name = task_node.group
            if sub_graph_name in sub_graph_name_set:
                continue
            sub_graph_name_set.add(sub_graph_name)
            self.graph.add_subgraph([], sub_graph_name)
            sub_graph = self.graph.get_subgraph(sub_graph_name)
            sub_graph.graph_attr["label"] = sub_graph_name
            sub_graph.graph_attr["fontname"] = "Consolas"
            sub_graph.graph_attr["fontsize"] = 30
            # NOTE(review): older Graphviz releases only treat subgraphs whose
            # name starts with "cluster" as clusters; confirm the target version
            # honours the `cluster` attribute.
            sub_graph.graph_attr["cluster"] = "true"
            sub_graph.graph_attr["color"] = "darkgrey"
            sub_graph.node_attr["style"] = "filled,rounded"

    def create_graph_nodes(self):
        """Add every TaskNode to the graph, then place each in its group's subgraph."""
        for node_name, task_node in self.task_graph.nodes.items():
            self.create_graph_node(node_name, task_node)
        # Second pass, so every node already exists when it is added to a subgraph.
        for node_name, task_node in self.task_graph.nodes.items():
            self.create_subgraph_node(node_name, task_node)

    def create_subgraph_node(self, node_name: str, task_node: "TaskNode"):
        """Add node *node_name* to the subgraph named by its group, if that subgraph exists."""
        cur_node = self.graph.get_node(node_name)
        sub_graph = self.graph.get_subgraph(task_node.group)
        if sub_graph is not None:
            sub_graph.add_node(cur_node)

    def create_graph_node(self, node_name: str, task_node: "TaskNode"):
        """Add *node_name* as a rounded white box whose label is the HTML table."""
        self.graph.add_node(node_name)
        gnode = self.graph.get_node(node_name)
        # Zero width/height/margin: presumably so the box hugs the label table.
        gnode.attr.update({
            "fontname": "Consolas",
            "fontsize": 20,
            "shape": "box",
            "style": "rounded,filled",
            "color": "black",
            "fillcolor": "white",
            "width": "0",
            "height": "0",
            "margin": "0",
        })
        gnode.attr["group"] = task_node.group
        gnode.attr["label"] = self.generate_html_label(node_name, task_node)

    def create_graph_edges(self):
        """Add one Graphviz edge per TaskEdge, with small Consolas edge labels."""
        self.graph.edge_attr.update({
            "fontname": "Consolas",
            "fontsize": 12,
        })
        # (node, port) pairs whose port name has already been drawn as an edge
        # label; a port is labelled only on the first edge that touches it.
        self._edges_labelled = set()
        for src_port, dst_port in self.task_graph.edges:
            self.create_graph_edge(src_port, dst_port)

    def create_graph_edge(self, src: "TaskPort", dst: "TaskPort"):
        """Connect output port *src* to input port *dst*.

        Raises:
            ValueError: if ``self.mode`` is neither "LR" nor "TB".
        """
        tooltip = f"{src.node}:{src.port} -> {dst.node}:{dst.port}"
        # Leave the source cell on its downstream side; enter the destination
        # cell on its upstream side.
        if self.mode == "LR":
            compass_point_s, compass_point_d = "e", "w"
        elif self.mode == "TB":
            compass_point_s, compass_point_d = "s", "n"
        else:
            raise ValueError(f"Unsupported layout mode: {self.mode!r}")

        head_label = ""
        head_key = (dst.node, dst.port)
        tail_label = ""
        tail_key = (src.node, src.port)
        # Head is checked before tail, so a key shared by both is labelled at the head.
        if head_key not in self._edges_labelled:
            self._edges_labelled.add(head_key)
            head_label = f"{dst.port}"
        if tail_key not in self._edges_labelled:
            self._edges_labelled.add(tail_key)
            tail_label = f"{src.port}"
        self.graph.add_edge(
            src.node, f"{dst.node}",
            tailport=f"{src.port}:{compass_point_s}",
            headport=f"{dst.port}:{compass_point_d}",
            tooltip=tooltip,
            # NOTE(review): "portPos" does not appear to be a documented Graphviz attribute.
            portPos="n",
            taillabel=tail_label,
            headlabel=head_label,
        )

    def generate_io_section(self, node_name: str, task_node: "TaskNode", io_mode: str):
        """Return the label fragment for the node's inputs (``"inp"``) or outputs (``"out"``).

        In "TB" mode the port table is wrapped in a full-width row of the outer
        table. In "LR" mode the caller places it inside a ``<TD>``.

        Raises:
            ValueError: for any other *io_mode*.
        """
        if io_mode == "inp":
            table_args = task_node.inputs
        elif io_mode == "out":
            table_args = task_node.outputs
        else:
            raise ValueError(f"Unsupported io_mode: {io_mode!r}")
        if self.mode == "TB":
            out = f'<TR><TD COLSPAN="{self.num_cols}">\n'
            out += self.generate_table(table_args)
            out += "\n</TD></TR>\n"
        else:
            out = "\n"
            out += self.generate_table(table_args)
        return out

    def generate_name_desc(self, node_name: str, task_node: "TaskNode"):
        """Return the ``<TR>`` holding the node name and its bulleted description.

        ``task_node.desc`` may be a string (split on newlines) or a list of
        strings. Empty entries are dropped, and each remaining entry becomes a
        "- " bullet; bullets of 25+ characters are re-wrapped at width 25.
        """

        def _wrap_text(desc_field):
            """Wrap long bullets at ``_max_w`` columns, keeping one "- " marker per bullet."""
            _max_w = 25
            out = []
            for line in desc_field:
                if len(line) < _max_w:
                    out.append(line)
                    continue
                # Drop only the leading bullet; a "- " inside the text is content.
                if line.startswith("- "):
                    line = line[2:]
                lines = textwrap.fill(line, _max_w)
                lines = [
                    f"- {x}" if idx == 0 else f" {x}"
                    for idx, x in enumerate(lines.split("\n"))
                ]
                out.extend(lines)
            return out

        out = "<TR>\n"
        name_row_td = '<TD COLOR="{color}" BALIGN="LEFT" BORDER="1" SIDES="{sides}" HEIGHT="40" COLSPAN="{num_cols}">'.format(
            color="black",
            num_cols=self.num_cols,
            sides=self.mode,
        )
        name_row_td += '<FONT POINT-SIZE="20">'
        name_row_td += node_name
        if task_node.desc:
            # Close the name font and open a smaller one for the description.
            name_row_td += '</FONT>'
            name_row_td += "<br/>"
            name_row_td += '<FONT POINT-SIZE="14">'
            if isinstance(task_node.desc, list):
                desc_field = [f"- {x}" for x in task_node.desc if x]
            else:
                desc_field = [f"- {x}" for x in task_node.desc.split("\n") if x]
            desc_field = _wrap_text(desc_field)
            name_row_td += "<br/>".join(desc_field)
        # Closes whichever FONT is open (the name's or the description's).
        name_row_td += '</FONT>'
        name_row_td += "</TD>"
        out += f"{name_row_td}\n</TR>\n"
        return out

    def generate_html_label(self, node_name: str, task_node: "TaskNode"):
        """Return the complete ``<...>`` HTML-like label for one node.

        Also sets ``self.num_cols``, which the section helpers read. It is the
        larger of the input and output port counts, but at least 1, because
        Graphviz rejects ``COLSPAN="0"``.
        """
        num_inps = len(task_node.inputs) if task_node.inputs else 0
        num_outs = len(task_node.outputs) if task_node.outputs else 0
        self.num_cols = max(num_inps, num_outs, 1)
        is_lr = self.mode == "LR"

        out = self.table_header + "\n"
        if is_lr:
            out += "<TR>"
        # Input section.
        if is_lr:
            out += "\n<TD>"
        out += self.generate_io_section(node_name, task_node, "inp")
        if is_lr:
            out += "\n</TD>"
        # Name / description section (its own nested table in LR mode).
        if is_lr:
            out += "\n<TD>"
            out += "\n" + self.table_header + "\n"
        out += self.generate_name_desc(node_name, task_node)
        if is_lr:
            out += "\n</TABLE>"
            out += "</TD>"
        # Output section.
        if is_lr:
            out += "\n<TD>"
        out += self.generate_io_section(node_name, task_node, "out")
        if is_lr:
            out += "\n</TD>"
            out += "\n</TR>"
        out += "</TABLE>"
        # Indent every line so the label reads well inside the .dot file.
        out = "\n".join(" " * 8 + x for x in out.split("\n"))
        return f"<\n{out}>"

    def generate_table(self, ios: Dict[str, Any]):
        """Return a nested table with one ``PORT``-tagged cell per entry of *ios*.

        In "LR" mode each port gets its own row; in "TB" mode all ports share
        one row. An empty (or ``None``) *ios* yields a placeholder cell with
        ``PORT="dummy"``; TB mode adds one more empty cell in front of it.

        Raises:
            ValueError: if ``self.mode`` is neither "LR" nor "TB".
        """
        ios = ios or {}
        out = self.table_header
        if self.mode != "LR":
            out += "\n<TR>\n"
            if not ios:
                # Only valid here, inside the open <TR>; in LR mode it would sit
                # directly under <TABLE>, which is an invalid HTML-like label.
                out += "<TD></TD>"
        if self.mode == "LR":
            sides = "TB"
        elif self.mode == "TB":
            sides = "LR"
        else:
            raise ValueError(f"Unsupported layout mode: {self.mode!r}")
        dummy_io = ("dummy", " ")
        iterator = ios.items() or [dummy_io]
        for idx, (io_id, io) in enumerate(iterator):
            # Separators are drawn only between neighbours. The first cell keeps
            # its trailing side, the last its leading side, middle cells both.
            if idx == 0:
                sides_ = sides[1]
            elif idx == len(ios) - 1:
                sides_ = sides[0]
            else:
                sides_ = sides
            border_str = "" if len(iterator) == 1 else f'BORDER="1" SIDES="{sides_}"'
            if self.mode == "LR":
                out += "\n<TR>\n"
            out += '<TD WIDTH="50" {border_str} PORT="{port_name}"><FONT POINT-SIZE="12" >{io_val}</FONT></TD>'.format(
                port_name=io_id,
                io_val=str(io),
                border_str=border_str,
            )
            if self.mode == "LR":
                out += "</TR>"
        if self.mode != "LR":
            out += "\n</TR>"
        out += "\n</TABLE>"
        return "\n".join(" " * 4 + x for x in out.split("\n"))

    def create_graph(self, graph_name: str, mode: str):
        """Start a fresh directed, non-strict graph and apply the graph-level style.

        Args:
            graph_name: title drawn at the top of the graph.
            mode: Graphviz ``rankdir`` value ("LR" or "TB").
        """
        self.graph = pygraphviz.AGraph(directed=True, strict=False)
        graph_attr = {
            "label": graph_name,
            "rankdir": mode,
            "bgcolor": "lightgray",
            "fontname": "Consolas",
            "fontsize": 30,
            "labelloc": "t",
        }
        # Further layout knobs (splines, nodesep, ranksep, compound) can be added above.
        # The dict must be applied to the graph. Previously it was built and then
        # discarded, so rankdir and the title never took effect.
        self.graph.graph_attr.update(graph_attr)

    def write_graph(self):
        """Write ``<out_dir>/<name>.<mode>.dot`` and render ``<...>.dot.svg`` with ``dot``.

        ``<name>`` is the task graph name, lower-cased, with spaces replaced by underscores.
        """
        graph_name = self.task_graph.name.lower().replace(" ", "_")
        path_no_ext = str(self.out_dir / f"{graph_name}.{self.mode}")
        self.graph.write(f"{path_no_ext}.dot")
        self.graph.draw(f"{path_no_ext}.dot.svg", format="svg", prog="dot")
def process():
    """Build a small three-node demo TaskGraph and render it beside this script.

    Nodes A and B are in group "GroupA" and node C in "GroupB". The edges are
    A.out_0 -> B.inp_1 and B.out_1 -> C.inp_0. Output files go to the
    directory containing this file.
    """
    task_graph = TaskGraph(name="Test", desc="This is a test")
    # NOTE(review): output ports are labelled "Input N"; probably meant
    # "Output N". Left unchanged here.
    node_a = TaskNode(
        name="A",
        desc="This is 'A'",
        group="GroupA",
        outputs={"out_0": "Input 0", "out_1": "Input 1"},
    )
    node_b = TaskNode(
        name="B",
        desc="This is 'B'",
        group="GroupA",
        inputs={"inp_0": "Input 0", "inp_1": "Input 1", "inp_2": "Input 2"},
        outputs={"out_0": "Input 0", "out_1": "Input 1"},
    )
    node_c = TaskNode(
        name="C",
        desc="This is 'C'",
        group="GroupB",
        inputs={"inp_0": "Input 0", "inp_1": "Input 1"},
    )
    for node in (node_a, node_b, node_c):
        task_graph.add_node(node)

    demo_edges = (
        (TaskPort("A", "out_0"), TaskPort("B", "inp_1")),
        (TaskPort("B", "out_1"), TaskPort("C", "inp_0")),
    )
    for edge in demo_edges:
        task_graph.add_edge(edge)

    GraphvizHtmlGraph(Path(__file__).parent, task_graph)
def main():
    """Command-line entry point: parse arguments, configure logging, run the demo.

    Returns the result of process(), which the caller hands to sys.exit().
    """
    parser = argparse.ArgumentParser(description="Graphviz Sample")
    # No options are defined yet. Parsing still provides --help and rejects
    # unknown arguments; parsed values could later be forwarded to process().
    parser.parse_args()
    setup_logging(__file__, level=logging.INFO)
    return process()
if __name__ == "__main__":
    try:
        exit_code = main()
    except Exception as error:  # pylint: disable=W0702, W0703
        # Top-level boundary: log the full traceback and a one-line summary,
        # then exit with a failure status.
        logger.error("\n%s", traceback.format_exc())
        logger.error("Error: %s", error)
        sys.exit(-1)
    # Pass main()'s return value through to the shell.
    sys.exit(exit_code)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment