Skip to content

Instantly share code, notes, and snippets.

@rmax
Created June 4, 2010 01:32
Show Gist options
  • Save rmax/424784 to your computer and use it in GitHub Desktop.
# HG changeset patch
# Parent b83d4a597f5fefe33a4eb0a7a4d4027402140870
# User Rolando Espinoza La fuente <[email protected]>
diff --git a/django/test/testcases.py b/django/test/testcases.py
--- a/django/test/testcases.py
+++ b/django/test/testcases.py
@@ -7,7 +7,7 @@ from django.conf import settings
from django.core import mail
from django.core.management import call_command
from django.core.urlresolvers import clear_url_caches
-from django.db import transaction, connections, DEFAULT_DB_ALIAS
+from django.db import transaction, connections, DEFAULT_DB_ALIAS, signals
from django.http import QueryDict
from django.test import _doctest as doctest
from django.test.client import Client
@@ -209,6 +209,56 @@ class DocTestRunner(doctest.DocTestRunne
         for conn in connections:
             transaction.rollback_unless_managed(using=conn)
 
+
+class _AssertNumQueriesContext(object):
+    """
+    Context manager that counts the queries executed on all database
+    connections while it is active and, on a clean exit, raises
+    testCase.failureException if the count differs from the expected one.
+
+    All setup happens in __enter__ and is undone in __exit__, so merely
+    creating an instance has no side effects.
+    """
+    def __init__(self, expected, testCase):
+        self.counter = 0
+        self.expected = expected
+        self.failureException = testCase.failureException
+        # Backup of each connection's use_signal_cursor value, filled in
+        # by __enter__ and restored by __exit__.
+        self.signal_cursor_flag = {}
+
+    def __enter__(self):
+        # Enable the signal cursor on all databases regardless of DEBUG
+        # mode, backing up the current value. Doing this here rather than
+        # in __init__ means a context that is created but never entered
+        # cannot leave a strongly-connected receiver behind.
+        self.counter = 0
+        for conn in connections.all():
+            self.signal_cursor_flag[conn] = conn.use_signal_cursor
+            conn.use_signal_cursor = True
+        signals.query_execute.connect(self.track_query, weak=False)
+        return self
+
+    def __exit__(self, exc_type, exc_value, tb):
+        signals.query_execute.disconnect(self.track_query, weak=False)
+        # Restore exactly the connections that __enter__ modified.
+        for conn, flag in self.signal_cursor_flag.items():
+            conn.use_signal_cursor = flag
+        self.signal_cursor_flag = {}
+        if exc_type is not None:
+            # Let exceptions raised inside the block pass through.
+            return False
+        if self.counter != self.expected:
+            raise self.failureException(
+                "Executed %s queries. Expected: %s."
+                % (self.counter, self.expected))
+        return False
+
+    def track_query(self, sender, *args, **kwargs):
+        """Receiver for signals.query_execute: count one executed query."""
+        self.counter += 1
+
+
 class TransactionTestCase(unittest.TestCase):
     def _pre_setup(self):
         """Performs any pre-test setup. This includes:
@@ -469,6 +503,38 @@ class TransactionTestCase(unittest.TestC
     def assertQuerysetEqual(self, qs, values, transform=repr):
         return self.assertEqual(map(transform, qs), values)
 
+    def assertNumQueries(self, num, callableObj=None, *args, **kwargs):
+        """
+        Fail unless the number of db queries performed by callableObj when
+        invoked with arguments args and keyword arguments kwargs is equal
+        to num.
+
+        If called with callableObj omitted or None, will return a context
+        object used like this:
+
+            with self.assertNumQueries(4):
+                do_something()
+
+        When callableObj is given and raises, the query tracking is still
+        torn down and the exception propagates unchanged.
+        """
+        context = _AssertNumQueriesContext(num, self)
+        if callableObj is None:
+            return context
+        # Hand-expanded ``with context:`` (PEP 343) for Pythons without the
+        # with statement. The try is essential: skipping __exit__ would leave
+        # a strongly-connected receiver and the signal cursors enabled.
+        import sys
+        context.__enter__()
+        try:
+            callableObj(*args, **kwargs)
+        except:
+            if not context.__exit__(*sys.exc_info()):
+                raise
+        else:
+            context.__exit__(None, None, None)
+
+
 def connections_support_transactions():
     """
     Returns True if all connections support transactions. This is messy
# HG changeset patch
# Parent 0a88925843522d427d84e6e68e3427c4cb416be1
# User Rolando Espinoza La fuente <[email protected]>
diff --git a/django/db/__init__.py b/django/db/__init__.py
--- a/django/db/__init__.py
+++ b/django/db/__init__.py
@@ -1,5 +1,5 @@
from django.conf import settings
-from django.core import signals
+from django.core import signals as core_signals  # aliased: importing the new django.db.signals submodule rebinds the name "signals" in this package
from django.core.exceptions import ImproperlyConfigured
from django.db.utils import ConnectionHandler, ConnectionRouter, load_backend, DEFAULT_DB_ALIAS, \
DatabaseError, IntegrityError
@@ -80,14 +80,14 @@ backend = load_backend(connection.settin
def close_connection(**kwargs):
for conn in connections.all():
conn.close()
-signals.request_finished.connect(close_connection)
+core_signals.request_finished.connect(close_connection)
# Register an event that resets connection.queries
# when a Django request is started.
def reset_queries(**kwargs):
for conn in connections.all():
conn.queries = []
-signals.request_started.connect(reset_queries)
+core_signals.request_started.connect(reset_queries)
# Register an event that rolls back the connections
# when a Django request has an exception.
@@ -98,4 +98,4 @@ def _rollback_on_exception(**kwargs):
transaction.rollback_unless_managed(using=conn)
except DatabaseError:
pass
-signals.got_request_exception.connect(_rollback_on_exception)
+core_signals.got_request_exception.connect(_rollback_on_exception)
diff --git a/django/db/backends/__init__.py b/django/db/backends/__init__.py
--- a/django/db/backends/__init__.py
+++ b/django/db/backends/__init__.py
@@ -20,6 +20,7 @@ class BaseDatabaseWrapper(local):
self.queries = []
self.settings_dict = settings_dict
self.alias = alias
+ self.use_signal_cursor = False  # True: cursor() wraps cursors in CursorSignalWrapper even when DEBUG is off
def __eq__(self, other):
return self.settings_dict == other.settings_dict
@@ -74,12 +75,18 @@ class BaseDatabaseWrapper(local):
from django.conf import settings
cursor = self._cursor()
if settings.DEBUG:
- return self.make_debug_cursor(cursor)
+ return self.make_signal_cursor(self.make_debug_cursor(cursor))  # under DEBUG the signal wrapper sits outside the debug wrapper
+ elif self.use_signal_cursor:  # opt-in outside DEBUG (e.g. _AssertNumQueriesContext in django.test.testcases)
+ return self.make_signal_cursor(cursor)
return cursor
def make_debug_cursor(self, cursor):
return util.CursorDebugWrapper(cursor, self)
+ def make_signal_cursor(self, cursor):  # execute()/executemany() on the result send signals.query_execute
+ return util.CursorSignalWrapper(cursor, self)
+
+
class BaseDatabaseFeatures(object):
allows_group_by_pk = False
# True if django.db.backend.utils.typecast_timestamp is used on values
diff --git a/django/db/backends/util.py b/django/db/backends/util.py
--- a/django/db/backends/util.py
+++ b/django/db/backends/util.py
@@ -2,6 +2,7 @@ import datetime
import decimal
from time import time
+from django.db import signals
from django.utils.hashcompat import md5_constructor
class CursorDebugWrapper(object):
@@ -41,6 +42,54 @@ class CursorDebugWrapper(object):
     def __iter__(self):
         return iter(self.cursor)
 
+class CursorSignalWrapper(object):
+    """
+    Wrap a DB-API cursor so that every executed statement sends
+    django.db.signals.query_execute with sql, params and time (seconds)
+    keyword arguments. Other attributes are delegated to the cursor.
+    """
+    def __init__(self, cursor, db):
+        self.cursor = cursor
+        self.db = db  # Instance of a BaseDatabaseWrapper subclass
+
+    def execute(self, sql, params=()):
+        start = time()
+        try:
+            return self.cursor.execute(sql, params)
+        finally:
+            # Sent even if execute() raises: the query was still attempted.
+            stop = time()
+            signals.query_execute.send(self, sql=sql, params=params,
+                                       time=stop - start)
+
+    def executemany(self, sql, param_list):
+        # param_list is consumed twice (by the cursor, then to send one
+        # signal per parameter set), so materialize one-shot iterables
+        # such as generators; otherwise no signal would be sent at all.
+        if not isinstance(param_list, (list, tuple)):
+            param_list = list(param_list)
+        start = time()
+        try:
+            return self.cursor.executemany(sql, param_list)
+        finally:
+            stop = time()
+            # One signal per parameter set, each reporting the duration of
+            # the whole executemany() call.
+            # XXX: Should this be a special signal that only gets called
+            # once for efficiency?
+            for params in param_list:
+                signals.query_execute.send(self, sql=sql, params=params,
+                                           time=stop - start)
+
+    def __getattr__(self, attr):
+        # Only called when normal lookup fails, i.e. for anything not set
+        # on the wrapper itself, so delegate straight to the real cursor.
+        return getattr(self.cursor, attr)
+
+    def __iter__(self):
+        return iter(self.cursor)
+
+
 ###############################################
 # Converters from database (string) to Python #
 ###############################################
diff --git a/django/db/signals.py b/django/db/signals.py
new file mode 100644
--- /dev/null
+++ b/django/db/signals.py
@@ -0,0 +1,3 @@
+from django.dispatch import Signal
+
+query_execute = Signal()  # sent by CursorSignalWrapper (django/db/backends/util.py) with sql, params and time kwargs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment