Tracking the times and stack traces at which Tensors are created, used, and destroyed
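The tracker is built on TorchDispatchMode, which intercepts every ATen-level operator call made while the mode is active. For context, a minimal pass-through mode looks roughly like this (an illustrative sketch, not part of the gist):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LoggingMode(TorchDispatchMode):
    # Runs once per ATen op executed under `with LoggingMode():`
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"op: {func}")
        return func(*args, **(kwargs or {}))

with LoggingMode():
    torch.ones(2) + 1  # logs the aten ops for ones() and add()

MemTrackerMode below uses the same hook to record a timestamp and stack trace for every Tensor it sees created or used, plus a weakref callback to catch deallocation.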
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary

import time
import traceback
import warnings
import weakref

# I'm lazy and need these weak refs to be alive without creating cycles
# (which would happen if I put them in the state)
GLOBAL_HOLDER_FOR_WEAKREFS = []
class MemTrackerMode(TorchDispatchMode):
    def __init__(self):
        # Each entry contains: [alloc time, [all use times], dealloc time]
        self.state = WeakTensorKeyDictionary()
        self.finished_state = []

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        kwargs = kwargs or {}  # kwargs may be None; **None would raise

        def track_input(t):
            if t not in self.state:
                raise RuntimeError("Unknown inputs!")
            else:
                self.state[t][1].append((time.time(), traceback.extract_stack()))

        def track_output(t):
            if t not in self.state:
                state = [(time.time(), traceback.extract_stack()), [], -1]

                def on_del(_):
                    # Record when and where the Tensor was deallocated.
                    state[2] = (time.time(), traceback.extract_stack())
                    self.finished_state.append(state)

                GLOBAL_HOLDER_FOR_WEAKREFS.append(weakref.ref(t, on_del))
                self.state[t] = state
            else:
                warnings.warn("Output is already tracked??")

        tree_map_only(torch.Tensor, track_input, (args, kwargs))
        res = func(*args, **kwargs)
        tree_map_only(torch.Tensor, track_output, res)
        return res
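# Note: the deletion tracking relies on weakref.ref(obj, callback) calling
# `callback` just before `obj` is finalized. A minimal illustration:
#
#     t = torch.rand(1)
#     r = weakref.ref(t, lambda ref: print("tensor collected"))
#     del t  # prints "tensor collected"
#
# on_del uses the same mechanism to timestamp each Tensor's death and move
# its record from `state` to `finished_state`.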
def foo(x):
    return x * 3


with MemTrackerMode() as m:
    a = torch.rand(10)
    b = a * 2
    c = a * 3
    del b
    e = foo(a)
    del a, c
    e = e * 3  # rebinding e deallocates the previous result
    del e

assert len(m.state) == 0, "There are some Tensors that escaped the scope!"
def print_entry(prefix, v):
    print(prefix + "---------------")
    print(prefix + f"At time {v[0]}")
    print(prefix + "At position:")
    full_prefix = 2 * prefix
    formatted_stack = v[1].format()
    if len(v[1]) > 6 and v[1][-6].name == "__torch_dispatch__":
        # Remove the torch dispatch frames to make it readable and
        # point to the line that calls into PyTorch's API
        formatted_stack = formatted_stack[:-6]
    elif len(v[1]) > 1 and v[1][-1].name == "on_del":
        # Remove the frame from our deleter callback to point to the
        # line that triggered the deletion
        formatted_stack = formatted_stack[:-1]
    print(full_prefix + ("\n" + full_prefix).join(formatted_stack))
    print(prefix + "---------------")
for i, v in enumerate(m.finished_state):
    c, use, d = v
    print(10 * "#")
    print(f"Tensor with id {i}:")
    print("  Creation:")
    print_entry("    ", c)
    print("  Usage:")
    for val_use in use:
        print_entry("    ", val_use)
    if len(use) == 0:
        print("    No usage")
    print("  Deletion:")
    print_entry("    ", d)
Running the script prints one report per tracked Tensor, in the order they were deallocated. The output looks roughly like the following (abridged to the first Tensor, b; timestamps, file paths, and line numbers are machine-specific and elided here):
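##########
Tensor with id 0:
  Creation:
    ---------------
    At time …
    At position:
        File "…", line …, in <module>
          b = a * 2
    ---------------
  Usage:
    No usage
  Deletion:
    ---------------
    At time …
    At position:
        File "…", line …, in <module>
          del b
    ---------------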