# Implements Alban's idea of making the forward traceback that corresponds to
# the execution of the current backward node available as a global.
# Updated version of https://gist.github.com/soulitzer/28140cc4cd7d26828ff7f07b1235d9f5
# to add inter-op tracking.
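#
# How it works: anomaly mode records the forward traceback in each grad_fn's
# .metadata; setup_hooks() copies the metadata of whichever node is currently
# executing during backward into a global; FwdTracebackMode then reads that
# global from __torch_dispatch__ for every op the backward dispatches.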

import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
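
# current_metadata holds (node name, metadata dict) for the backward node that is
# currently executing, or an "InBetween" marker between two nodes; callback_set
# tracks whether the end-of-backward cleanup callback has already been queued.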
current_metadata = None
callback_set = False

# Set up hooks so that during backward the global is properly set/unset
def setup_hooks(root):
    def iter_graph(root):
        q = []
        if not root:
            return
        i = 0
        seen = set([root])
        q.append(root)
        yield root
        while i < len(q):
            for fn, _idx in q[i].next_functions:
                if fn in seen or fn is None:
                    continue
                seen.add(fn)
                q.append(fn)
                yield fn
            i += 1
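
    # Attach the hooks below to every node reachable from the root
    # (breadth-first walk over next_functions above).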
    for node in iter_graph(root):
        metadata = node.metadata
        name = str(node)

        def callback():
            global current_metadata, callback_set
            current_metadata = None
            callback_set = False

        def get_prehook(name_, metadata_):
            def prehook(grad_output):
                global current_metadata, callback_set
                assert current_metadata is None or current_metadata[0] == "InBetween", \
                    "Reentrant backward is not supported"
                current_metadata = (name_, metadata_)
                if not callback_set:
                    torch.autograd.variable.Variable._execution_engine.queue_callback(callback)
                    callback_set = True
            return prehook

        def posthook(grad_input, grad_output):
            global current_metadata
            current_metadata = ("InBetween", {"traceback_": "Whatever you want here"})

        node.register_prehook(get_prehook(name, metadata))
        node.register_hook(posthook)
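
# Hook timing: a prehook on an autograd Node runs just before the node's backward
# (with its grad_outputs), the posthook runs just after (with grad_inputs and
# grad_outputs), and queue_callback() fires once the whole backward finishes.
# The posthook's "Whatever you want here" payload is only a placeholder for the
# period between two nodes; one could instead store, say, traceback.format_stack()
# there (an assumption, not part of the original gist), since print_metadata below
# simply prints whatever iterable of strings it finds under 'traceback_'.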

# A mode that accesses the global set by the hooks above
class FwdTracebackMode(TorchDispatchMode):
    @staticmethod
    def print_metadata(metadata):
        # print("Node: ", node)
        if 'traceback_' in metadata:
            for tb in metadata["traceback_"]:
                print(tb, end='')
        # For higher-order gradient computation: follow the parent chain recorded
        # by anomaly mode up to the forward that induced this graph.
        while 'parent_' in metadata:
            metadata = metadata['parent_'].metadata
            if 'traceback_' in metadata:
                print("The traceback of the forward that induced the above computation:")
                for tb in metadata["traceback_"]:
                    print(tb, end='')

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        global current_metadata
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        print("\nfunc:", func.__name__)
        if current_metadata is not None:
            print("Traceback that induced the computation of", func.__name__)
            FwdTracebackMode.print_metadata(current_metadata[1])
        return rs
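
# While FwdTracebackMode is active, every ATen op that executes (including the
# ops the autograd engine runs during backward) goes through __torch_dispatch__,
# so each backward op gets printed together with the forward traceback stored in
# the currently executing node's metadata.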

# Usage:
# (1) run the forward pass under anomaly mode so each grad_fn's metadata records
#     the forward traceback (check_nan=False avoids spurious NaN errors)
# (2) iterate through the graph created by this forward and attach hooks
# (3) as backward runs, the hooks set/unset the global accordingly
print("\n === Single backward === \n") | |
with autograd.detect_anomaly(check_nan=False): | |
a = torch.tensor(1., requires_grad=True) | |
b = a.sin() | |
# Don't need to enable anomaly mode here | |
setup_hooks(b.grad_fn) | |
with FwdTracebackMode(): | |
ga, = torch.autograd.grad(b, (a,)) | |
print("\n === Higher order gradients === \n") | |
with autograd.detect_anomaly(check_nan=False): | |
a = torch.tensor(1., requires_grad=True) | |
b = a.sin() | |
ga, = torch.autograd.grad(b, (a,), create_graph=True) | |
gga, = torch.autograd.grad(ga, (a,), create_graph=True) | |
# Don't need to enable anomaly mode here | |
setup_hooks(gga.grad_fn) | |
with FwdTracebackMode(): | |
gga.backward() | |
print("\n === Inter Op Tracking === \n") | |
with autograd.detect_anomaly(check_nan=False): | |
a = torch.tensor(1., requires_grad=True) | |
b = a.sin() | |
c = b + a | |
# Don't need to enable anomaly mode here | |
setup_hooks(c.grad_fn) | |
with FwdTracebackMode(): | |
c.backward() | |
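
# In the backward above, an op dispatched while a node is executing reports the
# forward line that created that node, while anything dispatched between two
# nodes (e.g. the engine accumulating the two gradient contributions to `a`) is
# attributed to the "InBetween" marker set by the posthook.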

# Check if there is a cycle involving grad_fn
import weakref
import gc

def get_ref():
    with autograd.detect_anomaly(check_nan=False):
        a = torch.tensor(1., requires_grad=True)
        b = a.sin()
    # Don't need to enable anomaly mode here
    setup_hooks(b.grad_fn)

    class A():
        pass
    a = A()
    b.grad_fn.metadata["a"] = a
    ref = weakref.ref(a)
    return ref

gc.collect()
assert get_ref()() is None
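
# If the hooks (or anything stored in grad_fn.metadata) created a reference cycle
# through grad_fn, the A() instance would outlive get_ref() and the weakref would
# still resolve; the assertion passes only if plain reference counting frees it,
# i.e. no such cycle exists.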