Skip to content

Instantly share code, notes, and snippets.

@Stonesjtu
Forked from wangg12/viz_net_pytorch.py
Created January 26, 2018 06:54
Show Gist options
  • Save Stonesjtu/3e3be59efe27ed0ad8e21aafd72a8f9b to your computer and use it in GitHub Desktop.
Save Stonesjtu/3e3be59efe27ed0ad8e21aafd72a8f9b to your computer and use it in GitHub Desktop.
from graphviz import Digraph
import re
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch.autograd import Variable
import torchvision.models as models
def make_dot(var):
    """Render the autograd graph that produced ``var`` as a Graphviz digraph.

    Leaf tensors (inputs and parameters) are drawn as light-blue boxes
    labelled with their size, e.g. ``(1, 64, 56, 56)``. Every other node is an
    autograd function labelled with its class name. Each edge points from an
    input node to the node that consumed it.

    Both autograd APIs are understood:

    * newer PyTorch: ``tensor.grad_fn`` and ``fn.next_functions``. Leaf
      tensors appear there as nodes exposing a ``.variable`` attribute.
    * legacy PyTorch: ``Variable.creator`` and ``fn.previous_functions``.

    Args:
        var: Output tensor/Variable whose computation history is drawn.

    Returns:
        graphviz.Digraph: the graph; call ``.view()`` or ``.render()`` on it.
    """
    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    def size_label(tensor):
        # torch.Size([1, 64, 56, 56]) -> "(1, 64, 56, 56)"
        return '(' + ', '.join('%d' % v for v in tensor.size()) + ')'

    def inputs_of(node):
        # Newer API first, then the legacy one; both hold (node, index) pairs.
        pairs = getattr(node, 'next_functions', None)
        if pairs is None:
            pairs = getattr(node, 'previous_functions', ())
        # next_functions contains (None, 0) for inputs that need no gradient.
        return [pair[0] for pair in pairs if pair[0] is not None]

    def add_node(node):
        if isinstance(node, (Variable, torch.Tensor)):
            dot.node(str(id(node)), size_label(node), fillcolor='lightblue')
        elif hasattr(node, 'variable'):
            # Gradient-accumulator node wrapping a leaf tensor: show the leaf.
            dot.node(str(id(node)), size_label(node.variable),
                     fillcolor='lightblue')
        else:
            dot.node(str(id(node)), type(node).__name__)

    root = getattr(var, 'grad_fn', None)
    if root is None:
        root = getattr(var, 'creator', None)
    if root is None:
        # No recorded history (leaf tensor or no autograd): draw var alone.
        root = var

    # Iterative DFS so deep networks cannot exceed the recursion limit.
    # Nodes are tracked by id(): tensor __eq__ is elementwise, which makes
    # storing tensors directly in a set unsafe. var keeps every node alive,
    # so the ids stay unique during the walk.
    seen = {id(root)}
    stack = [root]
    while stack:
        node = stack.pop()
        add_node(node)
        for parent in inputs_of(node):
            dot.edge(str(id(parent)), str(id(node)))
            if id(parent) not in seen:
                seen.add(id(parent))
                stack.append(parent)
    return dot
if __name__ == '__main__':
    # Demo: draw the autograd graph of an untrained ResNet-18 applied to one
    # random 1x3x224x224 input. Guarded so that importing this module (e.g.
    # to reuse make_dot) does not run a forward pass or open a viewer.
    inputs = torch.randn(1, 3, 224, 224)
    resnet18 = models.resnet18()
    # Variable() is required on legacy PyTorch and is a harmless
    # pass-through on newer releases, so it is kept for compatibility.
    y = resnet18(Variable(inputs))
    g = make_dot(y)
    g.view()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment