Created
March 2, 2023 22:48
-
-
Save TobeTek/e6214cebcf138f1127a1a64a4d1fa494 to your computer and use it in GitHub Desktop.
A cleaner approach to mocking unmanaged models in Django tests
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
A cleaner approach to temporarily creating unmanaged model db tables for tests | |
""" | |
import functools
from collections.abc import Iterable
from unittest import TestCase

from django.db import connections, models
from django.db.models.base import ModelBase
class create_unmanaged_model_tables:
    """
    Create db tables for unmanaged models for the duration of a test.

    Usable as a context manager, as a decorator for a test function, or as a
    decorator for a ``unittest.TestCase`` subclass.

    Adapted from: https://stackoverflow.com/a/49800437

    Examples::

        with create_unmanaged_model_tables([UnmanagedModel]):
            ...

        @create_unmanaged_model_tables([UnmanagedModel, FooModel])
        def test_generate_data():
            ...

        @create_unmanaged_model_tables([UnmanagedModel, FooModel])
        class MyTestCase(unittest.TestCase):
            ...

    A single model may also be passed on its own:
    ``create_unmanaged_model_tables(UnmanagedModel)``.

    When decorating a TestCase class, the tables are created and dropped
    around ``setUpClass``, around every test (``setUp`` .. ``tearDown``) and
    around ``tearDownClass``. Rows written during ``setUpClass`` are
    therefore not visible inside the individual tests.
    """

    def __init__(
        self,
        unmanaged_models: "ModelBase | Iterable[ModelBase]",
        db_alias: str = "default",
    ):
        """
        :param unmanaged_models: A model class, or an iterable of model
            classes, whose tables should be created.
        :param str db_alias: Name of the database to connect to, defaults to "default"
        """
        if isinstance(unmanaged_models, type):
            # A bare model class was given; wrap it so it can be iterated.
            unmanaged_models = (unmanaged_models,)
        # Materialise once so start() and stop() see the same models even
        # when a one-shot iterator (e.g. a generator) was passed in.
        self.unmanaged_models = tuple(unmanaged_models)
        self.db_alias = db_alias
        self.connection = connections[db_alias]

    def __call__(self, obj):
        """Decorate a TestCase subclass or any other callable."""
        # issubclass() raises TypeError for non-classes (e.g. test functions),
        # so check that obj is a class first.
        if isinstance(obj, type) and issubclass(obj, TestCase):
            return self.decorate_class(obj)
        return self.decorate_callable(obj)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.stop()

    def start(self):
        """
        Create the table of every model in ``self.unmanaged_models``.

        :raises ValueError: if a model's table is not present after creation.
        """
        with self.connection.schema_editor() as schema_editor:
            for model in self.unmanaged_models:
                schema_editor.create_model(model)
            # Query the table list once, after all creates, instead of per model.
            existing_tables = set(self.connection.introspection.table_names())
            for model in self.unmanaged_models:
                if model._meta.db_table not in existing_tables:
                    raise ValueError(
                        "Table `{table_name}` is missing in test database.".format(
                            table_name=model._meta.db_table
                        )
                    )

    def stop(self):
        """Drop the tables created by :meth:`start`."""
        with self.connection.schema_editor() as schema_editor:
            # Drop in reverse creation order so tables listed later (which may
            # reference earlier ones) are removed first.
            for model in reversed(self.unmanaged_models):
                schema_editor.delete_model(model)

    def copy(self):
        """Return a new instance targeting the same models and database."""
        return self.__class__(
            unmanaged_models=self.unmanaged_models, db_alias=self.db_alias
        )

    def decorate_class(self, klass):
        """
        Patch ``klass`` so the tables exist during its class and test hooks.

        ``klass`` is modified in place and returned.
        """
        orig_setUpClass = klass.setUpClass
        orig_tearDownClass = klass.tearDownClass
        orig_setUp = klass.setUp
        orig_tearDown = klass.tearDown

        def call_class_hook(hook, cls):
            # ``klass.setUpClass`` is a classmethod already bound to ``klass``;
            # re-bind the underlying function to ``cls`` so that subclasses of
            # the decorated class receive themselves rather than ``klass``.
            func = getattr(hook, "__func__", None)
            if func is None:
                return hook()
            return func(cls)

        # noinspection PyDecorator
        @classmethod
        def setUpClass(cls):
            self.start()
            try:
                call_class_hook(orig_setUpClass, cls)
            finally:
                # Also drop the tables if setUpClass fails; unittest will not
                # run tearDownClass in that case.
                self.stop()

        # noinspection PyDecorator
        @classmethod
        def tearDownClass(cls):
            self.start()
            try:
                call_class_hook(orig_tearDownClass, cls)
            finally:
                self.stop()

        def setUp(test_case, *args, **kwargs):
            self.start()
            try:
                orig_setUp(test_case, *args, **kwargs)
            except BaseException:
                # unittest does not call tearDown() when setUp() fails, so drop
                # the tables here instead of leaving them for the next test.
                self.stop()
                raise

        def tearDown(test_case, *args, **kwargs):
            try:
                orig_tearDown(test_case, *args, **kwargs)
            finally:
                self.stop()

        klass.setUpClass = setUpClass
        klass.tearDownClass = tearDownClass
        klass.setUp = setUp
        klass.tearDown = tearDown
        return klass

    def decorate_callable(self, callable_obj):
        """Wrap ``callable_obj`` so each call runs with the tables present."""

        @functools.wraps(callable_obj)
        def wrapper(*args, **kwargs):
            # A fresh copy per call keeps concurrent/nested uses independent.
            with self.copy():
                return callable_obj(*args, **kwargs)

        return wrapper
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment