Skip to content

Instantly share code, notes, and snippets.

@csghone
Created December 4, 2023 18:29
Show Gist options
  • Save csghone/dc6dd0049ec6d2452ba808d03eecf4d7 to your computer and use it in GitHub Desktop.
Save csghone/dc6dd0049ec6d2452ba808d03eecf4d7 to your computer and use it in GitHub Desktop.
Create GraphViz graph with HTML-like labels
#!/usr/bin/env python
import argparse
import html
import logging
import logging.handlers
import os
import sys
import textwrap
import traceback
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import attr
import marshmallow
import marshmallow_dataclass
import pygraphviz
# Per-module logger: create it like this in every file and let records
# propagate up to whatever handlers are installed on the root logger.
logger = logging.getLogger(__name__)
logger.propagate = True

# setup_logging() should only be called from the file that defines main().
# LOG_FORMATTER and setup_logging() may be moved to a shared module for reuse.
_LOG_FIELDS = (
    "%(asctime)s.%(msecs)03d",
    "%(name)s",
    "%(levelname)s",
    "%(lineno)s",
    "%(funcName)s",
    "%(message)s",
)
LOG_FORMATTER = logging.Formatter(" - ".join(_LOG_FIELDS), "%Y%m%d %H:%M:%S")
def setup_logging(inp_file, level=logging.INFO, enable_console=True):
    """Install file (and optionally console) handlers on the root logger.

    The log file is ``__<basename of inp_file>.main__.log`` relative to the
    current working directory, rotated after 1,000,000 bytes with 5 backups.
    After the handlers are added, every handler on the root logger -- including
    any that were already present -- is switched to LOG_FORMATTER.

    Args:
        inp_file: path whose basename names the log file (typically __file__).
        level: level set on the root logger.
        enable_console: also attach a StreamHandler when true.
    """
    root_logger = logging.getLogger()
    log_file_name = f"__{os.path.basename(inp_file)}.main__.log"

    handlers_to_add = [
        logging.handlers.RotatingFileHandler(
            log_file_name, maxBytes=1000000, backupCount=5
        )
    ]
    if enable_console:
        handlers_to_add.append(logging.StreamHandler())

    for new_handler in handlers_to_add:
        root_logger.addHandler(new_handler)
    root_logger.setLevel(level)

    for installed in root_logger.handlers:
        installed.setFormatter(fmt=LOG_FORMATTER)
@dataclass
class TaskPort:
    """One end of an edge: a named port on a named task node."""

    # Name of the TaskNode (its TaskNode.name).
    node: str
    # A key of that node's ``inputs`` or ``outputs`` mapping.
    port: str


# An edge runs from a source port (an output) to a destination port (an input).
TaskEdge = Tuple[TaskPort, TaskPort]
@marshmallow_dataclass.add_schema
@attr.s(auto_attribs=True, kw_only=True)
class TaskNode:
    """A single task (graph node) with named input and output ports.

    ``inputs`` and ``outputs`` map a port name to the value displayed for that
    port; passing ``None`` for either is normalised to an empty dict.
    """

    name: str
    desc: str
    # Nodes sharing a group are placed in the same subgraph when rendered.
    group: str = attr.ib(default="A")
    inputs: Optional[Dict[str, Any]] = attr.ib(
        factory=dict,
        converter=lambda value: {} if value is None else value,
    )
    outputs: Optional[Dict[str, Any]] = attr.ib(
        factory=dict,
        converter=lambda value: {} if value is None else value,
    )
@marshmallow_dataclass.add_schema
@attr.s(auto_attribs=True, kw_only=True)
class TaskGraph:
    """A named graph of TaskNodes connected by output-port -> input-port edges.

    ``nodes`` maps a node name to its TaskNode; ``edges`` is a list of
    (source TaskPort, destination TaskPort) pairs. ``None`` for either field is
    normalised to an empty container.
    """

    name: str
    desc: str
    nodes: Optional[Dict[str, TaskNode]] = attr.ib(factory=dict, converter=lambda x: x if x is not None else {})
    edges: Optional[List[TaskEdge]] = attr.ib(factory=list, converter=lambda x: x if x is not None else [])

    def add_node(self, node: TaskNode):
        """Register *node* under ``node.name``.

        Raises:
            ValueError: a node with the same name is already registered.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        if node.name in self.nodes:
            raise ValueError(f"Node {node.name!r} is already in graph {self.name!r}")
        self.nodes[node.name] = node

    def add_edge(self, edge: TaskEdge):
        """Append *edge* after checking that both endpoints exist.

        The source port must be a key of the source node's ``outputs`` and the
        destination port a key of the destination node's ``inputs``.

        Raises:
            ValueError: an endpoint node or port is not known to this graph.
        """
        src, dst = edge[0], edge[1]
        if src.node not in self.nodes:
            raise ValueError(f"Unknown source node {src.node!r}")
        if dst.node not in self.nodes:
            raise ValueError(f"Unknown destination node {dst.node!r}")
        if src.port not in self.nodes[src.node].outputs:
            raise ValueError(f"Node {src.node!r} has no output port {src.port!r}")
        if dst.port not in self.nodes[dst.node].inputs:
            raise ValueError(f"Node {dst.node!r} has no input port {dst.port!r}")
        self.edges.append(edge)
class GraphvizHtmlGraph:
    """Render a TaskGraph with Graphviz, using HTML-like tables as node labels.

    Instantiating the class does all the work: for each layout direction
    ("LR", then "TB") it builds a fresh pygraphviz graph, adds one subgraph per
    node group, the nodes, and the port-to-port edges, then writes
    ``<graph name>.<mode>.dot`` and ``<graph name>.<mode>.dot.svg`` into
    *out_dir*.
    """

    def __init__(self, out_dir: Path, task_graph: TaskGraph):
        self.out_dir = out_dir
        self.task_graph = task_graph
        self.table_header = '<TABLE BORDER="0" CELLBORDER="0" CELLSPACING="0">'
        for mode in ["LR", "TB"]:
            # self.mode is read by the label/edge helpers below.
            self.mode = mode
            self.create_graph(f"{task_graph.name}", mode)
            self.create_subgraphs()
            self.create_graph_nodes()
            self.create_graph_edges()
            self.write_graph()

    def create_subgraphs(self):
        """Add one styled subgraph per distinct TaskNode.group."""
        sub_graph_name_set = set()
        for task_node in self.task_graph.nodes.values():
            sub_graph_name = task_node.group
            if sub_graph_name in sub_graph_name_set:
                continue
            sub_graph_name_set.add(sub_graph_name)
            self.graph.add_subgraph([], sub_graph_name)
            sub_graph = self.graph.get_subgraph(sub_graph_name)
            sub_graph.graph_attr["label"] = sub_graph_name
            sub_graph.graph_attr["fontname"] = "Consolas"
            sub_graph.graph_attr["fontsize"] = 30
            # The group name has no "cluster" prefix, so clustering is requested
            # through the "cluster" graph attribute instead.
            sub_graph.graph_attr["cluster"] = "true"
            sub_graph.graph_attr["color"] = "darkgrey"
            sub_graph.node_attr["style"] = "filled,rounded"

    def create_graph_nodes(self):
        """Add every task node to the graph, then place each in its group's subgraph."""
        for node_name, task_node in self.task_graph.nodes.items():
            self.create_graph_node(node_name, task_node)
        for node_name, task_node in self.task_graph.nodes.items():
            self.create_subgraph_node(node_name, task_node)

    def create_subgraph_node(self, node_name: str, task_node: TaskNode):
        """Add the already-created node *node_name* to its group's subgraph, if any."""
        cur_node = self.graph.get_node(node_name)
        sub_graph_name = task_node.group
        sub_graph = self.graph.get_subgraph(sub_graph_name)
        if sub_graph is not None:
            sub_graph.add_node(cur_node)

    def create_graph_node(self, node_name: str, task_node: TaskNode):
        """Add one node with box styling and an HTML-like label for the current mode."""
        self.graph.add_node(node_name)
        gnode = self.graph.get_node(node_name)
        node_attr = {}
        node_attr["fontname"] = "Consolas"
        node_attr["fontsize"] = 20
        node_attr["shape"] = "box"
        node_attr["style"] = "rounded,filled"
        node_attr["color"] = "black"
        node_attr["fillcolor"] = "white"
        # Zero size/margin so the box hugs the label table.
        node_attr["width"] = "0"
        node_attr["height"] = "0"
        node_attr["margin"] = "0"
        gnode.attr.update(node_attr)
        gnode.attr["group"] = task_node.group
        # gnode.attr["fillcolor"] =
        gnode.attr["label"] = self.generate_html_label(node_name, task_node)

    def create_graph_edges(self):
        """Add all task edges; each port name is shown as an edge label only once."""
        self.graph.edge_attr.update({
            "fontname": "Consolas",
            "fontsize": 12,
            # "fontcolor":
        })
        # (node, port) pairs that already carry a head/tail label in this mode.
        self._edges_labelled = set()
        for src_port, dst_port in self.task_graph.edges:
            self.create_graph_edge(src_port, dst_port)

    def create_graph_edge(self, src: TaskPort, dst: TaskPort):
        """Connect output port *src* to input port *dst* via the label's PORT cells.

        Raises:
            ValueError: self.mode is neither "LR" nor "TB".
        """
        tooltip = f"{src.node}:{src.port} -> {dst.node}:{dst.port}"
        # Leave from the side facing the flow direction, arrive on the opposite side.
        if self.mode == "LR":
            compass_point_s = "e"
            compass_point_d = "w"
        elif self.mode == "TB":
            compass_point_s = "s"
            compass_point_d = "n"
        else:
            raise ValueError(f"Unsupported layout mode: {self.mode!r}")
        head_label = ""
        head_key = (dst.node, dst.port)
        tail_label = ""
        tail_key = (src.node, src.port)
        if head_key not in self._edges_labelled:
            self._edges_labelled.add(head_key)
            head_label = f"{dst.port}"
        if tail_key not in self._edges_labelled:
            self._edges_labelled.add(tail_key)
            tail_label = f"{src.port}"
        self.graph.add_edge(
            src.node, f"{dst.node}",
            tailport=f"{src.port}:{compass_point_s}",
            headport=f"{dst.port}:{compass_point_d}",
            tooltip=tooltip,
            portPos="n",
            taillabel=tail_label,
            headlabel=head_label,
        )

    def generate_io_section(self, node_name: str, task_node: TaskNode, io_mode: str):
        """Return the markup for a node's inputs (io_mode "inp") or outputs ("out").

        In TB mode the port table is wrapped in a full-width row of the outer
        label table; in LR mode the bare table is returned for the caller to put
        inside a cell. *node_name* is currently unused.

        Raises:
            ValueError: io_mode is neither "inp" nor "out".
        """
        if io_mode == "inp":
            table_args = task_node.inputs
        elif io_mode == "out":
            table_args = task_node.outputs
        else:
            raise ValueError(f"io_mode must be 'inp' or 'out', got {io_mode!r}")
        if self.mode == "TB":
            out = ""
            out += f'<TR><TD COLSPAN="{self.num_cols}">\n'
            out += self.generate_table(table_args)
            out += "\n</TD></TR>\n"
        else:
            out = "\n"
            out += self.generate_table(table_args)
        return out

    def generate_name_desc(self, node_name: str, task_node: TaskNode):
        """Return the ``<TR>`` row holding the node name and its wrapped description.

        ``task_node.desc`` may be a string (split on newlines) or a list of
        lines; each non-empty line becomes a "- " bullet wrapped at roughly
        25 characters. Text is HTML-escaped so it cannot break the label markup.
        """
        out = ""
        out += "<TR>\n"
        # SIDES reuses the mode string: "LR" draws left/right borders and "TB"
        # top/bottom, separating the name cell from the port tables.
        name_row_td = '<TD COLOR="{color}" BALIGN="LEFT" BORDER="1" SIDES="{sides}" HEIGHT="40" COLSPAN="{num_cols}">'.format(
            color="black",
            num_cols=self.num_cols,
            sides=self.mode,
        )
        name_row_td += '<FONT POINT-SIZE="20">'
        name_row_td += html.escape(node_name, quote=False)

        def _wrap_text(desc_field):
            _max_w = 25
            out = []
            for line in desc_field:
                if len(line) < _max_w:
                    out.append(line)
                    continue
                # Drop only the leading "- " bullet (re-added to the first wrapped
                # line below); str.replace would also delete any "- " in the text.
                if line.startswith("- "):
                    line = line[2:]
                lines = textwrap.fill(line, _max_w)
                lines = [
                    f"- {x}" if idx == 0 else f" {x}"
                    for idx, x in enumerate(lines.split("\n"))
                ]
                out.extend(lines)
            return out

        if task_node.desc:
            # Close the name font, then open a smaller one for the description.
            name_row_td += '</FONT>'
            name_row_td += "<br/>"
            name_row_td += '<FONT POINT-SIZE="14">'
            desc_field = []
            if isinstance(task_node.desc, list):
                desc_field.extend([f"- {x}" for x in task_node.desc if x])
            else:
                desc_field.extend([f"- {x}" for x in task_node.desc.split("\n") if x])
            desc_field = _wrap_text(desc_field)
            # Escape after wrapping so entity text does not skew the line widths.
            name_row_td += "<br/>".join(html.escape(x, quote=False) for x in desc_field)
        name_row_td += '</FONT>'
        name_row_td += "</TD>"
        out += f"{name_row_td}\n</TR>\n"
        return out

    def generate_html_label(self, node_name: str, task_node: TaskNode):
        """Build the complete HTML-like label (wrapped in ``<...>``) for one node.

        LR mode: one row of three cells -- inputs | name/description | outputs.
        TB mode: three stacked rows -- inputs / name/description / outputs.
        Also sets ``self.num_cols``, which the section helpers use for COLSPAN.
        """
        num_inps = len(task_node.inputs) if task_node.inputs else 0
        num_outs = len(task_node.outputs) if task_node.outputs else 0
        # A node without ports would otherwise yield COLSPAN="0", which Graphviz
        # does not accept.
        self.num_cols = max(1, num_inps, num_outs)
        out = ""
        out += self.table_header + "\n"
        if self.mode == "LR":
            out += "<TR>"
        # Generate input section
        if self.mode == "LR":
            out += "\n<TD>"
        out += self.generate_io_section(node_name, task_node, "inp")
        if self.mode == "LR":
            out += "\n</TD>"
        # Generate name/desc section (in LR it needs its own nested table)
        if self.mode == "LR":
            out += "\n<TD>"
            out += "\n" + self.table_header + "\n"
        out += self.generate_name_desc(node_name, task_node)
        if self.mode == "LR":
            out += "\n</TABLE>"
            out += "</TD>"
        # Generate output section
        if self.mode == "LR":
            out += "\n<TD>"
        out += self.generate_io_section(node_name, task_node, "out")
        if self.mode == "LR":
            out += "\n</TD>"
        if self.mode == "LR":
            out += "\n</TR>"
        out += "</TABLE>"
        out = "\n".join(
            [" " * 8 + x for x in out.split("\n")]
        )
        return f"<\n{out}>"

    def generate_table(self, ios: Dict[str, Any]):
        """Return a table with one cell, and one Graphviz PORT, per entry of *ios*.

        Keys become PORT names and ``str(value)`` (HTML-escaped) the cell text.
        Cells are stacked vertically in LR mode and placed in a single row in
        TB mode. An empty mapping yields one blank placeholder cell with
        PORT "dummy".

        Raises:
            ValueError: self.mode is neither "LR" nor "TB".
        """
        # Border sides used for the separators between neighbouring cells.
        if self.mode == "LR":
            sides = "TB"
        elif self.mode == "TB":
            sides = "LR"
        else:
            raise ValueError(f"Unsupported layout mode: {self.mode!r}")
        out = self.table_header
        if self.mode != "LR":
            out += "\n<TR>\n"
            if not ios:
                # Only emitted inside the TB row: in LR there is no enclosing
                # <TR> here, and a bare <TD> directly in <TABLE> is invalid.
                out += "<TD></TD>"
        dummy_io = ("dummy", " ")
        iterator = ios.items() or [dummy_io]
        num_items = len(iterator)
        for idx, (io_id, io) in enumerate(iterator):
            # First cell: border on its trailing side; last cell: leading side;
            # middle cells: both -- together they draw dividers between cells.
            if idx == 0:
                sides_ = sides[1]
            elif idx == len(ios) - 1:
                sides_ = sides[0]
            else:
                sides_ = sides
            if num_items == 1:
                border_str = ""
            else:
                border_str = f'BORDER="1" SIDES="{sides_}"'
            if self.mode == "LR":
                out += "\n<TR>\n"
            # Escape so values containing &, < or > cannot break the label markup.
            io_str = html.escape(str(io), quote=False)
            io_td = '<TD WIDTH="50" {border_str} PORT="{port_name}"><FONT POINT-SIZE="12" >{io_val}</FONT></TD>'.format(
                port_name=io_id,
                io_val=io_str,
                border_str=border_str,
            )
            out += io_td
            if self.mode == "LR":
                out += "</TR>"
        if self.mode != "LR":
            out += "\n</TR>"
        out += "\n</TABLE>"
        out = "\n".join(
            [" " * 4 + x for x in out.split("\n")]
        )
        return out

    def create_graph(self, graph_name: str, mode: str):
        """Start a new directed, non-strict graph titled *graph_name* laid out in *mode*."""
        self.graph: pygraphviz.AGraph = pygraphviz.AGraph(directed=True, strict=False)
        graph_attr = {}
        graph_attr["label"] = graph_name
        graph_attr["rankdir"] = mode
        graph_attr["bgcolor"] = "lightgray"
        graph_attr["fontname"] = "Consolas"
        graph_attr["fontsize"] = 30
        graph_attr["labelloc"] = "t"
        # Optional tuning knobs, currently unset: splines, nodesep, ranksep
        # (possibly per mode), compound.
        # Apply the collected attributes; without this, rankdir, the title and
        # the background colour never reached the rendered graph.
        self.graph.graph_attr.update(graph_attr)

    def write_graph(self):
        """Write the current graph as ``<name>.<mode>.dot`` plus an SVG rendered with dot."""
        graph_name = self.task_graph.name.lower().replace(" ", "_")
        path_no_ext = str(self.out_dir / f"{graph_name}.{self.mode}")
        self.graph.write(f"{path_no_ext}.dot")
        self.graph.draw(f"{path_no_ext}.dot.svg", format="svg", prog="dot")
def process():
    """Build a three-node demo TaskGraph and render it next to this script."""
    task_graph = TaskGraph(name="Test", desc="This is a test")

    demo_nodes = [
        TaskNode(name="A", desc="This is 'A'", group="GroupA"),
        TaskNode(name="B", desc="This is 'B'", group="GroupA"),
        TaskNode(name="C", desc="This is 'C'", group="GroupB"),
    ]
    node_a, node_b, node_c = demo_nodes
    node_a.outputs = {"out_0": "Input 0", "out_1": "Input 1"}
    node_b.inputs = {"inp_0": "Input 0", "inp_1": "Input 1", "inp_2": "Input 2"}
    node_b.outputs = {"out_0": "Input 0", "out_1": "Input 1"}
    node_c.inputs = {"inp_0": "Input 0", "inp_1": "Input 1"}

    for demo_node in demo_nodes:
        task_graph.add_node(demo_node)

    # (source node, output port) -> (destination node, input port)
    demo_edges = (
        (("A", "out_0"), ("B", "inp_1")),
        (("B", "out_1"), ("C", "inp_0")),
    )
    for src_spec, dst_spec in demo_edges:
        task_graph.add_edge((TaskPort(*src_spec), TaskPort(*dst_spec)))

    GraphvizHtmlGraph(Path(__file__).parent, task_graph)
def main():
    """Command-line entry point; its return value is passed to sys.exit()."""
    # No options are defined yet; parsing still provides --help.
    argparse.ArgumentParser(description="Graphviz Sample").parse_args()
    setup_logging(__file__, level=logging.INFO)
    # Once options exist, forward them with process(**vars(args)).
    return process()
if __name__ == "__main__":
    try:
        _exit_status = main()
    except Exception as exc:  # pylint: disable=W0702, W0703
        # Log the full traceback first, then a one-line summary.
        logger.error("\n%s", traceback.format_exc())
        logger.error("Error: %s", exc)
        sys.exit(-1)
    # Hand main()'s return value to the shell as the exit status.
    sys.exit(_exit_status)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment