Skip to content

Instantly share code, notes, and snippets.

@paulwinex
Last active December 25, 2024 04:30
Show Gist options
  • Save paulwinex/71615de8d6d75dfbda1010d57bb10afb to your computer and use it in GitHub Desktop.
Save paulwinex/71615de8d6d75dfbda1010d57bb10afb to your computer and use it in GitHub Desktop.
import enum
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Optional

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, ForeignKey, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import (
    Mapped,
    declarative_base,
    mapped_column,
    relationship,
    selectinload,
    sessionmaker,
)
from sqlalchemy.types import Enum
# Async SQLAlchemy connection URL (asyncpg driver). It can be overridden with the
# DATABASE_URL environment variable so credentials need not live in source; the
# fallback is the original local database with "test" credentials.
DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test"
)
# Declarative base class shared by all ORM models in this module; its metadata is
# what the startup lifespan hook passes to create_all().
Base = declarative_base()
class DataTypeEnum(enum.Enum):
    # Declared value type of an Attribute.
    # In API payloads pydantic matches members by *value* (e.g. "str"), while the
    # SQLAlchemy ``Enum`` column on Attribute.data_type persists member *names*
    # (e.g. "STRING"). Changing names or order therefore affects stored data.
    # NOTE(review): nothing visible here checks AttributeValue.value against this type.
    STRING = "str"
    INTEGER = "int"
    FLOAT = "float"
    BOOLEAN = "bool"
    # Member named after the JSON column type; its value "dict" suggests JSON objects.
    JSON = "dict"
class Product(Base):
    """A catalogue item whose dynamic properties are stored as AttributeValue rows.

    Together with Attribute and AttributeValue this forms an entity-attribute-value
    (EAV) layout: which properties a product has is data, not table schema.
    """

    __tablename__ = "products"
    # Integer primary key.
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False)
    # One row per attribute value set on this product. The "all" cascade saves new
    # values together with the product and deletes them with it; "delete-orphan"
    # also deletes a value that is removed from this collection.
    attribute_values: Mapped[list["AttributeValue"]] = relationship(
        "AttributeValue", back_populates="product", cascade="all, delete-orphan"
    )
class Attribute(Base):
    """Definition of a product property: a unique name plus its declared data type."""

    __tablename__ = "attributes"
    id: Mapped[int] = mapped_column(primary_key=True)
    # Unique: the same attribute cannot be defined twice. Product payloads refer to
    # attributes by this name.
    name: Mapped[str] = mapped_column(unique=True, nullable=False)
    # Stored through SQLAlchemy's Enum type, which persists the member names.
    data_type: Mapped[DataTypeEnum] = mapped_column(Enum(DataTypeEnum), nullable=False)
    # All AttributeValue rows that use this attribute. No delete cascade is configured.
    # NOTE(review): deleting an Attribute that still has values will probably fail,
    # since AttributeValue.attribute_id is NOT NULL — confirm before adding a delete API.
    values: Mapped[list["AttributeValue"]] = relationship("AttributeValue", back_populates="attribute")
class AttributeValue(Base):
    """The value of one Attribute for one Product (a single EAV row).

    NOTE(review): there is no unique constraint on (product_id, attribute_id), so
    the same attribute could be stored twice for a product.
    """

    __tablename__ = "attribute_values"
    id: Mapped[int] = mapped_column(primary_key=True)
    product_id: Mapped[int] = mapped_column(ForeignKey("products.id"), nullable=False)
    attribute_id: Mapped[int] = mapped_column(ForeignKey("attributes.id"), nullable=False)
    # Arbitrary JSON-serialisable payload. ``JSON`` is the SQLAlchemy *column* type,
    # not a Python value type, so the Python-side annotation is ``Any``.
    # nullable=False keeps the NOT NULL constraint that the previous non-Optional
    # annotation implied, independent of how ``Any`` is interpreted.
    value: Mapped[Any] = mapped_column(JSON, nullable=False)
    product: Mapped["Product"] = relationship("Product", back_populates="attribute_values")
    attribute: Mapped["Attribute"] = relationship("Attribute", back_populates="values")
# SCHEMAS
class ProductResponseSchema(BaseModel):
    # Outgoing representation of a product. ``attributes`` flattens the product's
    # AttributeValue rows into {attribute name: stored JSON value}.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    attributes: dict[str, Any]
class ProductCreateSchema(BaseModel):
    # Request body for POST /products/.
    name: str
    # Attribute name -> value. Keys must name existing attributes; that check happens
    # in the create_product handler, not here. The field may be omitted (pydantic
    # copies the ``{}`` default per instance), and an explicit JSON null yields None.
    attributes: Optional[dict[str, Any]] = {}
class AttributeResponseSchema(BaseModel):
    # Outgoing representation of an Attribute; built directly from ORM objects,
    # hence from_attributes.
    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    data_type: DataTypeEnum
class AttributeCreateSchema(BaseModel):
    # Request body for POST /attributes/. The fields are passed straight to the
    # Attribute model; ``data_type`` is given as the enum value (e.g. "str").
    name: str
    data_type: DataTypeEnum
