Skip to content

Instantly share code, notes, and snippets.

@chbndrhnns
Created May 27, 2021 11:35
Show Gist options
  • Save chbndrhnns/47f7734ea4ee65839b43834327ddebd3 to your computer and use it in GitHub Desktop.
Save chbndrhnns/47f7734ea4ee65839b43834327ddebd3 to your computer and use it in GitHub Desktop.
Testing subprocesses
from pathlib import Path
class GunicornConfigBuilder:
    """Fluent builder for the gunicorn command line used in the integration tests.

    Each modifier returns the builder itself so calls can be chained, e.g.
    ``GunicornConfigBuilder().print_only().build()``; :meth:`build` turns the
    accumulated state into an argv list.
    """

    def __init__(self):
        self._cmd = "gunicorn"
        self._log_level = None
        self._app_module = "tests.integration.webserver.app:app"
        self._check_only = False
        self._print_only = False
        self._config_file = Path(__file__).parent / "config" / "gunicorn_conf.py"
        self._worker = "uvicorn.workers.UvicornWorker"
        self._access_log = None
        self._error_log = None
        # Set by check_only / print_only / without_config; production()
        # refuses to continue once it is True.
        self._is_dirty = False

    def build(self):
        """Return the gunicorn argv list (the command is also echoed to stdout)."""
        argv = [self._cmd]
        if self._config_file:
            argv += ["-c", str(self._config_file)]
        # Bare switches, emitted in this order when enabled.
        for enabled, switch in (
            (self._print_only, "--print-config"),
            (self._check_only, "--check-config"),
        ):
            if enabled:
                argv.append(switch)
        # "--name=value" options, emitted only when a value was configured.
        for option, setting in (
            ("--log-level", self._log_level),
            ("--access-logfile", self._access_log),
            ("--error-logfile", self._error_log),
        ):
            if setting:
                argv.append(f"{option}={setting}")
        argv += ["-k", self._worker, self._app_module]
        print(f'\nCommand: {" ".join(argv)}')
        return argv

    def _refuse_if(self, conflicting):
        """Raise ValueError when the other of print_only/check_only is already on."""
        if conflicting:
            raise ValueError("'print_only' and 'check_only' are mutually exclusive")

    def _mark_dirty(self):
        """Record that a default was changed and hand the builder back."""
        self._is_dirty = True
        return self

    def check_only(self):
        """Make gunicorn only validate its configuration (``--check-config``).

        :raises ValueError: if print_only() was already requested.
        """
        self._refuse_if(self._print_only)
        self._check_only = True
        return self._mark_dirty()

    def print_only(self):
        """Make gunicorn only print its effective configuration (``--print-config``).

        :raises ValueError: if check_only() was already requested.
        """
        self._refuse_if(self._check_only)
        self._print_only = True
        return self._mark_dirty()

    def without_config(self):
        """Omit the ``-c <config file>`` argument."""
        self._config_file = None
        return self._mark_dirty()

    def with_loglevel(self, val):
        """Set ``--log-level``; only ``"debug"`` and ``"info"`` are accepted.

        :raises ValueError: for any other value.
        """
        if val in ("debug", "info"):
            self._log_level = val
            return self
        raise ValueError(f"Invalid log level. Got: {val}")

    def production(self):
        """Return the builder, or raise if a default was changed.

        NOTE(review): with_loglevel / with_access_log / with_error_log do not
        mark the builder dirty, so they are allowed here — confirm intended.

        :raises ValueError: if check_only / print_only / without_config was used.
        """
        if not self._is_dirty:
            return self
        raise ValueError("Cannot build with changed defaults")

    def with_access_log(self):
        """Pass ``--access-logfile=-``."""
        self._access_log = "-"
        return self

    def with_error_log(self):
        """Pass ``--error-logfile=-``."""
        self._error_log = "-"
        return self
import ast
import subprocess
from typing import MutableMapping
import pytest
from pydantic import BaseModel, validator
from .config_builder import GunicornConfigBuilder
def create_config_model(captured):
    """Parse ``gunicorn --print-config`` output into a ``GunicornConfig``.

    Each relevant line has the form ``name = value``. Every space is removed
    before splitting (the default-config tests compare against space-free
    values). Only the FIRST ``=`` separates name from value, so values that
    themselves contain ``=`` (e.g. ``raw_env = ['A=B']``) are kept intact.
    Lines without any ``=`` are not settings and are skipped instead of
    crashing the ``dict`` construction. A repeated name keeps its last value.

    :param captured: stdout text of the gunicorn process.
    :returns: the validated ``GunicornConfig``.
    """
    print(captured)
    settings = {}
    for raw_line in captured.replace(" ", "").splitlines():
        name, separator, value = raw_line.partition("=")
        if separator:
            settings[name] = value
    return GunicornConfig.parse_obj(settings)
class GunicornConfig(BaseModel):
    """Subset of the settings reported by ``gunicorn --print-config``."""

    access_log_format: str
    logger_class: str
    default_proc_name: str
    # Arrives as text (a Python literal) when built from print-config output;
    # _parse_logconfig converts it before validation.
    logconfig_dict: MutableMapping = {}
    loglevel: str
    worker_class: str
    config: str

    @validator("logconfig_dict", pre=True)
    def _parse_logconfig(cls, val):
        """Evaluate a string value as a Python literal; pass anything else through."""
        return ast.literal_eval(val) if isinstance(val, str) else val
@pytest.fixture(scope="session")
def gunicorn_options():
    """Session-wide gunicorn argv that only prints the effective configuration."""
    builder = GunicornConfigBuilder()
    return builder.print_only().build()
@pytest.fixture(scope="class")
def gunicorn__class(gunicorn_options):
    """Parsed gunicorn configuration, computed once per test class."""
    completed = _run_process(gunicorn_options)
    output = completed.stdout
    return create_config_model(output)
@pytest.fixture
def gunicorn(gunicorn_options):
    """Parsed gunicorn configuration, recomputed for every test.

    Function scope lets a test change the environment (e.g. via monkeypatch)
    before gunicorn is run.
    """
    return create_config_model(_run_process(gunicorn_options).stdout)
def _run_process(args, *, check=True, timeout=60):
    """Run *args* to completion and return the ``subprocess.CompletedProcess``.

    stdout is captured as text; stderr is inherited from the parent so the
    child's diagnostics are not swallowed.

    :param args: argv list of the command to run.
    :param check: when true (the default), a non-zero exit status raises
        ``subprocess.CalledProcessError`` instead of handing empty or partial
        stdout to the caller, where it would surface as a confusing parse error.
    :param timeout: seconds before the child is killed and
        ``subprocess.TimeoutExpired`` is raised; ``None`` waits forever.
    """
    return subprocess.run(
        args, stdout=subprocess.PIPE, text=True, check=check, timeout=timeout
    )
import asyncio
import os
import signal
import subprocess
from abc import ABCMeta, abstractmethod
from typing import Set
from .config_builder import GunicornConfigBuilder
class Events:
    """One-shot futures that output-line handlers resolve and tests await.

    The futures are created per instance, on the event loop that is current
    when the instance is built. As class attributes they were created once at
    import time: every ManageGunicorn shared the same (one-shot) results, and
    the futures were bound to the import-time loop rather than the test's loop.
    """

    def __init__(self):
        # Resolved by NotFoundLine with the matching "GET / " output line.
        self.request_processed = asyncio.Future()
        # Resolved by IsReadyLine on "Application startup complete.".
        self.startup_complete = asyncio.Future()
        # Resolved by ServerRunningLine on "Connection in use".
        self.should_stop = asyncio.Future()
class ManageGunicorn:
    """Run gunicorn as a child process and feed its output to line handlers.

    The child's stdout and stderr are merged and read line by line; every
    non-empty line is passed to each handler together with ``self.events``,
    whose futures tests can await (e.g. ``events.startup_complete``).
    """

    def __init__(self, *, gunicorn_config=None, line_handlers: Set = None):
        """
        :param gunicorn_config: gunicorn argv list; defaults to
            ``GunicornConfigBuilder().build()``.
        :param line_handlers: extra handler *classes*; they are merged with
            :attr:`default_line_handlers` and instantiated once each.
        """
        self.events = Events()
        self._gunicorn_process = None
        self._gunicorn_config = gunicorn_config or GunicornConfigBuilder().build()
        self._stdout_task = None
        handler_classes = (line_handlers or set()).union(self.default_line_handlers)
        # Keyed by class so every handler type exists exactly once.
        self._line_handlers = {h: h() for h in handler_classes}

    @property
    def default_line_handlers(self) -> Set:
        """Handler classes that are always registered."""
        return {IsReadyLine, NotFoundLine, ServerRunningLine}

    async def start_gunicorn(self):
        """Spawn gunicorn and wait until it reports that startup is complete.

        If waiting fails for any reason (including cancellation), the child
        is stopped before the exception propagates.

        :raises RuntimeError: if gunicorn's output ends (normally because it
            exited) before the startup line was seen.
        """
        self._gunicorn_process = await asyncio.create_subprocess_exec(
            *self._gunicorn_config, stderr=subprocess.STDOUT, stdout=subprocess.PIPE
        )
        self._stdout_task = asyncio.create_task(self._process_stdout())
        try:
            await self.events.startup_complete
        except BaseException:
            # BaseException so that cancellation (e.g. an outer timeout) also
            # tears the child down instead of leaking the process.
            await self.stop_gunicorn()
            raise

    async def stop_gunicorn(self):
        """Send SIGTERM to gunicorn, wait for it to exit, then stop the reader.

        Does nothing to the process if it was never started or already exited,
        so it is safe to call more than once.
        """
        print("stopping gunicorn")
        process = self._gunicorn_process
        if process is not None and process.returncode is None:
            try:
                process.send_signal(signal.SIGTERM)
            except ProcessLookupError:
                pass  # the child exited between the returncode check and the signal
            # The reader task keeps draining stdout, so wait() cannot block on
            # a full pipe.
            await process.wait()
        task = self._stdout_task
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def _process_line(self, line):
        """Hand one decoded, stripped output line to every registered handler."""
        print(f"processing '{line}'")
        for handler in self._line_handlers.values():
            handler(line, self.events)

    async def _process_stdout(self):
        """Dispatch every non-empty output line until gunicorn's stdout closes.

        If the stream ends before startup completed, ``startup_complete`` is
        failed so that :meth:`start_gunicorn` raises instead of waiting forever.
        """
        stdout = self._gunicorn_process.stdout
        while True:
            raw = await stdout.readline()
            if not raw:
                # readline() returns b"" only at EOF; looping on it would spin.
                break
            # errors="replace": one undecodable byte must not kill the reader.
            line = raw.decode("utf-8", errors="replace").strip()
            if line:
                await self._process_line(line)
        if not self.events.startup_complete.done():
            self.events.startup_complete.set_exception(
                RuntimeError("gunicorn output ended before startup completed")
            )
class BaseLineHandler(metaclass=ABCMeta):
    """Base for callables that react to single lines of gunicorn output.

    An instance is called with each line and the shared events object.
    Subclasses implement :meth:`_impl`, which inspects ``self._line`` and
    returns True when the line was handled; handled lines are collected in
    ``_processed_lines``.
    """

    def __init__(self):
        self._processed_lines = []

    def __call__(self, line, events):
        self._line = line
        self._events: Events = events
        if self._impl():
            self._processed_lines.append(line)

    @abstractmethod
    def _impl(self):
        """Return True if ``self._line`` was recognised and handled."""
        raise NotImplementedError()

    @staticmethod
    def _resolve(future, value):
        """Set *future*'s result unless it is already done.

        The same kind of line can appear more than once (e.g. several requests
        to "/"). A second ``set_result`` would raise ``InvalidStateError`` and
        take down the task that reads gunicorn's output.
        """
        if not future.done():
            future.set_result(value)


class IsReadyLine(BaseLineHandler):
    """Resolve ``startup_complete`` once the application reports startup."""

    def _impl(self):
        if not self._line.endswith("Application startup complete."):
            return False
        print(f"{self.__class__.__name__}: Handling '{self._line}'")
        self._resolve(self._events.startup_complete, None)
        return True


class NotFoundLine(BaseLineHandler):
    """Resolve ``request_processed`` with the first output line containing "GET / "."""

    def _impl(self):
        if "GET / " not in self._line:
            return False
        print(f"{self.__class__.__name__}: Handling '{self._line}'")
        self._resolve(self._events.request_processed, self._line)
        return True


class ServerRunningLine(BaseLineHandler):
    """Resolve ``should_stop`` when gunicorn reports "Connection in use"."""

    def _impl(self):
        if "Connection in use" not in self._line:
            return False
        print(f"{self.__class__.__name__}: Handling '{self._line}'")
        self._resolve(self._events.should_stop, None)
        return True
import pytest
class TestGunicornDefaults:
    """Effective gunicorn settings when the builder's defaults are used.

    All tests share one parsed config through the class-scoped
    ``gunicorn__class`` fixture.
    """

    def test_access_log_format(self, gunicorn__class):
        # create_config_model removes every space, so the expected value has none.
        expected = '%(h)s%(l)s%(u)s%(t)s"%(r)s"%(s)s%(b)s"%(f)s""%(a)s"'
        assert gunicorn__class.access_log_format == expected

    def test_config_file(self, gunicorn__class):
        suffix = "webserver/config/gunicorn_conf.py"
        assert gunicorn__class.config.endswith(suffix)

    def test_logger_class(self, gunicorn__class):
        expected = "gunicorn.glogging.Logger"
        assert gunicorn__class.logger_class == expected

    def test_app_name(self, gunicorn__class):
        expected = "tests.integration.webserver.app:app"
        assert gunicorn__class.default_proc_name == expected

    def test_loglevel(self, gunicorn__class):
        expected = "info"
        assert gunicorn__class.loglevel == expected

    def test_worker(self, gunicorn__class):
        expected = "uvicorn.workers.UvicornWorker"
        assert gunicorn__class.worker_class == expected
class TestModifyEnv:
    """Settings that can be overridden through environment variables.

    NOTE(review): presumably JUICE_LOG_LEVEL is read by config/gunicorn_conf.py
    — confirm against that file.
    """

    @pytest.fixture
    def set_loglevel(self, monkeypatch):
        """Set JUICE_LOG_LEVEL=debug for the duration of one test."""
        monkeypatch.setenv("JUICE_LOG_LEVEL", "debug")

    def test_can_modify_loglevel(self, set_loglevel, gunicorn):
        # Relies on set_loglevel being set up before the gunicorn fixture runs
        # the subprocess (it is requested first).
        observed = gunicorn.loglevel
        assert observed == "debug"
import asyncio
from datetime import datetime

import httpx
import pytest

from .config_builder import GunicornConfigBuilder
from .manage_gunicorn import ManageGunicorn
# strptime pattern for timestamps such as "2018-06-29 08:15:27,330": a comma
# precedes the fractional part, and %f accepts 1-6 digits when parsing.
timestamp_format = "%Y-%m-%d %H:%M:%S,%f"
@pytest.fixture
def gunicorn_config():
    """gunicorn argv with the defaults plus ``--access-logfile=-``."""
    builder = GunicornConfigBuilder().with_access_log()
    return builder.build()
class TestParsing:
    """Sanity check for ``timestamp_format`` on its own."""

    def test_can_parse_timestamp(self):
        sample = "2018-06-29 08:15:27,330"
        parsed = datetime.strptime(sample, timestamp_format)
        assert (parsed.day, parsed.month, parsed.year) == (29, 6, 2018)
        assert (parsed.hour, parsed.minute, parsed.second) == (8, 15, 27)
        # ",330" is milliseconds; %f scales it to microseconds.
        assert parsed.microsecond == 330000
class TestRequest:
    """End-to-end checks against a real gunicorn process."""

    # Upper bound (seconds) for waiting on a log line, so that a missing line
    # fails the test instead of hanging the whole run.
    _EVENT_TIMEOUT = 5.0

    @pytest.fixture(autouse=True)
    async def server(self, gunicorn_config):
        """Run gunicorn for the duration of one test."""
        g = ManageGunicorn(gunicorn_config=gunicorn_config)
        await g.start_gunicorn()
        yield g
        await g.stop_gunicorn()

    async def _wait_for_request(self, server):
        """Return the request log line, waiting at most ``_EVENT_TIMEOUT`` seconds.

        shield() keeps a timeout from cancelling the shared future itself.
        """
        return await asyncio.wait_for(
            asyncio.shield(server.events.request_processed),
            timeout=self._EVENT_TIMEOUT,
        )

    async def test_request_is_logged(self, server):
        send_request()
        await self._wait_for_request(server)

    @pytest.mark.skip()
    async def test_has_correct_log_format(self, server):
        send_request()
        res = await self._wait_for_request(server)
        # NOTE(review): placeholder assertion kept as-is; this test is skipped.
        assert res == ""
        # "YYYY-MM-DD HH:MM:SS,fff" is 23 characters; [:22] dropped the last
        # millisecond digit.
        assert parse_timestamp(res[:23])
def send_request(url="http://localhost:8000/", *, timeout=5.0):
    """Send a blocking GET to the test server and return the ``httpx.Response``.

    Transport failures (e.g. connection refused) propagate as ``httpx.HTTPError``
    subclasses; the HTTP status code is not checked.

    NOTE(review): this is a synchronous call made from async tests, so it
    blocks the event loop until the response arrives.

    :param url: address to request; defaults to the local test server.
    :param timeout: httpx timeout in seconds (5.0 is also httpx's default).
    """
    print("sending request")
    return httpx.get(url, timeout=timeout)
def parse_timestamp(val):
    """Parse *val* with ``timestamp_format`` into a naive ``datetime``.

    Raises ``ValueError`` when *val* does not match the format exactly.
    """
    parsed = datetime.strptime(val, timestamp_format)
    return parsed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment