Last active
November 7, 2024 18:18
-
-
Save lmann4/3fb7ecbe627072fbc6a0301846aa1cc1 to your computer and use it in GitHub Desktop.
A utility for printing and debugging queries.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import inspect | |
import logging | |
import shutil | |
import sys | |
from time import time | |
from types import MethodType | |
from typing import Optional | |
from django.conf import settings | |
from django.db import DEFAULT_DB_ALIAS, connections | |
from django.db.backends.utils import CursorWrapper | |
import sqlparse | |
from pygments import highlight | |
from pygments.formatters.terminal256 import Terminal256Formatter | |
from pygments.lexers.sql import SqlLexer | |
def getframeinfo(frame, context=1):
    """
    Get information about a frame or traceback object.

    A tuple of five things is returned: the filename, the line number of
    the current line, the function name, a list of lines of context from
    the source code, and the index of the current line within that list.
    The optional second argument specifies the number of lines of context
    to return, which are centered around the current line.

    This originally comes from ``inspect`` but is modified to handle issues
    with ``findsource()``: when the source cannot be located, the context
    lines and the index are returned as ``None`` instead of raising.

    :param frame: A frame or traceback object.
    :param context: Number of source lines to return; ``0`` or less disables context.
    :return: ``inspect.Traceback(filename, lineno, function, lines, index)``.
    :raises TypeError: If ``frame`` is neither a frame nor a traceback object.
    """
    if inspect.istraceback(frame):
        lineno = frame.tb_lineno
        frame = frame.tb_frame
    elif inspect.isframe(frame):
        lineno = frame.f_lineno
    else:
        # Checked before touching ``f_lineno`` so bad input gets the intended
        # TypeError rather than an AttributeError.
        raise TypeError("arg is not a frame or traceback object")

    filename = inspect.getsourcefile(frame) or inspect.getfile(frame)

    lines = index = None
    if context > 0:
        try:
            source, _ = inspect.findsource(frame)
        except (OSError, IndexError):
            # Source is unavailable; report the frame without context.
            pass
        else:
            start = lineno - 1 - context // 2
            # Clamp the window to the file. The lower bound is 0 (not 1) so a
            # frame executing on the first line of a file reports that line.
            start = max(0, min(start, len(source) - context))
            lines = source[start : start + context]
            index = lineno - 1 - start

    return inspect.Traceback(filename, lineno, frame.f_code.co_name, lines, index)
def get_stack(context=1):
    """
    Collect a record for the caller's frame and every frame that called it.

    Each record is a tuple holding the frame object followed by the fields
    produced by ``getframeinfo()``: filename, line number, function name,
    list of context lines, and index within that context.

    Modified version of ``inspect.stack()`` which calls our own ``getframeinfo()``

    :param context: Number of source lines of context to request per frame.
    :return: List of records, starting at the caller and walking outwards.
    """
    records = []
    # Depth 1 is whoever called get_stack(); this must stay a direct call so
    # the depth is not shifted by an extra frame.
    current = sys._getframe(1)
    while current:
        records.append((current, *getframeinfo(current, context)))
        current = current.f_back
    return records
def tidy_stacktrace(stack):
    """
    Reduce stack records to ``(path, line_no, func_name, text)`` tuples.

    Records whose frame defines ``__traceback_hide__`` in its locals are
    dropped; this supports hiding of frames, as used by various utilities
    that provide inspection. The context lines of each remaining record are
    joined and stripped, or replaced by an empty string when there are none.

    ``stack`` should be an iterable of frame tuples shaped like those from
    ``inspect.stack()``; only the first five items of each are used.
    """
    return [
        (path, line_no, func_name, "".join(text).strip() if text else "")
        for frame, path, line_no, func_name, text in (record[:5] for record in stack)
        if "__traceback_hide__" not in frame.f_locals
    ]
def get_stacktrace():
    """
    Return the current call stack as tidied tuples, outermost call first.

    The stack is gathered with ``get_stack()`` (so it begins at this
    function's own frame), reversed, and passed through ``tidy_stacktrace()``.
    """
    records = get_stack()
    records.reverse()
    return tidy_stacktrace(records)
"""Functions for wrapping strings in ANSI color codes. Borrowed from | |
https://github.com/fabric/fabric/blob/master/fabric/colors.py""" | |
def _wrap_with(code):
    """
    Build a function that wraps text in the ANSI escape sequence for ``code``.

    The returned callable takes ``text`` and an optional ``bold`` flag; when
    ``bold`` is true the code is prefixed with ``1;``. The text is always
    followed by the reset sequence.
    """

    def inner(text, bold=False):
        prefix = "1;" if bold else ""
        return f"\033[{prefix}{code}m{text}\033[0m"

    return inner
