Created November 9, 2023 21:22
-
-
Save DurandA/202c1b15b2d3bd2a7f3c9f5e6af8c677 to your computer and use it in GitHub Desktop.
Create a Pydantic object from a SQLAlchemy model with AsyncAttrs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Any, List, Type, TypeVar, get_args, get_origin

from pydantic import BaseModel
from sqlalchemy import Column, ForeignKey, Integer, String, create_engine
from sqlalchemy.ext.asyncio import (AsyncAttrs, AsyncEngine, AsyncSession,
                                    async_sessionmaker, create_async_engine)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.future import select
from sqlalchemy.inspection import inspect
from sqlalchemy.orm import RelationshipProperty, relationship, sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta
# Declarative base that additionally mixes in AsyncAttrs, so every mapped
# instance exposes ``awaitable_attrs`` for awaiting attribute loads.
Base: DeclarativeMeta = declarative_base(cls=AsyncAttrs)


class Parent(Base):
    """ORM row in ``parents``; related to zero or more ``Child`` rows."""

    __tablename__ = "parents"

    id: Any = Column(Integer, primary_key=True)
    name: Any = Column(String)
    # Paired with Child.parent. "Child" is given by name because that class
    # is declared further down.
    children = relationship("Child", back_populates="parent")


class Child(Base):
    """ORM row in ``children``; points at its parent through ``parent_id``."""

    __tablename__ = "children"

    id: Any = Column(Integer, primary_key=True)
    name: Any = Column(String)
    parent_id: Any = Column(Integer, ForeignKey("parents.id"))
    # Paired with Parent.children.
    parent = relationship("Parent", back_populates="children")
class ChildPydantic(BaseModel):
    """Pydantic view of a ``Child`` row (no back-reference to the parent)."""

    id: int  # primary key
    name: str
class ParentPydantic(BaseModel):
    """Pydantic view of a ``Parent`` row including its loaded children."""

    id: int  # primary key
    name: str
    # Populated from the Parent.children relationship.
    children: List[ChildPydantic]
T = TypeVar('T', bound=BaseModel)

# Sentinel telling "attribute absent on the ORM object" apart from a real None.
_MISSING = object()


def _model_type_in(annotation: Any) -> Any:
    """Return the first pydantic model class found inside *annotation*, or None.

    Containers and unions are unwrapped recursively, so ``ChildPydantic``,
    ``List[ChildPydantic]``, ``Optional[ChildPydantic]`` and
    ``Optional[List[ChildPydantic]]`` all resolve to ``ChildPydantic``.
    """
    # Checking get_origin() first keeps parametrized generics such as
    # list[int] (which pass isinstance(..., type) on Python 3.9/3.10) away
    # from issubclass(), which would reject them.
    if get_origin(annotation) is None:
        if isinstance(annotation, type) and issubclass(annotation, BaseModel):
            return annotation
        return None
    for arg in get_args(annotation):
        found = _model_type_in(arg)
        if found is not None:
            return found
    return None


async def _convert_related(related, annotation: Any, field_name: str, session: AsyncSession):
    """Turn an awaited relationship value (object, collection or None) into pydantic model(s)."""
    if related is None:
        # Nullable scalar relationship with no related row: nothing to recurse into.
        return None
    is_collection = isinstance(related, (list, set, tuple))
    if is_collection and not related:
        return []
    target_model = _model_type_in(annotation)
    if target_model is None:
        raise TypeError(
            f"field {field_name!r} is backed by a relationship, but its "
            f"annotation {annotation!r} contains no pydantic model type"
        )
    if is_collection:
        return [
            await load_relationships_and_create_pydantic(item, target_model, session)
            for item in related
        ]
    return await load_relationships_and_create_pydantic(related, target_model, session)


async def load_relationships_and_create_pydantic(
    db_obj,
    pydantic_model: Type[T],
    session: AsyncSession
) -> T:
    """Build a *pydantic_model* instance from *db_obj*, awaiting relationships.

    Each field of *pydantic_model* is matched by name to an attribute of
    *db_obj*. Fields that correspond to a SQLAlchemy relationship are loaded
    through ``db_obj.awaitable_attrs`` (so no implicit IO happens under
    asyncio) and converted recursively into the pydantic model type found in
    the field's annotation. Other fields are copied as plain attribute values.

    Args:
        db_obj: A mapped instance whose class mixes in ``AsyncAttrs``.
        pydantic_model: Target model class.
        session: Only forwarded to the recursive calls; not used directly.

    Returns:
        A validated instance of *pydantic_model*.

    Raises:
        TypeError: A relationship-backed field has no pydantic model type in
            its annotation.
    """
    # Inspect the mapper once rather than once per field.
    mapper_attrs = inspect(db_obj.__class__).mapper.attrs

    # Pydantic v2 exposes a resolved field table. Unlike __annotations__, it
    # includes inherited fields and knows which fields are required. Fall back
    # to __annotations__ (treating everything as optional, as before) otherwise.
    model_fields = getattr(pydantic_model, "model_fields", None)
    if model_fields is not None:
        fields = [
            (name, info.annotation, info.is_required())
            for name, info in model_fields.items()
        ]
    else:
        fields = [
            (name, annotation, False)
            for name, annotation in pydantic_model.__annotations__.items()
        ]

    loaded_attrs = {}
    for field_name, annotation, required in fields:
        if field_name in mapper_attrs and isinstance(mapper_attrs[field_name], RelationshipProperty):
            related = await getattr(db_obj.awaitable_attrs, field_name)
            value = await _convert_related(related, annotation, field_name, session)
        else:
            value = getattr(db_obj, field_name, _MISSING)
            if value is _MISSING:
                continue
        # A None for an optional field is omitted so the model default applies.
        # A required field receives None explicitly, so an Optional[...]
        # annotation validates instead of failing with "Field required".
        if value is None and not required:
            continue
        loaded_attrs[field_name] = value

    return pydantic_model(**loaded_attrs)
async def main():
    """Demo: store a parent with one child in SQLite, then print it via pydantic.

    Uses ``./test.db`` and creates the tables if they are missing. Every run
    inserts a new ``Parent1``/``Child1`` pair; only the freshly inserted
    parent is read back and converted.
    """
    database_url = "sqlite+aiosqlite:///./test.db"
    engine: AsyncEngine = create_async_engine(database_url, echo=True)
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        # expire_on_commit=False keeps attributes such as parent.id readable
        # after commit without a lazy refresh (implicit IO is not allowed
        # with AsyncSession).
        async_session = async_sessionmaker(engine, expire_on_commit=False)
        async with async_session() as session:
            parent = Parent(name='Parent1')
            session.add(parent)
            await session.commit()

            child = Child(name='Child1', parent_id=parent.id)
            session.add(child)
            await session.commit()

            stmt = select(Parent).where(Parent.id == parent.id)
            result = await session.execute(stmt)
            parent_obj = result.scalars().first()

            parent_pydantic = await load_relationships_and_create_pydantic(parent_obj, ParentPydantic, session)
            print(parent_pydantic.model_dump())
    finally:
        # Close pooled connections while the event loop is still running.
        await engine.dispose()
import asyncio

# Only run the demo (which creates/writes ./test.db) when executed as a
# script, not when this module is imported.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.