# Source: GitHub gist paulwinex/63efeddf0b085ed9d432976484773741, created March 8, 2024 17:25.
# First file of the gist: a FastAPI + async SQLAlchemy server demonstrating a
# self-referential many-to-many relationship (tags linked to other tags).
from contextlib import asynccontextmanager
from typing import List, Optional

import aiosqlite  # noqa: F401 - async SQLite driver selected by DATABASE_URL
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel as BaseSchema, ConfigDict
from sqlalchemy import Column, ForeignKey, Integer, String, Table, UniqueConstraint, select
from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship, selectinload
class Base(DeclarativeBase):
    """Declarative base class for every ORM model in this file.

    The column declared here is copied onto each mapped subclass. A table gets an
    integer primary key named ``id`` unless the subclass redefines it (TagModel does).
    """

    # Shared surrogate primary key, e.g. the ``id`` column of the tag_links table.
    id = Column(Integer, primary_key=True)
class TagsLinkModel(Base):
    """Association table for the self-referential tag-to-tag many-to-many.

    Each row is a directed edge ``tag_left_id -> tag_right_id``. TagModel.links follows
    edges from the left side, and its ``parent_links`` backref follows them from the right.
    The integer ``id`` primary key is inherited from Base.
    """

    __tablename__ = "tag_links"
    # Forbid storing the same ordered (left, right) pair twice.
    __table_args__ = (UniqueConstraint("tag_left_id", "tag_right_id", name="tag_links_uc"),)
    # ON DELETE CASCADE is emitted in the DDL. NOTE(review): SQLite only enforces foreign
    # keys with PRAGMA foreign_keys=ON, which this file does not visibly enable.
    tag_left_id: Mapped[int] = mapped_column(ForeignKey("tags.id", ondelete="CASCADE"), nullable=False)
    tag_right_id: Mapped[int] = mapped_column(ForeignKey("tags.id", ondelete="CASCADE"), nullable=False)
class TagModel(Base):
    """A named tag that can link to any number of other tags.

    ``links`` holds the tags this tag points to, i.e. tag_links rows where this tag is
    ``tag_left_id``. The ``parent_links`` backref holds the tags that point at this one.
    Both sides use ``lazy="selectin"``, so the collections are loaded together with the
    tag rather than through lazy attribute access.
    """

    __tablename__ = "tags"

    id: Mapped[int] = mapped_column(primary_key=True)
    # NOT NULL: the response schemas declare ``name: str``. A NULL name would make every
    # endpoint that serializes this row (including GET /tags) fail response validation.
    name: Mapped[str] = mapped_column(nullable=False)
    links: Mapped[list["TagModel"]] = relationship(
        "TagModel",
        secondary=TagsLinkModel.__table__,
        # ``id`` here is the class-body column defined above, i.e. tags.id.
        primaryjoin=(TagsLinkModel.tag_left_id == id),
        secondaryjoin=(TagsLinkModel.tag_right_id == id),
        backref=backref("parent_links", lazy="selectin"),
        lazy="selectin",
        uselist=True,
    )
class TagCreateSchema(BaseSchema):
    """Request body for POST /tags.

    ``links`` optionally lists the ids of tags the new tag should link to.
    """

    name: str
    links: Optional[list[int]] = None
class TagUpdateSchema(BaseSchema):
    """Request body for PATCH /tags/{tag_id}.

    Both fields are optional. Only the fields the client actually sends are applied.
    ``links`` replaces the tag's outgoing links, and an empty list or null clears them.
    """

    name: Optional[str] = None
    links: Optional[list[int]] = None
class TagShortResponseSchema(BaseSchema):
    """Compact tag representation (id and name), also used for nested links."""

    id: int
    name: str

    # Allow validation straight from ORM objects such as TagModel instances.
    model_config = ConfigDict(from_attributes=True)
class TagResponseSchema(TagShortResponseSchema):
    """Full tag representation: id, name, and the tags it links to (one level deep)."""

    links: list[TagShortResponseSchema]
# In-memory SQLite through the async aiosqlite driver; nothing is written to disk.
DATABASE_URL = "sqlite+aiosqlite:///:memory:"

engine = create_async_engine(DATABASE_URL, future=True, echo=False)

# Session factory used by get_session. expire_on_commit=False keeps loaded attributes
# valid after commit, so handlers can still serialize objects they just committed.
async_session = async_sessionmaker(
    engine,
    expire_on_commit=False,
    autocommit=False,
    autoflush=False,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the schema on startup and release the engine's connections on shutdown.

    With the in-memory DATABASE_URL used here, the tables must be created on every startup.

    Args:
        app: The FastAPI application (unused, but required by the lifespan protocol).
    """
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Close pooled connections explicitly instead of leaving them to garbage
        # collection, as the SQLAlchemy asyncio documentation recommends.
        await engine.dispose()
# ASGI application. ``lifespan`` creates the database schema before requests are served.
app = FastAPI(lifespan=lifespan)
async def get_tags_by_id(tag_ids: list[int], session: AsyncSession):
    """Load the tags whose primary keys are in *tag_ids*, with their links eager-loaded.

    Args:
        tag_ids: Primary keys to fetch. Duplicate ids collapse to a single row.
        session: Open async session used for the query.

    Returns:
        The matching TagModel rows, in database order (the query has no ORDER BY).

    Raises:
        HTTPException: 404 if no tag matched, or if any requested id does not exist.
    """
    query = select(TagModel).where(TagModel.id.in_(tag_ids)).options(selectinload(TagModel.links))
    result = await session.execute(query)
    tags = result.scalars().all()
    # Every requested id must resolve. Otherwise e.g. PATCH links [2, 999] would
    # silently link only tag 2 and hide the client's mistake.
    missing = set(tag_ids) - {tag.id for tag in tags}
    if not tags or missing:
        raise HTTPException(status_code=404, detail="Tag not found")
    return tags
async def get_session():
    """FastAPI dependency that yields one AsyncSession and closes it afterwards."""
    db = async_session()
    async with db:
        yield db
@app.get("/tags", response_model=List[TagResponseSchema])
async def get_tags(session: AsyncSession = Depends(get_session)):
    """Return every tag together with the tags it links to."""
    statement = select(TagModel).options(selectinload(TagModel.links))
    rows = await session.execute(statement)
    all_tags = rows.scalars().all()
    return all_tags
@app.get("/tags/{tag_id}", response_model=TagResponseSchema)
async def get_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    """Return one tag by id. get_tags_by_id raises the 404 when it does not exist."""
    matches = await get_tags_by_id([tag_id], session)
    found = matches[0]
    return found
@app.post("/tags", response_model=TagResponseSchema)
async def create_tag(tag_form: TagCreateSchema, session: AsyncSession = Depends(get_session)):
    """Create a tag, optionally linking it to existing tags by id.

    Returns:
        The new tag, refreshed after commit so ``id`` and ``links`` are populated.

    Raises:
        HTTPException: 404 (from ``get_tags_by_id``) if the linked ids cannot be resolved.
    """
    data = tag_form.model_dump(exclude_unset=True)
    # ``links`` arrives as a list of ids, but the relationship needs TagModel instances.
    # Passing the ids straight to the constructor would raise inside SQLAlchemy.
    link_ids = data.pop("links", None)
    new_tag = TagModel(**data)
    if link_ids:
        new_tag.links = list(await get_tags_by_id(link_ids, session))
    session.add(new_tag)
    await session.commit()
    await session.refresh(new_tag)
    return new_tag
@app.patch("/tags/{tag_id}", response_model=TagResponseSchema)
async def update_tag(tag_id: int, tag_data: TagUpdateSchema, session: AsyncSession = Depends(get_session)):
    """Partially update a tag.

    Only fields present in the request body are applied. ``links`` replaces the tag's
    outgoing links with the given tag ids, and an empty list or null clears them.

    Returns:
        The updated tag, refreshed after commit.

    Raises:
        HTTPException: 404 if the tag, or a linked tag (as decided by ``get_tags_by_id``),
            does not exist. 400 if the body is empty or sets ``name`` to null.
    """
    # get_tags_by_id raises 404 when nothing matches, so indexing [0] is safe.
    existing_tag = (await get_tags_by_id([tag_id], session))[0]

    data = tag_data.model_dump(exclude_unset=True)
    if not data:
        raise HTTPException(status_code=400, detail="No Data to update")
    # Validate before mutating anything. A NULL name cannot be serialized by the
    # response schema (name: str) and would break every later read of this tag.
    if "name" in data and data["name"] is None:
        raise HTTPException(status_code=400, detail="Tag name cannot be null")

    for field, value in data.items():
        if field == "links":
            existing_tag.links = list(await get_tags_by_id(value, session)) if value else []
        else:
            setattr(existing_tag, field, value)

    await session.commit()
    await session.refresh(existing_tag)
    return existing_tag
@app.delete("/tags/{tag_id}")
async def delete_tag(tag_id: int, session: AsyncSession = Depends(get_session)):
    """Delete a tag by primary key.

    Raises:
        HTTPException: 404 if no tag with ``tag_id`` exists.
    """
    doomed = await session.get(TagModel, tag_id)
    if doomed is None:
        raise HTTPException(status_code=404, detail="Tag not found")
    await session.delete(doomed)
    await session.commit()
    return {"message": "Tag deleted"}
if __name__ == "__main__":
    import uvicorn

    # Pass the app object instead of a "module:app" import string. The string form only
    # works when this file is named exactly self-m2m-relation.py. It also imports the file
    # a second time under that name, building a second engine and app instead of reusing these.
    uvicorn.run(app, host="0.0.0.0", port=8001)
# --- Second file of the gist: an example client for the server above. ---
# Run it as a separate script while the server is listening on port 8001.
import requests
from pprint import pprint

# The server binds 0.0.0.0 (all interfaces), but 0.0.0.0 is not a valid destination
# address on every platform (connecting to it fails on Windows), so use loopback.
url = "http://127.0.0.1:8001"
# requests has no default timeout; without one a stalled server would hang this script.
TIMEOUT = 10


def call(method, path, **kwargs):
    """Send *method* to ``url + path``, raise on HTTP error status, return the JSON body."""
    response = requests.request(method, url + path, timeout=TIMEOUT, **kwargs)
    response.raise_for_status()
    return response.json()


# create tags (ids 1..3 on a freshly started in-memory server)
for i in range(3):
    call("POST", "/tags", json={"name": f"tag{i}"})

# make links
pprint(call("PATCH", "/tags/1", json={"links": [2, 3]}))
# {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}
pprint(call("PATCH", "/tags/2", json={"links": [3]}))
# {'id': 2, 'name': 'tag1', 'links': [{'id': 3, 'name': 'tag2'}]}
pprint(call("PATCH", "/tags/3", json={"links": [1]}))
# {'id': 3, 'name': 'tag2', 'links': [{'id': 1, 'name': 'tag0'}]}

# request tag
pprint(call("GET", "/tags/1"))
# {'id': 1, 'name': 'tag0', 'links': [{'id': 2, 'name': 'tag1'}, {'id': 3, 'name': 'tag2'}]}

# request tags
pprint(call("GET", "/tags"))
# [
#     {
#         "id": 1,
#         "name": "tag0",
#         "links": [{"id": 2, "name": "tag1"}, {"id": 3, "name": "tag2"}],
#     },
#     {
#         "id": 2,
#         "name": "tag1",
#         "links": [{"id": 3, "name": "tag2"}],
#     },
#     {
#         "id": 3,
#         "name": "tag2",
#         "links": [{"id": 1, "name": "tag0"}],
#     },
# ]
# --- end of gist ---