Skip to content

Instantly share code, notes, and snippets.

@kinow
Created May 7, 2020 23:42
Show Gist options
  • Save kinow/a9286ce0b4b862fe3348110bb8b2917d to your computer and use it in GitHub Desktop.
Save kinow/a9286ce0b4b862fe3348110bb8b2917d to your computer and use it in GitHub Desktop.
import concurrent
import gc
import linecache
import logging
import os
import re
import time
import tracemalloc
from io import StringIO
from typing import List, Optional

import psutil
import wrapt
# Allocation traces excluded from every snapshot. Each entry is
# (filename pattern, whether the pattern is matched against every frame of the
# traceback rather than only the most recent one):
#   - the frozen import machinery;
#   - tracemalloc itself (it calls fnmatch while filtering);
#   - linecache (tracemalloc's Traceback.format reads source lines through it);
#   - this module (since we call weakref).
_TRACE_FILTERS = tuple(
    tracemalloc.Filter(False, pattern, all_frames=match_all_frames)
    for pattern, match_all_frames in (
        ("<frozen importlib._bootstrap>", False),
        (tracemalloc.__file__, True),
        (linecache.__file__, False),
        (os.path.abspath(__file__), True),
    )
)
class MemTracer:
    """Periodically diff ``tracemalloc`` snapshots and report allocation growth.

    Drive it by calling :meth:`tick` repeatedly (or :meth:`patch_and_tick` to
    lazily attach a tracer to an object). The first tick starts tracing. Each
    time more than ``period_s`` seconds have passed since the last period
    boundary, :meth:`dump_snapshot` runs. Its first call only records a
    baseline snapshot (warm-up). Later calls compare a fresh snapshot against
    the stored baseline and write a report of the tracebacks that grew most.
    """

    # Frames recorded per allocation traceback by tracemalloc.start().
    _TRACE_FRAMES = 40
    # Upper bound on the number of growing tracebacks listed in one report.
    _MAX_REPORTED_STATS = 40
    # Set to True once ThreadPoolExecutor has been wrapped in this process.
    _thread_pool_patched = False

    def __init__(self, logger: logging.Logger, period_s: float,
                 file_path: Optional[str] = None, incremental: bool = True):
        """Create a tracer and patch ``ThreadPoolExecutor`` (once per process).

        :param logger: stored as ``self._logger``; not otherwise used by this class
        :param period_s: minimum number of seconds between two reports
        :param file_path: if given, reports are appended to this file. Any
            existing file is deleted here, so each tracer starts with an empty
            report. If omitted, reports are printed to stdout.
        :param incremental: if True, each report's snapshot becomes the
            baseline for the next one; if False, every report is diffed
            against the first (warm-up) snapshot
        """
        self._logger = logger
        self._trace_start: Optional[float] = None
        self._last_snapshot: Optional[tracemalloc.Snapshot] = None
        self._period_s = period_s
        self._file_path = file_path
        self._incremental = incremental
        self._num_periods = 0
        self._num_ticks = 0
        if file_path:
            try:
                os.unlink(file_path)
            except FileNotFoundError:
                pass
        self.patch_thread_pool_executor()

    @classmethod
    def patch_thread_pool_executor(cls) -> None:
        """Wrap ``ThreadPoolExecutor._adjust_thread_count`` with :meth:`_adjust_thread_count`.

        The patch is applied at most once per process. An explicit flag is
        used because, after wrapping, the class attribute is a wrapt proxy
        that never compares equal to this classmethod. An equality check
        therefore cannot detect an existing patch and would stack wrappers.
        """
        if MemTracer._thread_pool_patched:
            return
        wrapt.wrap_function_wrapper(
            'concurrent.futures', 'ThreadPoolExecutor._adjust_thread_count', cls._adjust_thread_count)
        MemTracer._thread_pool_patched = True

    @classmethod
    def _adjust_thread_count(cls, wrapped, instance, args, kwargs):
        """wrapt wrapper: call the original until the pool holds ``_max_workers`` threads.

        This makes the executor create all of its worker threads up front
        instead of one per submission. NOTE(review): presumably this keeps
        thread creation from showing up as allocation growth between
        snapshots; confirm.
        """
        num_threads = len(instance._threads)
        while num_threads < instance._max_workers:
            wrapped(*args, **kwargs)
            num_threads = len(instance._threads)

    @classmethod
    def patch_and_tick(cls, obj, *args, **kwargs):
        """Attach a MemTracer to ``obj._tracer`` if it has none, then call :meth:`tick`.

        :param obj: object to store the tracer on
        :param args: positional arguments for ``MemTracer(...)``, used only
            when a new tracer is created
        :param kwargs: keyword arguments for ``MemTracer(...)``, used only
            when a new tracer is created
        """
        tracer = getattr(obj, '_tracer', None)
        if tracer is None:
            tracer = obj._tracer = cls(*args, **kwargs)
        tracer.tick()

    def start(self) -> None:
        """Start tracemalloc (unless already tracing) and begin a new period."""
        if not tracemalloc.is_tracing():
            tracemalloc.start(self._TRACE_FRAMES)
        # Monotonic clock: period lengths must not jump with wall-clock changes.
        self._trace_start = time.monotonic()

    def capture(self, store: bool = True) -> tracemalloc.Snapshot:
        """Take a filtered snapshot, starting tracing first if needed.

        :param store: if True, also keep the snapshot as the comparison baseline
        :return: the snapshot, with ``_TRACE_FILTERS`` applied
        """
        # Empty the regex cache and collect garbage first, so cached patterns
        # and unreachable objects are not reported as growth.
        re.purge()
        gc.collect()
        if self._trace_start is None:
            self.start()
        snapshot = tracemalloc.take_snapshot().filter_traces(_TRACE_FILTERS)
        if store:
            self._last_snapshot = snapshot
        return snapshot

    def tick(self) -> None:
        """Count a tick; dump a report once more than ``period_s`` has elapsed."""
        self._num_ticks += 1
        if self._trace_start is None:
            self.start()
            return
        elapsed_s = time.monotonic() - self._trace_start
        if elapsed_s <= self._period_s:
            return
        self._num_periods += 1
        try:
            self.dump_snapshot(elapsed_s)
        finally:
            # Restart the period only after the (possibly slow) dump, so the
            # next period measures a full period_s of application time.
            self._trace_start = time.monotonic()

    def get_max_rss(self) -> float:
        """Return the peak resident set size of this process, in MB.

        Uses ``resource.getrusage`` where it exists (POSIX). Otherwise it falls
        back to psutil: the peak working set on Windows, or the current RSS if
        no peak value is exposed.
        """
        try:
            import resource
        except ImportError:  # e.g. Windows has no ``resource`` module
            info = psutil.Process().memory_info()
            return getattr(info, 'peak_wset', info.rss) / (1024 * 1024)
        import sys
        max_rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        # ru_maxrss is in bytes on macOS but in kilobytes on Linux/BSD.
        if sys.platform == 'darwin':
            return max_rss / (1024 * 1024)
        return max_rss / 1024

    def dump_snapshot(self, elapsed_s: float = -1) -> None:
        """Diff a fresh snapshot against the baseline and emit a report.

        The first call (when there is no baseline yet) only stores a baseline
        snapshot and emits nothing.

        :param elapsed_s: seconds covered by this period, shown in the report header
        """
        if self._last_snapshot is None:
            self.capture()
            return
        snapshot = self.capture(store=False)
        top_stats: List[tracemalloc.StatisticDiff] = snapshot.compare_to(self._last_snapshot, 'traceback')
        self._write_report(self._format_report(top_stats, elapsed_s))
        if self._incremental:
            self._last_snapshot = snapshot

    # Backward-compatible alias for the original (misspelled) method name.
    dump_snapshop = dump_snapshot

    def _format_report(self, top_stats: List[tracemalloc.StatisticDiff], elapsed_s: float) -> str:
        """Render the snapshot diff as text: header, largest growths, totals."""
        max_stats = min(len(top_stats), self._MAX_REPORTED_STATS)
        total_acquired = sum(s.size_diff for s in top_stats if s.size_diff > 0)
        total_released = -sum(s.size_diff for s in top_stats if s.size_diff < 0)
        grown = sorted((s for s in top_stats if s.size_diff > 0), key=lambda s: s.size_diff, reverse=True)

        # '\n' rather than os.linesep: text-mode files and print() already
        # translate '\n' to the platform line ending.
        lines = [
            '===============================',
            f"[Top {max_stats}/{len(top_stats)} differences elapsed: {round(elapsed_s)}]"
            f" in periods: {self._num_periods} and ticks: {self._num_ticks}"
            f" max RSS: {self.get_max_rss():.1f} MB",
        ]
        for stat in grown[:max_stats]:
            lines.append(f"{stat.count_diff} memory blocks: {stat.size_diff / 1024} KB")
            lines.extend('\t' + str(line) for line in stat.traceback.format())
        lines.append(f"total KB acquired: {total_acquired / 1024} released: {total_released / 1024}")
        lines.append('===============================')
        return '\n'.join(lines) + '\n'

    def _write_report(self, report: str) -> None:
        """Append ``report`` to the configured file, or print it if none is set."""
        if self._file_path:
            # Append: __init__ clears the file once, so successive reports accumulate.
            with open(self._file_path, 'a', encoding='utf-8') as f:
                f.write(report)
        else:
            print(report)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment