Last active
February 6, 2024 12:58
-
-
Save kingbuzzman/c790bced3788d289a4d938d9d8fed596 to your computer and use it in GitHub Desktop.
Test excluding middlewares
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding:utf-8 -*- | |
# Stolen from: https://mlvin.xyz/django-single-file-project.html | |
import datetime | |
import inspect | |
import os | |
import sys | |
from types import ModuleType | |
import django | |
from django.conf import settings | |
from django.urls import include, path, reverse | |
from django.apps import apps, AppConfig | |
from django.http import HttpResponse | |
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# The current name of the file, which will be the name of our app
APP_LABEL, _ = os.path.splitext(os.path.basename(os.path.abspath(__file__)))

# Migrations folder needs to be created, and Django needs to be told where it is
APP_MIGRATION_MODULE = '%s_migrations' % APP_LABEL
APP_MIGRATION_PATH = os.path.join(BASE_DIR, APP_MIGRATION_MODULE)

# Create the migrations folder and its __init__.py on demand.
# ``exist_ok`` avoids the race between an existence check and creation.
# The __init__.py is ensured separately so it is also restored when the folder
# exists but the file was removed. Append mode creates a missing file without
# truncating an existing one.
os.makedirs(APP_MIGRATION_PATH, exist_ok=True)
with open(os.path.join(APP_MIGRATION_PATH, '__init__.py'), 'a'):
    pass

# Hack to trick Django into thinking this file is actually a package
sys.modules[APP_LABEL] = sys.modules[__name__]
sys.modules[APP_LABEL].__path__ = [os.path.abspath(__file__)]
settings.configure(
    # General switches.
    DEBUG=True,
    SITE_ID=1,
    LOGGING={},
    STATIC_URL='/static/',
    # This file is the only installed app. URLs and middleware are looked up
    # by dotted path under its label; migrations live in APP_MIGRATION_MODULE.
    INSTALLED_APPS=[APP_LABEL],
    ROOT_URLCONF='%s.urls' % APP_LABEL,
    # Listed outermost first: Middleware1 wraps middleware2, which wraps the view.
    MIDDLEWARE=(
        '%s.views.Middleware1' % APP_LABEL,
        '%s.views.middleware2' % APP_LABEL,
    ),
    MIGRATION_MODULES={APP_LABEL: APP_MIGRATION_MODULE},
    # NAME is a relative path, so the database file's location depends on the
    # working directory the script is launched from.
    DATABASES={
        'default': {'ENGINE': 'django.db.backends.sqlite3', 'NAME': "db.sqlite3"},
    },
)

# Populate Django's app registry now that settings are in place.
django.setup()
# Set up the AppConfig lookup so we don't have to add the app_label to all our models
def get_containing_app_config(module):
    """Return the AppConfig owning ``module``.

    Code defined in this file runs as ``__main__`` when the file is executed
    directly. That module name is mapped to this app's config; every other
    module is delegated to Django's original implementation.
    """
    if module != '__main__':
        return apps._get_containing_app_config(module)
    return apps.get_app_config(APP_LABEL)


# Stash Django's implementation under a private name, then route lookups
# through the wrapper above.
apps._get_containing_app_config = apps.get_containing_app_config
apps.get_containing_app_config = get_containing_app_config
# Your code below this line | |
# ############################################################################## | |
from django.test import TestCase # noqa: E402 isort:skip | |
from django.http import HttpResponse | |
from django.urls import path | |
from django.db import models | |
from django.utils.decorators import decorator_from_middleware | |
from django.test.utils import override_settings | |
class Middleware1:
    """Class-based middleware that prints a marker before and after the next handler."""

    def __init__(self, get_response):
        # The next callable in the chain; invoked once per request in __call__.
        self.get_response = get_response

    def __call__(self, request):
        """Print 'before', delegate to the next handler, print 'after', return its result."""
        print('before middleware 1')
        result = self.get_response(request)
        print('after middleware 1')
        return result
def middleware2(get_response):
    """Function-style middleware factory.

    Returns a callable that prints a marker, delegates the request to
    ``get_response``, prints a second marker, and returns the delegate's result.
    """
    def handle(request):
        print('before middleware 2')
        outcome = get_response(request)
        print('after middleware 2')
        return outcome

    return handle
def view1(request):
    """Return an HttpResponse whose body is ``view1``; the request is ignored."""
    body = 'view1'
    return HttpResponse(body)
def view2(request):
    """Return an HttpResponse whose body is ``view2``; the request is ignored."""
    body = 'view2'
    return HttpResponse(body)
# Default URLconf (ROOT_URLCONF): both views are routed as-is.
urlpatterns = [path('1', view1), path('2', view2)]

# Alternate URLconf, registered as '<app>.urls_test' and selected in tests via
# override_settings. Here view1 is additionally wrapped with
# decorator_from_middleware(Middleware1).
# NOTE(review): decorator_from_middleware drives only the process_request /
# process_view / process_exception / process_template_response /
# process_response hooks. Middleware1 defines none of them (only __call__),
# so this wrapper likely adds no behaviour and does not exclude the global
# middleware -- confirm.
urlpatterns_ex = [
    path('1', decorator_from_middleware(Middleware1)(view1)),
    path('2', view2),
]
import unittest.mock | |
class SimpleTestCase(TestCase):
    """Exercise the configured middleware chain through Django's test client.

    NOTE: despite the name, this subclasses the ``TestCase`` imported from
    django.test, not django.test.SimpleTestCase.
    """

    def test_view_all_middlewares(self):
        # Every view is wrapped by both middlewares: entry follows MIDDLEWARE
        # order and exit is the reverse.
        messages = (
            'before middleware 1',
            'before middleware 2',
            'after middleware 2',
            'after middleware 1',
        )
        expected = [unittest.mock.call(message) for message in messages]
        for url in ('/1', '/2'):
            with unittest.mock.patch('builtins.print') as mock_print:
                self.assertEqual(self.client.get(url).status_code, 200)
                self.assertEqual(mock_print.call_args_list, expected)

    @override_settings(ROOT_URLCONF='%s.urls_test' % APP_LABEL)
    def test_exclude_middleware(self):
        # Route through the alternate URLconf and check view1 still answers.
        # NOTE(review): only the status code is checked; which middlewares
        # actually ran is not asserted.
        response = self.client.get('/1')
        self.assertEqual(response.status_code, 200)
# Your code above this line
# ##############################################################################

# Fake submodules of this file-as-package. They let dotted paths such as
# '<name of file>.views.Middleware1' and '<name of file>.urls' resolve, and let
# you do 'from <name of file>.models import *'.
views_module = ModuleType('%s.views' % (APP_LABEL))
views_module.view1 = view1
views_module.view2 = view2
views_module.Middleware1 = Middleware1
views_module.middleware2 = middleware2
models_module = ModuleType('%s.models' % (APP_LABEL))
tests_module = ModuleType('%s.tests' % (APP_LABEL))
urls_module = ModuleType('%s.urls' % (APP_LABEL))
urls_module.urlpatterns = urlpatterns
urls_module_test = ModuleType('%s.urls_test' % (APP_LABEL))
urls_module_test.urlpatterns = urlpatterns_ex

# Collect the model and test classes defined above. At module level this is
# the same dict locals() returns. Iterate over a snapshot because the loop
# variables themselves become new module globals.
for variable_name, value in list(globals().items()):
    if not inspect.isclass(value):
        continue
    # Models go into the fake models module
    if issubclass(value, models.Model):
        setattr(models_module, variable_name, value)
    # Tests go into the fake tests module
    if issubclass(value, TestCase):
        setattr(tests_module, variable_name, value)

# Register the fake modules so import machinery (e.g. import_module) finds them
sys.modules[views_module.__name__] = views_module
sys.modules[models_module.__name__] = models_module
sys.modules[tests_module.__name__] = tests_module
sys.modules[urls_module.__name__] = urls_module
sys.modules[urls_module_test.__name__] = urls_module_test

# Also expose every fake module as an attribute of the package, so that
# 'import <name of file>.views' followed by attribute access works as well.
# NOTE: sys.modules[APP_LABEL] is this very module, so these assignments
# rebind module globals. In particular, ``models`` stops referring to
# django.db.models from here on.
sys.modules[APP_LABEL].views = views_module
sys.modules[APP_LABEL].models = models_module
sys.modules[APP_LABEL].tests = tests_module
sys.modules[APP_LABEL].urls = urls_module
sys.modules[APP_LABEL].urls_test = urls_module_test
if __name__ == "__main__":
    # Hack to fix tests: when the only positional argument is 'test' (option
    # flags are ignored), append this app's label as the test label.
    argv = [arg for arg in sys.argv if not arg.startswith('-')]
    if len(argv) == 2 and argv[1] == 'test':
        sys.argv.append(APP_LABEL)
    from django.core.management import execute_from_command_line
    execute_from_command_line(sys.argv)
else:
    # Imported as a module (e.g. by a WSGI server): bind the WSGI callable to
    # the conventional ``application`` name so the server can find it.
    # Calling get_wsgi_application() without keeping the result exposes nothing.
    from django.core.wsgi import get_wsgi_application
    application = get_wsgi_application()
Author
kingbuzzman
commented
Feb 6, 2024
•
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.