@albanD
Created August 17, 2022 21:12
# Implements Alban's idea of making available the forward traceback
# corresponding to the execution of the current backward node as a global
# Updated version of https://gist.github.com/soulitzer/28140cc4cd7d26828ff7f07b1235d9f5
# to add inter-op tracking
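# How it works: anomaly mode stores the forward traceback of each op in the
# corresponding autograd node's metadata (under "traceback_"). The hooks below expose
# the metadata of the node currently being executed through a module-level global,
# and a TorchDispatchMode reads that global to print the forward traceback for every
# op dispatched during backward.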
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
current_metadata = None
callback_set = False
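# current_metadata holds (node_name, node.metadata) for the backward node that is
# currently executing, or ("InBetween", ...) between two nodes; callback_set records
# whether the end-of-backward callback that resets these globals has been queued.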
# Set up hooks so that during backward the global is properly set/unset
def setup_hooks(root):
    # Breadth-first traversal of the autograd graph starting at root
    def iter_graph(root):
        q = []
        if not root:
            return
        i = 0
        seen = set([root])
        q.append(root)
        yield root
        while i < len(q):
            for fn, _idx in q[i].next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)
                yield fn
            i += 1

    for node in iter_graph(root):
        metadata = node.metadata
        name = str(node)

        # Runs once at the end of backward to reset the globals
        def callback():
            global current_metadata, callback_set
            current_metadata = None
            callback_set = False

        # The prehook publishes this node's metadata just before the node runs
        def get_prehook(name_, metadata_):
            def prehook(grad_output):
                global current_metadata, callback_set
                assert current_metadata is None or current_metadata[0] == "InBetween", \
                    "Reentrant backward is not supported"
                current_metadata = (name_, metadata_)
                if not callback_set:
                    torch.autograd.variable.Variable._execution_engine.queue_callback(callback)
                    callback_set = True
            return prehook

        # The posthook marks the "in between nodes" state once the node has finished
        def posthook(grad_input, grad_output):
            global current_metadata
            current_metadata = ("InBetween", {"traceback_": "Whatever you want here"})

        node.register_prehook(get_prehook(name, metadata))
        node.register_hook(posthook)
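# Not part of the original gist: a hedged sketch of how the "InBetween" placeholder
# above could carry a real traceback instead of a fixed string. The helper name
# `stack_capturing_posthook` and the use of traceback.format_stack() are assumptions
# for illustration, not the author's code; the hook signature matches posthook above.
import traceback

def stack_capturing_posthook(grad_input, grad_output):
    global current_metadata
    # Record the Python stack at the moment the previous node finished, so that ops
    # dispatched between two nodes (e.g. gradient accumulation) print a real traceback.
    current_metadata = ("InBetween", {"traceback_": traceback.format_stack()})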
# A mode that accesses the global set with the above
class FwdTracebackMode(TorchDispatchMode):
    @staticmethod
    def print_metadata(metadata):
        # print("Node: ", node)
        if 'traceback_' in metadata:
            for tb in metadata["traceback_"]:
                print(tb, end='')
        # For higher-order gradient computation
        while 'parent_' in metadata:
            metadata = metadata['parent_'].metadata
            if 'traceback_' in metadata:
                print("The traceback of the forward that induced the above computation:")
                for tb in metadata["traceback_"]:
                    print(tb, end='')

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        global current_metadata
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print("\nfunc:", func.__name__)
        if current_metadata is not None:
            print("Traceback that induced the computation of", func.__name__)
            FwdTracebackMode.print_metadata(current_metadata[1])
        return rs
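# Because the backward calls below run while this mode is active, every op the autograd
# engine dispatches goes through __torch_dispatch__ above, where the global set by the
# node hooks tells us which forward call produced the node currently being executed.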
# Usage:
# (1) run forward with anomaly mode, check_nan=False to avoid spurious errors
# (2) iterate through the graph created by this forward and attach hooks
# (3) as we run backward, the hooks will properly set the global
# Examples:
print("\n === Single backward === \n")
with autograd.detect_anomaly(check_nan=False):
    a = torch.tensor(1., requires_grad=True)
    b = a.sin()

# Don't need to enable anomaly mode here
setup_hooks(b.grad_fn)
with FwdTracebackMode():
    ga, = torch.autograd.grad(b, (a,))
print("\n === Higher order gradients === \n")
with autograd.detect_anomaly(check_nan=False):
    a = torch.tensor(1., requires_grad=True)
    b = a.sin()
    ga, = torch.autograd.grad(b, (a,), create_graph=True)
    gga, = torch.autograd.grad(ga, (a,), create_graph=True)

# Don't need to enable anomaly mode here
setup_hooks(gga.grad_fn)
with FwdTracebackMode():
    gga.backward()
print("\n === Inter Op Tracking === \n")
with autograd.detect_anomaly(check_nan=False):
    a = torch.tensor(1., requires_grad=True)
    b = a.sin()
    c = b + a

# Don't need to enable anomaly mode here
setup_hooks(c.grad_fn)
with FwdTracebackMode():
    c.backward()
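# In the generated output below, the add.Tensor call prints the posthook's
# "Whatever you want here" placeholder: summing the two gradient contributions that
# flow into `a` (through sin and through the direct add) happens in the engine between
# node executions, which is the inter-op work the "InBetween" state is meant to catch.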
# Check if there is a cycle involving grad_fn
import weakref
import gc
def get_ref():
    with autograd.detect_anomaly(check_nan=False):
        a = torch.tensor(1., requires_grad=True)
        b = a.sin()
    # Don't need to enable anomaly mode here
    setup_hooks(b.grad_fn)

    class A():
        pass

    a = A()
    b.grad_fn.metadata["a"] = a
    ref = weakref.ref(a)
    return ref
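# The check below verifies that registering the hooks and stashing a Python object in
# grad_fn.metadata does not create a reference cycle: the A instance must be collectable
# as soon as get_ref() returns, so the weakref should already be cleared.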
gc.collect()
assert get_ref()() is None
albanD commented Aug 17, 2022

Generated output:

$ python foo.py 

 === Single backward === 

foo.py:97: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with autograd.detect_anomaly(check_nan=False):

func: ones_like.default

func: cos.default
Traceback that induced the computation of cos.default
  File "foo.py", line 99, in <module>
    b = a.sin()

func: mul.Tensor
Traceback that induced the computation of mul.Tensor
  File "foo.py", line 99, in <module>
    b = a.sin()

 === Higher order gradients === 

foo.py:108: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with autograd.detect_anomaly(check_nan=False):

func: ones_like.default

func: mul.Tensor
Traceback that induced the computation of mul.Tensor
  File "foo.py", line 112, in <module>
    gga, = torch.autograd.grad(ga, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 111, in <module>
    ga, = torch.autograd.grad(b, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 110, in <module>
    b = a.sin()

func: neg.default
Traceback that induced the computation of neg.default
  File "foo.py", line 112, in <module>
    gga, = torch.autograd.grad(ga, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 111, in <module>
    ga, = torch.autograd.grad(b, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 110, in <module>
    b = a.sin()

func: cos.default
Traceback that induced the computation of cos.default
  File "foo.py", line 112, in <module>
    gga, = torch.autograd.grad(ga, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 111, in <module>
    ga, = torch.autograd.grad(b, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 110, in <module>
    b = a.sin()

func: mul.Tensor
Traceback that induced the computation of mul.Tensor
  File "foo.py", line 112, in <module>
    gga, = torch.autograd.grad(ga, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 111, in <module>
    ga, = torch.autograd.grad(b, (a,), create_graph=True)
  File "/home/albandes/local/pytorch/3.8_debug_source/torch/autograd/__init__.py", line 294, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
The traceback of the forward that induced the above computation:
  File "foo.py", line 110, in <module>
    b = a.sin()

func: detach.default
Traceback that induced the computation of detach.default
  File "foo.py", line 110, in <module>
    b = a.sin()

func: detach.default
Traceback that induced the computation of detach.default
  File "foo.py", line 110, in <module>
    b = a.sin()

 === Inter Op Tracking === 

foo.py:120: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with autograd.detect_anomaly(check_nan=False):

func: ones_like.default

func: cos.default
Traceback that induced the computation of cos.default
  File "foo.py", line 122, in <module>
    b = a.sin()

func: mul.Tensor
Traceback that induced the computation of mul.Tensor
  File "foo.py", line 122, in <module>
    b = a.sin()

func: add.Tensor
Traceback that induced the computation of add.Tensor
Whatever you want here
func: detach.default
Traceback that induced the computation of detach.default
  File "foo.py", line 122, in <module>
    b = a.sin()

func: detach.default
Traceback that induced the computation of detach.default
  File "foo.py", line 122, in <module>
    b = a.sin()
foo.py:136: UserWarning: Anomaly Detection has been enabled. This mode will increase the runtime and should only be enabled for debugging.
  with autograd.detect_anomaly(check_nan=False):
