Skip to content

Instantly share code, notes, and snippets.

@blakev
Created July 19, 2021 22:29
Show Gist options
Save blakev/d0813090077f7b29b6d472d2a597aa9d to your computer and use it in GitHub Desktop.
FastAPI Query parameter fixers/validators
from functools import wraps, partial
from http import HTTPStatus
from inspect import iscoroutinefunction, signature
from logging import getLogger
from typing import Dict, Callable, TypeVar, Optional

from fastapi import HTTPException
from fastapi.params import Query
# Generic value type for fixer callables: a fixer maps a value of type T
# to a (possibly adjusted) value of the same type T.
T = TypeVar("T")
class QueryValidatorType(type):
    """Metaclass that collects ``fix_<param>`` methods into ``_fixed_mapping``.

    Every callable class attribute named ``fix_<param>`` (with a non-empty
    ``<param>``) is registered under the key ``<param>``. Fixers declared on
    base classes are inherited; a subclass overrides one by redefining the
    same ``fix_<param>`` method. The resulting dict is stored on the class as
    ``_fixed_mapping``.
    """

    def __new__(cls, name, bases, ns):
        f_mapping = {}
        # Inherit fixers from base classes first. Iterate bases in reverse so
        # the left-most (highest-priority) base wins on key collisions. Copying
        # into a fresh dict keeps the bases' mappings untouched.
        for base in reversed(bases):
            f_mapping.update(getattr(base, '_fixed_mapping', {}))
        # This class's own ``fix_*`` definitions override inherited ones.
        for key, val in ns.items():
            if key.startswith('__'):
                continue
            if key.startswith('fix_'):
                new_key = key[4:]
                if new_key and callable(val):
                    f_mapping[new_key] = val
        ns['_fixed_mapping'] = f_mapping
        return super().__new__(cls, name, bases, ns)
class QueryFixer(metaclass=QueryValidatorType):
    """Enforces a Query field's value to be "fixed" before sending to the
    route function. By the time ``QueryFixer`` has seen the value it will have
    already passed FastAPI and Pydantic validation.

    Fixers come from two places, merged in this order (later wins):

    1. ``mapping`` and ``**fn_map`` passed to the constructor;
    2. ``fix_<param>`` methods defined on a subclass (bound to the instance).

    If none of the mapped parameter names appear in the route's signature,
    the decorator returns the route unchanged, as if it was never used. A
    mapped name that matches a parameter whose default is not a ``Query``
    raises ``ValueError`` at decoration time.

    Both ``async def`` and plain ``def`` routes are supported; the wrapper
    keeps the same sync/async kind as the wrapped route.

    Example::

        @router.get('/')
        @query_fixer(limit=lambda o: o + 1)
        async def return_the_last_tickets(
            limit: Optional[int] = QLimit,
            after: Optional[DateTime] = Query(None),
        ):
            # limit must be between 0 and 100 (QLimit), but by this point
            # it will have gone through the query_fixer instance and can
            # have one added.
            return {'limit': limit == (QLimit.le + 1)}

    This is useful for enforcing date ranges and other non-basic types that
    can be represented, pass validation, but still fail in the body of the
    route.
    """

    __slots__ = (
        '_name',
        '_mapping',
        'logger',
    )

    def __init__(
        self,
        mapping: Optional[Dict[str, Callable[[T], T]]] = None,
        *,
        name: Optional[str] = None,
        **fn_map,
    ):
        """Build the field-name -> fixer mapping.

        Args:
            mapping: optional dict of parameter name to fixer callable. It is
                copied, never mutated.
            name: optional label appended to the logger name.
            **fn_map: additional parameter name to fixer callable pairs.

        Raises:
            ValueError: if a key is not a ``str`` or a value is not callable.
        """
        cls_name = self.__class__.__name__
        # Work on a copy so the caller's dict is never modified.
        merged = dict(mapping or {})
        merged.update(fn_map)
        # ``fix_<param>`` methods collected by the metaclass are unbound
        # functions; bind them to this instance so they receive ``self``.
        fixed = getattr(self, '_fixed_mapping', {})
        merged.update({k: partial(fn, self) for k, fn in fixed.items()})
        for k, v in merged.items():
            if not isinstance(k, str):
                raise ValueError(f'cannot use mapping with key, {k}')
            if not callable(v):
                raise ValueError(f'value at mapping `{k}` needs to be a callable')
        self._name = name
        self._mapping = merged
        qv_name = f'{cls_name}.{name}' if name else cls_name
        self.logger = getLogger(f'{__name__}.{qv_name}')

    def __call__(self, fn):
        """Wrap route ``fn`` so mapped keyword arguments are fixed first.

        Raises:
            ValueError: if a mapped name matches a non-``Query`` parameter.
        """
        cls_name = self.__class__.__name__
        params = signature(fn).parameters
        found = False
        for key in self._mapping:
            param = params.get(key)
            if param is None:
                continue
            found = True
            if not isinstance(param.default, Query):
                raise ValueError(f'cannot use {cls_name} on non-query parameter, {key}')
        if not found:
            self.logger.debug(f'no Query parameters match, skipping {cls_name}')
            return fn

        # Match the route's sync/async kind. Awaiting a plain ``def`` route
        # would fail at request time, and a sync wrapper lets FastAPI keep
        # running sync routes in its threadpool.
        if iscoroutinefunction(fn):
            @wraps(fn)
            async def wrapped(*args, **kwargs):
                return await fn(*args, **self._apply_fixers(kwargs))
        else:
            @wraps(fn)
            def wrapped(*args, **kwargs):
                return fn(*args, **self._apply_fixers(kwargs))
        # return the wrapped function with new capabilities
        return wrapped

    def _apply_fixers(self, kwargs):
        """Return a copy of ``kwargs`` with each mapped field run through its fixer.

        Raises:
            HTTPException: 400 if a fixer raises any other exception. An
                ``HTTPException`` raised by a fixer itself is propagated
                unchanged so fixers can choose their own status/detail.
        """
        fixed = dict(kwargs)
        for attr, value in kwargs.items():
            fixer = self._mapping.get(attr)
            if fixer is None:
                continue
            try:
                fixed[attr] = fixer(value)
            except HTTPException:
                raise
            except Exception as e:
                self.logger.error(e, extra=dict(field=attr, value=value))
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST,
                    detail=f'Cannot accept value for field {attr}',
                ) from e
        return fixed
# Lower-case alias so call sites read like a decorator factory,
# e.g. ``@query_fixer(limit=...)``.
query_fixer = QueryFixer
@blakev
Copy link
Author

blakev commented Jul 19, 2021

Example

app/validators/dates.py:

from typing import Optional

import pendulum
from pendulum import DateTime

from app.config import CONFIG
from app.validators.base import QueryFixer
from app.utils import now

# Public API of this module: only the shared fixer instance.
__all__ = ['dates_fixer']


class DatesFixer(QueryFixer):
    """Clamp ``before_date`` / ``after_date`` query values to an allowed window.

    ``before_date`` defaults to now and is capped at now. ``after_date``
    defaults to ``FIXED_AFTER_DATE`` and is raised to it if earlier.
    """

    # Earliest date any query may reach back to, parsed once at import time.
    FIXED_AFTER_DATE = pendulum.parse(CONFIG.api.params.earliest_date)
    # str: '2021-04-01T12:00:00.0000-00:00'

    def fix_before_date(self, value: Optional[DateTime]) -> DateTime:
        """Return now if ``value`` is None, else the earlier of ``value`` and now."""
        current = now()
        if value is None:
            return current
        return min(current, value)

    def fix_after_date(self, value: Optional[DateTime]) -> DateTime:
        """Return ``FIXED_AFTER_DATE`` if ``value`` is None, else the later of the two."""
        if value is None:
            return self.FIXED_AFTER_DATE
        # Clamp the caller's own value; comparing ``now()`` here would discard
        # the requested date entirely.
        return max(value, self.FIXED_AFTER_DATE)


# One shared module-level instance, applied directly as a route decorator
# (``@dates_fixer``).
dates_fixer = DatesFixer()

routes.py:

# Decorators apply bottom-up: ``@dates_fixer`` wraps the function first, so
# the router is given the already-wrapped route.
@router.get('/')
@dates_fixer
async def return_the_last_tickets(
    limit: Optional[int] = QLimit,
    after_date: Optional[DateTime] = Query(None),
):
    # Only ``after_date`` matches a DatesFixer fixer here; there is no
    # ``before_date`` parameter, so ``fix_before_date`` is never applied.
    # ``limit`` is not in the fixer mapping and is passed through untouched.
    return {}  # <-- if `after_date` is None, it will become FIXED_AFTER_DATE

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment