Skip to content

Instantly share code, notes, and snippets.

@DiTo97
Last active March 13, 2025 02:16
Show Gist options
  • Save DiTo97/013edc70ffbacd3a9e63ff62f2f430f8 to your computer and use it in GitHub Desktop.
Save DiTo97/013edc70ffbacd3a9e63ff62f2f430f8 to your computer and use it in GitHub Desktop.
A feature-rich database model for use with SQLAlchemy's ORM abstraction
import typing
from uuid import UUID, uuid4
from datetime import datetime, timezone
from functools import partial
from typing import Any
from pydantic import BaseModel
from sqlalchemy import select, func, sql
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as postgres_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.engine.default import DefaultExecutionContext
from sqlalchemy.ext.asyncio import AsyncAttrs, async_object_session, AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy.sql.base import ExecutableOption
from typing_extensions import Self
if typing.TYPE_CHECKING:
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.sql.elements import BinaryExpression
# Type variable for request/response schemas; bound to pydantic's `BaseModel`.
Schema = typing.TypeVar("Schema", bound=BaseModel)
def get_field_value(context: DefaultExecutionContext, field: str) -> Any:
    """Return the value bound to ``field`` in the statement being executed.

    Intended for use from column ``default`` callables, which receive the
    execution context of the current statement.

    Args:
        context: Execution context handed to the default callable.
        field: Key of the parameter to read from the current parameters.

    Returns:
        The value stored under ``field`` in
        ``context.get_current_parameters(isolate_multiinsert_groups=False)``.

    Raises:
        KeyError: If ``field`` is not among the current parameters.
    """
    parameters = context.get_current_parameters(isolate_multiinsert_groups=False)
    return parameters[field]


def get_field_value_created_at(context: DefaultExecutionContext) -> Any:
    """Return the ``created_at`` value of the row currently being written.

    Written as a plain one-argument function (instead of a ``functools.partial``
    object) so that its signature can be introspected unambiguously when it is
    used as a context-sensitive column default.
    """
    return get_field_value(context, "created_at")
# Maps a dialect name to that dialect's own `insert()` construct.
# TODO: drop this lookup once SQLAlchemy offers a dialect-agnostic insert.
_insert_func = dict(
    sqlite=sqlite_insert,
    postgresql=postgres_insert,
    mysql=mysql_insert,
)
def _schema_or_values(schema: Schema | None, values: dict[str, Any]) -> dict[str, Any]:
    """Pick the attribute values to apply: the dumped schema, or the raw values.

    Args:
        schema: Optional schema object; when given, its ``model_dump()`` wins.
        values: Fallback mapping used only when ``schema`` is ``None``.

    Returns:
        ``schema.model_dump()`` if a schema was supplied, otherwise ``values``.
    """
    # Compare against None explicitly: a schema object that happens to be falsy
    # (e.g. one defining `__len__` or `__bool__`) must not be silently ignored.
    if schema is not None:
        return schema.model_dump()
    return values
def _utcnow() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
class NotFoundException(Exception):
    """Raised when a domain entity has not been found on the system.

    Attributes are set in ``__init__``:
        message: Human-readable description of what was not found.
        code: Numeric status code associated with the error (defaults to 404).
    """

    def __init__(self, message: str, code: int = 404):
        # Always hand both values to `Exception` so `args` has the same shape
        # no matter whether the caller used positional or keyword arguments
        # (this also keeps copy/pickle round-trips working).
        super().__init__(message, code)
        self.message = message
        self.code = code

    def __str__(self) -> str:
        # Show only the message rather than the repr of the `args` tuple.
        return self.message
class CRUDMixin:
    """Async create/read/update/delete helpers for declarative models.

    Every database-touching method takes the `AsyncSession` to work with and an
    `autocommit` flag; when the flag is true the session is committed before the
    method returns.
    """

    # Names of the columns `upsert_many` may overwrite when a conflict occurs.
    # Must be set by a model before it can be upserted.
    __upsertable_columns__: set[str] | None = None

    def fill(self, replace_dict: bool = False, **kwargs: Any) -> Self:
        """Assign the given attribute values on this instance.

        Args:
            replace_dict: When false (default), a dict value is merged into the
                attribute's current dict instead of replacing it.
            **kwargs: Attribute names and the values to assign.

        Returns:
            This instance, to allow chaining.

        Raises:
            AttributeError: If a name in ``kwargs`` is not an attribute.
        """
        for key, value in kwargs.items():
            if not hasattr(self, key):
                message = f"`{self.__class__.__name__}` has no attribute `{key}`"
                raise AttributeError(message)

            # Dict values update only the keys they carry rather than overriding
            # the whole dict. The merge builds a NEW dict: updating the stored
            # dict in place and re-assigning the same object would leave the old
            # and new values identical, so the change could go unnoticed.
            if isinstance(value, dict) and not replace_dict:
                value = {**(getattr(self, key) or {}), **value}

            setattr(self, key, value)

        return self

    @classmethod
    async def create(
        cls, session: AsyncSession, schema: Schema | None = None, autocommit: bool = True, **kwargs: Any
    ) -> Self:
        """Instantiate the model, fill it from ``schema`` or ``kwargs``, and save it.

        Args:
            session: Session the new instance is added to.
            schema: Optional schema; when given its dump is used and ``kwargs``
                are ignored.
            autocommit: Commit the session after adding the instance.
            **kwargs: Attribute values used when no schema is given.

        Returns:
            The newly created instance.
        """
        values = _schema_or_values(schema, kwargs)
        instance = cls()
        instance.fill(**values)
        return await instance.save(session, autocommit)

    @classmethod
    async def get(
        cls, session: AsyncSession, key: UUID, options: list[ExecutableOption] | None = None
    ) -> Self | None:
        """Fetch the instance whose ``key`` equals the given value.

        Args:
            session: Session used to run the query.
            key: Value matched against the ``key`` attribute.
            options: Optional statement options (``None`` means no options).

        Returns:
            The matching instance, or ``None`` when there is none.
        """
        statement = select(cls).filter_by(key=key).options(*(options or ()))
        return (await session.execute(statement)).scalar_one_or_none()

    @classmethod
    async def get_or_raise(
        cls, session: AsyncSession, key: UUID, options: list[ExecutableOption] | None = None
    ) -> Self:
        """Like `get`, but raise instead of returning ``None``.

        Raises:
            NotFoundException: If no instance has the given key.
        """
        instance = await cls.get(session, key, options)
        if not instance:
            message = f"{cls.__name__} with key `{key}` not found"
            raise NotFoundException(message)
        return instance

    @classmethod
    async def exists(cls, session: AsyncSession, **conditions) -> bool:
        """Tell whether at least one row matches the keyword equality filters."""
        # `.exists()` yields a column expression, not a statement; it has to be
        # wrapped in its own SELECT before it can be executed.
        statement = select(select(cls).filter_by(**conditions).exists())
        return bool((await session.execute(statement)).scalar())

    @classmethod
    async def get_by(cls, session: AsyncSession, **conditions) -> Self | None:
        """Fetch the single instance matching the keyword equality filters.

        Returns:
            The matching instance, or ``None`` when there is none.
        """
        statement = select(cls).filter_by(**conditions)
        return (await session.execute(statement)).scalar_one_or_none()

    @classmethod
    async def get_by_or_raise(cls, session: AsyncSession, **conditions) -> Self:
        """Like `get_by`, but raise instead of returning ``None``.

        Raises:
            NotFoundException: If no instance matches the filters.
        """
        instance = await cls.get_by(session, **conditions)
        if not instance:
            message = ", ".join([f"{key}={value}" for key, value in conditions.items()])
            message = f"{cls.__name__} not found filtering by {message}"
            raise NotFoundException(message)
        return instance

    @classmethod
    async def count_by(cls, session: AsyncSession, **conditions) -> int:
        """Count the rows matching the keyword equality filters."""
        # Count rows of the mapped entity itself rather than a specific column,
        # so the method does not depend on how the primary key is named.
        statement = select(func.count()).select_from(cls).filter_by(**conditions)
        return (await session.execute(statement)).scalar_one()

    async def update(
        self,
        session: AsyncSession,
        schema: Schema | None = None,
        replace_dict: bool = False,
        autocommit: bool = True,
        **kwargs: Any,
    ) -> Self:
        """Fill this instance from ``schema`` or ``kwargs`` and save it.

        Args:
            session: Session the instance is added to.
            schema: Optional schema; when given its dump is used and ``kwargs``
                are ignored.
            replace_dict: Forwarded to `fill`; replace dict attributes instead
                of merging into them.
            autocommit: Commit the session after adding the instance.
            **kwargs: Attribute values used when no schema is given.

        Returns:
            This instance.
        """
        values = _schema_or_values(schema, kwargs)
        result = self.fill(replace_dict=replace_dict, **values)
        return await result.save(session, autocommit)

    @classmethod
    async def update_many(
        cls,
        session: AsyncSession,
        objects: list[dict[str, Any]],
        autocommit: bool = True,
    ) -> None:
        """Run one UPDATE against the model with a list of parameter dicts.

        Args:
            session: Session used to execute the statement.
            objects: One parameter dict per row to update.
            autocommit: Commit the session after executing.

        Raises:
            ValueError: If ``objects`` is empty.
        """
        if not objects:
            message = "cannot update empty list of objects"
            raise ValueError(message)

        await session.execute(sql.update(cls), objects)

        if autocommit:
            await session.commit()

    @classmethod
    async def upsert_many(
        cls,
        session: AsyncSession,
        objects: list[Schema | dict],
        constraints: list["InstrumentedAttribute[Any]"],
        autocommit: bool = True,
    ) -> list[Self]:
        """Insert the given rows, updating the upsertable columns on conflict.

        Args:
            session: Session used to execute the statement; its bind's dialect
                name selects the insert construct from ``_insert_func``.
            objects: Rows to write, as dicts or schemas (dumped via
                ``model_dump()``).
            constraints: Attributes forming the conflict target.
            autocommit: Commit the session after executing.

        Returns:
            The instances returned by the statement.

        Raises:
            ValueError: If ``objects`` is empty, or the model does not declare
                ``__upsertable_columns__``.
            KeyError: If the dialect has no entry in ``_insert_func``.
        """
        if not objects:
            message = "cannot upsert empty list of objects"
            raise ValueError(message)

        if cls.__upsertable_columns__ is None:
            message = f"`{cls.__name__}` must define `__upsertable_columns__` to be upserted"
            raise ValueError(message)

        values = [obj if isinstance(obj, dict) else obj.model_dump() for obj in objects]

        # NOTE(review): the MySQL insert construct exposes `inserted` and
        # `on_duplicate_key_update()` rather than `excluded` and
        # `on_conflict_do_update()`, so the "mysql" entry presumably cannot
        # work with the statement built below -- confirm before relying on it.
        statement = _insert_func[session.bind.dialect.name](cls).values(values)

        # On conflict, should update only the columns that are upsertable
        columns_to_update = {column: statement.excluded[column] for column in cls.__upsertable_columns__}

        # onupdate for `updated_at` is not working; should force a new value on update
        if hasattr(cls, "updated_at"):
            columns_to_update["updated_at"] = _utcnow()

        statement = (
            statement.on_conflict_do_update(index_elements=constraints, set_=columns_to_update)
            .returning(cls)
            .execution_options(populate_existing=True)
        )

        result = await session.execute(statement)

        if autocommit:
            await session.commit()

        # Kept after the commit on purpose, matching the original ordering.
        return list(result.scalars().all())

    @classmethod
    async def upsert(
        cls,
        session: AsyncSession,
        schema: Schema | dict,
        constraints: list["InstrumentedAttribute[Any]"],
        autocommit: bool = True,
    ) -> Self:
        """Upsert a single row; see `upsert_many` for arguments and errors."""
        result = await cls.upsert_many(session, [schema], constraints, autocommit)
        return result[0]

    async def delete(self, session: AsyncSession, autocommit: bool = True) -> Self:
        """Mark this instance for deletion and optionally commit.

        Returns:
            This instance.
        """
        await session.delete(self)

        if autocommit:
            await session.commit()

        return self

    @classmethod
    async def delete_many(
        cls, session: AsyncSession, conditions: list["BinaryExpression"], autocommit: bool = True
    ) -> list[Self]:
        """Delete every row matching all of ``conditions`` and return them.

        Args:
            session: Session used to execute the statement.
            conditions: Criteria passed to the statement's ``where()``.
            autocommit: Commit the session after executing.

        Returns:
            The instances returned by the DELETE ... RETURNING statement.
        """
        # NOTE(review): an empty `conditions` list presumably yields a DELETE
        # without a WHERE clause (i.e. every row) -- confirm callers expect it.
        statement = sql.delete(cls).where(*conditions).returning(cls)
        result = await session.execute(statement)

        if autocommit:
            await session.commit()

        return list(result.scalars().all())

    async def save(self, session: AsyncSession, autocommit: bool = True) -> Self:
        """Add this instance to the session and optionally commit.

        Returns:
            This instance.
        """
        session.add(self)

        if autocommit:
            await session.commit()

        return self
