Skip to content

Instantly share code, notes, and snippets.

@charettes
Created November 11, 2022 01:29
Show Gist options
  • Save charettes/7b5718eb80c51fa23ec7b8db9a916609 to your computer and use it in GitHub Desktop.
Save charettes/7b5718eb80c51fa23ec7b8db9a916609 to your computer and use it in GitHub Desktop.
Compare the SQL output of Django test suite runs between two commits
import argparse
import difflib
import os
import sys
from itertools import chain
import yaml
# Name template for a recorded query dump; compare() fills in the commit SHA
# and the database vendor to locate each dump next to this script.
file_name_format = "{sha}:{vendor}.yml"
def _load_recorded_queries(path: str) -> dict:
    """Return a ``{test: queries}`` mapping parsed from the YAML dump at *path*.

    The dump is a multi-document YAML stream. Each non-empty document must
    provide a ``test`` key and a ``queries`` key; only the first item of
    ``queries`` is kept, since the recorder patch stores the per-test query
    list as the single element of that key. If the same test appears more
    than once, the last document wins.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(path, encoding="utf-8") as dump_file:
        return {
            entry["test"]: entry["queries"][0]
            for entry in yaml.load_all(dump_file, Loader=yaml.SafeLoader)
            # Empty documents (e.g. after a trailing "---") load as None.
            if entry
        }


def _query_lines(queries) -> list:
    """Flatten *queries* into a list of newline-terminated lines for difflib."""
    return list(
        chain.from_iterable(f"{query}\n".splitlines(True) for query in queries)
    )


def compare(control_sha: str, feature_sha: str, vendor: str):
    """Print a unified diff of the SQL recorded per test for two commits.

    Both dumps are looked up in the directory containing this script, using
    ``file_name_format`` to build their names. Only tests present in both
    dumps are compared; a diff is written to stdout for every test whose
    recorded queries differ.

    Args:
        control_sha: Commit identifier of the baseline dump.
        feature_sha: Commit identifier of the dump under review.
        vendor: Database vendor name embedded in the dump file names.

    Raises:
        FileNotFoundError: If either dump file does not exist.
    """
    base_dir = os.path.dirname(__file__)
    control = _load_recorded_queries(
        os.path.join(
            base_dir, file_name_format.format(sha=control_sha, vendor=vendor)
        )
    )
    feature = _load_recorded_queries(
        os.path.join(
            base_dir, file_name_format.format(sha=feature_sha, vendor=vendor)
        )
    )
    for test, control_queries in control.items():
        feature_queries = feature.get(test)
        # Skip tests missing from the feature dump and tests whose SQL is
        # unchanged; everything else gets a diff.
        if feature_queries is None or control_queries == feature_queries:
            continue
        sys.stdout.writelines(
            difflib.unified_diff(
                _query_lines(control_queries),
                _query_lines(feature_queries),
                fromfile=f"{test}:{vendor}:{control_sha}",
                tofile=f"{test}:{vendor}:{feature_sha}",
            )
        )
if __name__ == "__main__":
    # Command line: two positional commit SHAs and an optional database
    # vendor, forwarded to compare() as keyword arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument("control_sha")
    cli.add_argument("feature_sha")
    cli.add_argument("-v", "--vendor", default="postgresql")
    options = cli.parse_args()
    compare(
        control_sha=options.control_sha,
        feature_sha=options.feature_sha,
        vendor=options.vendor,
    )
diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py
index d505cd7904..d6f8c58acf 100644
--- a/django/db/backends/utils.py
+++ b/django/db/backends/utils.py
@@ -117,7 +117,9 @@ def debug_sql(
stop = time.monotonic()
duration = stop - start
if use_last_executed_query:
- sql = self.db.ops.last_executed_query(self.cursor, sql, params)
+ executed_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
+ else:
+ executed_sql = sql
try:
times = len(params) if many else ""
except TypeError:
@@ -125,7 +127,7 @@ def debug_sql(
times = "?"
self.db.queries_log.append(
{
- "sql": "%s times: %s" % (times, sql) if many else sql,
+ "sql": "%s times: %s" % (times, executed_sql) if many else executed_sql,
"time": "%.3f" % duration,
}
)
@@ -137,7 +139,8 @@ def debug_sql(
self.db.alias,
extra={
"duration": duration,
- "sql": sql,
+ "sql": executed_sql,
+ "raw_sql": sql,
"params": params,
"alias": self.db.alias,
},
diff --git a/django/test/runner.py b/django/test/runner.py
index fb4d77ed60..46f841a5e5 100644
--- a/django/test/runner.py
+++ b/django/test/runner.py
@@ -43,6 +43,49 @@
tblib = None
+import subprocess
+
+sha = (
+ subprocess.check_output(["git", "rev-parse", "--short", "HEAD^"])
+ .decode("ascii")
+ .strip()
+)
+
+
+class RecorderHandler(logging.Handler):
+ def __init__(self, test):
+ self.test = test
+ self.queries = []
+ super().__init__(logging.DEBUG)
+
+ def handle(self, record):
+ query = record.raw_sql
+ if (
+ query.startswith("SAVEPOINT")
+ or query.startswith("RELEASE SAVEPOINT")
+ or query.startswith("EXPLAIN")
+ ):
+ return
+ self.queries.append(sqlparse.format(query, reindent=True, keyword_case="upper"))
+
+ def flush(self):
+ import os
+ import yaml
+
+ try:
+ os.mkdir("tests/.rsql")
+ except FileExistsError:
+ pass
+ vendor = connections["default"].vendor
+ with open(f"tests/.rsql/{sha}:{vendor}.yml", "a") as file:
+ yaml.dump(
+ {"test": str(self.test), "queries": [self.queries]},
+ file,
+ sort_keys=False,
+ )
+ file.write(f"---\n")
+
+
class DebugSQLTextTestResult(unittest.TextTestResult):
def __init__(self, stream, descriptions, verbosity):
self.logger = logging.getLogger("django.db.backends")
@@ -54,15 +97,19 @@ def startTest(self, test):
self.debug_sql_stream = StringIO()
self.handler = logging.StreamHandler(self.debug_sql_stream)
self.logger.addHandler(self.handler)
+ self.record_handler = RecorderHandler(test)
+ self.logger.addHandler(self.record_handler)
super().startTest(test)
def stopTest(self, test):
super().stopTest(test)
self.logger.removeHandler(self.handler)
+ self.logger.removeHandler(self.record_handler)
if self.showAll:
self.debug_sql_stream.seek(0)
self.stream.write(self.debug_sql_stream.read())
self.stream.writeln(self.separator2)
+ self.record_handler.flush()
def addError(self, test, err):
super().addError(test, err)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment