Skip to content

Instantly share code, notes, and snippets.

@albanD
Created July 7, 2023 22:50
Show Gist options
  • Save albanD/0c81c8d46a8f1352cb088dab29b75be9 to your computer and use it in GitHub Desktop.
Tracking time and stack traces of when Tensors are created, used and die
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map_only
from torch.utils.weak import WeakTensorKeyDictionary
import time
import warnings
import weakref
import traceback
# I'm lazy and need these weak refs to be alive without creating cycles (which
# would happen if I put them in the state)
GLOBAL_HOLDER_FOR_WEAKREFS = []


class MemTrackerMode(TorchDispatchMode):
    """TorchDispatchMode that records when and where each Tensor created
    under it is allocated, used as an op input, and deallocated.

    While a Tensor is alive its entry lives in ``self.state`` (a
    WeakTensorKeyDictionary, so the tracker never keeps Tensors alive);
    when it dies, a weakref callback stamps the deallocation time/stack and
    moves the entry to ``self.finished_state``.
    """

    def __init__(self):
        # Each entry contains:
        #   [(alloc time, alloc stack), [(use time, use stack), ...], dealloc]
        # where dealloc is -1 until the Tensor dies, then (time, stack).
        self.state = WeakTensorKeyDictionary()
        # Completed entries for Tensors that have already been deallocated.
        self.finished_state = []

    def __torch_dispatch__(self, func, types, args, kwargs=None):
        # BUG FIX: the original passed kwargs straight to func(*args, **kwargs)
        # even though the default is None; **None raises TypeError. Normalize
        # to an empty dict first (the usual TorchDispatchMode idiom).
        kwargs = kwargs or {}

        def track_input(t):
            # Every Tensor flowing into an op must have been created under
            # this mode; anything else is a tracking hole we want to know about.
            if t not in self.state:
                raise RuntimeError("Unknown inputs!")
            else:
                self.state[t][1].append((time.time(), traceback.extract_stack()))

        def track_output(t):
            if t not in self.state:
                state = [(time.time(), traceback.extract_stack()), [], -1]

                def on_del(_):
                    # Fires when the Tensor is deallocated: record where/when
                    # and archive the entry.
                    state[2] = (time.time(), traceback.extract_stack())
                    self.finished_state.append(state)

                # Keep the weakref itself alive (module-level list) without
                # creating a cycle through self.state.
                GLOBAL_HOLDER_FOR_WEAKREFS.append(weakref.ref(t, on_del))
                self.state[t] = state
            else:
                warnings.warn("Output is already tracked??")

        tree_map_only(torch.Tensor, track_input, (args, kwargs))
        res = func(*args, **kwargs)
        tree_map_only(torch.Tensor, track_output, res)
        return res
def foo(x):
    # Demo helper: triple the input (gives the tracker a nested call
    # site to show in its recorded stack traces).
    tripled = x * 3
    return tripled
# Demo: run a few ops under the tracker, deleting Tensors along the way so
# that each deallocation (and its stack trace) is recorded as it happens.
with MemTrackerMode() as m:
    a = torch.rand(10)
    b = a * 2
    c = a * 3
    del b
    e = foo(a)
    del a, c
    e = e * 3  # rebinding e also drops the previous Tensor
    del e
# state is a WeakTensorKeyDictionary, so it only holds Tensors that are
# still alive; everything created above should have died by now.
assert len(m.state) == 0, "There are some Tensors that escaped the scope!"
def print_entry(prefix, v):
    """Pretty-print one (timestamp, stack) event record.

    prefix is the indentation for the header lines; the stack itself is
    indented twice as deep. Tracker-internal frames are trimmed so the
    trace points at the user line that triggered the event.
    """
    separator = prefix + "---------------"
    timestamp, stack = v[0], v[1]
    print(separator)
    print(prefix + f"At time {timestamp}")
    print(prefix + "At position:")
    frames = stack.format()
    if len(stack) > 6 and stack[-6].name == "__torch_dispatch__":
        # Drop the torch dispatch frames to make it readable and point
        # at the line that called into pytorch's API.
        del frames[-6:]
    elif len(stack) > 1 and stack[-1].name == "on_del":
        # Drop the frame of our deleter callback to point at the line
        # that triggered the deletion.
        del frames[-1:]
    deep_prefix = prefix * 2
    print(deep_prefix + ("\n" + deep_prefix).join(frames))
    print(separator)
# Dump the full life cycle (creation, every use, deletion) of each Tensor
# that was created and destroyed under the tracker, in deallocation order.
for idx, record in enumerate(m.finished_state):
    created, uses, deleted = record
    print("#" * 10)
    print(f"Tensor with id {idx}:")
    print(" Creation:")
    print_entry(" ", created)
    print(" Usage:")
    if not uses:
        print(" No usage")
    else:
        for use_event in uses:
            print_entry(" ", use_event)
    print(" Deletion:")
    print_entry(" ", deleted)
@albanD
Copy link
Author

albanD commented Jul 7, 2023

This code will output the following:

$ python foo.py 
##########
Tensor with id 0:
  Creation:
  ---------------
  At time 1688770161.8595796
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 49, in <module>
    b = a * 2

  ---------------
  Usage:
    No usage
  Deletion:
  ---------------
  At time 1688770161.8600788
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 51, in <module>
    del b

  ---------------
##########
Tensor with id 1:
  Creation:
  ---------------
  At time 1688770161.8584259
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 48, in <module>
    a = torch.rand(10)

  ---------------
  Usage:
    ---------------
    At time 1688770161.8592172
    At position:
          File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 49, in <module>
    b = a * 2

    ---------------
    ---------------
    At time 1688770161.8598256
    At position:
          File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 50, in <module>
    c = a * 3

    ---------------
    ---------------
    At time 1688770161.8602123
    At position:
          File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 52, in <module>
    e = foo(a)

          File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 45, in foo
    return x * 3

    ---------------
    No usage
  Deletion:
  ---------------
  At time 1688770161.860465
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 53, in <module>
    del a, c

  ---------------
##########
Tensor with id 2:
  Creation:
  ---------------
  At time 1688770161.8599792
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 50, in <module>
    c = a * 3

  ---------------
  Usage:
    No usage
  Deletion:
  ---------------
  At time 1688770161.8605087
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 53, in <module>
    del a, c

  ---------------
##########
Tensor with id 3:
  Creation:
  ---------------
  At time 1688770161.860361
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 52, in <module>
    e = foo(a)

      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 45, in foo
    return x * 3

  ---------------
  Usage:
    ---------------
    At time 1688770161.860627
    At position:
          File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 54, in <module>
    e = e * 3

    ---------------
    No usage
  Deletion:
  ---------------
  At time 1688770161.8608553
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 54, in <module>
    e = e * 3

  ---------------
##########
Tensor with id 4:
  Creation:
  ---------------
  At time 1688770161.8607626
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 54, in <module>
    e = e * 3

  ---------------
  Usage:
    No usage
  Deletion:
  ---------------
  At time 1688770161.860896
  At position:
      File "/home/albandes/local/pytorch/3.11_debug_source/test/foo.py", line 55, in <module>
    del e

  ---------------

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment