Last active
January 25, 2025 15:10
-
-
Save edpyt/e670e0ad6b268b97cbd18e9b644417fb to your computer and use it in GitHub Desktop.
SQLAlchemy, PyTest nested transaction mode
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Generator | |
import pytest | |
from sqlalchemy import Engine, create_engine | |
from sqlalchemy.orm import Session, sessionmaker | |
from .models import Base | |
@pytest.fixture(name="engine", scope="session")
def sa_engine() -> Generator[Engine, None, None]:
    """Provide one in-memory SQLite engine for the whole test session.

    Every table registered on ``Base.metadata`` is created before the first
    test that requests the fixture and dropped again during teardown.

    Yields:
        The engine; ``echo=True`` makes it log the SQL it emits.
    """
    engine = create_engine("sqlite+pysqlite:///:memory:", echo=True)
    try:
        Base.metadata.create_all(engine)
        yield engine
        Base.metadata.drop_all(engine)
    finally:
        # Release pooled connections even when schema setup/teardown fails,
        # so the engine does not outlive the test session.
        engine.dispose()
@pytest.fixture(name="session_factory", scope="session")
def sa_session_factory() -> sessionmaker:
    """Build the unbound session factory shared by the whole test session.

    The factory is created without a bind; callers pass ``bind=...`` when
    they instantiate a session from it.
    """
    factory = sessionmaker()
    return factory
@pytest.fixture(name="session")
def sa_session(engine: Engine, session_factory: sessionmaker) -> Generator[Session, None, None]:
    """Yield a session bound to a connection with an open outer transaction.

    The outer transaction is begun before the session is created and rolled
    back during teardown, so that each test starts from the same database
    state.

    Args:
        engine: Engine the connection is checked out from.
        session_factory: Factory used to create the session on that connection.

    Yields:
        A session bound to the per-test connection.
    """
    # NOTE(review): a test that calls session.rollback() may also end the
    # outer transaction; SQLAlchemy documents
    # join_transaction_mode="create_savepoint" for that case - confirm
    # SAVEPOINT behaviour of the SQLite driver before enabling it.
    with engine.connect() as conn:  # connection is closed on every exit path
        trans = conn.begin()
        try:
            session = session_factory(bind=conn)
            try:
                yield session
            finally:
                session.close()
        finally:
            # Runs even if closing the session raised, so nothing is persisted.
            trans.rollback()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | |
class Base(DeclarativeBase):
    """Declarative base class; its ``metadata`` collects the mapped tables."""
class User(Base):
    """ORM model mapped to the ``user`` table."""

    __tablename__ = "user"

    # Primary key column.
    id: Mapped[int] = mapped_column(primary_key=True)
    # Column type and nullability are derived from the ``Mapped[str]`` annotation.
    name: Mapped[str] = mapped_column()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sqlalchemy import select | |
from sqlalchemy.orm import Session | |
from .models import User | |
def test_create_user(session: Session) -> None:
    """Adding and committing a user keeps it in the session with its values."""
    bob = User(id=1, name="bob")

    session.add(bob)
    session.commit()

    assert bob in session
    assert bob.id == 1
    assert bob.name == "bob"
def test_create_same_user(session: Session) -> None:
    """Insert a user with the same primary key as ``test_create_user``.

    Expects that exactly one row is visible afterwards.
    """
    bob = User(id=1, name="bob")

    session.add(bob)
    session.commit()
    assert bob.id == 1

    stored = session.scalars(select(User)).all()
    assert len(stored) == 1
def test_users_empty_in_database(session: Session) -> None:
    """A fresh session sees no ``User`` rows."""
    stored = session.scalars(select(User)).all()
    assert not stored
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment