Created November 11, 2022 01:29
-
-
Save charettes/7b5718eb80c51fa23ec7b8db9a916609 to your computer and use it in GitHub Desktop.
Django test suite SQL output comparison
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import difflib | |
import os | |
import sys | |
from itertools import chain | |
import yaml | |
# Recordings are named "<sha>:<vendor>.yml" and are read from this script's directory.
file_name_format = "{sha}:{vendor}.yml"


def _load_queries(path):
    """Return a ``{test name: queries}`` mapping read from the YAML file at *path*.

    Each YAML document must hold a ``test`` name and a ``queries`` list whose
    first element is the list of SQL strings recorded for that test. Empty
    documents (such as the one after a trailing ``---`` separator) are skipped.
    """
    with open(path) as stream:
        return {
            entry["test"]: entry["queries"][0]
            # SafeLoader only builds plain YAML types, never arbitrary objects.
            for entry in yaml.load_all(stream, Loader=yaml.SafeLoader)
            if entry
        }


def _split_lines(queries):
    """Flatten *queries* into newline-terminated lines, as ``difflib`` expects."""
    # unified_diff needs real sequences (it measures and indexes them), not iterators.
    return list(
        chain.from_iterable(f"{query}\n".splitlines(True) for query in queries)
    )


def compare(control_sha: str, feature_sha: str, vendor: str):
    """Print a unified diff of the SQL recorded for each test whose queries changed.

    Reads the ``<sha>:<vendor>.yml`` recordings for *control_sha* and
    *feature_sha* from the directory containing this script. Only tests
    present in both recordings are compared; tests found in just one of them
    are skipped. Diffs are written to stdout in the control recording's order.

    Raises:
        FileNotFoundError: if either recording does not exist.
    """
    directory = os.path.dirname(__file__)

    def load(sha):
        file_name = file_name_format.format(sha=sha, vendor=vendor)
        return _load_queries(os.path.join(directory, file_name))

    control = load(control_sha)
    feature = load(feature_sha)
    for test, control_queries in control.items():
        feature_queries = feature.get(test)
        # A test that did not run for the feature commit cannot be compared.
        if feature_queries is None or feature_queries == control_queries:
            continue
        sys.stdout.writelines(
            difflib.unified_diff(
                _split_lines(control_queries),
                _split_lines(feature_queries),
                fromfile=f"{test}:{vendor}:{control_sha}",
                tofile=f"{test}:{vendor}:{feature_sha}",
            )
        )
if __name__ == "__main__":
    # Usage: <control_sha> <feature_sha> [-v/--vendor VENDOR]. Each parsed
    # option name matches a compare() parameter, so they are passed through as keywords.
    cli = argparse.ArgumentParser()
    for positional in ("control_sha", "feature_sha"):
        cli.add_argument(positional)
    cli.add_argument("-v", "--vendor", default="postgresql")
    options = cli.parse_args()
    compare(**vars(options))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/django/db/backends/utils.py b/django/db/backends/utils.py | |
index d505cd7904..d6f8c58acf 100644 | |
--- a/django/db/backends/utils.py | |
+++ b/django/db/backends/utils.py | |
@@ -117,7 +117,9 @@ def debug_sql( | |
stop = time.monotonic() | |
duration = stop - start | |
if use_last_executed_query: | |
- sql = self.db.ops.last_executed_query(self.cursor, sql, params) | |
+ executed_sql = self.db.ops.last_executed_query(self.cursor, sql, params) | |
+ else: | |
+ executed_sql = sql | |
try: | |
times = len(params) if many else "" | |
except TypeError: | |
@@ -125,7 +127,7 @@ def debug_sql( | |
times = "?" | |
self.db.queries_log.append( | |
{ | |
- "sql": "%s times: %s" % (times, sql) if many else sql, | |
+ "sql": "%s times: %s" % (times, executed_sql) if many else executed_sql, | |
"time": "%.3f" % duration, | |
} | |
) | |
@@ -137,7 +139,8 @@ def debug_sql( | |
self.db.alias, | |
extra={ | |
"duration": duration, | |
- "sql": sql, | |
+ "sql": executed_sql, | |
+ "raw_sql": sql, | |
"params": params, | |
"alias": self.db.alias, | |
}, | |
diff --git a/django/test/runner.py b/django/test/runner.py | |
index fb4d77ed60..46f841a5e5 100644 | |
--- a/django/test/runner.py | |
+++ b/django/test/runner.py | |
@@ -43,6 +43,49 @@ | |
tblib = None | |
+import subprocess | |
+ | |
+sha = ( | |
+ subprocess.check_output(["git", "rev-parse", "--short", "HEAD^"]) | |
+ .decode("ascii") | |
+ .strip() | |
+) | |
+ | |
+ | |
+class RecorderHandler(logging.Handler):
+    """Collect the raw SQL logged while `test` runs; append it to a YAML file."""
+
+    IGNORED_PREFIXES = ("SAVEPOINT", "RELEASE SAVEPOINT", "EXPLAIN")
+
+    def __init__(self, test):
+        super().__init__(logging.DEBUG)
+        self.test = test
+        self.queries = []
+        self.flushed = False
+
+    def emit(self, record):
+        # Only records from debug_sql() carry raw_sql; skip any other record
+        # reaching this logger (or its children) instead of raising.
+        query = getattr(record, "raw_sql", None)
+        if query is None or query.startswith(self.IGNORED_PREFIXES):
+            return
+        self.queries.append(sqlparse.format(query, reindent=True, keyword_case="upper"))
+
+    def flush(self):
+        # logging.shutdown() flushes live handlers too; write each test once.
+        if self.flushed:
+            return
+        self.flushed = True
+        import os
+        import yaml
+
+        os.makedirs("tests/.rsql", exist_ok=True)
+        vendor = connections["default"].vendor
+        with open(f"tests/.rsql/{sha}:{vendor}.yml", "a") as file:
+            yaml.dump({"test": str(self.test), "queries": [self.queries]}, file, sort_keys=False)
+            file.write("---\n")
+ | |
+ | |
class DebugSQLTextTestResult(unittest.TextTestResult): | |
def __init__(self, stream, descriptions, verbosity): | |
self.logger = logging.getLogger("django.db.backends") | |
@@ -54,15 +97,19 @@ def startTest(self, test): | |
self.debug_sql_stream = StringIO() | |
self.handler = logging.StreamHandler(self.debug_sql_stream) | |
self.logger.addHandler(self.handler) | |
+ self.record_handler = RecorderHandler(test) | |
+ self.logger.addHandler(self.record_handler) | |
super().startTest(test) | |
def stopTest(self, test): | |
super().stopTest(test) | |
self.logger.removeHandler(self.handler) | |
+ self.logger.removeHandler(self.record_handler) | |
if self.showAll: | |
self.debug_sql_stream.seek(0) | |
self.stream.write(self.debug_sql_stream.read()) | |
self.stream.writeln(self.separator2) | |
+ self.record_handler.flush() | |
def addError(self, test, err): | |
super().addError(test, err) |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.