Skip to content

Instantly share code, notes, and snippets.

@wangg12
Last active December 18, 2022 20:54
Show Gist options
  • Save wangg12/f11258583ffcc4728eb71adc0f38e832 to your computer and use it in GitHub Desktop.
Save wangg12/f11258583ffcc4728eb71adc0f38e832 to your computer and use it in GitHub Desktop.
from graphviz import Digraph
from torch.autograd import Variable
import torch
def make_dot(var, params=None):
    """Build a Graphviz ``Digraph`` of the autograd graph behind ``var``.

    Node colours:
      * orange    -- tensors: ``var`` itself and any tensor found in a node's
                     ``saved_tensors``; each is linked from its ``grad_fn``
                     when it has one (leaf tensors have none);
      * lightblue -- autograd nodes exposing a ``variable`` attribute,
                     labelled with that variable's name from ``params``
                     (empty if unknown) and its size;
      * default   -- every other autograd node, labelled by its type name.

    The graph is walked with an explicit stack rather than recursion, so very
    deep graphs do not hit Python's recursion limit.

    Args:
        var: output tensor (or autograd node) to start the traversal from.
        params: optional dict mapping a name to a tensor, used to label the
            light-blue nodes.

    Returns:
        The graphviz ``Digraph``; call ``.view()`` on it to display it.

    Raises:
        TypeError: if the values of ``params`` are not tensors.
    """
    param_map = {}
    if params is not None:
        # dict.values() is not indexable on Python 3; peek at the first value.
        first = next(iter(params.values()), None)
        if first is not None and not isinstance(first, Variable):
            raise TypeError("params values must be tensors, got %s" % type(first).__name__)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style="filled", shape="box", align="left", fontsize="12", ranksep="0.1", height="0.2")
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # id(obj) -> obj. Keeping every drawn object alive guarantees that no id()
    # used as a node name is recycled by another object during the walk.
    seen = {}

    def size_to_str(size):
        return "(" + ", ".join("%d" % v for v in size) + ")"

    stack = [var]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node

        if torch.is_tensor(node):
            dot.node(str(id(node)), size_to_str(node.size()), fillcolor="orange")
            fn = node.grad_fn
            if fn is not None:  # leaf tensors were not produced by any op
                dot.edge(str(id(fn)), str(id(node)))
                stack.append(fn)
            continue

        if hasattr(node, "variable"):
            u = node.variable
            node_name = "%s\n %s" % (param_map.get(id(u), ""), size_to_str(u.size()))
            dot.node(str(id(node)), node_name, fillcolor="lightblue")
        else:
            dot.node(str(id(node)), str(type(node).__name__))

        # Edges are drawn for every consumer, even if the producer was
        # already visited; only the traversal itself is de-duplicated.
        for u in getattr(node, "next_functions", ()):
            if u[0] is not None:
                dot.edge(str(id(u[0])), str(id(node)))
                stack.append(u[0])
        for t in getattr(node, "saved_tensors", ()):
            dot.edge(str(id(t)), str(id(node)))
            stack.append(t)
    return dot
if __name__ == "__main__":
    # Demo: draw the autograd graph of a freshly constructed ResNet-18's
    # forward pass on one random 224x224 RGB image.
    import torchvision.models as models

    sample = torch.randn(1, 3, 224, 224)
    net = models.resnet18()
    graph = make_dot(net(sample))
    graph.view()
@SelvamArul
Copy link

@wangg12 Thanks for the nice script. Unfortunately, I get errors when I use it.

print (model) prints

FeedForwardNet (
  (fc1): Linear (3 -> 5)
  (relu1): ReLU ()
  (fc2): Linear (5 -> 3)
  (relu2): ReLU ()
  (fc3): Linear (3 -> 1)
)

but g = make_dot(model) results in the following error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-33-b58b8072a3a4> in <module>()
----> 1 g = make_dot(model)

<ipython-input-24-ef81ffc78667> in make_dot(var)
     21                     dot.edge(str(id(u[0])), str(id(var)))
     22                     add_nodes(u[0])
---> 23     add_nodes(var.creator)
     24     return dot


/opt/python3.5/lib/python3.5/site-packages/torch/nn/modules/module.py in __getattr__(self, name)
    235             if name in modules:
    236                 return modules[name]
--> 237         return object.__getattr__(self, name)
    238 
    239     def __setattr__(self, name, value):

AttributeError: type object 'object' has no attribute '__getattr__'

Any thoughts on this?

@FabianIsensee
Copy link

Hi, thank you very much for sharing. Your script no longer seems to be up to date: when I run it, it shows only the very last nodes of the graph and does not traverse all the way.

@FabianIsensee
Copy link

The troublemaker seems to be BatchNormBackward. It does not have the attribute 'previous_functions':
In [1]: var
Out[1]: <BatchNormBackward at 0x7fac77697290>

In [2]: hasattr(var, 'previous_functions')
Out[2]: False

In fact, it does not seem to have any attributes at all. Strange...

@yaochx
Copy link

yaochx commented May 14, 2019

def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are the autograd nodes exposing a ``variable`` attribute (the
    Variables that require grad); orange nodes are tensors saved for backward
    (found in a node's ``saved_tensors``); all other nodes are labelled with
    their type name. The walk uses an explicit stack, not recursion.

    Args:
        var: output tensor; traversal starts from ``var.grad_fn``.
        params: optional dict of (name, tensor) used to label the blue nodes;
            tensors not in the dict get an empty name.

    Returns:
        The graphviz ``Digraph`` (with no nodes if ``var.grad_fn`` is None).

    Raises:
        TypeError: if the values of ``params`` are not tensors.
    """
    param_map = {}
    if params is not None:
        # dict.values() is not indexable on Python 3; peek at the first value.
        first = next(iter(params.values()), None)
        if first is not None and not isinstance(first, Variable):
            raise TypeError('params values must be tensors, got %s' % type(first).__name__)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # id(obj) -> obj; holding the objects keeps their id()-based node names unique.
    seen = {}

    def size_to_str(size):
        return '(' + ', '.join('%d' % v for v in size) + ')'

    # A tensor without grad_fn has no graph to draw (avoids a "NoneType" node).
    stack = [var.grad_fn] if var.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node
        if torch.is_tensor(node):
            dot.node(str(id(node)), size_to_str(node.size()), fillcolor='orange')
        elif hasattr(node, 'variable'):
            u = node.variable
            node_name = '%s\n %s' % (param_map.get(id(u), ''), size_to_str(u.size()))
            dot.node(str(id(node)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(node)), str(type(node).__name__))
        for u in getattr(node, 'next_functions', ()):
            if u[0] is not None:
                dot.edge(str(id(u[0])), str(id(node)))
                stack.append(u[0])
        for t in getattr(node, 'saved_tensors', ()):
            dot.edge(str(id(t)), str(id(node)))
            stack.append(t)
    return dot

Author: gyguo95
Ref.: https://blog.csdn.net/gyguo95/article/details/78821617

@chg0901
Copy link

chg0901 commented Jan 4, 2020

I got this error:

Traceback (most recent call last):
  File "C:\Users\CHG\AppData\Roaming\Python\Python35\site-packages\IPython\core\interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-10-213383e64662>", line 41, in <module>
    g = make_dot(y)
  File "<ipython-input-10-213383e64662>", line 32, in make_dot
    add_nodes(var.creator)
AttributeError: 'Tensor' object has no attribute 'creator'

@Danial-Alh
Copy link

def make_dot(var, params=None):
    """Produce a Graphviz representation of a PyTorch autograd graph.

    Blue nodes are the autograd nodes exposing a ``variable`` attribute (the
    Variables that require grad); orange nodes are tensors saved for backward
    (found in a node's ``saved_tensors``); all other nodes are labelled with
    their type name. The walk uses an explicit stack, not recursion.

    Args:
        var: output tensor; traversal starts from ``var.grad_fn``.
        params: optional dict of (name, tensor) used to label the blue nodes;
            tensors not in the dict get an empty name.

    Returns:
        The graphviz ``Digraph`` (with no nodes if ``var.grad_fn`` is None).

    Raises:
        TypeError: if the values of ``params`` are not tensors.
    """
    param_map = {}
    if params is not None:
        # dict.values() is not indexable on Python 3; peek at the first value.
        first = next(iter(params.values()), None)
        if first is not None and not isinstance(first, Variable):
            raise TypeError('params values must be tensors, got %s' % type(first).__name__)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # id(obj) -> obj; holding the objects keeps their id()-based node names unique.
    seen = {}

    def size_to_str(size):
        return '(' + ', '.join('%d' % v for v in size) + ')'

    # A tensor without grad_fn has no graph to draw (avoids a "NoneType" node).
    stack = [var.grad_fn] if var.grad_fn is not None else []
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node
        if torch.is_tensor(node):
            dot.node(str(id(node)), size_to_str(node.size()), fillcolor='orange')
        elif hasattr(node, 'variable'):
            u = node.variable
            node_name = '%s\n %s' % (param_map.get(id(u), ''), size_to_str(u.size()))
            dot.node(str(id(node)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(node)), str(type(node).__name__))
        for u in getattr(node, 'next_functions', ()):
            if u[0] is not None:
                dot.edge(str(id(u[0])), str(id(node)))
                stack.append(u[0])
        for t in getattr(node, 'saved_tensors', ()):
            dot.edge(str(id(t)), str(id(node)))
            stack.append(t)
    return dot

Author: gyguo95
Ref.: https://blog.csdn.net/gyguo95/article/details/78821617

I updated it a bit.

from graphviz import Digraph
from torch.autograd import Variable
import torch


def make_dot(var, params=None):
    """Build a Graphviz ``Digraph`` of the autograd graph behind ``var``.

    Node colours:
      * orange    -- tensors: ``var`` itself and any tensor found in a node's
                     ``saved_tensors``; each is linked from its ``grad_fn``
                     when it has one (leaf tensors have none);
      * lightblue -- autograd nodes exposing a ``variable`` attribute,
                     labelled with that variable's name from ``params``
                     (empty if unknown) and its size;
      * default   -- every other autograd node, labelled by its type name.

    The graph is walked with an explicit stack rather than recursion, so very
    deep graphs do not hit Python's recursion limit.

    Args:
        var: output tensor (or autograd node) to start the traversal from.
        params: optional dict mapping a name to a tensor, used to label the
            light-blue nodes.

    Returns:
        The graphviz ``Digraph``.

    Raises:
        TypeError: if the values of ``params`` are not tensors.
    """
    param_map = {}
    if params is not None:
        # dict.values() is not indexable on Python 3; peek at the first value.
        first = next(iter(params.values()), None)
        if first is not None and not isinstance(first, Variable):
            raise TypeError('params values must be tensors, got %s' % type(first).__name__)
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # id(obj) -> obj. Keeping every drawn object alive guarantees that no id()
    # used as a node name is recycled by another object during the walk.
    seen = {}

    def size_to_str(size):
        return '(' + ', '.join('%d' % v for v in size) + ')'

    stack = [var]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue
        seen[id(node)] = node

        if torch.is_tensor(node):
            dot.node(str(id(node)), size_to_str(node.size()), fillcolor='orange')
            fn = node.grad_fn
            if fn is not None:  # leaf tensors were not produced by any op
                dot.edge(str(id(fn)), str(id(node)))
                stack.append(fn)
            continue

        if hasattr(node, 'variable'):
            u = node.variable
            node_name = '%s\n %s' % (param_map.get(id(u), ''), size_to_str(u.size()))
            dot.node(str(id(node)), node_name, fillcolor='lightblue')
        else:
            dot.node(str(id(node)), str(type(node).__name__))

        # Edges are drawn for every consumer, even if the producer was
        # already visited; only the traversal itself is de-duplicated.
        for u in getattr(node, 'next_functions', ()):
            if u[0] is not None:
                dot.edge(str(id(u[0])), str(id(node)))
                stack.append(u[0])
        for t in getattr(node, 'saved_tensors', ()):
            dot.edge(str(id(t)), str(id(node)))
            stack.append(t)
    return dot

@wangg12
Copy link
Author

wangg12 commented Jan 20, 2021

Thanks. I updated the script based on yours.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment