Last active
June 24, 2022 15:38
-
-
Save DomWeldon/adb17318202dd0b673a0753e37c46b4b to your computer and use it in GitHub Desktop.
CRUD Router
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard Library
from typing import (
    Any,
    Callable,
    Dict,
    FrozenSet,
    List,
    Optional,
    Sequence,
    Set,
    Type,
    cast,
)

# Third Party Libraries
from fastapi import APIRouter, HTTPException, status
from fastapi.routing import APIRoute
from pydantic import BaseModel
from sqlalchemy import ForeignKeyConstraint, Table, UniqueConstraint, inspect
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import UnaryExpression
from starlette import routing
from starlette.responses import Response
from starlette.types import ASGIApp

# App and Model Imports
from app.utils.oop import all_subclasses
class CRUDApiRouter(APIRouter):
    """Automatically generate a router with essential CRUD methods on.

    Subclass it, set ``model`` and ``db_dep``, then set the attributes of each
    view you want. A view is only registered when it is configured:

    * list   -> ``list_schema``        (GET ``list_view_path``)
    * detail -> ``detail_schema``      (GET ``detail_view_path``)
    * create -> ``create_schema_in``   (POST ``create_view_path``)
    * update -> ``update_schema_in``   (PUT ``update_view_path``)
    * delete -> ``delete_view = True`` (DELETE ``delete_view_path``)

    Detail, update and delete views only support a model with a single
    integer primary key. Create and update views also need ``model_base`` so
    that referenced tables can be mapped back to model classes.
    """

    # Dependency providing a DB session. It is used as the default of every
    # endpoint's ``db`` parameter, so presumably a ``Depends(...)`` marker.
    db_dep: Session
    # Model this router is for (must be set by subclasses)
    model: DeclarativeMeta

    # Status codes whose responses must not carry a body
    _HTTP_2XX_NO_RETURN_CODES: Set[int] = {
        status.HTTP_204_NO_CONTENT,
        status.HTTP_205_RESET_CONTENT,
    }

    # Base for the model if using create or update; its subclasses are used
    # to resolve foreign-key target tables to model classes.
    model_base: Optional[DeclarativeMeta] = None

    # for listing rows
    list_schema: Optional[Type[BaseModel]] = None
    list_deps: Optional[List[Any]] = None
    list_sort: Optional[Sequence[UnaryExpression]] = None
    list_default_offset: Optional[int] = 0
    list_default_limit: Optional[int] = 100
    list_max_limit: int = 1_000
    list_view_path: str = "/"

    # for detail page
    detail_schema: Optional[Type[BaseModel]] = None
    detail_deps: Optional[List[Any]] = None
    detail_view_path: str = "/{id}"
    detail_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    # for create view
    create_schema_in: Optional[Type[BaseModel]] = None
    create_schema_out: Optional[Type[BaseModel]] = None
    create_status_code: int = status.HTTP_201_CREATED
    create_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    create_status_code_conflict: int = status.HTTP_409_CONFLICT
    create_view_path: str = "/"
    create_deps: Optional[List[Any]] = None

    # for update view
    update_schema_in: Optional[Type[BaseModel]] = None
    update_schema_out: Optional[Type[BaseModel]] = None
    update_status_code: int = status.HTTP_204_NO_CONTENT
    update_status_code_fk_error: int = status.HTTP_422_UNPROCESSABLE_ENTITY
    update_status_code_conflict: int = status.HTTP_409_CONFLICT
    update_status_code_not_found: int = status.HTTP_404_NOT_FOUND
    update_view_path: str = "/{id}"
    update_deps: Optional[List[Any]] = None

    # for delete view (only registered when delete_view is truthy)
    delete_view: bool = False
    delete_deps: Optional[List[Any]] = None
    delete_view_path: str = "/{id}"
    delete_status_code: int = status.HTTP_204_NO_CONTENT
    delete_schema_out: Optional[Type[BaseModel]] = None
    delete_status_code_not_found: int = status.HTTP_404_NOT_FOUND

    def __init__(
        self,
        routes: Optional[List[routing.BaseRoute]] = None,
        redirect_slashes: bool = True,
        default: Optional[ASGIApp] = None,
        dependency_overrides_provider: Optional[Any] = None,
        route_class: Type[APIRoute] = APIRoute,
        default_response_class: Optional[Type[Response]] = None,
        on_startup: Optional[Sequence[Callable]] = None,
        on_shutdown: Optional[Sequence[Callable]] = None,
        **kwargs: Any,
    ) -> None:
        """Instantiate like a normal API router then add CRUD methods.

        Any extra keyword arguments (e.g. ``prefix`` or ``tags`` on FastAPI
        versions whose ``APIRouter`` accepts them) are forwarded unchanged.
        """
        # ``model`` is only annotated on the class, so ``self.model`` would
        # raise AttributeError (not this assertion) if a subclass forgot it.
        assert getattr(self, "model", None) is not None, (
            f"{type(self).__name__} must set `model`"
        )
        super().__init__(
            routes=routes,
            redirect_slashes=redirect_slashes,
            default=default,
            dependency_overrides_provider=dependency_overrides_provider,
            route_class=route_class,
            default_response_class=default_response_class,
            on_startup=on_startup,
            on_shutdown=on_shutdown,
            **kwargs,
        )
        if self.list_schema is not None:
            endpoint = self._generate_list_view()
            self.get(
                self.list_view_path,
                response_model=List[self.list_schema],  # type: ignore
                dependencies=(self.list_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.detail_schema is not None:
            endpoint = self._generate_detail_view()
            self.get(
                self.detail_view_path,
                response_model=self.detail_schema,  # type: ignore
                dependencies=(self.detail_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.create_schema_in is not None:
            assert self.create_schema_out is not None, (
                "create_schema_out is required when create_schema_in is set"
            )
            assert self.model_base is not None, (
                "model_base is required for the create view"
            )
            endpoint = self._generate_create_view()
            self.post(
                self.create_view_path,
                response_model=self.create_schema_out,  # type: ignore
                status_code=self.create_status_code,
                dependencies=(self.create_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.update_schema_in is not None:
            assert self.model_base is not None, (
                "model_base is required for the update view"
            )
            endpoint = self._generate_update_view()
            self.put(
                self.update_view_path,
                response_model=self.update_schema_out,  # type: ignore
                status_code=self.update_status_code,
                dependencies=(self.update_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)
        if self.delete_view:
            endpoint = self._generate_delete_view()
            self.delete(
                self.delete_view_path,
                response_model=self.delete_schema_out,  # type: ignore
                status_code=self.delete_status_code,
                dependencies=(self.delete_deps or []),
                description=endpoint._description,  # type: ignore
            )(endpoint)

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _single_int_pk(self) -> Any:
        """Return the model's only primary-key column.

        The generated detail/update/delete views take an ``id: int`` path
        parameter, so only a single, non-composite integer PK is supported.
        """
        cols = inspect(self.model).primary_key
        assert len(cols) == 1, (
            f"{self.model.__name__} must have exactly one primary key column"
        )
        [pk_col] = cols
        assert pk_col.type.python_type == int, (
            f"{self.model.__name__}.{pk_col.key} must be an integer column"
        )
        return pk_col

    def _get_by_pk_or_raise(
        self, db: Session, pk_col: Any, pk_value: int, status_code_not_found: int
    ) -> Any:
        """Load the row whose ``pk_col`` equals ``pk_value``.

        Raises:
            HTTPException: with ``status_code_not_found`` if no row matches.
        """
        obj = (
            db.query(self.model)
            .filter(getattr(self.model, pk_col.key) == pk_value)
            .one_or_none()
        )
        if obj is None:
            raise HTTPException(status_code=status_code_not_found)
        return obj

    @staticmethod
    def _schema_field_names(schema: Any) -> Set[str]:
        """Return the Python field names declared on a pydantic schema class.

        Field names (not JSON aliases, which ``.schema()`` reports) are what
        ``getattr(obj_in, ...)`` and ``obj_in.dict()`` use below.
        """
        return set(schema.__fields__)

    def _fk_constraints_for(
        self, field_names: Set[str]
    ) -> Set[ForeignKeyConstraint]:
        """FK constraints on the model whose local columns are all supplied."""
        table = self.model.__table__  # type: ignore
        return {
            constraint
            for constraint in table.foreign_key_constraints
            # <= (not <): a schema exposing exactly the FK columns still
            # needs the check.
            if {c.key for c in constraint.columns} <= field_names
        }

    def _unique_column_sets(self, field_names: Set[str]) -> Set[FrozenSet[str]]:
        """Column-key sets that must be unique and are fully supplied.

        Covers unique indexes, ``UniqueConstraint`` objects (which is what
        ``Column(unique=True)`` without ``index=True`` produces) and the
        primary key.
        """
        table = self.model.__table__  # type: ignore
        column_groups: List[Any] = [ix.columns for ix in table.indexes if ix.unique]
        column_groups.extend(
            c.columns
            for c in table.constraints
            if isinstance(c, UniqueConstraint)
        )
        column_groups.append(table.primary_key.columns)
        unique_sets: Set[FrozenSet[str]] = set()
        for cols in column_groups:
            keys = frozenset(c.key for c in cols)
            # An empty set (e.g. a purely expression-based index) would
            # produce an unfiltered query matching every row.
            if keys and keys <= field_names:
                unique_sets.add(keys)
        return unique_sets

    def _generate_list_view(self) -> Callable:
        """Create a generic list view for Model.

        To sort, set

            list_sort = [SomeModel.some_property.asc()]

        The ``offset``/``limit`` query parameters default to
        ``list_default_offset``/``list_default_limit``. Negative values are
        rejected with 422. The limit is capped at ``list_max_limit``, which is
        also used when no limit is given at all.
        """

        def list_view(
            db: Session = self.db_dep,
            offset: Optional[int] = self.list_default_offset,
            limit: Optional[int] = self.list_default_limit,
        ) -> Any:
            # A negative LIMIT means "no limit" on some backends (e.g.
            # SQLite), which would bypass list_max_limit.
            if (offset is not None and offset < 0) or (
                limit is not None and limit < 0
            ):
                raise HTTPException(
                    status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                    detail="offset and limit must not be negative",
                )
            query = db.query(self.model)
            if self.list_sort is not None:
                query = query.order_by(*self.list_sort)
            if offset:
                query = query.offset(offset)
            # Cap the page size so clients cannot request unbounded pages.
            max_limit = self.list_max_limit
            if limit is None:
                limit = max_limit
            elif max_limit is not None:
                limit = min(limit, max_limit)
            if limit is not None:
                query = query.limit(limit)
            return query.all()

        list_view._description = (  # type: ignore
            f"🗃️ List {self.model.__name__} sorted by "
            f"{', '.join(str(x) for x in self.list_sort or [])}"
        )
        return list_view

    def _generate_detail_view(self) -> Callable:
        """Create a generic detail view for Model.

        At the moment this supports _only_ a model with a single, non-composite
        integer PK. The path parameter is called ``id``.
        """
        pk_col = self._single_int_pk()

        def detail_view(id: int, db: Session = self.db_dep) -> Any:
            return self._get_by_pk_or_raise(
                db, pk_col, id, self.detail_status_code_not_found
            )

        detail_view._description = (  # type: ignore
            f"""📁 Show {self.model.__name__} identified by {pk_col.key}"""
        )
        return detail_view

    @property
    def _MODEL_MAP(self) -> Dict[Table, DeclarativeMeta]:
        """Map each Table to the model class built on it.

        Subclasses of ``model_base`` without a ``__table__`` (abstract bases,
        mixins) are skipped. When several classes share one table (single-table
        inheritance), the most general class is kept, because querying a
        subclass adds a discriminator filter that would hide base rows.
        """
        mapping: Dict[Table, DeclarativeMeta] = {}
        for m in all_subclasses(self.model_base):  # type: ignore
            table = getattr(m, "__table__", None)
            if table is None:
                continue
            existing = mapping.get(table)
            if existing is None or issubclass(existing, m):
                mapping[cast(Table, table)] = cast(DeclarativeMeta, m)
        return mapping

    def _check_fk_constraints(
        self,
        *,
        db: Session,
        model_map: Dict[Table, DeclarativeMeta],
        fk_constraints: Set[ForeignKeyConstraint],
        obj_in: BaseModel,
        status_code_fk_error: int,
    ) -> None:
        """Raise if ``obj_in`` references rows that do not exist.

        A constraint is skipped when any of its local values is ``None``.
        SQL does not enforce a (composite) FK containing a NULL under the
        default MATCH SIMPLE, and ``col == None`` would compile to
        ``IS NULL``.

        Raises:
            HTTPException: with ``status_code_fk_error`` for the first
                constraint whose referenced row is missing.
        """
        for constraint in fk_constraints:
            referred_model = model_map[constraint.referred_table]
            # local column key on self.model -> column key on referred_model
            constraint_columns_map = {
                fk.parent.key: fk.column.key for fk in constraint.elements
            }
            values = {
                local_key: getattr(obj_in, local_key)
                for local_key in constraint_columns_map
            }
            if any(v is None for v in values.values()):
                continue
            # now, we check a row with those values exists
            num_rows = (
                db.query(referred_model)
                .filter(
                    *(
                        getattr(referred_model, constraint_columns_map[local_key])
                        == value
                        for local_key, value in values.items()
                    )
                )
                .count()
            )
            if num_rows == 0:
                error_row = (
                    f"{constraint_columns_map[local_key]}={value}"
                    for local_key, value in values.items()
                )
                raise HTTPException(
                    status_code=status_code_fk_error,
                    detail=(
                        "I could not find a value of "
                        f"{referred_model.__name__} with values "
                        f"{' '.join(error_row)}"
                    ),
                )

    def _check_unique_indexes(
        self,
        db: Session,
        unique_indexes: Set[FrozenSet[str]],
        obj_in: BaseModel,
        status_code_conflict: int,
        *,
        exclude_id: Optional[Any] = None,
    ) -> None:
        """Check unique indexes/constraints won't be violated.

        Args:
            unique_indexes: sets of column keys that must be unique together.
            status_code_conflict: status code used for the raised error.
            exclude_id: primary-key value of a row to ignore. This is the row
                being updated, which must not conflict with itself.

        Sets where any supplied value is ``None`` are skipped. Unique indexes
        generally do not treat NULLs as equal (SQL Server is a notable
        exception), and ``col == None`` would match every NULL row.

        Raises:
            HTTPException: with ``status_code_conflict`` on the first clash.
        """
        pk_attr = None
        if exclude_id is not None:
            pk_col = inspect(self.model).primary_key[0]
            pk_attr = getattr(self.model, pk_col.key)
        for ix in unique_indexes:
            # sorted() makes the query and error message order deterministic
            keys = sorted(ix)
            values = [getattr(obj_in, k) for k in keys]
            if any(v is None for v in values):
                continue
            query = db.query(self.model).filter(
                *(getattr(self.model, k) == v for k, v in zip(keys, values))
            )
            if pk_attr is not None:
                query = query.filter(pk_attr != exclude_id)
            if query.count() != 0:
                error_row = (f"{k}={v}" for k, v in zip(keys, values))
                raise HTTPException(
                    status_code=status_code_conflict,
                    detail=(
                        f"A row already exists in {self.model.__name__} "
                        f"with values {' '.join(error_row)}"
                    ),
                )

    def _generate_create_view(self) -> Callable:
        """Generate a generic create view for the Model.

        Required features:
        - check for conflicts on unique constraints
        - check referenced foreign keys exist
        - create and return resource with 201 by default
        """
        # Only constraints fully covered by the input schema can be checked
        # before insert. The values themselves already passed the schema.
        field_names = self._schema_field_names(self.create_schema_in)
        fk_constraints = self._fk_constraints_for(field_names)
        unique_indexes = self._unique_column_sets(field_names)
        # we'll need this to lookup models based on tables
        model_map = self._MODEL_MAP

        def create_view(
            *,
            obj_in: self.create_schema_in,  # type: ignore
            db: Session = self.db_dep,
        ) -> Any:
            """Insert a new row built from ``obj_in`` and return it."""
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.create_status_code_fk_error,
            )
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.create_status_code_conflict,
            )
            # make the insert
            instance = self.model()
            for k, v in obj_in.dict().items():
                setattr(instance, k, v)
            db.add(instance)
            db.commit()
            db.refresh(instance)
            return instance

        create_view._description = (  # type: ignore
            f"""💾 Create new {self.model.__name__}"""
        )
        return create_view

    def _generate_update_view(self) -> Callable:
        """Generic update view creator (full replacement from the schema)."""
        assert not (
            self.update_schema_out is not None
            and self.update_status_code in self._HTTP_2XX_NO_RETURN_CODES
        ), "update_schema_out cannot be used with a no-content status code"
        pk_col = self._single_int_pk()
        field_names = self._schema_field_names(self.update_schema_in)
        fk_constraints = self._fk_constraints_for(field_names)
        unique_indexes = self._unique_column_sets(field_names)
        # we'll need this to lookup models based on tables
        model_map = self._MODEL_MAP

        def update_view(
            *,
            id: int,
            db: Session = self.db_dep,
            obj_in: self.update_schema_in,  # type: ignore
        ) -> Any:
            obj = self._get_by_pk_or_raise(
                db, pk_col, id, self.update_status_code_not_found
            )
            self._check_fk_constraints(
                db=db,
                model_map=model_map,
                fk_constraints=fk_constraints,
                obj_in=obj_in,
                status_code_fk_error=self.update_status_code_fk_error,
            )
            # Exclude the row being updated so unchanged unique values do
            # not conflict with themselves.
            self._check_unique_indexes(
                db=db,
                unique_indexes=unique_indexes,
                obj_in=obj_in,
                status_code_conflict=self.update_status_code_conflict,
                exclude_id=id,
            )
            for k, v in obj_in.dict().items():
                setattr(obj, k, v)
            db.add(obj)
            db.commit()
            db.refresh(obj)
            return (
                obj
                if self.update_status_code not in self._HTTP_2XX_NO_RETURN_CODES
                else None
            )

        update_view._description = (  # type: ignore
            f"""📝 Update {self.model.__name__} identified by {pk_col.key}"""
        )
        return update_view

    def _generate_delete_view(self) -> Callable:
        """Generic delete view.

        A response body is returned iff ``delete_schema_out`` is set, and that
        must agree with whether ``delete_status_code`` allows a body.
        """
        assert not (
            self.delete_schema_out is not None
            and self.delete_status_code in self._HTTP_2XX_NO_RETURN_CODES
        ), "delete_schema_out cannot be used with a no-content status code"
        assert not (
            self.delete_schema_out is None
            and self.delete_status_code not in self._HTTP_2XX_NO_RETURN_CODES
        ), "delete_schema_out is required for a status code with a body"
        pk_col = self._single_int_pk()

        def delete_view(*, id: int, db: Session = self.db_dep) -> Any:
            obj = self._get_by_pk_or_raise(
                db, pk_col, id, self.delete_status_code_not_found
            )
            # Serialise before deleting. With SQLAlchemy's default
            # expire_on_commit=True the instance is expired at commit, and a
            # deleted row cannot be reloaded for FastAPI to serialise later.
            # Assumes the schema is configured for ORM input (orm_mode),
            # as FastAPI response models of ORM objects already require.
            response = (
                self.delete_schema_out.from_orm(obj)
                if self.delete_schema_out is not None
                else None
            )
            db.delete(obj)
            db.commit()
            return response

        delete_view._description = (  # type: ignore
            f"""❌ Delete {self.model.__name__} identified by {pk_col.key}"""
        )
        return delete_view
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard Library | |
from typing import TYPE_CHECKING | |
# Third Party Libraries | |
from sqlalchemy import Column, ForeignKey, Integer, String | |
from sqlalchemy.orm import relationship | |
# App and Model Imports | |
from app.db.base_class import Base | |
if TYPE_CHECKING: | |
from .user import User # noqa: F401 | |
class Item(Base):
    """Item row, optionally linked to its owning ``User`` via ``owner_id``."""

    # NOTE(review): primary keys are already indexed by the database;
    # index=True additionally creates a separate Index — likely redundant.
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    description = Column(String, index=True)
    # No nullable=False, so the column is nullable and an Item may have no
    # owner. Presumably "user.id" is the table/column of the `User` model —
    # confirm against that model's table name.
    owner_id = Column(Integer, ForeignKey("user.id"))
    # back_populates requires a matching `items` relationship on User.
    owner = relationship("User", back_populates="items")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Standard Library | |
from typing import TYPE_CHECKING | |
# Third Party Libraries | |
from sqlalchemy import Column, ForeignKey, Integer, String | |
from sqlalchemy.orm import relationship | |
# App and Model Imports | |
from app.db.base_class import Base | |
if TYPE_CHECKING: | |
from .user import User # noqa: F401 | |
class Item(Base):
    """Item row, optionally linked to its owning ``User`` via ``owner_id``."""

    # NOTE(review): primary keys are already indexed by the database;
    # index=True additionally creates a separate Index — likely redundant.
    id = Column(Integer, primary_key=True, index=True)
    title = Column(String, index=True)
    description = Column(String, index=True)
    # No nullable=False, so the column is nullable and an Item may have no
    # owner. Presumably "user.id" is the table/column of the `User` model —
    # confirm against that model's table name.
    owner_id = Column(Integer, ForeignKey("user.id"))
    # back_populates requires a matching `items` relationship on User.
    owner = relationship("User", back_populates="items")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment