Skip to content

Instantly share code, notes, and snippets.

@kurtbrose
Last active June 28, 2025 00:39
Show Gist options
  • Save kurtbrose/5be3fc39f74d26043d5fad15aed55bbb to your computer and use it in GitHub Desktop.
Save kurtbrose/5be3fc39f74d26043d5fad15aed55bbb to your computer and use it in GitHub Desktop.
Some helpers for debugging N+1 queries.

helper for n+1 query debugging

how to use?

# inside your test
detector = LazySQLDetector()

with detector.track(), no_dirty_allowed():
    your_heavy_read_only_api()

if detector.has_errors():
    self.fail(detector.summary())

this will flag two things for you:

  1. if any model instance has been modified (is dirty) when the session flushes, the flush immediately raises a RuntimeError listing the changed fields per model class
  2. if 10 or more (configurable via `threshold`) queries are dispatched from the same stack, they will show up in a report at the end
from __future__ import annotations
import traceback
from collections import Counter
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from dataclasses import field
from sqlalchemy import event
from sqlalchemy import inspect
from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
@dataclass
class LazySQLDetector:
    """Group executed SQL by originating call stack to expose N+1 queries.

    While :meth:`track` is active, each statement about to be executed is
    stored under the formatted call stack that dispatched it.  A stack that
    dispatched ``threshold`` or more statements is treated as an error by
    :meth:`has_errors` and is listed by :meth:`summary`.
    """

    # Minimum number of statements from a single stack for it to be reported.
    threshold: int = 10
    # Substring handed to ``_lazy_stack`` to pick the anchoring frame by
    # filename; the empty default matches every filename.
    stack_filter: str = ""
    # Recorded SQL text, keyed by the formatted stack that dispatched it.
    _stack_sql_map: dict[tuple[str, ...], list[str]] = field(default_factory=lambda: defaultdict(list))

    @contextmanager
    def track(self):
        """Record every statement executed inside the ``with`` block.

        A ``before_cursor_execute`` listener is registered on the ``Engine``
        class for the duration of the block and is removed again on exit,
        including when the block raises.
        """

        def _record_sql(_c, _cur, stmt, _p, _ctx, _x):
            stack = _lazy_stack(self.stack_filter)
            # NOTE(review): ``pretty_query`` is not defined in the visible
            # part of this file; presumably a SQL formatter provided
            # elsewhere -- confirm it is importable where this runs.
            sql = pretty_query(str(stmt))
            self._stack_sql_map[stack].append(sql)

        event.listen(Engine, "before_cursor_execute", _record_sql, retval=False)
        try:
            yield
        finally:
            event.remove(Engine, "before_cursor_execute", _record_sql)

    def summary(self) -> str:
        """Return a human-readable report of the stacks that hit the threshold.

        Stacks are listed from most to fewest statements.  Within a stack,
        distinct statements are listed most-repeated first so the likely
        N+1 culprit is at the top.  When no stack qualifies, the report
        contains the header followed by a "none" line.
        """
        lines = ["❌ **Stack traces with excessive SQLs**"]
        by_volume = sorted(self._stack_sql_map.items(), key=lambda kv: len(kv[1]), reverse=True)
        for stack, sqls in by_volume:
            total = len(sqls)
            if total < self.threshold:
                # Sorted descending, so no later stack can qualify either.
                break
            lines.append(f" • {total} SQLs from this stack:")
            lines.extend(f" {frame}" for frame in stack)
            for sql, count in Counter(sqls).most_common():
                lines.append(f" ({count}×)\n{sql}")
            lines.append("")
        if not self.has_errors():
            lines.append(" ✅ none")
        return "\n".join(lines)

    def has_errors(self) -> bool:
        """Return True when any stack dispatched ``threshold`` or more statements."""
        return any(len(sqls) >= self.threshold for sqls in self._stack_sql_map.values())
def _lazy_stack(filter_str: str, depth: int = 8) -> tuple[str, ...]:
    """Return a formatted slice of the current call stack, anchored on caller code.

    The stack is scanned from the outermost frame inward.  The first frame
    whose filename contains ``filter_str`` and does not contain this module's
    ``__file__`` becomes the anchor; that frame and the frames following it
    are formatted and returned.

    Args:
        filter_str: Substring the anchoring frame's filename must contain.
            The empty string matches every filename.
        depth: Maximum number of frames to return, counting the anchor.
            Values below zero are treated as zero.

    Returns:
        A tuple of strings as produced by ``traceback.format_list``, or an
        empty tuple when no frame qualifies as the anchor.
    """
    # Clamp so a negative depth cannot turn into a wrap-around slice bound.
    window = max(depth, 0)
    frames = traceback.extract_stack()
    for idx, frame in enumerate(frames):
        if filter_str in frame.filename and __file__ not in frame.filename:
            # Start here, take the next few frames (including this one).
            return tuple(traceback.format_list(frames[idx : idx + window]))
    return ()  # Fallback if no relevant frame is found
@contextmanager
def no_dirty_allowed():
    """Make any session flush fail while modified objects are present.

    For the duration of the ``with`` block an ``after_flush`` listener is
    attached to the ``Session`` class and detached again on exit, including
    when the block raises.

    Raises:
        RuntimeError: From within a flush, when at least one object in
            ``session.dirty`` is modified (collections excluded from that
            check) and has attributes whose history reports changes.  The
            message lists each affected class with its changed attribute
            names and how many objects changed each one.
    """

    def _record_dirty(session, _flush_context):
        # class name -> attribute name -> number of objects that changed it
        tally: dict[str, Counter[str]] = defaultdict(Counter)
        for obj in session.dirty:
            if session.is_modified(obj, include_collections=False):
                for state_attr in inspect(obj).attrs:
                    if state_attr.history.has_changes():
                        tally[type(obj).__name__][state_attr.key] += 1

        if not tally:
            return

        report = ["❌ Dirty objects detected:"]
        for cls, field_counts in tally.items():
            report.append(f" {cls}:")
            report.extend(f" • {count}× {name}" for name, count in field_counts.most_common())
        raise RuntimeError("\n".join(report))

    event.listen(Session, "after_flush", _record_dirty, retval=False)
    try:
        yield
    finally:
        event.remove(Session, "after_flush", _record_dirty)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment