Skip to content

Instantly share code, notes, and snippets.

@bomzheg
Last active June 10, 2021 05:55
Show Gist options
  • Save bomzheg/3732fd178b9334db149ec388221fae63 to your computer and use it in GitHub Desktop.
Save bomzheg/3732fd178b9334db149ec388221fae63 to your computer and use it in GitHub Desktop.
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, TypeVar, Type, Generic
from app.models.db.base import Base
# Type variable for the ORM model class a DAO manages. ``bound=Base`` accepts
# Base and any subclass (e.g. Sport), so ``BaseDAO[Sport]`` methods are typed
# as returning ``Sport``. A constrained TypeVar such as TypeVar('Model', Base, Base)
# is solved to exactly ``Base`` by type checkers and loses the subclass type.
Model = TypeVar('Model', bound=Base)
class BaseDAO(Generic[Model]):
    """Generic async data-access object for a single ORM model class.

    :param model: the mapped class this DAO queries.
    :param session: the ``AsyncSession`` used to execute every statement.
    """

    def __init__(self, model: Type[Model], session: AsyncSession):
        self.model = model
        self.session = session

    async def get_all(self) -> List[Model]:
        """Return every row of ``self.model`` as model instances."""
        result = await self.session.execute(select(self.model))
        # ``Result.all()`` would yield Row tuples like ``(Sport(...),)``;
        # ``scalars()`` unwraps the single selected entity so callers get
        # model objects, matching the ``List[Model]`` annotation.
        return result.scalars().all()

    async def get_by_id(self, id_: int) -> Model:
        """Return the instance whose ``id`` column equals *id_*.

        Assumes ``self.model`` has an ``id`` attribute (true for Sport; confirm
        for other models).

        :raises sqlalchemy.exc.NoResultFound: if no row matches.
        :raises sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        result = await self.session.execute(
            select(self.model).where(self.model.id == id_)
        )
        return result.scalar_one()
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import relationship
from .base import Base
class Sport(Base):
    """ORM mapping for the ``sports`` table.

    A sport has an integer primary key and a name that is unique and indexed.
    ``matchs`` is the relationship to ``Match`` objects, paired with the
    ``sport`` attribute on the Match side via ``back_populates``.
    """

    __tablename__ = "sports"
    # Fetch server-generated default values as part of the flush instead of
    # lazily on first attribute access.
    __mapper_args__ = dict(eager_defaults=True)

    id = Column(Integer(), primary_key=True)
    name = Column(String(), unique=True, index=True)
    # NOTE(review): spelling kept as-is; presumably referenced by name from the
    # Match model's back_populates, so renaming requires changing both sides.
    matchs = relationship("Match", back_populates="sport")
from contextlib import suppress
from sqlalchemy.exc import IntegrityError
from sqlalchemy.future import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm.exc import NoResultFound
from app.dao.common import BaseDAO
from app.models.db import Sport
class SportDAO(BaseDAO[Sport]):
    """Data-access object for :class:`Sport` rows."""

    def __init__(self, session: AsyncSession):
        super().__init__(Sport, session)

    async def insert_or_select(self, sport_name: str) -> Sport:
        """Return the sport named *sport_name*, inserting it if it is missing.

        A new row is committed immediately. If that commit raises
        :class:`IntegrityError` (presumably because a concurrent insert created
        the same unique name first), the session is rolled back and the
        existing row is selected instead.

        Note: the rollback discards *all* pending, uncommitted work in the
        session, not only this insert.

        :raises sqlalchemy.exc.NoResultFound: if the insert failed with
            IntegrityError and no row with this name exists afterwards.
        """
        with suppress(NoResultFound):
            return await self.get_by_name(sport_name)
        sport = Sport(name=sport_name)
        self.session.add(sport)
        try:
            await self.session.commit()
        except IntegrityError:
            await self.session.rollback()
            return await self.get_by_name(sport_name)
        # commit() expires loaded attributes by default (expire_on_commit=True).
        # With an AsyncSession, reading an expired attribute later would require
        # implicit IO, which is unsupported (MissingGreenlet). Reload explicitly
        # so the returned object is usable whatever the session configuration.
        await self.session.refresh(sport)
        return sport

    async def get_by_name(self, sport_name: str) -> Sport:
        """Return the sport whose ``name`` equals *sport_name*.

        :raises sqlalchemy.exc.NoResultFound: if no such sport exists.
        :raises sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        statement = select(Sport).where(Sport.name == sport_name)
        sport_res = await self.session.execute(statement)
        return sport_res.scalar_one()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment