Created July 19, 2021 22:29
Save blakev/d0813090077f7b29b6d472d2a597aa9d to your computer and use it in GitHub Desktop.
FastAPI Query parameter fixer/validators
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from functools import partial, wraps
from http import HTTPStatus
from inspect import iscoroutinefunction, signature
from logging import getLogger
from typing import Callable, Dict, Optional, TypeVar

from fastapi import HTTPException
from fastapi.params import Query
# Type variable for fixers: a fixer takes a value and returns a value of the
# same type (used in the ``Callable[[T], T]`` annotation of ``QueryFixer``).
T = TypeVar("T")
class QueryValidatorType(type):
    """Metaclass that collects ``fix_<param>`` methods into ``_fixed_mapping``.

    After the class is created, its MRO is walked from the most generic
    class down to the new class itself. Every callable attribute named
    ``fix_<param>`` (with a non-empty ``<param>``) is registered under
    ``<param>``.

    Because the walk covers the whole MRO, a subclass keeps its parents'
    fixers. It can override one by redefining ``fix_<param>``, or remove it
    by rebinding ``fix_<param>`` to a non-callable value (e.g. ``None``).
    The result is stored on the class as ``_fixed_mapping``
    (``{param: function}``).
    """

    def __new__(cls, name, bases, ns):
        new_cls = super().__new__(cls, name, bases, ns)
        f_mapping = {}
        # Most generic first, so entries from more derived classes
        # (and finally ``new_cls`` itself) win.
        for klass in reversed(new_cls.__mro__):
            for key, val in vars(klass).items():
                if not key.startswith('fix_'):
                    continue
                param = key[len('fix_'):]
                if not param:
                    # A bare ``fix_`` attribute names no parameter.
                    continue
                if callable(val):
                    f_mapping[param] = val
                else:
                    # A non-callable override disables an inherited fixer.
                    f_mapping.pop(param, None)
        new_cls._fixed_mapping = f_mapping
        return new_cls
class QueryFixer(metaclass=QueryValidatorType):
    """Enforce that a Query field's value is "fixed" before it reaches the
    route function. By the time ``QueryFixer`` sees the value it has already
    passed FastAPI and Pydantic validation.

    Fixers come from two places:

    * the ``mapping`` dict and/or keyword arguments given to the constructor
      (``param_name=callable``), and
    * ``fix_<param_name>`` methods defined on a subclass (collected by the
      metaclass). These take precedence over constructor entries with the
      same name.

    Each fixer is called with the parameter's value, and its return value is
    passed to the route instead. If a fixer raises an ``HTTPException``, that
    exception is propagated unchanged. Any other exception is logged and the
    request is rejected with ``400 Bad Request``.

    If none of the mapped parameter names appear in the decorated function's
    signature, the function is returned unchanged, as if the decorator had
    never been used. A mapped name that does appear, but whose default is not
    a ``Query(...)`` instance, raises ``ValueError`` at decoration time. Both
    ``async def`` and plain ``def`` routes can be decorated.

    Example::

        @router.get('/')
        @query_fixer(limit=lambda o: o + 1)
        async def return_the_last_tickets(
            limit: Optional[int] = QLimit,
            after: Optional[DateTime] = Query(None),
        ):
            # limit must be between 0 and 100 (QLimit), but by this point
            # it will have gone through the query_fixer instance and can
            # have one added.
            # This is useful for enforcing date ranges and other non-basic
            # types that can be represented, pass validation, but still
            # fail in the body of the route.
            return {'limit': limit == (QLimit.le + 1)}
    """

    __slots__ = (
        '_name',
        '_mapping',
        'logger',
    )

    def __init__(
        self,
        mapping: Optional[Dict[str, Callable[[T], T]]] = None,
        *,
        name: Optional[str] = None,
        **fn_map: Callable[[T], T],
    ):
        """Build the table of fixers.

        Args:
            mapping: ``{param_name: fixer}``. It is copied, never mutated.
            name: Optional label appended to the logger name.
            **fn_map: Additional ``param_name=fixer`` entries.

        Raises:
            ValueError: if a key is not a string or a fixer is not callable.
        """
        # Copy so the caller's dict is not mutated by the updates below.
        fixers = dict(mapping) if mapping else {}
        fixers.update(fn_map)
        # Bind the class-level ``fix_<param>`` methods to this instance.
        # ``getattr`` uses normal attribute lookup, so subclass overrides
        # and static/class methods are bound correctly.
        for param in getattr(self, '_fixed_mapping', {}):
            fixers[param] = getattr(self, f'fix_{param}')
        for k, v in fixers.items():
            if not isinstance(k, str):
                raise ValueError(f'cannot use mapping with key, {k}')
            if not callable(v):
                raise ValueError(f'value at mapping `{k}` needs to be a callable')
        # Logger name: ``<module>.<ClassName>`` or ``<module>.<ClassName>.<name>``.
        cls_name = type(self).__name__
        qv_name = f'{cls_name}.{name}' if name else cls_name
        self._name = name
        self._mapping = fixers
        self.logger = getLogger(f'{__name__}.{qv_name}')

    def __call__(self, fn):
        """Decorate route ``fn`` so that its mapped Query parameters are fixed.

        Returns:
            ``fn`` itself when none of the mapped names are parameters of
            ``fn``. Otherwise, a wrapper of the same kind (coroutine function
            or plain function) that applies the fixers to keyword arguments.

        Raises:
            ValueError: if a mapped name is a parameter of ``fn`` whose
                default is not a ``Query`` instance.
        """
        cls = type(self).__name__
        params = signature(fn).parameters
        matched = [key for key in self._mapping if key in params]
        for key in matched:
            if not isinstance(params[key].default, Query):
                raise ValueError(f'cannot use {cls} on non-query parameter, {key}')
        if not matched:
            self.logger.debug('no Query parameters match, skipping %s', cls)
            return fn

        # Keep the wrapper's kind the same as ``fn``'s: ``await``-ing the
        # return value of a plain ``def`` route would raise TypeError.
        if iscoroutinefunction(fn):
            @wraps(fn)
            async def wrapped(*args, **kwargs):
                return await fn(*args, **self._fix_kwargs(kwargs))
        else:
            @wraps(fn)
            def wrapped(*args, **kwargs):
                return fn(*args, **self._fix_kwargs(kwargs))
        return wrapped

    def _fix_kwargs(self, kwargs):
        """Return a copy of ``kwargs`` with each mapped value run through its fixer.

        Raises:
            HTTPException: re-raised unchanged if a fixer raises one.
                Otherwise, ``400 Bad Request`` when a fixer raises any other
                exception.
        """
        fixed = {}
        for attr, value in kwargs.items():
            fixer = self._mapping.get(attr)
            if fixer is None:
                fixed[attr] = value
                continue
            try:
                fixed[attr] = fixer(value)
            except HTTPException:
                # The fixer chose its own status/detail; don't mask it.
                raise
            except Exception as e:
                self.logger.exception(e, extra=dict(field=attr, value=value))
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail=f'Cannot accept value for field {attr}',
                ) from e
        return fixed
# Lower-case alias so the class reads like a decorator: ``@query_fixer(...)``.
query_fixer = QueryFixer
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.
Example usage files: `app/validators/dates.py` and `routes.py`.