Last active
March 13, 2025 02:16
-
-
Save DiTo97/013edc70ffbacd3a9e63ff62f2f430f8 to your computer and use it in GitHub Desktop.
A feature-rich database model for use with SQLAlchemy's ORM abstraction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import typing | |
from uuid import UUID, uuid4 | |
from datetime import datetime, timezone | |
from functools import partial | |
from typing import Any | |
from pydantic import BaseModel | |
from sqlalchemy import select, func, sql | |
from sqlalchemy.dialects.mysql import insert as mysql_insert | |
from sqlalchemy.dialects.postgresql import insert as postgres_insert | |
from sqlalchemy.dialects.sqlite import insert as sqlite_insert | |
from sqlalchemy.engine.default import DefaultExecutionContext | |
from sqlalchemy.ext.asyncio import AsyncAttrs, async_object_session, AsyncSession | |
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column | |
from sqlalchemy.sql.base import ExecutableOption | |
from typing_extensions import Self | |
if typing.TYPE_CHECKING: | |
from sqlalchemy.orm import InstrumentedAttribute | |
from sqlalchemy.sql.elements import BinaryExpression | |
Schema = typing.TypeVar("Schema", bound=BaseModel) | |
def get_field_value(context: DefaultExecutionContext, field: str) -> Any:
    """Return the value bound to ``field`` in the parameters of the executing statement.

    The parameters are fetched with ``isolate_multiinsert_groups=False``.

    Raises:
        KeyError: if ``field`` is not among the current parameters.
    """
    current_parameters = context.get_current_parameters(isolate_multiinsert_groups=False)
    return current_parameters[field]
# Context-sensitive column default: reads the `created_at` parameter of the current statement.
get_field_value_created_at = partial(get_field_value, field="created_at")


# TODO: should remove once SQLAlchemy has an agnostic insert
# Dialect name -> that dialect's `insert` construct.
_insert_func = dict(
    sqlite=sqlite_insert,
    postgresql=postgres_insert,
    mysql=mysql_insert,
)
def _schema_or_values(schema: Schema | None, values: dict[str, Any]) -> dict[str, Any]: | |
return schema.model_dump() if schema else values | |
def _utcnow() -> datetime: | |
"""constructs a UTC datetime from time""" | |
return datetime.now(timezone.utc) | |
class NotFoundException(Exception):
    """Raised when a domain entity has not been found on the system.

    Attributes:
        message: human-readable description of what was not found.
        code: numeric code attached to the error (defaults to 404).
    """

    def __init__(self, message: str, code: int = 404):
        # Forward the message to `Exception` so `str(exc)`, `exc.args` and tracebacks carry it.
        super().__init__(message)
        self.message = message
        self.code = code
class CRUDMixin:
    """Async create/read/update/delete helpers for ORM models.

    Every database-facing helper receives the ``AsyncSession`` to use. Helpers that write
    also take ``autocommit``, which decides whether the session is committed before returning.
    """

    # Names of the columns `upsert_many` overwrites on conflict; `None` is treated as "none".
    __upsertable_columns__: set[str] | None = None

    def fill(self, replace_dict: bool = False, **kwargs: Any) -> Self:
        """Assign ``kwargs`` to the matching attributes of this instance.

        Args:
            replace_dict: when False, a dict value is merged on top of the attribute's
                current dict instead of replacing it.
            **kwargs: attribute names and the values to assign.

        Returns:
            This instance, to allow chaining.

        Raises:
            AttributeError: if a key does not name an existing attribute.
        """
        for key, value in kwargs.items():
            if not hasattr(self, key):
                message = f"`{self.__class__.__name__}` has no attribute `{key}`"
                raise AttributeError(message)

            if isinstance(value, dict) and not replace_dict:
                # Update only the keys present in `value`. The merge builds a NEW dict:
                # mutating the current one in place and re-assigning the same object
                # would leave the old and new values identical, hiding the change.
                value = {**(getattr(self, key) or {}), **value}

            setattr(self, key, value)

        return self

    @classmethod
    async def create(
        cls, session: AsyncSession, schema: Schema | None = None, autocommit: bool = True, **kwargs: Any
    ) -> Self:
        """Build an instance from ``schema`` (or ``kwargs`` when no schema is given) and save it."""
        values = _schema_or_values(schema, kwargs)
        instance = cls()
        instance.fill(**values)
        return await instance.save(session, autocommit)

    @classmethod
    async def get(
        cls, session: AsyncSession, key: UUID, options: list[ExecutableOption] | None = None
    ) -> Self | None:
        """Return the row whose ``key`` matches, or None; ``options`` are applied to the select."""
        # `None` replaces the former shared mutable `[]` default.
        statement = select(cls).filter_by(key=key).options(*(options or ()))
        return (await session.execute(statement)).scalar_one_or_none()

    @classmethod
    async def get_or_raise(
        cls, session: AsyncSession, key: UUID, options: list[ExecutableOption] | None = None
    ) -> Self:
        """Like `get`, but raise `NotFoundException` when no row matches."""
        instance = await cls.get(session, key, options)
        if instance is None:
            message = f"{cls.__name__} with key `{key}` not found"
            raise NotFoundException(message)
        return instance

    @classmethod
    async def exists(cls, session: AsyncSession, **conditions) -> bool:
        """Return whether at least one row matches the ``filter_by`` conditions."""
        # An EXISTS clause is a column expression, not an executable statement,
        # so it has to be wrapped in its own SELECT before being executed.
        statement = select(select(cls).filter_by(**conditions).exists())
        return bool((await session.execute(statement)).scalar())

    @classmethod
    async def get_by(cls, session: AsyncSession, **conditions) -> Self | None:
        """Return the single row matching the ``filter_by`` conditions, or None."""
        statement = select(cls).filter_by(**conditions)
        return (await session.execute(statement)).scalar_one_or_none()

    @classmethod
    async def get_by_or_raise(cls, session: AsyncSession, **conditions) -> Self:
        """Like `get_by`, but raise `NotFoundException` when no row matches."""
        instance = await cls.get_by(session, **conditions)
        if instance is None:
            message = ", ".join([f"{key}={value}" for key, value in conditions.items()])
            message = f"{cls.__name__} not found filtering by {message}"
            raise NotFoundException(message)
        return instance

    @classmethod
    async def count_by(cls, session: AsyncSession, **conditions) -> int:
        """Return the number of rows matching the ``filter_by`` conditions."""
        # Count rows of the entity itself instead of a hard-coded `id` column,
        # which models keyed by another column (e.g. `key`) do not have.
        statement = select(func.count()).select_from(cls).filter_by(**conditions)
        return (await session.execute(statement)).scalar_one()

    async def update(
        self,
        session: AsyncSession,
        schema: Schema | None = None,
        replace_dict: bool = False,
        autocommit: bool = True,
        **kwargs: Any,
    ) -> Self:
        """Apply ``schema`` (or ``kwargs``) to this instance through `fill`, then save it."""
        values = _schema_or_values(schema, kwargs)
        result = self.fill(replace_dict=replace_dict, **values)
        return await result.save(session, autocommit)

    @classmethod
    async def update_many(
        cls,
        session: AsyncSession,
        objects: list[dict[str, Any]],
        autocommit: bool = True,
    ) -> None:
        """Execute one UPDATE for this model with ``objects`` as the list of parameter sets.

        Raises:
            ValueError: if ``objects`` is empty.
        """
        if not objects:
            message = "cannot update empty list of objects"
            raise ValueError(message)

        await session.execute(sql.update(cls), objects)

        if autocommit:
            await session.commit()

    @classmethod
    async def upsert_many(
        cls,
        session: AsyncSession,
        objects: list[Schema | dict],
        constraints: list["InstrumentedAttribute[Any]"],
        autocommit: bool = True,
    ) -> list[Self]:
        """Insert ``objects``, updating the upsertable columns of rows that conflict.

        Args:
            objects: dicts, or schemas that are dumped with ``model_dump()``.
            constraints: the attributes whose conflict triggers the update.

        Returns:
            The instances returned by the statement.

        Raises:
            ValueError: if ``objects`` is empty.
        """
        # NOTE(review): `excluded` / `on_conflict_do_update` are what the PostgreSQL and
        # SQLite insert constructs provide; confirm the "mysql" entry of `_insert_func`
        # is usable here before relying on it.
        if not objects:
            message = "cannot upsert empty list of objects"
            raise ValueError(message)

        values = [obj if isinstance(obj, dict) else obj.model_dump() for obj in objects]
        statement = _insert_func[session.bind.dialect.name](cls).values(values)

        # On conflict, should update only the columns that are upsertable.
        # The class default is `None`, which must not be iterated.
        upsertable = cls.__upsertable_columns__ or ()
        columns_to_update = {column: statement.excluded[column] for column in upsertable}

        # onupdate for `updated_at` is not working; should force a new value on update
        if hasattr(cls, "updated_at"):
            columns_to_update["updated_at"] = _utcnow()

        statement = (
            statement.on_conflict_do_update(index_elements=constraints, set_=columns_to_update)
            .returning(cls)
            .execution_options(populate_existing=True)
        )

        result = await session.execute(statement)

        # NOTE(review): rows are read after the optional commit, as before; presumably
        # so the returned instances are not expired by it. Confirm.
        if autocommit:
            await session.commit()

        return list(result.scalars().all())

    @classmethod
    async def upsert(
        cls,
        session: AsyncSession,
        schema: Schema | dict,
        constraints: list["InstrumentedAttribute[Any]"],
        autocommit: bool = True,
    ) -> Self:
        """Upsert a single object through `upsert_many` and return the resulting instance."""
        result = await cls.upsert_many(session, [schema], constraints, autocommit)
        return result[0]

    async def delete(self, session: AsyncSession, autocommit: bool = True) -> Self:
        """Delete this instance through ``session`` and return it."""
        await session.delete(self)

        if autocommit:
            await session.commit()

        return self

    @classmethod
    async def delete_many(
        cls, session: AsyncSession, conditions: list["BinaryExpression"], autocommit: bool = True
    ) -> list[Self]:
        """Delete every row matching all ``conditions`` and return the deleted instances."""
        statement = sql.delete(cls).where(*conditions).returning(cls)
        result = await session.execute(statement)

        if autocommit:
            await session.commit()

        return list(result.scalars().all())

    async def save(self, session: AsyncSession, autocommit: bool = True) -> Self:
        """Add this instance to ``session`` (committing when ``autocommit``) and return it."""
        session.add(self)

        if autocommit:
            await session.commit()

        return self
class TimestampMixin:
    """Adds ``created_at`` and ``updated_at`` timestamp columns to a model."""

    # Defaults to the current UTC time.
    created_at: Mapped[datetime] = mapped_column(default=_utcnow)

    # Defaults to the statement's `created_at` parameter; `onupdate` supplies the current UTC time.
    updated_at: Mapped[datetime] = mapped_column(
        onupdate=_utcnow,
        default=get_field_value_created_at,
    )
class DatabaseModel(DeclarativeBase, AsyncAttrs, CRUDMixin, TimestampMixin):
    """Abstract declarative base combining the CRUD helpers, timestamps and a UUID primary key."""

    __abstract__ = True

    # https://docs.sqlalchemy.org/en/14/orm/extensions/asyncio.html#preventing-implicit-io-when-using-asyncsession
    __mapper_args__ = dict(eager_defaults=True)

    # Primary key, generated with `uuid4` when no value is supplied.
    key: Mapped[UUID] = mapped_column(default=uuid4, primary_key=True)

    def has_relationship(self, relationship: str) -> bool:
        """Return whether ``relationship`` is currently a key of this instance's ``__dict__``."""
        instance_state = self.__dict__
        return relationship in instance_state

    @property
    def async_session(self) -> AsyncSession | None:
        """The result of ``async_object_session`` for this instance."""
        session = async_object_session(self)
        return session
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
from uuid import UUID, uuid4 | |
from datetime import datetime | |
from sqlalchemy import String, ForeignKey | |
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker | |
from sqlalchemy.orm import Mapped, mapped_column, relationship | |
from database import DatabaseModel | |
class DatasetUser(DatabaseModel):
    """Association rows linking datasets to users.

    Backs the ``datasets_users`` table that the `Dataset.users` and `User.datasets`
    relationships name as their ``secondary``; without it that table is undefined.
    """

    __tablename__ = "datasets_users"

    dataset_key: Mapped[UUID] = mapped_column(ForeignKey("datasets.key"), index=True, nullable=False)
    user_key: Mapped[UUID] = mapped_column(ForeignKey("users.key"), index=True, nullable=False)


class Dataset(DatabaseModel):
    """A named dataset that can be shared with many users."""

    __tablename__ = "datasets"

    name: Mapped[str] = mapped_column(String, index=True, nullable=False)
    description: Mapped[str] = mapped_column(String, nullable=True)

    # Many-to-many with `User` through the `datasets_users` association table.
    users: Mapped[list["User"]] = relationship(secondary="datasets_users", back_populates="datasets")
class User(DatabaseModel):
    """An account identified by a unique username and a unique email."""

    __tablename__ = "users"

    first_name: Mapped[str] = mapped_column(String, nullable=True)
    last_name: Mapped[str] = mapped_column(String, nullable=True)
    username: Mapped[str] = mapped_column(String, nullable=False, unique=True)
    email: Mapped[str] = mapped_column(String, nullable=False, unique=True)

    # Many-to-many with `Dataset` through the `datasets_users` association table.
    datasets: Mapped[list[Dataset]] = relationship(back_populates="users", secondary="datasets_users")

    def __repr__(self):
        """Render the scalar fields as ``User(name=value, ...)``."""
        shown = {
            "key": str(self.key),
            "first_name": self.first_name,
            "last_name": self.last_name,
            "username": self.username,
            "email": self.email,
            "created_at": str(self.created_at),
            "updated_at": str(self.updated_at),
        }
        described = ", ".join(f"{name}={value}" for name, value in shown.items())
        return f"User({described})"
# Async engine for the example's SQLite file; `echo=True` turns on the engine's statement logging.
_engine = create_async_engine(
    "sqlite+aiosqlite:///example.db",
    echo=True,
)

# Session factory bound to the engine; `expire_on_commit=False` keeps instances usable after commit.
Session = async_sessionmaker(
    bind=_engine,
    expire_on_commit=False,
    autocommit=False,
)
async def main():
    """Create a user and a dataset, link them, then fetch the user back and print it."""
    username = "DiTo97"

    # Create any missing tables so the example also runs against a fresh database file.
    async with _engine.begin() as connection:
        await connection.run_sync(DatabaseModel.metadata.create_all)

    async with Session() as session:
        user = await User.create(session, username=username, email=f"{username}@example.com")
        dataset = await Dataset.create(session, name="example")

        # `user` is already persistent, so reading its unloaded `datasets` relationship
        # would lazy-load. Under asyncio that implicit IO has to be awaited, which
        # `awaitable_attrs` (provided by `AsyncAttrs`) allows.
        datasets = await user.awaitable_attrs.datasets
        datasets.append(dataset)
        await user.save(session)

        user = await User.get_by_or_raise(session, username=username)
        print(user)


if __name__ == "__main__":
    asyncio.run(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pydantic>2,<3 | |
SQLAlchemy>2,<3 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The module is inspired by Argilla's server implementation (see footnotes 1, 2 and 3).
Footnotes
https://github.com/argilla-io/argilla/blob/72fef19968d0e8eb818d7ccac13839ea51df4a4a/argilla-server/src/argilla_server/models/base.py ↩
https://github.com/argilla-io/argilla/blob/72fef19968d0e8eb818d7ccac13839ea51df4a4a/argilla-server/src/argilla_server/models/mixins.py ↩
https://github.com/argilla-io/argilla/blob/72fef19968d0e8eb818d7ccac13839ea51df4a4a/argilla-server/src/argilla_server/errors/future/base_errors.py#L31 ↩