# Colour helpers: each name is a callable ``fn(text, bold=False)`` built by
# ``_wrap_with`` from an ANSI colour code.
red = _wrap_with("31")
green = _wrap_with("32")
stata_green = _wrap_with("38;5;78")
# ``valid_colors`` lists "light_green", but no helper of that name existed.
# NOTE(review): aliased to ``stata_green`` so every listed name resolves to a
# helper -- confirm this is the intended shade.
light_green = stata_green
yellow = _wrap_with("33")
blue = _wrap_with("34")
magenta = _wrap_with("35")
cyan = _wrap_with("36")
white = _wrap_with("37")

# Names of the colour helpers offered by this module.
valid_colors = ("red", "green", "light_green", "yellow", "blue", "magenta", "cyan", "white")
# Logger named "django.db.backends"; StacktraceCursorWrapper.execute emits one
# DEBUG record per executed query through it.
logger = logging.getLogger("django.db.backends")
class StacktraceCursorWrapper(CursorWrapper):
    """
    Drop-in replacement for ``CursorWrapper`` that records where queries come from.

    Every ``execute()`` call is timed, and an entry holding the executed SQL,
    its duration, and the Python stack trace captured at call time is appended
    to the connection's ``queries_log``.
    """

    def execute(self, sql, params=None):
        """
        Run ``sql`` through the wrapped cursor and log it.

        The log entry and the debug log record are written whether the query
        succeeds or raises; the result (or exception) of the underlying
        ``execute()`` is passed through unchanged.
        """
        stacktrace = get_stacktrace()
        started_at = time()
        try:
            return super().execute(sql, params)
        finally:
            duration = time() - started_at
            executed_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
            entry = {"sql": executed_sql, "time": "%.3f" % duration, "stacktrace": stacktrace}
            self.db.queries_log.append(entry)
            log_extra = {"duration": duration, "sql": executed_sql, "params": params}
            logger.debug("(%.3f) %s; args=%s", duration, executed_sql, params, extra=log_extra)
def format_sql(sql: str, colorize: bool = True) -> str:
    """
    Pretty-print a SQL query, optionally with terminal colors.

    :param sql: The SQL query to format
    :param colorize: When ``True``, syntax-highlight the result for a 256-color terminal.
    :return: The re-indented (and possibly highlighted) SQL string.
    """
    # See the docs for format options: https://sqlparse.readthedocs.io/en/latest/api/#formatting-of-sql-statements
    formatted = sqlparse.format(sql, keyword_case="upper", reindent=True, reindent_aligned=False, wrap_after=120)
    if colorize is not True:
        return formatted
    return highlight(formatted, SqlLexer(), Terminal256Formatter(style="stata-dark"))
class ShowDBQueries:
    """
    Context manager that reports the queries executed inside its block.

    This class can be used to print the queries and the number of queries used for a section of code.
    For example if you wanted to see the queries used in a unit test you could do the following::

        ...
        def test_some_test(self):
            with ShowDBQueries():
                self.client.get(reverse('my_view'))

    :param db_connection: Connection to observe; defaults to ``connections[DEFAULT_DB_ALIAS]``.
    :param print_queries: Print the captured queries and a count when the block exits.
    :param colorize: Colorize the printed output (the file output is never colorized).
    :param file_path: If given, the captured queries are also written to this file.
    :param include_stacktrace: Append each query's recorded stack trace to the output.
    :param threshold: If given, queries whose recorded time is below this value are skipped.
        Times are recorded in seconds.
    """

    final_queries: int = 0
    # Index into the connection's query log where this capture starts; set in __enter__.
    initial_queries: int = 0
    # Index where the capture ends; ``None`` (open-ended) until __exit__ runs.
    _log_end: Optional[int] = None

    def __init__(
        self,
        db_connection: Optional[object] = None,
        print_queries: bool = True,
        colorize: bool = True,
        file_path: Optional[str] = None,
        include_stacktrace: bool = False,
        threshold: Optional[float] = None,
    ):
        self.print_queries = print_queries
        self.colorize = colorize
        self.file_path = file_path
        self.include_stacktrace = include_stacktrace
        self.query_time_threshold = threshold

        if db_connection is None:
            db_connection = connections[DEFAULT_DB_ALIAS]
        self.connection = db_connection

    @property
    def captured_queries(self):
        """Log entries recorded between ``__enter__`` and ``__exit__`` (a list)."""
        # The log is copied to a list before slicing. Only entries added while
        # the block was active are returned, so queries that were already in
        # the log are not reported.
        return list(self.connection.queries_log)[self.initial_queries : self._log_end]

    @staticmethod
    def format_stacktrace(stacks, colorize=False):
        """
        Render stack entries as ``File "...", line N, in func`` lines joined by newlines.

        :param stacks: Iterable of ``(path, line_no, func_name, text)`` tuples.
        :param colorize: When ``True``, entries whose path contains ``settings.BASE_DIR``
            are colored with ``stata_green`` and all others with ``white``.
        :return: The formatted stack trace as a single string.
        """
        rendered = []
        for stack in stacks:
            path = stack[0]
            # NOTE(review): presumably the location of this utility in the
            # original project, skipped so it does not show up in its own traces.
            if "apps/base/utils/db_stacktrace" in path:
                continue

            # Percent signs are doubled, as in the original implementation.
            stack_str = 'File "{}", line {}, in {}\n\t{}'.format(*stack).replace("%", "%%")
            if colorize is True:
                stack_str = stata_green(stack_str) if str(settings.BASE_DIR) in path else white(stack_str)
            rendered.append(stack_str)

        return "\n".join(rendered)

    def get_queries(self, colorize, use_sql_comments=False):
        """
        Build the textual report of the captured queries.

        ``EXPLAIN`` queries and, when a threshold is set, queries faster than
        the threshold are left out. ``self.final_queries`` is set to the number
        of queries included.

        :param colorize: Colorize the SQL and size the separator to the terminal width.
        :param use_sql_comments: Prefix header lines with ``-- `` and wrap stack traces
            in ``/* ... */``.
        :return: The report as a single string.
        """
        prefix = "-- " if use_sql_comments else ""
        ending_linebreaks = "\n" if colorize is True else "\n\n"
        column_width = shutil.get_terminal_size().columns if colorize is True else 80

        chunks = []
        num = 0
        for q in self.captured_queries:
            if q["sql"].startswith("EXPLAIN"):
                continue
            if self.query_time_threshold is not None and float(q["time"]) < self.query_time_threshold:
                continue

            # Recorded times come from ``time()`` differences, i.e. seconds.
            chunks.append(f"{prefix}{num}. {q['time']} s\n")
            chunks.append(f"{prefix}{'-' * column_width}\n")
            chunks.append(f"{format_sql(q['sql'], colorize=colorize)};{ending_linebreaks}")

            if self.include_stacktrace is True:
                # Entries not written by StacktraceCursorWrapper have no stack trace.
                stacktrace = f"\n{self.format_stacktrace(q.get('stacktrace', ()), colorize)}\n\n"
                if use_sql_comments is True:
                    stacktrace = f"\n/* {stacktrace} */\n"
                chunks.append(stacktrace)
            num += 1

        self.final_queries = num
        return "".join(chunks)

    def __enter__(self):
        def make_debug_cursor(db_wrapper_class, cursor):
            return StacktraceCursorWrapper(cursor, db_wrapper_class)

        # Remember the connection's settings so __exit__ can restore them.
        self.orig_force_debug_cursor = self.connection.force_debug_cursor
        self.orig_make_debug_cursor = self.connection.make_debug_cursor

        self.connection.force_debug_cursor = True
        self.connection.make_debug_cursor = MethodType(make_debug_cursor, self.connection)

        self.initial_queries = len(self.connection.queries_log)
        self._log_end = None
        self.final_queries = 0
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.connection.force_debug_cursor = self.orig_force_debug_cursor
        self.connection.make_debug_cursor = self.orig_make_debug_cursor
        # Freeze the capture window so later queries are not attributed to this block.
        self._log_end = len(self.connection.queries_log)

        if self.print_queries is True:
            print("\nQueries:\n")
            print(self.get_queries(colorize=self.colorize))

            if self.query_time_threshold is not None:
                print(f"Queries over {self.query_time_threshold} s: {self.final_queries}")
            else:
                self.final_queries = self._log_end - self.initial_queries
                print(f"Queries executed: {self.final_queries}\n")

        if self.file_path is not None:
            with open(self.file_path, "w", encoding="utf-8") as f:
                f.write(self.get_queries(colorize=False, use_sql_comments=True))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Great demo. Looks very useful, and I look forward to trying it this week!