# GitHub gist paulwinex/71615de8d6d75dfbda1010d57bb10afb
# Last active: December 25, 2024 04:30
import enum
import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, Optional

from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, ForeignKey, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import (
    Mapped,
    declarative_base,
    mapped_column,
    relationship,
    selectinload,
    sessionmaker,
)
from sqlalchemy.types import Enum
# Connection string for the async (asyncpg) engine. It can be overridden with the
# DATABASE_URL environment variable so credentials need not live in source.
# The default points at a local "test" database.
DATABASE_URL = os.environ.get(
    "DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test"
)

# Declarative base shared by every ORM model in this module.
Base = declarative_base()
@enum.unique
class DataTypeEnum(enum.Enum):
    """Closed set of value types an ``Attribute`` can declare.

    Each member's value is the short name of the matching Python type.
    """

    STRING = "str"
    INTEGER = "int"
    FLOAT = "float"
    BOOLEAN = "bool"
    # Member name JSON; its value names the Python container type.
    JSON = "dict"
class Product(Base):
    """A catalogue item; its dynamic attributes are stored as ``AttributeValue`` rows."""

    __tablename__ = "products"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(nullable=False)
    # The target class is taken from the Mapped[list[...]] annotation.
    # "all, delete-orphan": values follow the product through the session and
    # are deleted when removed from this collection or when the product is deleted.
    attribute_values: Mapped[list["AttributeValue"]] = relationship(
        back_populates="product",
        cascade="all, delete-orphan",
    )
class Attribute(Base):
    """Definition of a named, typed attribute that products may carry."""

    __tablename__ = "attributes"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Names are unique; clients refer to attributes by name when creating products.
    name: Mapped[str] = mapped_column(unique=True, nullable=False)
    data_type: Mapped[DataTypeEnum] = mapped_column(
        Enum(DataTypeEnum), nullable=False
    )
    # Target class inferred from the annotation; mirrors AttributeValue.attribute.
    values: Mapped[list["AttributeValue"]] = relationship(back_populates="attribute")
class AttributeValue(Base):
    """The value of one ``Attribute`` on one ``Product`` (entity-attribute-value row)."""

    __tablename__ = "attribute_values"

    id: Mapped[int] = mapped_column(primary_key=True)
    product_id: Mapped[int] = mapped_column(ForeignKey("products.id"), nullable=False)
    attribute_id: Mapped[int] = mapped_column(
        ForeignKey("attributes.id"), nullable=False
    )
    # Stored in a JSON column. The Python-side value is arbitrary JSON-serialisable
    # data (dict, list, str, number, bool), hence Any. The SQLAlchemy JSON type
    # class is the column type, not the attribute's value type.
    value: Mapped[Any] = mapped_column(JSON)
    product: Mapped["Product"] = relationship(back_populates="attribute_values")
    attribute: Mapped["Attribute"] = relationship(back_populates="values")
# SCHEMAS | |
class ProductResponseSchema(BaseModel):
    """Product as returned by the API, with attributes flattened to ``{name: value}``."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    attributes: dict[str, Any]
class ProductCreateSchema(BaseModel):
    """Request body for creating a product."""

    name: str
    # Optional mapping of existing attribute names to values; defaults to empty.
    attributes: Optional[dict[str, Any]] = {}
class AttributeResponseSchema(BaseModel):
    """Attribute definition as returned by the API (readable from ORM objects)."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    name: str
    data_type: DataTypeEnum
class AttributeCreateSchema(BaseModel):
    """Request body for defining a new attribute."""

    name: str
    data_type: DataTypeEnum
# DATABASE | |
# Module-level async engine; echo=True logs every SQL statement it emits.
engine = create_async_engine(DATABASE_URL, echo=True)

# Session factory. expire_on_commit=False keeps already-loaded attributes
# readable after commit, so handlers can build responses without a refresh.
async_session = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
@asynccontextmanager
async def lifespan(app):
    """Application lifespan: create missing tables on startup, dispose the engine on shutdown.

    ``create_all`` only creates tables that do not exist yet; it does not
    migrate existing ones.
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Close pooled connections so the server shuts down cleanly.
        await engine.dispose()
async def get_session() -> AsyncIterator[AsyncSession]:
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    This is an async generator, so it is annotated as an iterator of sessions.
    The session is closed when the request finishes, which discards any
    uncommitted work.
    """
    async with async_session() as session:
        yield session
# APP | |
# ASGI application; table creation and engine cleanup run in `lifespan`.
app = FastAPI(lifespan=lifespan)
@app.get("/attributes/", response_model=list[AttributeResponseSchema])
async def get_attributes(session: AsyncSession = Depends(get_session)):
    """Return every defined attribute."""
    scalar_rows = await session.scalars(select(Attribute))
    return scalar_rows.all()
@app.post("/attributes/", response_model=AttributeResponseSchema)
async def create_attribute(
    payload: AttributeCreateSchema,
    session: AsyncSession = Depends(get_session),
):
    """Define a new attribute.

    Raises:
        HTTPException: 409 if an attribute with the same name already exists
            (the ``attributes.name`` column is unique).
    """
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    new_attribute = Attribute(**payload.model_dump())
    session.add(new_attribute)
    try:
        await session.commit()
    except IntegrityError as err:
        # Without this, a duplicate name surfaces as an unhandled 500.
        await session.rollback()
        raise HTTPException(
            status_code=409,
            detail=f"Attribute '{payload.name}' already exists.",
        ) from err
    return new_attribute
@app.get("/products/", response_model=list[ProductResponseSchema])
async def get_products(session: AsyncSession = Depends(get_session)):
    """Return all products with their attribute values flattened by name."""
    # Eager-load values and their attribute definitions up front; implicit
    # lazy loading is not available on an AsyncSession.
    eager_values = selectinload(Product.attribute_values).selectinload(
        AttributeValue.attribute
    )
    scalar_rows = await session.scalars(select(Product).options(eager_values))
    return list(map(map_product_to_response, scalar_rows.all()))
@app.post("/products/", response_model=ProductResponseSchema)
async def create_product(
    payload: ProductCreateSchema,
    session: AsyncSession = Depends(get_session),
):
    """Create a product together with values for existing attributes.

    Raises:
        HTTPException: 400 if any key in ``payload.attributes`` does not name an
            existing attribute. Nothing is written in that case.
    """
    # `attributes` is Optional in the schema, so an explicit null is allowed.
    requested = payload.attributes or {}

    # Resolve every attribute name with a single query instead of one per name.
    attributes_by_name: dict[str, Attribute] = {}
    if requested:
        found = await session.scalars(
            select(Attribute).where(Attribute.name.in_(list(requested)))
        )
        attributes_by_name = {attribute.name: attribute for attribute in found}

    for attr_name in requested:
        if attr_name not in attributes_by_name:
            raise HTTPException(
                status_code=400, detail=f"Attribute '{attr_name}' does not exist."
            )

    # Link values through the relationships rather than raw FK ids. The
    # product's collection and each value's `attribute` are then populated in
    # memory, so building the response below triggers no lazy load. A lazy load
    # would fail on an AsyncSession.
    product = Product(
        name=payload.name,
        attribute_values=[
            AttributeValue(attribute=attributes_by_name[attr_name], value=value)
            for attr_name, value in requested.items()
        ],
    )
    session.add(product)
    await session.commit()
    return map_product_to_response(product)
@app.get("/products/{product_id}", response_model=ProductResponseSchema)
async def get_product(product_id: int, session: AsyncSession = Depends(get_session)):
    """Return a single product with its attribute values flattened by name.

    Raises:
        HTTPException: 404 when no product has ``product_id``.
    """
    stmt = (
        select(Product)
        .where(Product.id == product_id)
        .options(
            selectinload(Product.attribute_values).selectinload(AttributeValue.attribute)
        )
    )
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=404, detail="Product not found")
    return map_product_to_response(found)
# UTILS | |
def map_product_to_response(product: Product) -> ProductResponseSchema:
    """Convert an ORM ``Product`` into its API response schema.

    Attribute values are flattened into ``{attribute name: value}``. Reads
    ``product.attribute_values`` and each value's ``attribute``, so callers
    must have those loaded already.
    """
    flattened: dict[str, Any] = {}
    for attr_value in product.attribute_values:
        flattened[attr_value.attribute.name] = attr_value.value
    return ProductResponseSchema(id=product.id, name=product.name, attributes=flattened)
if __name__ == "__main__":
    # uvicorn is only needed when this module is run directly as a script.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.