Created
June 20, 2023 09:01
-
-
Save TheBigRoomXXL/b6fe902cd877f09f35058786e8324ddc to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Extend SQLAlchemy | |
This is an experiment to share some logic between SQLAlchemy model classes so that
most queries are done through a shared interface.
Some calls are specific to flask_sqlalchemy or flask_smorest but can be easily adapted.
""" | |
from typing import TYPE_CHECKING | |
if TYPE_CHECKING: | |
from typing import Self, Tuple | |
from sqlalchemy import Select, Insert, Update, Delete, ColumnCollection | |
from sqlalchemy.orm import ORMExecuteState | |
from sqlalchemy import select, insert, update, delete | |
from sqlalchemy.inspection import inspect | |
from flask_sqlalchemy import SQLAlchemy | |
from flask_sqlalchemy.model import Model | |
from flask_jwt_extended import current_user | |
from flask_smorest import abort | |
class MyDbModel(Model):
    """Base class used as the ``model_class`` of ``db`` (aka ``db.Model``).

    Extends the SQLAlchemy functionality with classmethods that build
    statements for the model, so most queries go through one shared
    interface. Except for ``select_one_or_404``, these methods only *build*
    statements; the caller is responsible for executing them.

    The single-row helpers (``*_one``) address rows through the first
    primary-key column only; composite primary keys are not supported.
    """

    @classmethod
    def _primary_key_column(klass):
        """Return the first primary-key ``Column`` of the mapped class."""
        # Compare against the Column object itself instead of building
        # ``filter_by(**{col.name: ...})``: ``filter_by`` resolves mapped
        # *attribute* names, which can differ from the database column name.
        return inspect(klass).primary_key[0]

    @classmethod
    def select(klass) -> "Select[Tuple[Self]]":
        """Return the select statement of the class."""
        return select(klass)

    @classmethod
    def insert(klass, values: list[dict]) -> "Insert":
        """Return a multi-row INSERT statement.

        :param values: one dict of column values per row to insert.
        """
        # ``Insert.values`` takes a list of dicts as ONE positional argument;
        # unpacking it (``*values``) raises ArgumentError for more than one row.
        return insert(klass).values(values)

    @classmethod
    def update(klass, filters: dict) -> "Update":
        """Return an UPDATE statement restricted to rows matching ``filters``.

        :param filters: attribute name -> required value, AND-ed together
            (``filter_by`` semantics). Chain ``.values(...)`` on the result to
            set the new values. An empty dict adds no WHERE clause.
        """
        # ``where()`` only accepts positional SQL expressions; keyword
        # equality criteria must go through ``filter_by()``.
        return update(klass).filter_by(**filters)

    @classmethod
    def insert_one(klass, value: dict) -> "Insert":
        """Return an INSERT statement for a single row.

        :param value: column values of the row to insert.
        """
        return insert(klass).values(**value)

    @classmethod
    def delete(klass, filters: dict) -> "Delete":
        """Return a DELETE statement restricted to rows matching ``filters``.

        :param filters: attribute name -> required value, AND-ed together
            (``filter_by`` semantics). An empty dict adds no WHERE clause.
        """
        return delete(klass).filter_by(**filters)

    @classmethod
    def select_one(klass, obj_pk: int | str) -> "Select[Tuple[Self]]":
        """Return a SELECT statement for the row whose primary key is ``obj_pk``."""
        return select(klass).where(klass._primary_key_column() == obj_pk)

    @classmethod
    def select_one_or_404(klass, obj_pk: int | str) -> "Self":
        """Load the instance whose primary key is ``obj_pk`` or abort with 404.

        Unlike the other helpers, this one executes the query on
        ``db.session``.

        :returns: the loaded model instance.
        :raises: aborts the request via ``flask_smorest.abort(404, ...)`` when
            no row matches.
        """
        obj = db.session.scalars(klass.select_one(obj_pk)).one_or_none()
        if obj is None:
            abort(404, message="The resource you requested does not exist")
        return obj

    @classmethod
    def update_one(klass, obj_pk: int | str, value: dict) -> "Update":
        """Return an UPDATE statement setting ``value`` on the row ``obj_pk``.

        :param value: column values to assign.
        """
        return (
            update(klass)
            .where(klass._primary_key_column() == obj_pk)
            .values(**value)
        )

    @classmethod
    def delete_one(klass, obj_pk: int | str) -> "Delete":
        """Return a DELETE statement for the row whose primary key is ``obj_pk``."""
        return delete(klass).where(klass._primary_key_column() == obj_pk)
# Users whose role is listed here are not restricted by the per-user
# FILTER_KEY criterion in the authorization listener.
ADMIN_ROLES = ["admin", "technician"]

# Flask-SQLAlchemy extension object; ``db.Model`` is backed by MyDbModel.
db = SQLAlchemy(
    model_class=MyDbModel,
    engine_options=dict(future=True),
)
@db.event.listens_for(db.session, "do_orm_execute")
def shared_authorization_filter(orm_execute_state: "ORMExecuteState") -> None:
    """Add authorization criteria to every ORM statement run on ``db.session``.

    For each mapped entity involved in the statement:

    * if the mapper has a ``deleted_at`` column, only rows where it IS NULL
      are kept (applies to every user, including admins);
    * unless ``current_user.role`` is in ``ADMIN_ROLES``, rows are restricted
      to those whose ``FILTER_KEY`` column equals the current user's
      attribute of the same name.

    INSERT statements are passed through untouched.

    TODO: add a filter on what a user can insert.
    """
    if orm_execute_state.is_insert:
        return None

    mappers = orm_execute_state.all_mappers
    if not mappers:
        # Nothing mapped (e.g. a plain textual statement): nothing to filter,
        # and no need to touch current_user.
        return None

    # The role does not depend on the mapper, so check it once. Admins skip
    # only the per-user criterion, not the deleted_at one.
    is_admin = current_user.role in ADMIN_ROLES

    statement = orm_execute_state.statement
    for mapper in mappers:
        columns: "ColumnCollection" = mapper.columns

        deleted_at = columns.get("deleted_at")
        if deleted_at is not None:
            statement = statement.where(deleted_at.is_(None))

        if is_admin:
            continue

        # NOTE(review): FILTER_KEY is not defined in the visible code; it is
        # assumed to be a module-level string naming both the model column and
        # the matching current_user attribute. Index by its *value*, not by
        # the literal attribute name "FILTER_KEY".
        owner_column = columns.get(FILTER_KEY)
        if owner_column is not None:
            statement = statement.where(
                owner_column == getattr(current_user, FILTER_KEY)
            )

    # Statements are immutable: the rewritten one must be assigned back,
    # otherwise every criterion added above is discarded.
    orm_execute_state.statement = statement
    return None
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment