|
from __future__ import annotations |
|
|
|
import traceback |
|
from collections import Counter |
|
from collections import defaultdict |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
from dataclasses import field |
|
|
|
from sqlalchemy import event |
|
from sqlalchemy import inspect |
|
from sqlalchemy import select |
|
from sqlalchemy.engine import Engine |
|
from sqlalchemy.orm import Session |
|
|
|
|
|
@dataclass
class LazySQLDetector:
    """Group executed SQL statements by the call stack that issued them.

    While :meth:`track` is active, every statement sent through any SQLAlchemy
    ``Engine`` is recorded under the stack returned by ``_lazy_stack``.  Stacks
    that issued at least ``threshold`` statements are reported by
    :meth:`summary` and flagged by :meth:`has_errors`.  (The class name
    suggests the target is lazy-loading / N+1 query patterns.)

    Attributes:
        threshold: Minimum number of statements from one stack for that stack
            to be reported.
        stack_filter: Substring handed to ``_lazy_stack`` to pick the frame the
            recorded stack starts at (presumably a path fragment of the
            application code — confirm with callers).
        _stack_sql_map: Recorded statements keyed by a tuple of formatted
            stack frames.
    """

    threshold: int = 10
    stack_filter: str = ""
    _stack_sql_map: dict[tuple[str, ...], list[str]] = field(default_factory=lambda: defaultdict(list))

    @contextmanager
    def track(self):
        """Record every SQL statement executed while the context is active.

        A ``before_cursor_execute`` listener is attached to the ``Engine``
        class (so it sees statements from every engine) and is removed on
        exit, even if the ``with`` body raises.
        """

        # SQLAlchemy invokes the listener positionally with
        # (conn, cursor, statement, parameters, context, executemany).
        def _record_sql(_conn, _cursor, statement, _params, _context, _executemany):
            stack = _lazy_stack(self.stack_filter)
            # pretty_query is not defined in this view; presumably it
            # normalizes/pretty-prints the statement text.
            self._stack_sql_map[stack].append(pretty_query(str(statement)))

        event.listen(Engine, "before_cursor_execute", _record_sql, retval=False)
        try:
            yield
        finally:
            event.remove(Engine, "before_cursor_execute", _record_sql)

    def summary(self) -> str:
        """Return a human-readable report of stacks that issued too many SQLs.

        Offending stacks (``>= threshold`` statements) are listed from most to
        fewest statements.  Within a stack, each distinct statement is shown
        once with its repeat count, most frequent first.  When no stack
        reaches the threshold, a single "none" line follows the header.
        """
        lines = ["❌ **Stack traces with excessive SQLs**"]

        offenders = [
            (stack, sqls)
            for stack, sqls in self._stack_sql_map.items()
            if len(sqls) >= self.threshold
        ]
        # Stable sort, so ties keep insertion order.
        offenders.sort(key=lambda kv: len(kv[1]), reverse=True)

        for stack, sqls in offenders:
            lines.append(f" • {len(sqls)} SQLs from this stack:")
            # traceback.format_list() entries end with "\n"; strip it so the
            # final "\n".join() does not add a blank line after every frame.
            lines.extend(f" {frame.rstrip()}" for frame in stack)
            # Most repeated statement first, the likeliest lazy-load culprit.
            lines.extend(f" ({count}×)\n{sql}" for sql, count in Counter(sqls).most_common())
            lines.append("")

        if not offenders:
            lines.append(" ✅ none")

        return "\n".join(lines)

    def has_errors(self) -> bool:
        """Return True if any recorded stack issued ``threshold`` or more SQLs."""
        return any(len(sqls) >= self.threshold for sqls in self._stack_sql_map.values())
|
|
|
|
|
def _lazy_stack(filter_str: str) -> tuple[str, ...]: |
|
frames = traceback.extract_stack() |
|
for idx, frame in enumerate(frames): |
|
if filter_str in frame.filename and __file__ not in frame.filename: |
|
# Start here, take next few frames (including this one) |
|
return tuple(traceback.format_list(frames[idx : idx + 8])) |
|
return () # Fallback if no relevant frame is found |
|
|
|
|
|
@contextmanager
def no_dirty_allowed():
    """Make any flush that finds modified ORM objects raise, while active.

    An ``after_flush`` listener is attached to the ``Session`` class (so it
    applies to every session) for the duration of the ``with`` block.  It is
    removed on exit, even if the body raises.  The listener inspects
    ``session.dirty`` and raises if any object there has changed attributes.

    Raises:
        RuntimeError: raised from inside the flush.  Its message lists, per
            class name, each changed attribute and how many objects changed it.
    """

    def _record_dirty(session, _flush_context):
        # class name -> attribute name -> number of objects with that attribute changed
        by_type_field: dict[str, Counter[str]] = defaultdict(Counter)

        for obj in session.dirty:
            # Skip objects with no net scalar change (collection changes ignored).
            if not session.is_modified(obj, include_collections=False):
                continue

            # NOTE(review): insp.attrs also covers relationship attributes, so a
            # changed collection is still listed here once some scalar changed.
            changed = [attr.key for attr in inspect(obj).attrs if attr.history.has_changes()]
            if changed:
                by_type_field[type(obj).__name__].update(changed)

        if by_type_field:
            lines = ["❌ Dirty objects detected:"]
            for cls_name, field_counts in by_type_field.items():
                lines.append(f" {cls_name}:")
                for attr_name, count in field_counts.most_common():
                    lines.append(f" • {count}× {attr_name}")
            raise RuntimeError("\n".join(lines))

    event.listen(Session, "after_flush", _record_dirty, retval=False)
    try:
        yield
    finally:
        event.remove(Session, "after_flush", _record_dirty)