Created May 14, 2020 16:41
Given a PyTorch model, print its autograd graph to the console and draw it with Graphviz to see what is going on.
# Our graph-drawing functions. They rely on PyTorch and on the graphviz
# and rich Python libraries (pip install graphviz rich).
from graphviz import Digraph
from rich import print as rprint


def draw_graph(start, watch=None):
    ''' Walk the autograd graph backwards from the tensor `start` (e.g. the
    loss), print every node to the console and return a graphviz Digraph.
    `watch` is an optional list of (name, tensor) pairs to highlight.'''
    watch = [] if watch is None else watch
    # style='filled' is needed for each node's fillcolor to show up
    node_attr = dict(style='filled', shape='box')
    graph = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    assert hasattr(start, "grad_fn")
    if start.grad_fn is not None:
        _draw_graph(start.grad_fn, graph, watch=watch)

    size_per_element = 0.15
    min_size = 12

    # Scale the drawing with the approximate number of nodes and edges
    num_rows = len(graph.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    graph.graph_attr.update(size=size_str)
    return graph  # e.g. draw_graph(loss).render("net_graph", format="png")


def _draw_graph(var, graph, watch=(), seen=None, indent="", pobj=None):
    ''' Recursive function going through the hierarchical graph, printing
    what we need to see what autograd is doing.'''
    if seen is None:
        seen = set()
    if not hasattr(var, "next_functions"):
        return
    for fun in var.next_functions:
        joy = fun[0]
        if joy is None:  # an input that does not require grad
            continue
        if joy not in seen:
            seen.add(joy)
            # "<class 'AccumulateGrad'>" -> "<AccumulateGrad>"
            label = str(type(joy)).replace(
                "class", "").replace("'", "").replace(" ", "")
            label_graph = label
            colour_graph = "lightgrey"
            # AccumulateGrad nodes expose their leaf tensor as .variable
            if hasattr(joy, 'variable'):
                happy = joy.variable
                if happy.is_leaf:
                    label += " \U0001F343"  # leaf marker
                    colour_graph = "green"
                for (name, obj) in watch:
                    if obj is happy:
                        label += " \U000023E9 " + \
                            "[bold underline #FF00FF]" + name + "[/]"
                        label_graph += " " + name
                        colour_graph = "blue"
                        break
                # Shape and variance of the leaf tensor itself
                vv = [str(s) for s in happy.shape]
                label += " [["
                label += ', '.join(vv)
                label += "]]"
                label += " " + str(happy.var())
            graph.node(str(id(joy)), label_graph, fillcolor=colour_graph)
            rprint(indent + label)
            _draw_graph(joy, graph, watch, seen, indent + ".", joy)
        if pobj is not None:
            graph.edge(str(id(pobj)), str(id(joy)))
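Stripped of the Graphviz and rich formatting, the core of both functions is a depth-first walk over `grad_fn.next_functions`. The following is a minimal console-only sketch of that walk; the helper name `walk_grad_fn` and the toy tensors are illustrative and not part of the gist. It keeps a `set` of visited nodes so that sub-graphs shared by several branches are listed only once.

```python
import torch


def walk_grad_fn(fn, depth=1, seen=None, out=None):
    """Depth-first walk over an autograd graph, starting at a grad_fn node.

    Returns a list of (depth, node_type_name) pairs. Leaf tensors with
    requires_grad=True show up as AccumulateGrad nodes.
    """
    if seen is None:
        seen = set()
    if out is None:
        out = []
    for child, _ in fn.next_functions:
        # None marks an input that does not require grad
        if child is None or child in seen:
            continue
        seen.add(child)
        out.append((depth, type(child).__name__))
        walk_grad_fn(child, depth + 1, seen, out)
    return out


# A tiny example graph: loss = sum(w * x + b) with two leaf parameters
w = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
x = torch.randn(3)  # does not require grad, so it appears as None
loss = (w * x + b).sum()

for depth, name in walk_grad_fn(loss.grad_fn):
    print("." * depth + name)
```

For this toy graph it prints one line per node, with `w` and `b` both appearing as `AccumulateGrad` entries; the exact names of the other backward nodes (e.g. `AddBackward0`) can vary between PyTorch versions.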

seyeeet commented Sep 30, 2021

Would it be possible to do this for all the elements in the model instead of defining the watch list?
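One way to avoid writing the list by hand, assuming the model is a standard `torch.nn.Module`, is to build it from `named_parameters()`, which already yields `(name, tensor)` pairs in the shape the `watch` argument expects. A minimal sketch; the helper name `watch_all_parameters` and the toy model are hypothetical:

```python
import torch
from torch import nn


def watch_all_parameters(model: nn.Module):
    """Build a (name, tensor) watch list covering every trainable parameter.

    The tensors are the model's own parameter objects, so an identity
    check (`obj is tensor`) against autograd's leaf tensors can match them.
    """
    return [(name, p) for name, p in model.named_parameters() if p.requires_grad]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
watch = watch_all_parameters(model)
print([name for name, _ in watch])
# prints ['0.weight', '0.bias', '2.weight', '2.bias']
```

The resulting list can then be passed in as `draw_graph(loss, watch=watch)`.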


cocoaaa commented Jan 8, 2022

Thank you for sharing the code. However, could you check whether the posted code is correct?
I'm getting errors when I run it, e.g. on line 67 "obj" is undefined, probably because of an indentation mistake.
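That diagnosis fits how Python scopes `for`-loop variables: the loop target (`obj` in `for (name, obj) in watch:`) stays bound after the loop ends, but only if the loop ran at least once. Code after the loop that still uses `obj` therefore silently sees the last entry of `watch`, or fails outright when `watch` is empty. A small stdlib-only illustration; the function name `last_match` is hypothetical:

```python
def last_match(watch):
    """Mimic the pattern: `obj` is only ever bound by the for loop."""
    for (name, obj) in watch:
        pass
    # The loop variable outlives the loop, so this returns the *last*
    # entry; if `watch` was empty, `obj` was never bound at all.
    return obj


print(last_match([("a", 1), ("b", 2)]))  # prints 2
try:
    last_match([])
except UnboundLocalError as err:
    print("empty watch list:", type(err).__name__)
```

Indenting the lines that use `obj` into the loop body (or using the tensor being inspected instead of `obj`) avoids both problems.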


@OniDaito - How do I read the output of this? I had the SAME issue with using .view(), and I want to prove to myself that I "broke the computation graph". But when I switch between movedim() and view(), I don't see any difference in the printed output of your script:

This is the output when I use .view(), which results in a model that doesn't learn:

..<AccumulateGrad> 🍃 [[1, 250]] tensor(nan, grad_fn=<VarBackward0>)
.....<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0015, grad_fn=<VarBackward0>)
................<AccumulateGrad> 🍃 ⏩ embedding [[5000, 50]] tensor(1.0022, grad_fn=<VarBackward0>)
.............<AccumulateGrad> 🍃 ⏩ conv1 [[250, 50, 3]] tensor(0.0022, grad_fn=<VarBackward0>)
............<AccumulateGrad> 🍃 [[1, 250]] tensor(0.0020, grad_fn=<VarBackward0>)
......<AccumulateGrad> 🍃 ⏩ linear1 [[250, 250]] tensor(0.0013, grad_fn=<VarBackward0>)
...<AccumulateGrad> 🍃 ⏩ output [[1, 250]] tensor(0.0012, grad_fn=<VarBackward0>)

Would LOVE your help understanding this visualization!
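A plausible reading of this, assuming the `.view()` call was meant to swap the channel and length axes before a `Conv1d` (the tensor sizes below are hypothetical): `view()` does not detach anything, so the same parameters still appear as leaves of the autograd graph either way. What differs is the data. `view()` reinterprets the existing memory with a new shape, while `movedim()` actually moves an axis, so the two can produce tensors of identical shape but scrambled contents. A minimal sketch:

```python
import torch

# Hypothetical batch: 2 sequences, 3 time steps, 4 features -> (N, L, C)
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4).requires_grad_()

# Conv1d expects (N, C, L). movedim really moves the feature axis ...
swapped = x.movedim(2, 1)
# ... while view only re-reads the same memory with a new shape.
reshaped = x.view(2, 4, 3)

print(swapped.shape == reshaped.shape)   # True: identical shapes
print(torch.equal(swapped, reshaped))    # False: different contents
# Both results are still attached to the autograd graph.
print(swapped.grad_fn is not None, reshaped.grad_fn is not None)  # True True
```

Because gradients flow through both versions, a graph printout alone is not expected to reveal this kind of bug; comparing the tensor contents is a more direct check.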
