Created
March 8, 2021 15:19
-
-
Save jameshfisher/f99ad86fc23d2ae7c856ee2f2ec89cd8 to your computer and use it in GitHub Desktop.
Plot a TensorFlow graph with graphviz/dot
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
try: | |
# pydot-ng is a fork of pydot that is better maintained. | |
import pydot_ng as pydot | |
except ImportError: | |
# pydotplus is an improved version of pydot | |
try: | |
import pydotplus as pydot | |
except ImportError: | |
# Fall back on pydot if necessary. | |
try: | |
import pydot | |
except ImportError: | |
pydot = None | |
def add_edge(dot, src, dst, **kwargs):
    """Add an edge ``src -> dst`` to *dot* unless one is already present.

    Args:
        dot: pydot graph (``Dot`` or ``Cluster``) that receives the edge.
        src: source endpoint name, optionally with a ``:port`` suffix.
        dst: destination endpoint name, optionally with a ``:port`` suffix.
        **kwargs: edge attributes forwarded to ``pydot.Edge``. They are
            ignored when an edge between the same endpoints already exists,
            i.e. the first call for a given pair wins.
    """
    already_there = dot.get_edge(src, dst)
    if already_there:
        return
    dot.add_edge(pydot.Edge(src, dst, **kwargs))
# Characters that Graphviz interprets as record-label syntax: ``<``/``>``
# delimit port names, ``{``/``}`` group fields and ``|`` separates fields.
_RECORD_LABEL_SPECIAL_CHARS = str.maketrans('', '', '<>{}|')


def format_shape(shape):
    """Render *shape* as text that is safe to embed in a DOT record label.

    ``str(shape)`` is used unchanged except that record-label syntax
    characters (``<``, ``>``, ``{``, ``}``, ``|``) are removed, so the text
    cannot be mistaken for a port name or field separator. For example, an
    unknown-rank TensorFlow shape printed as ``<unknown>`` becomes
    ``unknown``, while ``(None, 3)`` is returned as-is.

    Args:
        shape: any object; typically a ``TensorShape``.

    Returns:
        The sanitized string form of *shape*.
    """
    return str(shape).translate(_RECORD_LABEL_SPECIAL_CHARS)
# Names of (private) attributes on an operation that refer to a subgraph.
# For each one present on an op, add_graph_to_dot draws an edge from that
# subgraph's cluster to the op, labelled with the attribute name.
# TODO what other attrs refer to subgraphs?
subgraph_attrs = [
    # StatelessIf: one subgraph per branch.
    '_true_graph',
    '_false_graph',
    # StatelessWhile: loop condition and loop body.
    '_cond_graph',
    '_body_graph',
]
def _port_labels(port_prefix, tensors):
    """Return ``|``-separated record fields, one per tensor in *tensors*.

    Field ``pos`` reads ``<{port_prefix}{pos}>{dtype} {shape}``, so an edge
    can attach to it as ``node_name:{port_prefix}{pos}``.
    """
    return '|'.join(
        f"<{port_prefix}{pos}>{tensor.dtype.name} {format_shape(tensor.shape)}"
        for pos, tensor in enumerate(tensors)
    )


