@daskol
Last active December 3, 2021 20:36
Visualization of the backward-pass graph for a simple model in PyTorch
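#!/usr/bin/env python3
# Run this script, then run the command below to render graph.dot as a PNG.
#
#   dot -Tpng -ograph.png graph.dot
#
import torch as T
from torchviz import make_dot

# A small two-layer MLP: 8 inputs -> 4 hidden units -> 1 output.
model = T.nn.Sequential(
    T.nn.Linear(8, 4),
    T.nn.ReLU(),
    T.nn.Linear(4, 1),
)

# A batch of 3 random inputs. requires_grad_() turns the input into a leaf
# of the autograd graph, so it shows up in the visualization as well.
xs = T.randn((3, 8))
ys = model(xs.requires_grad_())

# Build the backward graph of ys, including autograd node attributes and the
# tensors saved for the backward pass, then write the DOT source to disk.
dot = make_dot(ys, show_attrs=True, show_saved=True)
dot.save('graph.dot')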
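The picture that make_dot draws comes from walking the autograd graph backwards from ys.grad_fn through each node's next_functions. The sketch below is a minimal version of that traversal without torchviz. It relies only on the grad_fn and next_functions attributes and on the variable attribute of AccumulateGrad nodes, which are the same attributes torchviz reads. The helper name collect_graph is hypothetical and is not part of torchviz or PyTorch.

```python
import torch as T


def collect_graph(root):
    """Walk an autograd graph from `root` (a grad_fn node) toward the inputs.

    Returns (edges, leaves): edges are (node, input_node) type-name pairs;
    leaves are the tensors held by AccumulateGrad nodes, i.e. the leaf
    tensors (parameters and inputs) whose .grad the backward pass fills in.
    """
    edges, leaves = [], []
    seen = set()
    stack = [root]
    while stack:
        fn = stack.pop()
        if fn in seen:
            continue
        seen.add(fn)
        # AccumulateGrad nodes expose the leaf tensor they accumulate into.
        if hasattr(fn, 'variable'):
            leaves.append(fn.variable)
        # next_functions holds (node, input_nr) pairs; the node is None for
        # inputs that do not require gradients, so those are skipped.
        for nxt, _ in fn.next_functions:
            if nxt is None:
                continue
            edges.append((type(fn).__name__, type(nxt).__name__))
            stack.append(nxt)
    return edges, leaves


# Same model and input as in the script above.
model = T.nn.Sequential(
    T.nn.Linear(8, 4),
    T.nn.ReLU(),
    T.nn.Linear(4, 1),
)
xs = T.randn((3, 8)).requires_grad_()
ys = model(xs)

edges, leaves = collect_graph(ys.grad_fn)
for node, input_node in edges:
    print(f'{node} -> {input_node}')
print('leaf shapes:', sorted(tuple(t.shape) for t in leaves))
```

The exact node names, such as AddmmBackward0 or ReluBackward0, differ between PyTorch releases, so the sketch uses them only for display. The five leaves it finds are the input xs and the weight and bias of both Linear layers. These correspond to the AccumulateGrad boxes in the rendered graph.png.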