Skip to content

Instantly share code, notes, and snippets.

@Niccolum
Created June 14, 2022 06:22
Show Gist options
  • Save Niccolum/245ad37b74c0713c240155d81730bfeb to your computer and use it in GitHub Desktop.
Save Niccolum/245ad37b74c0713c240155d81730bfeb to your computer and use it in GitHub Desktop.
How to prepare a conftest.py for testing FastAPI with SQLModel
import asyncio
import logging
import typing as t
from contextlib import contextmanager
import fastapi
import pytest
from httpx import AsyncClient
from mixer.backend.sqlalchemy import Mixer
from pytest_mock import MockerFixture
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import scoped_session, sessionmaker
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import NullPool
from sqlmodel import SQLModel, create_engine
from sqlmodel.ext.asyncio.session import AsyncEngine, AsyncSession
from application.models import * # type: ignore
from core.config import settings
from core.db import get_engine
from main import create_app
from other.models import * # type: ignore
from scoring.models import * # type: ignore
from .helpers import create_database, database_exists, drop_database
@pytest.fixture(scope="session")
def event_loop() -> t.Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a single event loop shared by the whole test session.

    The session-scoped async fixtures in this file need a loop that outlives
    any single test, so this overrides pytest-asyncio's default loop fixture.
    """
    # https://github.com/pytest-dev/pytest-asyncio#async-fixtures
    # Create a fresh loop rather than asyncio.get_event_loop(): calling that
    # with no running loop is deprecated, and the returned loop was never closed.
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()
@pytest.fixture(scope="session")
def app():
    """Build the FastAPI application once and reuse it for every test in the session."""
    application = create_app()
    return application
@pytest.fixture
async def test_client(app: fastapi.FastAPI) -> t.AsyncGenerator[AsyncClient, None]:
    """Yield an httpx AsyncClient that dispatches requests directly to ``app``.

    The client is bound to the ASGI application through a transport, so no
    server process is involved; ``http://test`` only serves as the base URL
    for building request URLs.
    """
    # AsyncClient(app=...) was deprecated in httpx 0.27 and removed in 0.28;
    # passing an explicit ASGITransport works on old and new releases alike.
    from httpx import ASGITransport

    async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client:
        yield client
@pytest.fixture
def override_settings(app: fastapi.FastAPI):
    """Return a context manager that temporarily overrides one FastAPI dependency.

    Usage::

        with override_settings(some_dependency, replacement_callable):
            ...  # requests made here resolve some_dependency via the replacement

    On exit, only ``key`` is put back: its previous override is restored, or the
    key is removed if it had none. Other overrides, including those from an
    enclosing ``with override_settings(...)``, are left untouched.
    """
    _missing = object()

    @contextmanager
    def wrapper(key: t.Callable[..., t.Any], value: t.Callable[..., t.Any]):
        previous = app.dependency_overrides.get(key, _missing)
        app.dependency_overrides[key] = value
        try:
            yield
        finally:
            # Restore just this key instead of wiping every override on the
            # session-wide app, so nested overrides keep working.
            if previous is _missing:
                app.dependency_overrides.pop(key, None)
            else:
                app.dependency_overrides[key] = previous

    return wrapper
@pytest.fixture(scope="session")
async def async_engine() -> t.AsyncGenerator[AsyncEngine, None]:
    """Yield an async engine for the test database, disposed at session end.

    ``NullPool`` closes each connection when it is released instead of pooling
    it, so the engine does not keep connections to the test database open.
    """
    engine = AsyncEngine(
        create_engine(
            settings.ASYNC_TEST_DATABASE_URL,  # type: ignore
            echo=True,
            poolclass=NullPool,
        )
    )
    yield engine
    await engine.dispose()
@pytest.fixture(scope="session")
def sync_engine() -> t.Generator[Engine, None, None]:
    """Yield a synchronous engine for the test database, disposed at session end.

    ``NullPool`` closes each connection on release rather than pooling it.
    """
    engine = create_engine(
        settings.SYNC_TEST_DATABASE_URL,  # type: ignore
        echo=True,
        poolclass=NullPool,
    )
    yield engine
    engine.dispose()
@pytest.fixture(scope="session")
async def prepare_db() -> t.AsyncGenerator[None, None]:
    """Create an empty test database for the session and drop it afterwards.

    The CREATE/DROP DATABASE statements go through the application's engine
    from ``get_engine()``. Any existing database named
    ``settings.TEST_POSTGRES_DB`` (e.g. one left by an interrupted run) is
    dropped first, so each session starts from a fresh database.
    """
    app_async_engine = get_engine()
    test_db = settings.TEST_POSTGRES_DB
    try:
        if await database_exists(app_async_engine, test_db=test_db):
            await drop_database(app_async_engine, test_db=test_db)
        await create_database(app_async_engine, test_db=test_db)
        yield
        await drop_database(app_async_engine, test_db=test_db)
    finally:
        # Dispose the engine even if creating or dropping the database failed.
        await app_async_engine.dispose()
@pytest.fixture(scope="session")
async def setup_database(async_engine: AsyncEngine, prepare_db: None) -> t.AsyncGenerator[None, None]:
    """Create all tables registered on ``SQLModel.metadata``; drop them at session end.

    Depends on ``prepare_db`` so the test database exists before any DDL runs.
    NOTE(review): the ``import *`` of the models modules at the top of the file
    presumably exists to register the tables on the metadata. Confirm before
    removing those imports.
    """
    # Suppress log records at WARNING level and below (including the echo=True
    # SQL output) for the rest of the session.
    logging.disable(logging.WARNING)
    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.create_all)
    yield
    async with async_engine.begin() as conn:
        await conn.run_sync(SQLModel.metadata.drop_all)
    # logging.disable() is process-global; undo it when this fixture is torn down.
    logging.disable(logging.NOTSET)
@pytest.fixture(scope="session")
async def created_async_session(
    session_mocker: MockerFixture, async_engine: AsyncEngine, setup_database: None
) -> t.AsyncGenerator[AsyncSession, None]:
    """Yield one session-wide async session (a ``scoped_session`` proxy) on the test engine.

    For the whole test session, ``core.db.get_engine`` is patched to return the
    test engine. The patch uses ``session_mocker``, pytest-mock's session-scoped
    mocker, because requesting the function-scoped ``mocker`` from a
    session-scoped fixture makes pytest raise ``ScopeMismatch``.
    """
    # NOTE(review): patching "core.db.get_engine" only affects code that looks the
    # name up on the core.db module at call time; modules that already did
    # `from core.db import get_engine` keep the original function. Confirm.
    _ = session_mocker.patch("core.db.get_engine", return_value=async_engine)
    session = scoped_session(  # pyright: reportUnknownArgumentType=false
        sessionmaker(  # pyright: reportUnknownVariableType=false
            async_engine,
            class_=AsyncSession,  # pyright: reportGeneralTypeIssues=false
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
    )
    yield session
    await session.close()
@pytest.fixture
async def async_session(created_async_session: AsyncSession):
    """Hand a test the shared async session, opening a SAVEPOINT first.

    After the test, ``rollback()`` rolls back the session's outermost
    transaction, which discards the savepoint along with it.
    """
    session = created_async_session
    await session.begin_nested()
    yield session
    await session.rollback()
@pytest.fixture(scope="session")
def created_sync_session(sync_engine: Engine, setup_database: None) -> t.Generator[Session, None, None]:
    """Yield one session-wide synchronous session (a ``scoped_session`` proxy).

    A sync session is needed for Mixer-based fixtures, since Mixer works with
    a synchronous SQLAlchemy session.
    """
    session = scoped_session(  # pyright: reportUnknownArgumentType=false
        sessionmaker(  # pyright: reportUnknownVariableType=false
            sync_engine,
            expire_on_commit=False,
            autocommit=False,
            autoflush=False,
        )
    )
    yield session
    # remove() closes the current Session and also clears the scoped registry,
    # so no stale Session object is left behind.
    session.remove()
@pytest.fixture
def sync_session(created_sync_session: Session):
    """Hand a test the shared sync session, opening a SAVEPOINT first.

    After the test, ``rollback()`` rolls back the session's outermost
    transaction, which discards the savepoint along with it.
    """
    session = created_sync_session
    session.begin_nested()
    yield session
    session.rollback()
@pytest.fixture
def mixer_db(sync_session: Session) -> Mixer:
    """Return a Mixer bound to the per-test sync session that commits what it blends.

    This is a plain (sync) fixture because nothing here is awaited, and an
    ``async def`` fixture declared with ``@pytest.fixture`` is only resolved by
    pytest-asyncio in auto mode.
    """
    return Mixer(session=sync_session, commit=True)
@pytest.fixture
def generate_model(mixer_db: Mixer) -> t.Callable[..., SQLModel]:
    """Return a factory that creates one model instance through ``mixer_db``.

    ``wrapper(mymodel, **kwargs)`` forwards to ``mixer_db.blend(mymodel, **kwargs)``:
    ``mymodel`` names the model, and ``kwargs`` pin specific field values.
    This is a plain (sync) fixture because nothing here is awaited.
    """

    def wrapper(mymodel: str, **kwargs: t.Any) -> SQLModel:
        # Mixer.blend() produces a single model instance, not a list of classes.
        model: SQLModel = mixer_db.blend(mymodel, **kwargs)
        return model

    return wrapper
import typing as t
from sqlalchemy import text
if t.TYPE_CHECKING:
from sqlmodel.ext.asyncio.session import AsyncEngine
async def create_database(engine: "AsyncEngine", test_db: str) -> None:
    """Create the database ``test_db`` using a connection from ``engine``.

    CREATE DATABASE cannot run inside a transaction block in PostgreSQL, so the
    connection is switched to AUTOCOMMIT. This replaces issuing a manual
    ``COMMIT`` to end SQLAlchemy's implicit transaction.

    Args:
        engine: async engine whose server the database is created on.
        test_db: database name. It is quoted as an SQL identifier when needed,
            because identifiers cannot be passed as bound parameters.
    """
    quoted_name = engine.dialect.identifier_preparer.quote(test_db)
    async with engine.connect() as conn:
        autocommit_conn = await conn.execution_options(isolation_level="AUTOCOMMIT")
        create_db_command = text(f"CREATE DATABASE {quoted_name}")
        _ = await autocommit_conn.execute(create_db_command)  # type: ignore
async def database_exists(engine: "AsyncEngine", test_db: str) -> bool:
    """Return True if a database named ``test_db`` exists on ``engine``'s server.

    Only PostgreSQL is supported; the check queries the ``pg_database`` catalog.

    Raises:
        NotImplementedError: if ``engine`` uses any other dialect. This is a
            subclass of RuntimeError, so existing ``except RuntimeError``
            handlers still catch it.
    """
    dialect_name = engine.dialect.name
    if dialect_name != "postgresql":
        # A bare `raise` here has no active exception to re-raise; report the
        # real reason instead.
        raise NotImplementedError(
            f"database_exists() does not support the {dialect_name!r} dialect"
        )
    # Pass the name as a bound parameter instead of formatting it into the SQL.
    check_db_command = text(
        "SELECT 1 FROM pg_database WHERE datname = :datname"
    ).bindparams(datname=test_db)
    async with engine.connect() as conn:
        result = await conn.scalar(check_db_command)  # type: ignore
    return bool(result)
async def drop_database(engine: "AsyncEngine", test_db: str) -> None:
    """Drop the database ``test_db`` using a connection from ``engine``.

    DROP DATABASE cannot run inside a transaction block in PostgreSQL, so the
    connection is switched to AUTOCOMMIT. This replaces issuing a manual
    ``COMMIT`` to end SQLAlchemy's implicit transaction.

    Args:
        engine: async engine whose server the database lives on.
        test_db: database name. It is quoted as an SQL identifier when needed,
            because identifiers cannot be passed as bound parameters.
    """
    quoted_name = engine.dialect.identifier_preparer.quote(test_db)
    async with engine.connect() as conn:
        autocommit_conn = await conn.execution_options(isolation_level="AUTOCOMMIT")
        drop_db_command = text(f"DROP DATABASE {quoted_name}")
        _ = await autocommit_conn.execute(drop_db_command)  # type: ignore
@Niccolum
Copy link
Author

Привет)

Асинхронный контекст для приложения. Синхронный - для фикстур через Mixer (он пока не поддерживает асинхронный энжин)

Hello)
The async context is for the app; the sync one is for DB fixtures created with Mixer (sadly, Mixer does not support async engines yet).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment