Created
June 14, 2022 06:22
-
-
Save Niccolum/245ad37b74c0713c240155d81730bfeb to your computer and use it in GitHub Desktop.
How to prepare a conftest.py for testing FastAPI with SQLModel
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
import typing as t | |
from contextlib import contextmanager | |
import fastapi | |
import pytest | |
from httpx import AsyncClient | |
from mixer.backend.sqlalchemy import Mixer | |
from pytest_mock import MockerFixture | |
from sqlalchemy.engine.base import Engine | |
from sqlalchemy.orm import scoped_session, sessionmaker | |
from sqlalchemy.orm.session import Session | |
from sqlalchemy.pool import NullPool | |
from sqlmodel import SQLModel, create_engine | |
from sqlmodel.ext.asyncio.session import AsyncEngine, AsyncSession | |
from application.models import * # type: ignore | |
from core.config import settings | |
from core.db import get_engine | |
from main import create_app | |
from other.models import * # type: ignore | |
from scoring.models import * # type: ignore | |
from .helpers import create_database, database_exists, drop_database | |
@pytest.fixture(scope="session")
def event_loop() -> t.Iterator[asyncio.AbstractEventLoop]:
    """Provide one event loop shared by every async fixture and test in the session.

    Overrides pytest-asyncio's function-scoped ``event_loop`` so the
    session-scoped async fixtures in this file run on the same loop as the
    tests. A brand-new loop is created rather than calling
    ``asyncio.get_event_loop()``, which is deprecated when no loop is running
    (and raises on Python 3.14+). The loop is closed once the session ends.
    """
    # https://github.com/pytest-dev/pytest-asyncio#async-fixtures
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()
@pytest.fixture(scope="session")
def app():
    """Build the FastAPI application once and share it for the whole test session."""
    application = create_app()
    return application
@pytest.fixture
async def test_client(app: fastapi.FastAPI) -> t.AsyncGenerator[AsyncClient, None]:
    """Yield an HTTPX async client that sends requests to ``app`` in-process.

    The app is mounted through ``httpx.ASGITransport``. The
    ``AsyncClient(app=...)`` shortcut is deprecated since httpx 0.27 and was
    removed in 0.28. The client is closed when the test finishes.
    """
    from httpx import ASGITransport  # local import: only this fixture needs it

    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
@pytest.fixture
def override_settings(app: fastapi.FastAPI):
    """Return a context manager that temporarily overrides one FastAPI dependency.

    Usage::

        with override_settings(some_dependency, replacement_callable):
            ...

    On exit only the overridden ``key`` is touched. It is restored to the
    override it had before (or removed if it had none). The rest of
    ``app.dependency_overrides`` is left alone, so nested overrides and
    overrides installed elsewhere on the session-scoped app survive.
    """
    missing = object()  # sentinel: "key had no override before"

    @contextmanager
    def wrapper(key: t.Callable[..., t.Any], value: t.Callable[..., t.Any]):
        previous = app.dependency_overrides.get(key, missing)
        app.dependency_overrides[key] = value
        try:
            yield
        finally:
            if previous is missing:
                app.dependency_overrides.pop(key, None)
            else:
                app.dependency_overrides[key] = previous

    return wrapper
@pytest.fixture(scope="session")
async def async_engine() -> t.AsyncGenerator[AsyncEngine, None]:
    """Yield an async engine bound to the test database for the whole session.

    The URL comes from ``settings.ASYNC_TEST_DATABASE_URL``. ``NullPool``
    disables connection pooling, so no idle pooled connection is kept open
    between uses. The engine is disposed after the session.
    """
    engine = AsyncEngine(
        create_engine(
            settings.ASYNC_TEST_DATABASE_URL,  # type: ignore
            echo=True,
            poolclass=NullPool,
        )
    )
    yield engine
    await engine.dispose()
@pytest.fixture(scope="session")
def sync_engine() -> t.Generator[Engine, None, None]:
    """Yield a synchronous engine bound to the test database for the whole session.

    This engine backs the sync session used by Mixer-based data fixtures. The
    URL comes from ``settings.SYNC_TEST_DATABASE_URL``. ``NullPool`` disables
    connection pooling. The engine is disposed after the session.
    """
    engine = create_engine(
        settings.SYNC_TEST_DATABASE_URL,  # type: ignore
        echo=True,
        poolclass=NullPool,
    )
    yield engine
    engine.dispose()
@pytest.fixture(scope="session")
async def prepare_db() -> t.AsyncGenerator[None, None]:
    """Create an empty test database for the session and drop it afterwards.

    The CREATE/DROP DATABASE statements go through the application's engine
    from ``get_engine()``. That engine is presumably connected to a
    maintenance database on the same server; confirm. If a database named
    ``settings.TEST_POSTGRES_DB`` already exists, it is dropped first.
    """
    app_async_engine = get_engine()
    test_db = settings.TEST_POSTGRES_DB
    try:
        # Presumably a leftover from an interrupted earlier run: start clean.
        if await database_exists(app_async_engine, test_db=test_db):
            await drop_database(app_async_engine, test_db=test_db)
        await create_database(app_async_engine, test_db=test_db)
        yield
        await drop_database(app_async_engine, test_db=test_db)
    finally:
        # Release the app engine's connections even if a CREATE/DROP failed.
        await app_async_engine.dispose()
@pytest.fixture(scope="session")
async def setup_database(async_engine: AsyncEngine, prepare_db: None) -> t.AsyncGenerator[None, None]:
    """Create every table in ``SQLModel.metadata`` for the session and drop them afterwards.

    ``prepare_db`` is requested for ordering only: the database must exist
    before tables can be created in it. Log records at WARNING level and
    below are disabled for the session (presumably to silence the engines'
    ``echo=True`` output). Logging is re-enabled at teardown so the global
    suppression does not outlive the session.
    """
    logging.disable(logging.WARNING)
    try:
        async with async_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)
        yield
        async with async_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.drop_all)
    finally:
        logging.disable(logging.NOTSET)
@pytest.fixture(scope="session")
async def created_async_session(
    session_mocker: MockerFixture, async_engine: AsyncEngine, setup_database: None
) -> t.AsyncGenerator[AsyncSession, None]:
    """Yield one async session (behind ``scoped_session``) bound to the test database.

    ``session_mocker`` (pytest-mock's session-scoped mocker) is used because
    pytest refuses to let a session-scoped fixture request the function-scoped
    ``mocker`` and raises ScopeMismatch.

    ``core.db.get_engine`` is patched to return the test engine for the whole
    session. NOTE(review): only code that looks up ``core.db.get_engine`` at
    call time sees the patch. Modules that did ``from core.db import
    get_engine`` (as this conftest does) keep the original function.

    ``expire_on_commit=False`` keeps loaded attributes readable after a
    commit without another round-trip.
    """
    session_mocker.patch("core.db.get_engine", return_value=async_engine)
    session = scoped_session(  # pyright: reportUnknownArgumentType=false
        sessionmaker(  # pyright: reportUnknownVariableType=false
            async_engine,
            class_=AsyncSession,  # pyright: reportGeneralTypeIssues=false
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
    )
    yield session
    await session.close()
@pytest.fixture
async def async_session(created_async_session: AsyncSession):
    """Hand each test the shared async session inside a SAVEPOINT.

    After the test, ``rollback()`` is awaited on the session to discard the
    test's changes.
    """
    shared = created_async_session
    await shared.begin_nested()
    yield shared
    await shared.rollback()
@pytest.fixture(scope="session")
def created_sync_session(sync_engine: Engine, setup_database: None) -> Session:
    """Yield one synchronous session (behind ``scoped_session``) for the whole test session.

    Data fixtures generated with Mixer use this session. Mixer needs a sync
    session, per the gist author. The session is closed when the test
    session ends.
    """
    factory = sessionmaker(  # pyright: reportUnknownVariableType=false
        sync_engine,
        expire_on_commit=False,
        autocommit=False,
        autoflush=False,
    )
    shared_session = scoped_session(factory)  # pyright: reportUnknownArgumentType=false
    yield shared_session
    shared_session.close()
@pytest.fixture
def sync_session(created_sync_session: Session):
    """Hand each test the shared sync session inside a SAVEPOINT.

    After the test, ``rollback()`` is called on the session to discard the
    test's changes.
    """
    session = created_sync_session
    session.begin_nested()
    yield session
    session.rollback()
@pytest.fixture
def mixer_db(sync_session: Session) -> Mixer:
    """Return a Mixer that writes generated objects through the per-test sync session.

    This is a plain fixture because nothing here is awaited. Declaring it
    ``async def`` only forced it onto the event loop for no benefit.
    ``commit=True`` tells Mixer to commit the objects it creates.
    """
    return Mixer(session=sync_session, commit=True)
@pytest.fixture
def generate_model(mixer_db: Mixer) -> t.Callable[..., SQLModel]:
    """Return a helper that creates and persists one model instance via Mixer.

    ``mymodel`` is passed straight to ``Mixer.blend``. It is presumably a
    dotted import path such as ``"package.models.Model"``; confirm against
    callers. Keyword arguments pin specific field values. ``Mixer.blend``
    returns a single instance, not a list of model classes, so the return
    annotation says so. The fixture is plain (not ``async def``) because it
    awaits nothing.
    """

    def wrapper(mymodel: str, **kwargs: t.Any) -> SQLModel:
        instance: SQLModel = mixer_db.blend(mymodel, **kwargs)
        return instance

    return wrapper
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing as t | |
from sqlalchemy import text | |
if t.TYPE_CHECKING: | |
from sqlmodel.ext.asyncio.session import AsyncEngine | |
async def create_database(engine: "AsyncEngine", test_db: str) -> None:
    """Create a database named exactly ``test_db`` using a connection from ``engine``.

    PostgreSQL refuses CREATE DATABASE inside a transaction block. The bare
    ``COMMIT`` ends the transaction the connection has implicitly begun
    before the CREATE is sent.

    A database name is an identifier and cannot be a bound parameter.
    Instead it is quoted with the dialect's identifier preparer. Names that
    need quoting (upper case, dashes, reserved words) are then created
    verbatim rather than case-folded or rejected, so they match an
    exact-name ``pg_database.datname`` lookup.
    """
    quoted_name = engine.dialect.identifier_preparer.quote(test_db)
    async with engine.connect() as conn:
        _ = await conn.execute(text("COMMIT"))  # type: ignore
        create_db_command = text(f"CREATE DATABASE {quoted_name}")
        _ = await conn.execute(create_db_command)  # type: ignore
async def database_exists(engine: "AsyncEngine", test_db: str) -> bool:
    """Return True if a database named exactly ``test_db`` exists on the server.

    Only PostgreSQL is supported; it queries the ``pg_database`` catalog.
    The name is sent as a bound parameter instead of being formatted into
    the SQL text.

    Raises:
        NotImplementedError: if ``engine`` uses a dialect other than
            PostgreSQL. This is a subclass of RuntimeError.
    """
    dialect_name = engine.dialect.name
    if dialect_name != "postgresql":
        raise NotImplementedError(
            f"database_exists() only supports PostgreSQL, not {dialect_name!r}"
        )
    check_db_command = text("SELECT 1 FROM pg_database WHERE datname = :name")
    async with engine.connect() as conn:
        result = await conn.scalar(check_db_command, {"name": test_db})  # type: ignore
    return bool(result)
async def drop_database(engine: "AsyncEngine", test_db: str) -> None:
    """Drop the database named exactly ``test_db`` using a connection from ``engine``.

    PostgreSQL refuses DROP DATABASE inside a transaction block. The bare
    ``COMMIT`` ends the transaction the connection has implicitly begun.

    The name is an identifier and cannot be a bound parameter. It is quoted
    with the dialect's identifier preparer, so the exact name is dropped
    (no case-folding), matching an exact-name ``pg_database.datname`` lookup.
    """
    quoted_name = engine.dialect.identifier_preparer.quote(test_db)
    async with engine.connect() as conn:
        _ = await conn.execute(text("COMMIT"))  # type: ignore
        drop_db_command = text(f"DROP DATABASE {quoted_name}")
        _ = await conn.execute(drop_db_command)  # type: ignore
Привет)
Асинхронный контекст для приложения. Синхронный - для фикстур через Mixer (он пока не поддерживает асинхронный энжин)
Hello :)
The async context is for the app; the sync one is for database fixtures created with Mixer (which, sadly, does not support async engines yet).
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Здравствуйте, спасибо, что поделились.
зачем вам тестировать как асинхронный? какой контекст?
Hello, thanks for sharing.
Why do you need to test asynchronously? Which context do you mean?