class TimestampMixin:
    """Mixin contributing `created_at` and `updated_at` columns."""

    # Filled by calling `_utcnow` when no value is supplied on insert.
    created_at: Mapped[datetime] = mapped_column(
        default=_utcnow,
    )

    # On insert the default is produced by `get_field_value_created_at`;
    # on update the value is produced by `_utcnow`.
    # NOTE(review): `_utcnow` returns timezone-aware datetimes while the column
    # type is derived from the plain `datetime` annotation -- confirm the
    # target backend accepts aware values.
    updated_at: Mapped[datetime] = mapped_column(
        onupdate=_utcnow,
        default=get_field_value_created_at,
    )
class DatabaseModel(DeclarativeBase, AsyncAttrs, CRUDMixin, TimestampMixin):
    """Declarative base combining the CRUD helpers, timestamps and a UUID key."""

    __abstract__ = True

    # https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#preventing-implicit-io-when-using-asyncsession
    __mapper_args__ = {"eager_defaults": True}

    # Primary key; defaults to a freshly generated `uuid4`.
    key: Mapped[UUID] = mapped_column(primary_key=True, default=uuid4)

    def has_relationship(self, relationship: str) -> bool:
        """Report whether `relationship` is a key of this instance's `__dict__`."""
        instance_state = vars(self)
        return relationship in instance_state

    @property
    def async_session(self) -> AsyncSession | None:
        """Result of `async_object_session` for this instance."""
        owner = async_object_session(self)
        return owner
import asyncio
from datetime import datetime
from uuid import UUID, uuid4

from sqlalchemy import Column, ForeignKey, String, Table
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import Mapped, mapped_column, relationship

from database import DatabaseModel
# Association table behind the many-to-many link between datasets and users.
# Both relationships name it through `secondary="datasets_users"`, so it has to
# exist in the shared metadata; each column takes its type from the referenced key.
datasets_users = Table(
    "datasets_users",
    DatabaseModel.metadata,
    Column("dataset_key", ForeignKey("datasets.key"), primary_key=True),
    Column("user_key", ForeignKey("users.key"), primary_key=True),
)


class Dataset(DatabaseModel):
    """A named dataset that can be shared with many users."""

    __tablename__ = "datasets"

    name: Mapped[str] = mapped_column(String, index=True, nullable=False)
    # Nullable column, so the Python-side type is optional as well.
    description: Mapped[str | None] = mapped_column(String, nullable=True)

    # Many-to-many with `User` through the `datasets_users` table.
    users: Mapped[list["User"]] = relationship(secondary="datasets_users", back_populates="datasets")
class User(DatabaseModel):
    """An account identified by a unique username and e-mail address."""

    __tablename__ = "users"

    # Nullable columns, so the Python-side types are optional as well.
    first_name: Mapped[str | None] = mapped_column(String, nullable=True)
    last_name: Mapped[str | None] = mapped_column(String, nullable=True)

    username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    email: Mapped[str] = mapped_column(String, unique=True, nullable=False)

    # Many-to-many with `Dataset` through the `datasets_users` table.
    datasets: Mapped[list[Dataset]] = relationship(secondary="datasets_users", back_populates="users")

    def __repr__(self) -> str:
        """Return a one-line summary of the user's scalar columns."""
        return (
            f"User(key={self.key}, first_name={self.first_name}, last_name={self.last_name}, "
            f"username={self.username}, email={self.email}, created_at={self.created_at}, updated_at={self.updated_at})"
        )
# Async engine for the local SQLite example database; `echo=True` logs the SQL.
_engine = create_async_engine(
    "sqlite+aiosqlite:///example.db",
    echo=True,
)

# Session factory bound to the engine; instances are not expired on commit.
Session = async_sessionmaker(
    _engine,
    expire_on_commit=False,
    autocommit=False,
)
async def main():
    """Create a user and a dataset, link them, and print the stored user."""
    username = "DiTo97"

    # Make sure the tables exist before the first INSERT is issued.
    async with _engine.begin() as connection:
        await connection.run_sync(DatabaseModel.metadata.create_all)

    async with Session() as session:
        user = await User.create(session, username=username, email=f"{username}@example.com")
        dataset = await Dataset.create(session, name="example")

        # Load the collection explicitly: touching `user.datasets` directly
        # would attempt an implicit lazy load, which asyncio sessions reject.
        datasets = await user.awaitable_attrs.datasets
        datasets.append(dataset)
        await user.save(session)

        user = await User.get_by_or_raise(session, username=username)
        print(user)

    # Release pooled connections before the event loop shuts down.
    await _engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())
pydantic>2,<3
SQLAlchemy>2,<3