def add_graph_to_dot(graph, dot):
    """Draw *graph*, and recursively its function subgraphs, into *dot*.

    Adds to *dot*:

    * a record node ``graphinput_<id(graph)>`` with one port ``in<i>`` per
      entry of ``graph.inputs``, and a record node
      ``graphoutput_<id(graph)>`` with one port ``out<i>`` per entry of
      ``graph.outputs``;
    * one cluster per entry of ``graph._functions``, filled by a recursive
      call on that function's graph;
    * one record node per operation (named ``str(id(op))``) listing its
      input and output tensors as ports;
    * edges between operations; from a subgraph's output node to the op
      that references it via one of ``subgraph_attrs``; and from/to this
      graph's input/output nodes.

    Args:
        graph: TensorFlow graph to draw. Reads ``inputs``, ``outputs``,
            ``_functions`` and ``get_operations()``.
        dot: pydot ``Dot`` or ``Cluster`` to draw into; it is mutated.
    """
    graph_id = id(graph)

    input_labels = _port_labels("in", graph.inputs)
    graphinput = pydot.Node(f"graphinput_{graph_id}", label=f'Graph inputs: |{input_labels}')
    dot.add_node(graphinput)
    output_labels = _port_labels("out", graph.outputs)
    graphoutput = pydot.Node(f"graphoutput_{graph_id}", label=f'Graph outputs: |{output_labels}')
    dot.add_node(graphoutput)

    for f_name, func in graph._functions.items():
        # pydot prepends "cluster_" to the name. Graphviz only draws a border
        # around subgraphs whose name starts with "cluster" (awful, but it
        # is also what the ltail="cluster_<id>" edges below rely on).
        cluster = pydot.Cluster(str(id(func.graph)), label=f_name)
        dot.add_subgraph(cluster)
        add_graph_to_dot(func.graph, cluster)

    ops = graph.get_operations()

    # Add nodes first, so every edge below refers to an existing node.
    for op in ops:
        if op.type == 'Placeholder':
            # For our purposes, a Placeholder _does_ have an input: it comes
            # from the graph inputs. Use the placeholder's outputs to
            # describe that input.
            op_input_labels = _port_labels("in", op.outputs)
        else:
            op_input_labels = _port_labels("in", op.inputs)
        op_output_labels = _port_labels("out", op.outputs)
        label = f"{op.name}: {op.type}\n|{{inputs:|outputs:}}|{{{{{op_input_labels}}}|{{{op_output_labels}}}}}"
        dot.add_node(pydot.Node(str(id(op)), label=label))

    # Now add edges.
    for op in ops:
        op_id = id(op)
        try:
            for pos, input_tensor in enumerate(op.inputs):
                # Don't show the tensors; just draw arrows between operations.
                add_edge(
                    dot,
                    f"{id(input_tensor.op)}:out{input_tensor.value_index}",
                    f"{op_id}:in{pos}",
                )
        except Exception as err:  # best effort: keep drawing the other ops
            # Accessing inputs raises for _OperationWithOutputs (a TensorFlow
            # bug?). Report it and move on. A bare `except:` here used to
            # swallow KeyboardInterrupt/SystemExit as well.
            print(f"Could not get inputs for {op}: {err!r}")

        for subgraph_attr in subgraph_attrs:
            if hasattr(op, subgraph_attr):
                subgraph = getattr(op, subgraph_attr)
                # ltail clips the arrow at the subgraph's cluster border
                # (requires compound=true on the top-level graph).
                add_edge(
                    dot,
                    f"graphoutput_{id(subgraph)}",
                    str(op_id),
                    ltail=f"cluster_{id(subgraph)}",
                    label=subgraph_attr,
                )

    for pos, input_tensor in enumerate(graph.inputs):
        # Always target port in0: a graph input feeds a Placeholder, whose
        # node was given one "in" port per output above.
        add_edge(
            dot,
            f"graphinput_{graph_id}:in{pos}",
            f"{id(input_tensor.op)}:in0",
        )

    for pos, output_tensor in enumerate(graph.outputs):
        add_edge(
            dot,
            f"{id(output_tensor.op)}:out{output_tensor.value_index}",
            f"graphoutput_{graph_id}:out{pos}",
        )
def graph_to_dot(graph):
    """Build a ``pydot.Dot`` describing *graph*.

    The top-level graph is laid out top-to-bottom with record-shaped nodes.
    ``compound`` is enabled so edges may be clipped at cluster borders via
    ``ltail``. See https://stackoverflow.com/a/2012106/229792.

    Args:
        graph: TensorFlow graph to draw (passed to ``add_graph_to_dot``).

    Returns:
        The populated ``pydot.Dot``.

    Raises:
        ImportError: if none of pydot_ng, pydotplus or pydot is installed.
            Without this check, the failure would surface as an opaque
            AttributeError on ``None``.
    """
    if pydot is None:
        raise ImportError(
            "graph_to_dot requires one of pydot_ng, pydotplus or pydot "
            "to be installed."
        )
    dot = pydot.Dot()
    dot.set('rankdir', 'TB')
    dot.set('concentrate', 'true')
    dot.set('dpi', 96)
    dot.set_node_defaults(shape='record')
    dot.set('compound', 'true')  # https://stackoverflow.com/a/2012106/229792
    dot.set('newrank', 'true')
    add_graph_to_dot(graph, dot)
    return dot
def plot_graph(graph, path='./graph.png', fmt='png'):
    """Print the DOT source for *graph* and render it to an image file.

    Args:
        graph: TensorFlow graph to draw (see ``graph_to_dot``).
        path: output file. Defaults to ``./graph.png``, the location that
            was previously hard-coded.
        fmt: output format passed as ``format=`` to ``dot.write``. It
            should match the extension of *path*.
    """
    dot = graph_to_dot(graph)
    print(dot)
    dot.write(path, format=fmt)
### EXAMPLE
# Trace a small function that contains data-dependent control flow, then
# plot the traced graph. This runs at import time and writes ./graph.png
# through plot_graph.
def py_func(x):
    """Example function: randomly square *x*, cast to float32, return 2*x + 5.

    The `if` condition is a tensor. When traced by tf.function it is
    presumably converted to a conditional op whose branches show up as
    subgraph clusters in the plot (TODO confirm against the output).
    """
    if tf.random.uniform(()) < 0.5:
        x = x*x
    x = tf.cast(x, 'float32')
    return 2*x + 5
tf_func = tf.function(py_func)
# Tracing with tf.constant(3) yields a single concrete function. Its
# .graph is the graph that gets plotted.
tf_concrete_func = tf_func.get_concrete_function(tf.constant(3))
my_graph = tf_concrete_func.graph
plot_graph(my_graph)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment