Last active
May 15, 2023 22:55
-
-
Save thehesiod/2f56f98370bea45f021d3704b21707a9 to your computer and use it in GitHub Desktop.
Memory Tracer
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import concurrent.futures
import gc
import linecache
import logging
import os
import re
import time
import tracemalloc
from io import StringIO
from typing import List, Optional

import wrapt

# NOTE(review): `log_elapsed` and `get_max_rss` are called by MemTracer but are
# neither defined nor imported in this file; presumably they come from a
# project utility module -- TODO confirm and import them here.
# Exclusion filters applied to every snapshot so that allocations made by the
# import machinery, by tracemalloc/linecache, and by this module itself are
# left out of the reports.
_TRACE_FILTERS = (
    tracemalloc.Filter(
        inclusive=False,
        filename_pattern="<frozen importlib._bootstrap>",
    ),
    # all_frames is needed because tracemalloc calls fnmatch
    tracemalloc.Filter(
        inclusive=False,
        filename_pattern=tracemalloc.__file__,
        all_frames=True,
    ),
    tracemalloc.Filter(
        inclusive=False,
        filename_pattern=linecache.__file__,
    ),
    # all_frames is needed since we call weakref
    tracemalloc.Filter(
        inclusive=False,
        filename_pattern=os.path.abspath(__file__),
        all_frames=True,
    ),
)
class MemTracer:
    """Periodically report memory growth using ``tracemalloc``.

    Call :meth:`tick` regularly. The first tick starts tracing; once more than
    ``period_s`` seconds have elapsed, a tick triggers :meth:`dump_snapshot`.
    The first dump only stores a baseline snapshot; later dumps write the
    allocation differences relative to the stored snapshot to ``file_path``
    (or stdout when no path was given).
    """

    # Number of frames tracemalloc records for each allocation traceback.
    _TRACEBACK_FRAMES = 40
    # Upper bound on the number of growing allocation sites printed per dump.
    _MAX_REPORTED_STATS = 40
    # Set once ThreadPoolExecutor._adjust_thread_count has been wrapped, so
    # creating several tracers does not stack wrappers on top of each other.
    _executor_patched = False

    def __init__(self, logger: logging.Logger, period_s: float, file_path: Optional[str] = None,
                 incremental: bool = True, append: bool = False):
        """
        Will create an instance of Memory Tracer that dumps out results each `period_s`, first period is ignored to warm-up the tracer
        :param logger: logger handed to `log_elapsed` while a snapshot is captured
        :param period_s: minimum number of seconds between two dumps
        :param file_path: file path to log results to, results are printed when not set
        :param incremental: set to True to be incremental snapshots (each dump is compared to the previous one)
        :param append: set to True to append each report to `file_path` instead of overwriting the previous report
        """
        self._logger = logger
        self._trace_start = None
        self._last_snapshot = None
        self._period_s = period_s
        self._file_path = file_path
        self._incremental = incremental
        self._append = append
        self._num_periods = 0
        self._num_ticks = 0

        if file_path:
            # Remove a stale report from an earlier run; a missing file is fine.
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                pass

        # gc.set_debug(gc.DEBUG_LEAK)
        self.patch_thread_pool_executor()

    @classmethod
    def patch_thread_pool_executor(cls):
        """Wrap ``ThreadPoolExecutor._adjust_thread_count`` exactly once per process."""
        # Comparing the executor attribute with `cls._adjust_thread_count` can
        # never be equal (function/wrapper vs. bound classmethod), so an
        # explicit flag is used to keep the patch idempotent.
        if MemTracer._executor_patched:
            return

        wrapt.wrap_function_wrapper(
            'concurrent.futures', 'ThreadPoolExecutor._adjust_thread_count', cls._adjust_thread_count)
        MemTracer._executor_patched = True

    @classmethod
    def _adjust_thread_count(cls, wrapped, instance, args, kwargs):
        """wrapt wrapper that keeps calling the original until the pool holds ``_max_workers`` threads.

        NOTE(review): presumably this forces all worker threads to be created
        up front so their allocations do not appear in later diffs -- confirm.
        """
        num_threads = len(instance._threads)
        while num_threads < instance._max_workers:
            wrapped(*args, **kwargs)
            num_threads = len(instance._threads)

    @classmethod
    def patch_and_tick(cls, obj, *args, **kwargs):
        """
        Will add an instance of MemTracer to "obj" as a private property if it doesn't exist and call tick
        :param obj: object to patch onto
        :param args: init args to MemTracer
        :param kwargs: init args to MemTracer
        """
        tracer = getattr(obj, '_tracer', None)
        if tracer is None:
            tracer = obj._tracer = cls(*args, **kwargs)

        tracer.tick()

    def start(self):
        """Start tracemalloc if it is not already tracing and begin a new period."""
        if not tracemalloc.is_tracing():
            tracemalloc.start(self._TRACEBACK_FRAMES)

        # monotonic clock: period measurement must not be affected by wall-clock adjustments
        self._trace_start = time.monotonic()

    def capture(self, store: bool = True):
        """
        Take a filtered tracemalloc snapshot
        :param store: set to True to remember the snapshot as the baseline for the next comparison
        :return: the filtered snapshot
        """
        # to avoid the regex cache popping up in the traces
        re.purge()
        gc.collect()

        if self._trace_start is None:
            self.start()

        with log_elapsed(self._logger, "Capturing trace"):
            snapshot = tracemalloc.take_snapshot()

        snapshot = snapshot.filter_traces(_TRACE_FILTERS)

        if store:
            self._last_snapshot = snapshot

        return snapshot

    def tick(self):
        """Count a tick and dump a report once more than ``period_s`` seconds have elapsed.

        The very first tick only starts tracing.
        """
        self._num_ticks += 1

        if self._trace_start is None:
            self.start()
            return

        elapsed_s = time.monotonic() - self._trace_start
        if elapsed_s > self._period_s:
            self._num_periods += 1
            try:
                self.dump_snapshot(elapsed_s)
                # objgraph.show_most_common_types(limit=50)
            finally:
                # want to set this at the end so we get the correct period after this dump
                self._trace_start = time.monotonic()

    def dump_snapshot(self, elapsed_s=-1):
        """
        Write the allocation differences since the stored snapshot
        :param elapsed_s: seconds covered by this report, only used in the report header
        """
        if self._last_snapshot is None:
            # Nothing to compare against yet: record the baseline and wait for the next period.
            self.capture()
            return

        snapshot = self.capture(False)
        top_stats: List[tracemalloc.StatisticDiff] = snapshot.compare_to(self._last_snapshot, 'traceback')
        report = self._format_report(top_stats, elapsed_s)

        # report += mem_top(25, 300)
        if self._file_path:
            mode = 'a' if self._append else 'w'
            with open(self._file_path, mode, encoding='utf-8') as f:
                f.write(report)
        else:
            print(report)

        if self._incremental:
            self._last_snapshot = snapshot

    # Backward-compatible alias for the original (misspelled) method name.
    dump_snapshop = dump_snapshot

    def _format_report(self, top_stats, elapsed_s):
        """Render the report text for a list of ``tracemalloc.StatisticDiff`` entries."""
        max_stats = min(len(top_stats), self._MAX_REPORTED_STATS)

        # Stable sorts keep the relative order of equally sized entries.
        growing = sorted(
            (stat for stat in top_stats if stat.size_diff > 0),
            key=lambda stat: stat.size_diff, reverse=True)
        total_acquired = sum(stat.size_diff for stat in growing)
        total_released = sum(-stat.size_diff for stat in top_stats if stat.size_diff <= 0)

        # "\n" rather than os.linesep: text-mode files and print() translate
        # newlines themselves, so os.linesep would yield "\r\r\n" on Windows.
        lines = [
            '===============================',
            f"[Top {max_stats}/{len(top_stats)} differences elapsed: {round(elapsed_s)}] in periods: {self._num_periods} and ticks: {self._num_ticks} max RSS: {get_max_rss()} MB",
        ]

        for stat in growing[:max_stats]:
            lines.append(f"{stat.count_diff} memory blocks: {stat.size_diff / 1024} KB")
            lines.extend('\t' + str(line) for line in stat.traceback.format())

        lines.append(f"total KB acquired: {total_acquired / 1024} released: {total_released / 1024}")
        lines.append('===============================')

        return "\n".join(lines) + "\n"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment