Skip to content

Instantly share code, notes, and snippets.

Last active March 17, 2024 07:04
Show Gist options
  • Save HacKanCuBa/bfee44fb8f3e81289c36c7bf5a579dfa to your computer and use it in GitHub Desktop.
Save HacKanCuBa/bfee44fb8f3e81289c36c7bf5a579dfa to your computer and use it in GitHub Desktop.
SQLAlchemy handy helper functions
import functools
from contextlib import asynccontextmanager, contextmanager
from time import monotonic
from typing import Annotated, Any, AsyncGenerator, Generator, Hashable, Iterable, Literal, Optional, Sized, Union, overload
from sqlalchemy import event
from sqlalchemy.dialects.mysql.asyncmy import AsyncAdapt_asyncmy_cursor
from sqlalchemy.engine import URL, Connection, Engine, Row, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import AsyncAdaptedQueuePool, QueuePool
# Marker type for keyword arguments that must be hashable so they can be used as
# functools.cache keys: e.g. pass a tuple of (key, value) pairs instead of a dict.
AnyCacheable = Annotated[
    Hashable,
    "Any type that works well with functools.cache, meaning hashable (i.e., not dicts!)",
]
@event.listens_for(Engine, "before_cursor_execute")
def _before_cursor_execute(conn: Connection, *_: Any) -> None:
    """Record when a cursor execution starts on this connection.

    Start times are pushed onto a per-connection list stored in
    ``conn.info["query_start_time"]``. The matching ``after_cursor_execute`` listener
    pops the latest entry to compute the elapsed time.
    """
    conn.info.setdefault("query_start_time", []).append(monotonic())
# noinspection PyUnusedLocal
@event.listens_for(Engine, "after_cursor_execute")
def _after_cursor_execute(
    conn: Connection,
    cursor: AsyncAdapt_asyncmy_cursor,
    statement: str,
    parameters: tuple[dict[str, Any], ...] | dict[str, Any] | None,
    *_: Any,
) -> None:
    """Compute how long the cursor execution took.

    Pops the most recent start time from ``conn.info["query_start_time"]``, which the
    ``before_cursor_execute`` listener pushes. The pop always happens, so the stack
    stays balanced even when the timing is not reported anywhere.
    """
    total = monotonic() - conn.info["query_start_time"].pop(-1)
    # Plug in your logger here to report the timing, e.g.:
    # logger.debug('DB query\n\t%s\n\tparams: %s\n\tfinished in %f seconds', statement.replace("\n", ""), parameters, total)
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[True], **kwargs: AnyCacheable) -> Engine:
def _get_db_engine(db_url: Union[str, URL], *, sync: Literal[False], **kwargs: AnyCacheable) -> AsyncEngine:
def _get_db_engine(db_url: Union[str, URL], *, sync: bool, **kwargs: AnyCacheable) -> Union[Engine, AsyncEngine]:
if "connect_args" in kwargs:
connect_args_raw = kwargs.pop("connect_args")
assert isinstance(connect_args_raw, Iterable) and all(
isinstance(arg, Sized) and len(arg) == 2 for arg in connect_args_raw
connect_args = dict(connect_args_raw)
connect_args = {"connect_timeout": 5} # Some dialects use "timeout"
poolclass = QueuePool if sync else AsyncAdaptedQueuePool
# You may want to move some of this to some sort of global constant
params = {
"isolation_level": "READ COMMITTED", # See
"echo": kwargs.pop("echo", False), # Don't be so verbose unless this is true
"future": True,
"connect_args": connect_args,
"poolclass": poolclass,
if sync:
return create_engine(db_url, **params)
return create_async_engine(db_url, **params)
@asynccontextmanager
async def async_db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> AsyncGenerator[AsyncEngine, None]:
    """Get a new async pooled engine ready to be used, as a context manager.

    The engine's connection pool is disposed when the ``async with`` block exits,
    including when the block raises.

    :param db_url: Database URL, as a string or ``sqlalchemy.engine.URL``.
    :param kwargs: Engine options forwarded to ``_get_db_engine``; they must be hashable.
    """
    engine = _get_db_engine(db_url, sync=False, **kwargs)
    try:
        yield engine
    finally:
        # Without try/finally an exception raised in the caller's block would skip this.
        await engine.dispose()
@asynccontextmanager
async def async_db_session(engine: AsyncEngine, **kwargs: Any) -> AsyncGenerator[AsyncSession, None]:
    """Get a new async ORM session ready to be used, as a context manager.

    The session is opened with ``async with``, so it is closed when the block exits.

    :param engine: Async engine the session is bound to.
    :param kwargs: Extra ``async_sessionmaker`` options. They override the defaults
        below; e.g. pass ``expire_on_commit=True`` to restore SQLAlchemy's default.
    """
    # You may want to move some of this to some sort of global constant
    params = {
        "expire_on_commit": False,  # loaded attributes are not expired on commit
        **kwargs,
    }
    async_session = async_sessionmaker(engine, **params)  # type: ignore[call-overload]
    async with async_session() as session:
        yield session
@contextmanager
def db_engine(db_url: Union[str, URL], **kwargs: AnyCacheable) -> Generator[Engine, None, None]:
    """Get a new pooled engine ready to be used, as a context manager.

    Mirrors ``async_db_engine``: the engine's connection pool is disposed when the
    ``with`` block exits, including when the block raises.

    :param db_url: Database URL, as a string or ``sqlalchemy.engine.URL``.
    :param kwargs: Engine options forwarded to ``_get_db_engine``; they must be hashable.
    """
    engine = _get_db_engine(db_url, sync=True, **kwargs)
    try:
        yield engine
    finally:
        engine.dispose()
@contextmanager
def db_session(engine: Engine, **kwargs: Any) -> Generator[Session, None, None]:
    """Get a new ORM session ready to be used, as a context manager.

    The session is opened with ``with``, so it is closed when the block exits.

    :param engine: Engine the session is bound to.
    :param kwargs: Extra ``sessionmaker`` options. They override the defaults below.
    """
    params = {
        "expire_on_commit": False,  # loaded attributes are not expired on commit
        **kwargs,
    }
    # Separate name for the factory so it isn't shadowed by the session it produces.
    session_factory = sessionmaker(engine, **params)  # type: ignore[call-overload]
    with session_factory() as session:
        yield session
def asdict(row: Row) -> dict[str, Any]:
    """Return ``row`` as a plain ``dict`` keyed by column name.

    ``Row._asdict()`` may use ``sqlalchemy.sql.elements.quoted_name`` objects as keys;
    every key is coerced to a plain ``str`` here.
    """
    # The leading underscore notwithstanding, ``_asdict`` is documented API of ``Row``.
    # noinspection PyProtectedMember
    mapping = row._asdict()
    return {str(column): cell for column, cell in mapping.items()}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment