Last active
August 1, 2021 07:11
-
-
Save picaso/817399e1206177ef4d80c53a6bf001b2 to your computer and use it in GitHub Desktop.
Repository pattern
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Generic, List, Optional, Type, TypeVar

from sqlalchemy.exc import SQLAlchemyError

from db.postgres_db.db_manager import DBManager
from db.postgres_db.models import db
# Type variable for the model class a repository manages; bounded to
# ``db.Model`` so only model subclasses can parameterise a repository.
T = TypeVar("T", bound=db.Model)
class BaseRepository(Generic[T]):
    """Generic CRUD repository for a single model class.

    Writes go through the session taken from the ``DBManager``; reads go
    through the model class's ``query`` attribute.
    NOTE(review): this presumes ``clazz.query`` is bound to the same session
    as ``db_manager.session`` -- confirm against ``DBManager``.
    """

    def __init__(self, clazz: Type[T], db_manager: DBManager) -> None:
        """Bind the repository to a model class and a session.

        Args:
            clazz: The model *class* (not an instance) this repository serves.
            db_manager: Provider of the session used for writes and commits.
        """
        self.clazz = clazz
        self.session = db_manager.session

    def save(self, model: T) -> T:
        """Add ``model`` to the session, commit, and return it."""
        self.session.add(model)
        self.commit()
        return model

    def save_all(self, models: List[T]) -> List[T]:
        """Bulk-save ``models``, commit, and return the same list.

        NOTE: per the SQLAlchemy docs, ``bulk_save_objects`` does not attach
        the objects to the session and, without ``return_defaults``, does not
        populate server-generated values on them; the returned list is the
        caller's input as passed in.
        """
        self.session.bulk_save_objects(models)
        self.commit()
        return models

    def delete(self, model: T) -> T:
        """Mark ``model`` for deletion, commit, and return it."""
        self.session.delete(model)
        self.commit()
        return model

    def find(self, pk) -> Optional[T]:
        """Return the row with primary key ``pk``, or ``None`` if absent."""
        return self.clazz.query.get(pk)

    def find_by(self, **keys) -> List[T]:
        """Return all rows whose columns equal the given keyword values."""
        return self.clazz.query.filter_by(**keys).all()

    def find_first(self, **keys) -> Optional[T]:
        """Return the first row matching the keyword values, or ``None``."""
        return self.clazz.query.filter_by(**keys).first()

    def filter(self, *criterion) -> List[T]:
        """Return all rows matching the given SQL expression criteria."""
        return self.clazz.query.filter(*criterion).all()

    def truncate(self) -> Optional[int]:
        """Delete every row of the model's table.

        Returns:
            The value returned by ``query.delete()`` (the number of rows
            matched) on success, or ``None`` if a ``SQLAlchemyError``
            occurred and the transaction was rolled back.
        """
        try:
            deleted_rows = self.clazz.query.delete()
            self.commit()
        except SQLAlchemyError:
            # Best-effort operation: undo the partial work and signal the
            # failure with ``None`` instead of raising.
            self.session.rollback()
            return None
        return deleted_rows

    def commit(self) -> None:
        """Commit the current transaction.

        Raises:
            SQLAlchemyError: Re-raised after rolling back, so the session is
                usable again for subsequent operations.
        """
        try:
            self.session.commit()
        except SQLAlchemyError:
            # A failed commit leaves the session in a state that rejects
            # further work until it is rolled back.
            self.session.rollback()
            raise
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class UserRepository(BaseRepository[User]):
    """Repository bound to the ``User`` model."""

    def __init__(self, db_manager: DBManager) -> None:
        """Create a ``User`` repository backed by ``db_manager``."""
        super().__init__(User, db_manager)

    def update_user_status(self, *, user_id: str, status: Status) -> None:
        """Store ``status.name`` on the first user matching ``user_id``.

        The value is written to the user's ``order_status`` attribute and
        committed. When no user matches, nothing is changed or committed.
        """
        matched = self.find_first(user_id=user_id)
        if not matched:
            return
        matched.order_status = status.name
        self.commit()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment