Created
June 4, 2010 01:32
-
-
Save rmax/424784 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# HG changeset patch
# Parent b83d4a597f5fefe33a4eb0a7a4d4027402140870
# User Rolando Espinoza La fuente <[email protected]>
Add TransactionTestCase.assertNumQueries(), which counts the queries reported
through the django.db query_execute signal.

diff --git a/django/test/testcases.py b/django/test/testcases.py
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -7,7 +7,7 @@ from django.conf import settings
 from django.core import mail
 from django.core.management import call_command
 from django.core.urlresolvers import clear_url_caches
-from django.db import transaction, connections, DEFAULT_DB_ALIAS
+from django.db import transaction, connections, DEFAULT_DB_ALIAS, signals
 from django.http import QueryDict
 from django.test import _doctest as doctest
 from django.test.client import Client
@@ -209,6 +209,55 @@ class DocTestRunner(doctest.DocTestRunne
         for conn in connections:
             transaction.rollback_unless_managed(using=conn)
 
+
+class _AssertNumQueriesContext(object):
+    """
+    Counts query_execute signals while active and, on a clean exit, fails
+    the test unless exactly `expected` signals were received.
+    """
+    def __init__(self, expected, testCase):
+        self.counter = 0
+        self.expected = expected
+        self.failureException = testCase.failureException
+        # (connection, previous use_signal_cursor value) pairs saved by
+        # __enter__ so that _restore() can undo its changes.
+        self.saved_flags = []
+
+    def __enter__(self):
+        # Enable the signal cursor on all databases regardless of debug
+        # mode, backing up the current values. This is done here rather
+        # than in __init__ so that a context object which is never entered
+        # leaves no connection flag changed and no receiver connected.
+        self.counter = 0
+        self.saved_flags = []
+        for conn in connections.all():
+            self.saved_flags.append((conn, conn.use_signal_cursor))
+            conn.use_signal_cursor = True
+        signals.query_execute.connect(self.track_query, weak=False)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        self._restore()
+        if exc_type is not None:
+            # Let exceptions raised inside the block pass through.
+            return False
+        if self.counter != self.expected:
+            raise self.failureException("Executed %s queries. Expected: %s."
+                                        % (self.counter, self.expected))
+        return False
+
+    def _restore(self):
+        """Disconnect the receiver and restore each connection's flag."""
+        signals.query_execute.disconnect(self.track_query, weak=False)
+        for conn, flag in self.saved_flags:
+            conn.use_signal_cursor = flag
+        self.saved_flags = []
+
+    def track_query(self, sender, *args, **kwargs):
+        # query_execute receiver: one call per signal sent.
+        self.counter += 1
+
+
 class TransactionTestCase(unittest.TestCase):
     def _pre_setup(self):
         """Performs any pre-test setup. This includes:
@@ -469,6 +518,34 @@ class TransactionTestCase(unittest.TestC
     def assertQuerysetEqual(self, qs, values, transform=repr):
         return self.assertEqual(map(transform, qs), values)
 
+    def assertNumQueries(self, num, callableObj=None, *args, **kwargs):
+        """
+        Fail unless the number of database queries performed by callableObj,
+        when invoked with positional arguments args and keyword arguments
+        kwargs, is equal to num.
+
+        If callableObj is omitted or None, return a context manager instead,
+        to be used like this:
+
+            with self.assertNumQueries(4):
+                do_something()
+        """
+        context = _AssertNumQueriesContext(num, self)
+        if callableObj is None:
+            return context
+        # Drive the context manager by hand instead of using a `with`
+        # statement, which may not be available on every supported Python.
+        context.__enter__()
+        try:
+            callableObj(*args, **kwargs)
+        except:
+            # Undo the connection changes and disconnect the receiver before
+            # re-raising; otherwise they would outlive this assertion.
+            context._restore()
+            raise
+        context.__exit__(None, None, None)
+
+
 def connections_support_transactions():
     """
     Returns True if all connections support transactions. This is messy
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# HG changeset patch
# Parent 0a88925843522d427d84e6e68e3427c4cb416be1
# User Rolando Espinoza La fuente <[email protected]>
Add a django.db.signals.query_execute signal, sent for executed queries when
settings.DEBUG is on or a connection's use_signal_cursor flag is set.

django.db used to import django.core.signals under the name `signals`; import
it as `core_signals` instead so that `from django.db import signals` resolves
to the new django.db.signals module.

diff --git a/django/db/__init__.py b/django/db/__init__.py
--- a/django/db/__init__.py
+++ b/django/db/__init__.py
@@ -1,5 +1,5 @@
 from django.conf import settings
-from django.core import signals
+from django.core import signals as core_signals
 from django.core.exceptions import ImproperlyConfigured
 from django.db.utils import ConnectionHandler, ConnectionRouter, load_backend, DEFAULT_DB_ALIAS, \
                             DatabaseError, IntegrityError
@@ -80,14 +80,14 @@ backend = load_backend(connection.settin
 def close_connection(**kwargs):
     for conn in connections.all():
         conn.close()
-signals.request_finished.connect(close_connection)
+core_signals.request_finished.connect(close_connection)
 
 # Register an event that resets connection.queries
 # when a Django request is started.
 def reset_queries(**kwargs):
     for conn in connections.all():
         conn.queries = []
-signals.request_started.connect(reset_queries)
+core_signals.request_started.connect(reset_queries)
 
 # Register an event that rolls back the connections
 # when a Django request has an exception.
@@ -98,4 +98,4 @@ def _rollback_on_exception(**kwargs):
             transaction.rollback_unless_managed(using=conn)
         except DatabaseError:
             pass
-signals.got_request_exception.connect(_rollback_on_exception)
+core_signals.got_request_exception.connect(_rollback_on_exception)
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -20,6 +20,9 @@ class BaseDatabaseWrapper(local):
         self.queries = []
         self.settings_dict = settings_dict
         self.alias = alias
+        # When True, cursor() wraps cursors with make_signal_cursor() even
+        # if settings.DEBUG is off, so query_execute signals are sent.
+        self.use_signal_cursor = False
 
     def __eq__(self, other):
         return self.settings_dict == other.settings_dict
@@ -74,12 +77,18 @@ class BaseDatabaseWrapper(local):
         from django.conf import settings
         cursor = self._cursor()
         if settings.DEBUG:
-            return self.make_debug_cursor(cursor)
+            return self.make_signal_cursor(self.make_debug_cursor(cursor))
+        elif self.use_signal_cursor:
+            return self.make_signal_cursor(cursor)
         return cursor
 
     def make_debug_cursor(self, cursor):
         return util.CursorDebugWrapper(cursor, self)
 
+    def make_signal_cursor(self, cursor):
+        return util.CursorSignalWrapper(cursor, self)
+
+
 class BaseDatabaseFeatures(object):
     allows_group_by_pk = False
     # True if django.db.backend.utils.typecast_timestamp is used on values
diff --git a/django/db/backends/util.py b/django/db/backends/util.py
--- a/django/db/backends/util.py
+++ b/django/db/backends/util.py
@@ -2,6 +2,7 @@ import datetime
 import decimal
 from time import time
 
+from django.db import signals
 from django.utils.hashcompat import md5_constructor
 
 class CursorDebugWrapper(object):
@@ -41,6 +42,50 @@ class CursorDebugWrapper(object):
     def __iter__(self):
         return iter(self.cursor)
 
+class CursorSignalWrapper(object):
+    """
+    Cursor wrapper that sends django.db.signals.query_execute after each
+    execute() call and once per parameter set of each executemany() call,
+    whether or not the underlying call raised.
+    """
+    def __init__(self, cursor, db):
+        self.cursor = cursor
+        self.db = db # Instance of a BaseDatabaseWrapper subclass
+
+    def execute(self, sql, params=()):
+        start = time()
+        try:
+            return self.cursor.execute(sql, params)
+        finally:
+            stop = time()
+            signals.query_execute.send(self, sql=sql, params=params, time=stop-start)
+
+    def executemany(self, sql, param_list):
+        # param_list may be a one-shot iterator such as a generator. The
+        # underlying cursor may exhaust it, leaving nothing for the loop
+        # below, so materialize anything that is not already a sequence.
+        if not isinstance(param_list, (list, tuple)):
+            param_list = list(param_list)
+        start = time()
+        try:
+            return self.cursor.executemany(sql, param_list)
+        finally:
+            stop = time()
+            # XXX: Should this be a special signal that only gets called once for efficiency?
+            # Each signal reports the elapsed time of the whole batch.
+            for params in param_list:
+                signals.query_execute.send(self, sql=sql, params=params, time=stop-start)
+
+    def __getattr__(self, attr):
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        else:
+            return getattr(self.cursor, attr)
+
+    def __iter__(self):
+        return iter(self.cursor)
+
+
 ###############################################
 # Converters from database (string) to Python #
 ###############################################
diff --git a/django/db/signals.py b/django/db/signals.py
new file mode 100644
--- /dev/null
+++ b/django/db/signals.py
@@ -0,0 +1,5 @@
+from django.dispatch import Signal
+
+# Sent by django.db.backends.util.CursorSignalWrapper after each execute()
+# call and once per parameter set of each executemany() call.
+query_execute = Signal(providing_args=["sql", "params", "time"])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment