Skip to content

Instantly share code, notes, and snippets.

# Implements Alban's idea of making the forward traceback corresponding to
# the currently executing backward node available as a global.
# Updated version of https://gist.github.com/soulitzer/28140cc4cd7d26828ff7f07b1235d9f5
# that adds inter-op tracking.
import torch
from torch import autograd
from torch.utils._python_dispatch import TorchDispatchMode
# Module-level global mentioned in the header comment: presumably holds the
# forward traceback/metadata for the backward node currently executing.
# Starts as None; whatever sets it is not visible in this chunk (TODO confirm).
current_metadata = None
@albanD
albanD / pytreeify.py
Created January 24, 2023 19:34
Make PyTorch custom Function unpack input and output using pytree.
import torch
from torch.autograd import Function
import torch.utils._pytree as pytree
# Wraps the user-defined functions so that inputs and outputs are unpacked/repacked (via pytree) before being passed to them.
def pytreeify(cls):
assert issubclass(cls, Function)
orig_fw = cls.forward
orig_bw = cls.backward
@albanD
albanD / mem_tracker.py
Created July 7, 2023 22:50
Tracking time and stack traces of when Tensors are created, used and die
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
import time
import warnings
import weakref
import traceback