Last active
August 24, 2021 14:35
-
-
Save btimby/5811298 to your computer and use it in GitHub Desktop.
Use a Django database router, a TestCase mixin and thread local storage to allow unit tests to switch databases.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# True when any of the listed test/lint command names appears in sys.argv.
TESTING = not {
    'test', 'csslint', 'jenkins', 'jslint', 'jtest',
    'lettuce', 'pep8', 'pyflakes', 'pylint', 'sloccount',
}.isdisjoint(sys.argv)

if TESTING:
    # Under test, keep the configured default database reachable under the
    # 'mysql' alias and make an in-memory SQLite database the new default.
    # The dict literal is evaluated before update() runs, so 'mysql'
    # receives the original default configuration.
    DATABASES.update({
        'mysql': DATABASES['default'],
        'default': {
            'ENGINE': 'django.db.backends.sqlite3',
            'NAME': ':memory:',
        },
    })

    # Register the router that lets each unit test pick its database.
    DATABASE_ROUTERS = ('myapp.tests.TestUsingDbRouter',)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
from django.db import DEFAULT_DB_ALIAS | |
from django.test.testcases import TestCase | |
# Thread-local storage; the helpers below keep the selected database alias
# on it under the attribute name 'test_db_name'.
_LOCALS = threading.local()
def set_test_db(db_name):
    """Remember ``db_name`` as the database alias to route to.

    The value is stored on the module's thread-local object, so it is only
    visible to the thread that set it.
    """
    _LOCALS.test_db_name = db_name
def get_test_db():
    """Return the database alias currently selected for this thread.

    Falls back to ``DEFAULT_DB_ALIAS`` when no alias has been stored, and
    also when the stored value is ``None`` (for example after
    ``set_test_db(None)``), so callers always receive a usable alias.
    """
    selected = getattr(_LOCALS, 'test_db_name', None)
    # A stored None means "nothing selected"; treat it like the unset case
    # instead of handing None back to the caller.
    if selected is None:
        return DEFAULT_DB_ALIAS
    return selected
def del_test_db():
    """Forget the stored database alias, if there is one.

    Calling this when nothing is stored is a no-op.
    """
    if hasattr(_LOCALS, 'test_db_name'):
        del _LOCALS.test_db_name
class TestUsingDbRouter(object):
    """Database router that answers with whatever ``get_test_db()`` returns."""

    def db_for_read(self, model, **hints):
        """Return the alias to use for reads; ``model`` and hints are ignored."""
        alias = get_test_db()
        return alias

    def db_for_write(self, model, **hints):
        """Return the alias to use for writes; ``model`` and hints are ignored."""
        alias = get_test_db()
        return alias
class UsingDbMixin(object):
    """Mixin that lets a TestCase select the database it runs against.

    Subclasses set ``using_db`` to a database alias. ``setUp`` passes that
    value to ``set_test_db`` and ``tearDown`` clears it again with
    ``del_test_db``, so the selection only lasts for the duration of a test.
    """

    # Legacy flag understood by older Django releases (deprecated in 2.2,
    # removed in 3.1): make the test framework handle every database.
    multi_db = True
    # Equivalent setting for Django >= 2.2; without it, queries against a
    # non-default alias are rejected by the test framework. Older releases
    # simply ignore this attribute.
    databases = '__all__'
    # Alias handed to set_test_db() in setUp; subclasses override it.
    using_db = None

    def setUp(self, *args, **kwargs):
        """Run the parent setUp, then select ``self.using_db``."""
        super(UsingDbMixin, self).setUp(*args, **kwargs)
        set_test_db(self.using_db)

    def tearDown(self, *args, **kwargs):
        """Clear the selected database, then run the parent tearDown."""
        # Clear first so the selection is removed even if the parent
        # tearDown raises.
        del_test_db()
        super(UsingDbMixin, self).tearDown(*args, **kwargs)
class MySQLTestCase(UsingDbMixin, TestCase):
    """Test case whose ``using_db`` selects the 'mysql' database alias."""

    # Alias picked up by UsingDbMixin.setUp.
    using_db = 'mysql'

    def test_mysql_something(self):
        # TODO: test something specific to MySQL
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment