Created October 29, 2024 11:37
-
-
Save paulwinex/c3c7472b07cac6190107f32f96e0ab4b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Requirements | |
python = "^3.11" | |
uvicorn = "^0.32.0" | |
SQLAlchemy = "^2.0.36" | |
asyncpg = "^0.30.0" | |
fastapi = "^0.115.4" | |
""" | |
from __future__ import annotations | |
from contextlib import asynccontextmanager | |
from typing import Optional, Literal, Union | |
from pydantic import ConfigDict, Field | |
from sqlalchemy import ForeignKey, String, UniqueConstraint | |
from sqlalchemy.orm import ( | |
relationship, | |
DeclarativeBase, | |
Mapped, | |
mapped_column, | |
) | |
# MODELS | |
class BaseDbModel(DeclarativeBase):
    """Declarative base shared by every ORM model in this module.

    Every mapped subclass inherits an integer ``id`` primary key. ``__tablename__``
    is left as None here; concrete models provide their own.
    """

    __abstract__ = True
    __tablename__ = None
    # NOTE(review): ``extend_existing`` lets a class silently redefine a table that is
    # already registered in the metadata (e.g. when this module is imported twice).
    # Subclasses that declare their own ``__table_args__`` replace this default.
    __table_args__ = {"extend_existing": True}

    id: Mapped[int] = mapped_column(primary_key=True)
class AssetTypeModel(BaseDbModel):
    """Asset type referenced by ``AssetModel.asset_type_id``.

    Both ``name`` and ``label`` are required, unique and indexed.
    """

    __tablename__ = "asset_types"

    name: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
    label: Mapped[str] = mapped_column(nullable=False, unique=True, index=True)
class EntityModel(BaseDbModel):
    """Root of the polymorphic entity hierarchy (joined-table inheritance).

    Every concrete entity (asset, asset group) has a row in ``entities``;
    the ``entity_type`` column stores the polymorphic identity that tells
    SQLAlchemy which subclass to instantiate for a row.
    """

    __tablename__ = "entities"

    # Discriminator column; values are the subclasses' polymorphic identities.
    entity_type: Mapped[str] = mapped_column(String(50), nullable=False)

    __mapper_args__ = {
        "polymorphic_on": entity_type,
        "polymorphic_identity": "entity",
    }

    # Tasks attached to this entity, eagerly loaded with an extra SELECT ... IN query.
    tasks: Mapped[list["TaskModel"]] = relationship(
        "TaskModel",
        lazy="selectin",
        back_populates="entity",
    )
class AssetModel(EntityModel):
    """Asset entity stored in ``assets`` and joined to ``entities`` by primary key."""

    __tablename__ = "assets"

    # Shared primary key: each ``assets`` row extends the ``entities`` row with the same id.
    id: Mapped[int] = mapped_column(ForeignKey("entities.id"), primary_key=True)

    __mapper_args__ = {
        "polymorphic_identity": "asset",
        "inherit_condition": id == EntityModel.id,
    }

    name: Mapped[str] = mapped_column(unique=True, nullable=False, index=True)
    label: Mapped[str] = mapped_column(nullable=False)
    # The column is nullable, so an asset may have no type; the annotations are
    # Optional to match instead of claiming a value is always present.
    asset_type_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("asset_types.id"),
        nullable=True,
    )
    asset_type: Mapped[Optional[AssetTypeModel]] = relationship(
        lazy="selectin",
    )
class AssetGroupModel(EntityModel):
    """Named group of assets; membership rows live in ``assetgroup_asset_link``."""

    __tablename__ = "assetgroup"

    # Joined-inheritance primary key shared with the parent ``entities`` row.
    id: Mapped[int] = mapped_column(ForeignKey("entities.id"), primary_key=True)

    __mapper_args__ = {
        "polymorphic_identity": "assetgroup",
        "inherit_condition": id == EntityModel.id,
    }

    label: Mapped[str] = mapped_column(index=True)
    name: Mapped[str] = mapped_column(index=True)

    # Many-to-many to assets through the link table, eagerly loaded with SELECT ... IN.
    assets: Mapped[list[AssetModel]] = relationship(
        AssetModel,
        lazy="selectin",
        secondary="assetgroup_asset_link",
    )
class AssetGroupAssetLinkModel(BaseDbModel):
    """Association row linking one asset to one asset group.

    Both foreign keys cascade on delete, and the (asset, group) pair is unique,
    so an asset appears at most once in a given group.
    """

    __tablename__ = "assetgroup_asset_link"
    __table_args__ = (
        UniqueConstraint("asset_id", "assetgroup_id", name="assetgroup_asset_link_uc"),
    )

    asset_id: Mapped[int] = mapped_column(
        ForeignKey("assets.id", ondelete="CASCADE"),
    )
    assetgroup_id: Mapped[int] = mapped_column(
        ForeignKey("assetgroup.id", ondelete="CASCADE"),
    )
class TaskModel(BaseDbModel):
    """Task optionally attached to one polymorphic entity (asset or asset group)."""

    __tablename__ = "tasks"
    # Task names are unique per entity. NOTE: the constraint name mentions
    # "task_type" although it covers (name, entity_id); it is kept unchanged so
    # databases created with this name keep matching the metadata.
    __table_args__ = (UniqueConstraint("name", "entity_id", name="uq_task_type_entity"),)

    name: Mapped[str] = mapped_column(String(64))
    label: Mapped[str] = mapped_column(String(64))
    # Nullable (and SET NULL on entity deletion), so the annotation is Optional to match.
    entity_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("entities.id", ondelete="SET NULL"),
        nullable=True,
    )
    # Paired with EntityModel.tasks (which already names this attribute in
    # back_populates) so in-memory changes on either side are mirrored on the other.
    entity: Mapped[Optional[EntityModel]] = relationship(
        back_populates="tasks",
        lazy="selectin",
    )
class ProjectModel(BaseDbModel):
    """Project that points at a single task."""

    __tablename__ = "project"

    # ``ondelete="SET NULL"`` needs a nullable column: with a plain ``Mapped[int]``
    # SQLAlchemy emits NOT NULL, so deleting the referenced task would violate the
    # constraint instead of clearing the reference.
    task_id: Mapped[Optional[int]] = mapped_column(
        ForeignKey("tasks.id", ondelete="SET NULL"),
        nullable=True,
    )
    task: Mapped[Optional[TaskModel]] = relationship(
        lazy="selectin",
    )
# SCHEMAS | |
from pydantic import BaseModel | |
class AssetTypeResponseSchema(BaseModel):
    """Serialized asset type, read directly from an ``AssetTypeModel`` instance."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    label: str | None = None
    name: str | None = None
class AssetGroupTaskResponseSchema(BaseModel):
    """Asset-group variant of the polymorphic task ``entity`` field.

    ``entity_type`` acts as the discriminator: only objects whose
    ``entity_type`` equals ``"assetgroup"`` validate against this schema.
    """

    model_config = ConfigDict(from_attributes=True)

    entity_type: Literal["assetgroup"]
    id: int
    label: str
    name: str
class AssetTaskResponseSchema(BaseModel):
    """Asset variant of the polymorphic task ``entity`` field.

    Selected when the entity's ``entity_type`` equals ``"asset"``.
    """

    model_config = ConfigDict(from_attributes=True)

    entity_type: Literal["asset"] = Field(...)
    id: int
    name: str
    label: str
    # ``assets.asset_type_id`` is a nullable column, so an asset may have no type;
    # making these fields required would make serializing such an asset fail.
    asset_type_id: Optional[int] = None
    asset_type: Optional[AssetTypeResponseSchema] = None
class TaskResponseSchema(BaseModel):
    """Task as returned by the API, including its polymorphic entity."""

    # Instances are built from ORM objects (nested inside ProjectResponseSchema),
    # so attribute access must be enabled here too, like the sibling schemas.
    # Without it, ProjectResponseSchema.model_validate(orm_obj) fails on this field.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    label: str
    # Polymorphic field: ``entity_type`` selects the schema used for the entity.
    # None when the task is not attached to any entity (tasks.entity_id is nullable).
    entity: Optional[
        Union[
            AssetGroupTaskResponseSchema,
            AssetTaskResponseSchema,
        ]
    ] = Field(..., discriminator="entity_type")
    meta_data: Optional[dict] = None
class ProjectResponseSchema(BaseModel):
    """Project as returned by the index endpoint, with its task nested."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    # Optional because ``project.task_id`` is declared with ``ondelete="SET NULL"``,
    # i.e. a project is meant to outlive its task; a missing task serializes as null.
    task: Optional[TaskResponseSchema] = None
# APP | |
from fastapi import FastAPI, Depends | |
from sqlalchemy.ext.asyncio import ( | |
AsyncSession, | |
async_sessionmaker, | |
create_async_engine, | |
) | |
import uvicorn | |
import os

# Connection settings. Both can be overridden from the environment; the defaults
# are the local development values this demo was written against.
# For a throw-away database use e.g. DATABASE_URL="sqlite+aiosqlite:///:memory:".
db_url = os.environ.get(
    "DATABASE_URL",
    "postgresql+asyncpg://test:test@localhost:5432/test",
)
echo_sql = os.environ.get("SQL_ECHO", "").strip().lower() in {"1", "true", "yes", "on"}

# The pools SQLAlchemy selects for SQLite (e.g. StaticPool for ":memory:") do not
# accept pool sizing arguments, so they are only passed for server databases.
_pool_options = {} if db_url.startswith("sqlite") else {"pool_size": 5, "max_overflow": 5}

engine = create_async_engine(
    url=db_url,
    echo=echo_sql,
    **_pool_options,
)
# expire_on_commit=False keeps loaded attributes usable after commit, which the
# response serialization relies on when it reads ORM objects outside a transaction.
session_factory = async_sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
    expire_on_commit=False,
)
async def init_db(app):
    """Create every table registered on ``BaseDbModel.metadata`` that is missing.

    ``app`` is accepted because the lifespan handler passes it, but it is unused.
    """
    async with engine.begin() as connection:
        # create_all is a synchronous API; run_sync hands it a sync-style connection.
        await connection.run_sync(BaseDbModel.metadata.create_all)
async def get_session():
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    The session is closed when the ``async with`` block exits, i.e. when FastAPI
    finishes with the dependency.
    """
    async with session_factory() as db_session:
        yield db_session
async def create_demo_objects(session: AsyncSession):
    """Seed demo data unless at least one project already exists.

    Creates one asset type, one asset, one asset group, a task for each of the
    two entities and a project for each task, then commits. Each object is
    flushed right after it is added so its primary key is available to the
    objects created after it.
    """
    from sqlalchemy import select

    async def persist(obj):
        # Add and flush immediately so ``obj.id`` is populated from the database.
        session.add(obj)
        await session.flush()
        return obj

    existing = await session.scalars(select(ProjectModel))
    if existing.first():
        print("Skip creation demo records")
        return

    asset_type = await persist(AssetTypeModel(name="type1", label="Type 1"))
    asset = await persist(
        AssetModel(name="asset1", label="Asset 1", asset_type_id=asset_type.id)
    )
    asset_group = await persist(AssetGroupModel(label="Group 1", name="group1"))

    task_for_asset = await persist(
        TaskModel(name="task_asset", label="Task for Asset", entity_id=asset.id)
    )
    task_for_group = await persist(
        TaskModel(
            name="task_asset_group",
            label="Task for AssetGroup",
            entity_id=asset_group.id,
        )
    )

    # One project per task, in the same order as the tasks were created.
    for task in (task_for_asset, task_for_group):
        await persist(ProjectModel(task_id=task.id))

    await session.commit()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database on startup and release it on shutdown.

    Creates missing tables and seeds demo records, then yields while the app
    serves requests. The engine's connection pool is disposed in ``finally`` so
    it is released even when table creation or seeding raises during startup.
    """
    try:
        await init_db(app)
        async with session_factory() as session:
            await create_demo_objects(session)
        yield
    finally:
        await engine.dispose()
# ASGI application; startup/shutdown work is delegated to ``lifespan``.
application = FastAPI(lifespan=lifespan)
@application.get("/", response_model=list[ProjectResponseSchema])
async def index(
    session: AsyncSession = Depends(get_session),
):
    """Return all projects ordered by id, with task and polymorphic entity loaded.

    The task's entity is loaded through ``with_polymorphic`` over AssetModel and
    AssetGroupModel, so subclass columns are fetched eagerly. Lazy loading them
    later would require I/O that an AsyncSession cannot do implicitly during
    response serialization.
    """
    from sqlalchemy import select
    from sqlalchemy.orm import with_polymorphic, selectinload

    all_entities = with_polymorphic(EntityModel, [AssetModel, AssetGroupModel], flat=True)
    # selectinload issues separate SELECT ... IN queries and adds no joins to this
    # statement, so each project row is already unique. The Postgres-only
    # DISTINCT ON and Result.unique() are therefore unnecessary.
    query = (
        select(ProjectModel)
        .options(
            selectinload(ProjectModel.task).selectinload(
                TaskModel.entity.of_type(all_entities)
            ),
        )
        .order_by(ProjectModel.id)
    )
    result = await session.scalars(query)
    return result.all()
if __name__ == "__main__":
    from pathlib import Path

    # uvicorn's reloader needs an import string rather than the app object. Derive
    # it from this file's name and directory instead of hard-coding "app", so
    # ``python <anything>.py`` works from any working directory.
    _this_file = Path(__file__).resolve()
    uvicorn.run(
        f"{_this_file.stem}:application",
        host="0.0.0.0",
        port=8080,
        reload=True,
        app_dir=str(_this_file.parent),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.