# DATABASE
# Shared async engine. echo=True logs every emitted SQL statement (verbose; consider
# disabling it outside development).
engine = create_async_engine(DATABASE_URL, echo=True)
# Factory for per-request AsyncSession objects (async_sessionmaker is the SQLAlchemy 2.0
# async counterpart of sessionmaker and produces AsyncSession by default).
# expire_on_commit=False keeps ORM attributes loaded after commit, so handlers can build
# responses from committed objects without an implicit refresh, which AsyncSession
# cannot perform.
async_session = async_sessionmaker(engine, expire_on_commit=False)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create missing tables at startup, release the pool at shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    async with engine.begin() as conn:
        # MetaData.create_all is a synchronous API; run_sync executes it on the async
        # connection. It skips tables that already exist.
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Close pooled connections explicitly on shutdown instead of leaving them to
        # garbage collection after the event loop has stopped.
        await engine.dispose()
async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields one AsyncSession per request.

    This is an async generator, so the return type is an async iterator of sessions
    rather than a session. The session is closed when the request finishes, and any
    work that was not committed is rolled back at that point.
    """
    async with async_session() as session:
        yield session
# APP
# ASGI application; startup/shutdown work is delegated to ``lifespan``.
app = FastAPI(lifespan=lifespan)
@app.get("/attributes/", response_model=list[AttributeResponseSchema])
async def get_attributes(session: AsyncSession = Depends(get_session)):
    # List every attribute definition; response_model serialises the ORM rows.
    attribute_rows = await session.scalars(select(Attribute))
    return attribute_rows.all()
@app.post("/attributes/", response_model=AttributeResponseSchema)
async def create_attribute(
    payload: AttributeCreateSchema,
    session: AsyncSession = Depends(get_session),
):
    """Create a new attribute definition.

    Returns 409 if an attribute with the same name already exists.
    """
    # model_dump() is the pydantic v2 replacement for the deprecated .dict().
    new_attribute = Attribute(**payload.model_dump())
    session.add(new_attribute)
    try:
        await session.commit()
    except IntegrityError as exc:
        # Attribute.name is unique; a duplicate is rejected by the database at commit.
        # Previously this escaped as an unhandled error (HTTP 500).
        await session.rollback()
        raise HTTPException(
            status_code=409, detail=f"Attribute '{payload.name}' already exists."
        ) from exc
    return new_attribute
@app.get("/products/", response_model=list[ProductResponseSchema])
async def get_products(session: AsyncSession = Depends(get_session)):
    # Eager-load each product's values and each value's Attribute in the query itself:
    # the response mapper touches both relationships, and implicit lazy loading is not
    # available under AsyncSession.
    load_values = selectinload(Product.attribute_values).selectinload(AttributeValue.attribute)
    query = select(Product).options(load_values)
    products = (await session.execute(query)).scalars().all()
    return list(map(map_product_to_response, products))
@app.post("/products/", response_model=ProductResponseSchema)
async def create_product(
    payload: ProductCreateSchema,
    session: AsyncSession = Depends(get_session),
):
    """Create a product together with its attribute values.

    Every key of ``attributes`` must name an existing attribute; otherwise 400 is
    returned and nothing is stored.
    """
    # ``attributes`` is Optional in the schema: an explicit JSON null arrives as None.
    requested = payload.attributes or {}

    attributes_by_name: dict[str, Attribute] = {}
    if requested:
        # Resolve all names in one query instead of one SELECT per attribute.
        found = await session.scalars(
            select(Attribute).where(Attribute.name.in_(list(requested)))
        )
        attributes_by_name = {attribute.name: attribute for attribute in found}

    # Report the first unknown name in payload order, before anything is written.
    for attr_name in requested:
        if attr_name not in attributes_by_name:
            raise HTTPException(
                status_code=400, detail=f"Attribute '{attr_name}' does not exist."
            )

    # NOTE(review): values are not validated against Attribute.data_type.
    # Attach values through the relationships (not raw foreign-key ids) so the
    # in-memory graph is complete: building the response must not lazy-load
    # ``attribute_values`` or ``attribute``, which AsyncSession cannot do implicitly.
    product = Product(
        name=payload.name,
        attribute_values=[
            AttributeValue(attribute=attributes_by_name[name], value=value)
            for name, value in requested.items()
        ],
    )
    session.add(product)  # cascades to the new AttributeValue rows
    await session.commit()
    return map_product_to_response(product)
@app.get("/products/{product_id}", response_model=ProductResponseSchema)
async def get_product(product_id: int, session: AsyncSession = Depends(get_session)):
    # Fetch one product with its values and their attributes eagerly loaded; 404 if absent.
    statement = (
        select(Product)
        .where(Product.id == product_id)
        .options(selectinload(Product.attribute_values).selectinload(AttributeValue.attribute))
    )
    product = (await session.execute(statement)).scalar_one_or_none()
    if product is None:
        raise HTTPException(status_code=404, detail="Product not found")
    return map_product_to_response(product)
# UTILS
def map_product_to_response(product: Product) -> ProductResponseSchema:
    """Flatten a product's AttributeValue rows into a ProductResponseSchema.

    ``product.attribute_values`` and each value's ``attribute`` are expected to be
    loaded already (the handlers eager-load them); under AsyncSession an implicit
    lazy load on access would fail. If two values share an attribute name, the one
    appearing later in ``attribute_values`` wins.
    """
    flattened: dict[str, Any] = {}
    for attribute_value in product.attribute_values:
        flattened[attribute_value.attribute.name] = attribute_value.value
    return ProductResponseSchema(id=product.id, name=product.name, attributes=flattened)
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment