@phillipuniverse
Created December 1, 2024 19:17
Attempt at monkeypatching FastAPI to work with pydantic.v1 package when Pydantic v2 is installed

An incomplete attempt at allowing pydantic.v1 models in FastAPI. My overall conclusion is that this is probably a waste of time: it is too risky and likely takes more effort than migrating all models to Pydantic v2. There are too many edge cases in FastAPI's usage of the global fastapi._compat.PYDANTIC_V2 variable, which is resolved at import time.
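
For context, the flag is computed once when fastapi._compat is first imported, roughly like the following (paraphrased, not an exact copy of FastAPI's source). Every other FastAPI module either imports the flag itself or imports functions whose implementation was already chosen by it:

# Paraphrase of what fastapi/_compat.py does at import time; the exact
# expression differs between FastAPI releases.
from pydantic.version import VERSION as PYDANTIC_VERSION

PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")

if PYDANTIC_V2:
    # v2 implementations of ModelField, get_definitions, is_scalar_field, ...
    ...
else:
    # v1 implementations of the same names
    ...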

Versions

FastAPI 0.115.5 and Pydantic 2.10.2

Usage

Invoke patch_v2_compat() before you do anything with the FastAPI application (like adding routes):

from _pydantic_v1_compat import patch_v2_compat

patch_v2_compat()

from fastapi import FastAPI
from fastapi.responses import PlainTextResponse

app = FastAPI()

@app.get("/exists")
async def exists() -> PlainTextResponse:
    return PlainTextResponse("exists")
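
For reference, the kind of route this patch is ultimately meant to support is one whose request/response models are still pydantic.v1 models while Pydantic v2 is installed. A hypothetical example, reusing the app from above:

from pydantic.v1 import BaseModel as V1BaseModel

class ItemV1(V1BaseModel):
    # A legacy model that has not been migrated to Pydantic v2
    name: str
    price: float

@app.post("/items")
async def create_item(item: ItemV1) -> ItemV1:
    return item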

The implementation doesn't work but might be a good enough start for somebody else to take on. Other notes:

  • _v1_compat.py and _v2_compat.py are copies of fastapi/_compat.py: _v1_compat.py keeps the code paths taken when PYDANTIC_V2 is false, and _v2_compat.py the paths taken when it is true. I didn't copy in the other functions.
  • The monkeypatch targets both fastapi._compat and all of the places that import functions from fastapi._compat. I did this because the patching always happens too late, after fastapi._compat has already been imported by other modules; that's why there is additional patching of fastapi.routing, fastapi.utils, etc. This isn't the best solution. It would probably be better to use a MetaPathFinder or some other import-loader mechanism to change what gets imported from fastapi._compat (see the sketch after this list), though that would still need to be installed before any fastapi import happens.
  • There are other places in FastAPI, inside methods and class definitions, that rely on fastapi._compat.PYDANTIC_V2. That's why I ended up monkeypatching fastapi.utils.create_model_field: its implementation changes the kwargs being passed based on fastapi._compat.PYDANTIC_V2.
  • The USING_PYDANTIC_V2 global in _pydantic_v1_compat.py isn't very useful. I was trying to go down a path where this could be more dynamic, but it should probably go away; the functions should instead have real signatures and make their implementations dynamic.
    • For instance, see get_missing_field_error, which just takes *args and **kwargs, vs. is_scalar_field, which dispatches based on the actual type of the ModelField passed to the method.
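
To make the MetaPathFinder idea above concrete, here is a minimal, untested sketch of an import hook that wraps the loader for fastapi._compat and applies overrides immediately after the real module executes, so every later "from fastapi._compat import ..." already sees the patched names. The _pydantic_v1_compat shim and the list of override names are assumptions based on this gist, not something FastAPI provides.

# Sketch only: an import hook that patches fastapi._compat at load time.
import importlib.abc
import sys

def _apply_overrides(module) -> None:
    import _pydantic_v1_compat as shim  # assumed module name from this gist
    for name in (
        "is_scalar_field",
        "is_sequence_field",
        "get_definitions",
        "get_compat_model_name_map",
        "get_schema_from_model_field",
        "create_body_model",
    ):
        setattr(module, name, getattr(shim, name))

class _PatchingLoader(importlib.abc.Loader):
    def __init__(self, wrapped):
        self._wrapped = wrapped

    def create_module(self, spec):
        return self._wrapped.create_module(spec)

    def exec_module(self, module):
        self._wrapped.exec_module(module)  # run the real fastapi._compat first
        _apply_overrides(module)           # then swap in the v1-aware functions

class _CompatFinder(importlib.abc.MetaPathFinder):
    target = "fastapi._compat"

    def find_spec(self, fullname, path=None, target=None):
        if fullname != self.target:
            return None
        # Delegate to the remaining finders for the real spec, then wrap its loader.
        for finder in sys.meta_path:
            if finder is self or not hasattr(finder, "find_spec"):
                continue
            spec = finder.find_spec(fullname, path, target)
            if spec is not None and spec.loader is not None:
                spec.loader = _PatchingLoader(spec.loader)
                return spec
        return None

def install() -> None:
    # Must run before any `import fastapi`, otherwise the real module is already cached.
    assert "fastapi._compat" not in sys.modules, "fastapi was imported too early"
    sys.meta_path.insert(0, _CompatFinder())

The rest of this gist is the actual attempt: _pydantic_v1_compat.py, followed by _v1_compat.py and _v2_compat.py.
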
# _pydantic_v1_compat.py
from typing import (
    Any,
    Literal,
)
import fastapi
import pydantic.v1.typing
from fastapi._compat import ModelField
from fastapi.types import ModelNameMap
from pydantic import BaseConfig, PydanticSchemaGenerationError
from pydantic._internal import _typing_extra
from pydantic.fields import FieldInfo as V2FieldInfo
from pydantic.v1 import BaseConfig as V1BaseConfig, Extra as V1Extra
from pydantic.v1.fields import FieldInfo as V1FieldInfo
from pydantic.v1.fields import Undefined as V1Undefined
from pydantic.v1 import BaseModel as V1BaseModel
from pydantic.v1.fields import UndefinedType as V1UndefinedType
from pydantic_core import PydanticUndefined
from pydantic import BaseModel as V2BaseModel
from ._v1_compat import lenient_issubclass as v1_lenient_issubclass
from ._v1_compat import ModelField as FastApiPV1ModelField
from ._v1_compat import with_info_plain_validator_function as v1_with_info_plain_validator_function
from ._v1_compat import _model_rebuild as v1__model_rebuild
from ._v1_compat import get_annotation_from_field_info as v1_get_annotation_from_field_info
from ._v1_compat import _normalize_errors as v1__normalize_errors
from ._v1_compat import get_model_definitions as v1_get_model_definitions
from ._v1_compat import _model_dump as v1__model_dump
from ._v1_compat import _get_model_config as v1__get_model_config
from ._v1_compat import get_schema_from_model_field as v1_get_schema_from_model_field
from ._v1_compat import get_compat_model_name_map as v1_get_compat_model_name_map
from ._v1_compat import get_definitions as v1_get_definitions
from ._v1_compat import is_scalar_field as v1_is_scalar_field
from ._v1_compat import is_sequence_field as v1_is_sequence_field
from ._v1_compat import is_scalar_sequence_field as v1_is_scalar_sequence_field
from ._v1_compat import is_bytes_field as v1_is_bytes_field
from ._v1_compat import is_bytes_sequence_field as v1_is_bytes_sequence_field
from ._v1_compat import copy_field_info as v1_copy_field_info
from ._v1_compat import serialize_sequence_value as v1_serialize_sequence_value
from ._v1_compat import get_missing_field_error as v1_get_missing_field_error
from ._v1_compat import create_body_model as v1_create_body_model
from ._v1_compat import get_model_fields as v1_get_model_fields
from ._v2_compat import lenient_issubclass as v2_lenient_issubclass
from ._v2_compat import ModelField as FastApiPV2ModelField
from ._v2_compat import with_info_plain_validator_function as v2_with_info_plain_validator_function
from ._v2_compat import get_annotation_from_field_info as v2_get_annotation_from_field_info
from ._v2_compat import _normalize_errors as v2__normalize_errors
from ._v2_compat import _model_rebuild as v2__model_rebuild
from ._v2_compat import _model_dump as v2__model_dump
from ._v2_compat import _get_model_config as v2__get_model_config
from ._v2_compat import get_schema_from_model_field as v2_get_schema_from_model_field
from ._v2_compat import get_compat_model_name_map as v2_get_compat_model_name_map
from ._v2_compat import get_definitions as v2_get_definitions
from ._v2_compat import is_scalar_field as v2_is_scalar_field
from ._v2_compat import is_sequence_field as v2_is_sequence_field
from ._v2_compat import is_scalar_sequence_field as v2_is_scalar_sequence_field
from ._v2_compat import is_bytes_field as v2_is_bytes_field
from ._v2_compat import is_bytes_sequence_field as v2_is_bytes_sequence_field
from ._v2_compat import copy_field_info as v2_copy_field_info
from ._v2_compat import serialize_sequence_value as v2_serialize_sequence_value
from ._v2_compat import get_missing_field_error as v2_get_missing_field_error
from ._v2_compat import create_body_model as v2_create_body_model
from ._v2_compat import get_model_fields as v2_get_model_fields

# ORIGINAL_PYDANTIC_V2 = fastapi._compat.PYDANTIC_V2

class PydanticV2DynamicCheck:
    def __bool__(self):
        return False

USING_PYDANTIC_V2 = False

def lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool:
    # Modified to make v1 and v2 work at the same time.
    # fastapi.dependencies.utils (around line 508) has an assertion that incorrectly assumes
    # Pydantic v1: it compares field.type_ against BaseModel, which won't work when
    # field.type_ is a v1 model, because pydantic.BaseModel is now the v2 class.
    #
    # Attempt to catch that case and swap in the v1 BaseModel comparison.
    try:
        type_is_v1_basemodel = issubclass(cls, V1BaseModel) and class_or_tuple is V2BaseModel
        return isinstance(cls, type) and (issubclass(cls, class_or_tuple) or type_is_v1_basemodel)
    except TypeError:
        if isinstance(cls, _typing_extra.WithArgsTypes):
            return False
        if isinstance(cls, pydantic.v1.typing.WithArgsTypes):
            return False
        raise  # pragma: no cover

def with_info_plain_validator_function(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_with_info_plain_validator_function(*args, **kwargs)
    else:
        return v1_with_info_plain_validator_function(*args, **kwargs)

def get_model_definitions(*args, **kwargs) -> Any:
    # Function does not exist in fastapi._compat for v2
    return v1_get_model_definitions(*args, **kwargs)

def get_annotation_from_field_info(
    annotation: Any, field_info: Any, field_name: str
) -> Any:
    if isinstance(field_info, V2FieldInfo):
        return v2_get_annotation_from_field_info(annotation, field_info, field_name)
    else:
        return v1_get_annotation_from_field_info(annotation, field_info, field_name)

def _normalize_errors(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2__normalize_errors(*args, **kwargs)
    else:
        return v1__normalize_errors(*args, **kwargs)

def _model_rebuild(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2__model_rebuild(*args, **kwargs)
    else:
        return v1__model_rebuild(*args, **kwargs)

def _model_dump(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2__model_dump(*args, **kwargs)
    else:
        return v1__model_dump(*args, **kwargs)

def _get_model_config(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2__get_model_config(*args, **kwargs)
    else:
        return v1__get_model_config(*args, **kwargs)

def get_schema_from_model_field(*, field: ModelField, **kwargs) -> Any:
    if isinstance(field, FastApiPV2ModelField):
        return v2_get_schema_from_model_field(field=field, **kwargs)
    else:
        return v1_get_schema_from_model_field(field=field, **kwargs)

def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
    # v2 version doesn't return any values, so only need to care about v1 fields
    v1_fields = [f for f in fields if isinstance(f, FastApiPV1ModelField)]
    if not v1_fields:
        return {}
    else:
        return v1_get_compat_model_name_map(v1_fields)

def get_definitions(*, fields: list[ModelField], **kwargs) -> Any:
    v1_fields = [f for f in fields if isinstance(f, FastApiPV1ModelField)]
    v2_fields = [f for f in fields if isinstance(f, FastApiPV2ModelField)]
    v1_result = v1_get_definitions(fields=v1_fields, **kwargs)
    v2_result = v2_get_definitions(fields=v2_fields, **kwargs)
    final = v1_result
    final[0].update(v2_result[0])
    final[1].update(v2_result[1])
    return final

def is_scalar_field(field: ModelField) -> bool:
    if isinstance(field, FastApiPV2ModelField):
        return v2_is_scalar_field(field=field)
    else:
        return v1_is_scalar_field(field=field)

def is_sequence_field(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_is_sequence_field(*args, **kwargs)
    else:
        return v1_is_sequence_field(*args, **kwargs)

def is_scalar_sequence_field(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_is_scalar_sequence_field(*args, **kwargs)
    else:
        return v1_is_scalar_sequence_field(*args, **kwargs)

def is_bytes_field(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_is_bytes_field(*args, **kwargs)
    else:
        return v1_is_bytes_field(*args, **kwargs)

def is_bytes_sequence_field(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_is_bytes_sequence_field(*args, **kwargs)
    else:
        return v1_is_bytes_sequence_field(*args, **kwargs)

def copy_field_info(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_copy_field_info(*args, **kwargs)
    else:
        return v1_copy_field_info(*args, **kwargs)

def serialize_sequence_value(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_serialize_sequence_value(*args, **kwargs)
    else:
        return v1_serialize_sequence_value(*args, **kwargs)

def get_missing_field_error(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_get_missing_field_error(*args, **kwargs)
    else:
        return v1_get_missing_field_error(*args, **kwargs)

def create_body_model(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_create_body_model(*args, **kwargs)
    else:
        return v1_create_body_model(*args, **kwargs)

def get_model_fields(*args, **kwargs) -> Any:
    if USING_PYDANTIC_V2:
        return v2_get_model_fields(*args, **kwargs)
    else:
        return v1_get_model_fields(*args, **kwargs)

_UNSET = object()

class ConfigWithExtra(V1BaseConfig):
    extra = V1Extra.allow

class PydanticV1CompatModelField:
    def __new__(cls, *args, **kwargs):
        pydantic_v1: bool = kwargs.pop("pydantic_v1", False)
        if pydantic_v1:
            if "model_config" in kwargs and not issubclass(kwargs["model_config"], V1BaseConfig):
                kwargs["model_config"] = ConfigWithExtra
            if "field_info" in kwargs and not isinstance(kwargs["field_info"], V1FieldInfo):
                kwargs["field_info"] = V1FieldInfo()
            if "mode" in kwargs:
                kwargs.pop("mode")
            res = FastApiPV1ModelField(*args, **kwargs)
        else:
            res = FastApiPV2ModelField(*args, **kwargs)
        return res

def create_model_field(
    name: str,
    type_: Any,
    class_validators: dict[str, Any] | None = None,
    default: Any | None = _UNSET,
    required: bool | Any = _UNSET,
    model_config: Any = BaseConfig,
    field_info: Any | None = None,
    alias: str | None = None,
    mode: Literal["validation", "serialization"] = "validation",
) -> Any:
    class_validators = class_validators or {}
    # Lenient check: type_ may be a non-class annotation (e.g. a Union)
    pydantic_v1 = v1_lenient_issubclass(type_, pydantic.v1.BaseModel)
    if pydantic_v1:
        if default is _UNSET:
            default = V1Undefined
        if required is _UNSET:
            required = V1Undefined
        if model_config is _UNSET:
            model_config = V1BaseConfig
        field_info = field_info or V1FieldInfo()
    else:
        if default is _UNSET:
            default = PydanticUndefined
        if required is _UNSET:
            required = PydanticUndefined
        if model_config is _UNSET:
            model_config = BaseConfig
        field_info = field_info or V2FieldInfo(
            annotation=type_, default=default, alias=alias
        )
    kwargs = {"name": name, "field_info": field_info}
    if pydantic_v1:
        kwargs.update(
            {
                "type_": type_,
                "class_validators": class_validators,
                "default": default,
                "required": required,
                "model_config": model_config,
                "alias": alias,
            }
        )
    else:
        kwargs.update({"mode": mode})
    try:
        return PydanticV1CompatModelField(**kwargs, pydantic_v1=pydantic_v1)  # type: ignore[arg-type]
    except (RuntimeError, PydanticSchemaGenerationError):
        raise fastapi.exceptions.FastAPIError(
            "Invalid args for response field! Hint: "
            f"check that {type_} is a valid Pydantic field type. "
            "If you are using a return type annotation that is not a valid Pydantic "
            "field (e.g. Union[Response, dict, None]) you can disable generating the "
            "response model from the type annotation with the path operation decorator "
            "parameter response_model=None. Read more: "
            "https://fastapi.tiangolo.com/tutorial/response-model/"
        ) from None

def patch_v2_compat():
    # fastapi._compat.PYDANTIC_V2 = PydanticV2DynamicCheck()
    import fastapi._compat

    fastapi._compat.ModelField = PydanticV1CompatModelField
    fastapi._compat.with_info_plain_validator_function = with_info_plain_validator_function
    fastapi._compat._get_model_config = _get_model_config
    fastapi._compat.get_annotation_from_field_info = get_annotation_from_field_info
    fastapi._compat._normalize_errors = _normalize_errors
    fastapi._compat.get_model_definitions = get_model_definitions
    fastapi._compat.get_schema_from_model_field = get_schema_from_model_field
    fastapi._compat.get_compat_model_name_map = get_compat_model_name_map
    fastapi._compat.get_definitions = get_definitions
    fastapi._compat.is_scalar_field = is_scalar_field
    fastapi._compat.is_sequence_field = is_sequence_field
    fastapi._compat.is_scalar_sequence_field = is_scalar_sequence_field
    fastapi._compat.is_bytes_field = is_bytes_field
    fastapi._compat.is_bytes_sequence_field = is_bytes_sequence_field
    fastapi._compat.copy_field_info = copy_field_info
    fastapi._compat.serialize_sequence_value = serialize_sequence_value
    fastapi._compat.get_missing_field_error = get_missing_field_error
    fastapi._compat.create_body_model = create_body_model
    fastapi._compat.get_model_fields = get_model_fields

    # Now patch all the places that import from fastapi._compat, since patching
    # fastapi._compat itself happens too late.
    import fastapi.routing

    fastapi.routing.ModelField = PydanticV1CompatModelField
    # fastapi.routing.Undefined
    fastapi.routing._get_model_config = _get_model_config
    fastapi.routing._model_dump = _model_dump
    fastapi.routing._normalize_errors = _normalize_errors
    fastapi.routing.lenient_issubclass = lenient_issubclass
    fastapi.routing.create_model_field = create_model_field

    import fastapi.utils

    # fastapi.utils.PYDANTIC_V2 = PydanticV2DynamicCheck()  # needs the dynamic check, else we miss kwargs in initializations
    # fastapi.utils.BaseConfig = # TODO
    fastapi.utils.ModelField = PydanticV1CompatModelField
    # fastapi.utils.PydanticSchemaGenerationError,
    # fastapi.utils.Undefined,  # TODO: I think this is ok?
    # fastapi.utils.UndefinedType,  # TODO: I think this is ok?
    # fastapi.utils.Validator,
    fastapi.utils.lenient_issubclass = lenient_issubclass
    fastapi.utils.create_model_field = create_model_field

    # import fastapi.dependencies.models
    # fastapi.dependencies.models.ModelField = PydanticV1CompatModelField
    import fastapi.dependencies.utils

    # fastapi.dependencies.utils.PYDANTIC_V2 = PydanticV2DynamicCheck()  # needs the dynamic check, else we call an invalid method
    # fastapi.dependencies.utils.ErrorWrapper,  # TODO: maybe ok?
    # fastapi.dependencies.utils.ModelField,  # only used in type hints
    # fastapi.dependencies.utils.RequiredParam,  # TODO: maybe ok?
    # fastapi.dependencies.utils.Undefined,  # TODO: maybe ok?
    # fastapi.dependencies.utils._regenerate_error_with_loc  # same in both, no need to override
    fastapi.dependencies.utils.copy_field_info = copy_field_info
    fastapi.dependencies.utils.create_body_model = create_body_model
    # fastapi.dependencies.utils.evaluate_forwardref
    # fastapi.dependencies.utils.field_annotation_is_scalar  # same in both, no need to override
    fastapi.dependencies.utils.get_annotation_from_field_info = get_annotation_from_field_info
    # fastapi.dependencies.utils.get_cached_model_fields  # same in both, no need to override
    fastapi.dependencies.utils.get_missing_field_error = get_missing_field_error
    fastapi.dependencies.utils.is_bytes_field = is_bytes_field
    fastapi.dependencies.utils.is_bytes_sequence_field = is_bytes_sequence_field
    fastapi.dependencies.utils.is_scalar_field = is_scalar_field
    fastapi.dependencies.utils.is_scalar_sequence_field = is_scalar_sequence_field
    fastapi.dependencies.utils.is_sequence_field = is_sequence_field
    # fastapi.dependencies.utils.is_uploadfile_or_nonable_uploadfile_annotation  # same in both, no need to override
    # fastapi.dependencies.utils.is_uploadfile_sequence_annotation  # same in both, no need to override
    fastapi.dependencies.utils.lenient_issubclass = lenient_issubclass
    # fastapi.dependencies.utils.sequence_types  # same in both, no need to override
    fastapi.dependencies.utils.serialize_sequence_value = serialize_sequence_value
    # fastapi.dependencies.utils.value_is_sequence  # same in both, no need to override
    fastapi.dependencies.utils.create_model_field = create_model_field

    import fastapi.openapi.models

    # fastapi.openapi.models.PYDANTIC_V2,  # TODO
    # fastapi.openapi.models.CoreSchema,  # only used in type hints
    # fastapi.openapi.models.GetJsonSchemaHandler,  # only used in type hints
    # fastapi.openapi.models.JsonSchemaValue,  # only used in type hints
    fastapi.openapi.models._model_rebuild = _model_rebuild
    fastapi.openapi.models.with_info_plain_validator_function = with_info_plain_validator_function

    import fastapi.openapi.utils

    # fastapi.openapi.utils.GenerateJsonSchema,  # TODO: this might be a problem in how it's used from
    #     get_openapi as a default value, the V1 DataClass, but maybe not
    # fastapi.openapi.utils.JsonSchemaValue,  # only used in type hints
    # fastapi.openapi.utils.ModelField,  # only used in type hints
    # fastapi.openapi.utils.Undefined,  # TODO: I think this is ok?
    fastapi.openapi.utils.get_compat_model_name_map = get_compat_model_name_map
    fastapi.openapi.utils.get_definitions = get_definitions
    fastapi.openapi.utils.get_schema_from_model_field = get_schema_from_model_field
    fastapi.openapi.utils.lenient_issubclass = lenient_issubclass

# _v1_compat.py
from collections import deque
from collections.abc import Callable, Sequence
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
    Any,
    Deque,
    FrozenSet,
    List,
    Literal,
    Set,
    Tuple,
)

from fastapi._compat import _annotation_is_sequence, field_annotation_is_sequence
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from pydantic.v1 import AnyUrl as Url  # noqa: F401
from pydantic.v1 import (  # type: ignore[assignment]
    BaseConfig as BaseConfig,
)
from pydantic.v1 import ValidationError as ValidationError
from pydantic.v1.class_validators import (  # type: ignore[no-redef]
    Validator as Validator,
)
from pydantic.v1.error_wrappers import (  # type: ignore[no-redef]
    ErrorWrapper as ErrorWrapper,
)
from pydantic.v1.errors import MissingError
from pydantic.v1.fields import (  # type: ignore[attr-defined]
    SHAPE_FROZENSET,
    SHAPE_LIST,
    SHAPE_SEQUENCE,
    SHAPE_SET,
    SHAPE_SINGLETON,
    SHAPE_TUPLE,
    SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.v1.fields import FieldInfo as FieldInfo
from pydantic.v1.fields import (  # type: ignore[no-redef,attr-defined]
    ModelField as ModelField,
)

# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = Ellipsis  # type: ignore[no-redef]

from fastapi.exceptions import RequestErrorModel
from fastapi.types import ModelNameMap
from pydantic.v1 import BaseModel, create_model
from pydantic.v1.fields import (  # type: ignore[no-redef,attr-defined]
    Undefined as Undefined,
)
from pydantic.v1.fields import (  # type: ignore[no-redef, attr-defined]
    UndefinedType as UndefinedType,
)
from pydantic.v1.schema import (
    field_schema,
    get_flat_models_from_fields,
    get_model_name_map,
    model_process_schema,
)
from pydantic.v1.schema import (  # type: ignore[no-redef]
    get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.v1.typing import (  # type: ignore[no-redef]
    evaluate_forwardref as evaluate_forwardref,
)
from pydantic.v1.utils import (  # type: ignore[no-redef]
    lenient_issubclass as lenient_issubclass,
)

sequence_annotation_to_type = {
    Sequence: list,
    List: list,  # noqa: UP006
    list: list,
    Tuple: tuple,  # noqa: UP006
    tuple: tuple,
    Set: set,  # noqa: UP006
    set: set,
    FrozenSet: frozenset,  # noqa: UP006
    frozenset: frozenset,
    Deque: deque,  # noqa: UP006
    deque: deque,
}

sequence_types = tuple(sequence_annotation_to_type.keys())

GetJsonSchemaHandler = Any  # type: ignore[assignment,misc]
JsonSchemaValue = dict[str, Any]  # type: ignore[misc]
CoreSchema = Any  # type: ignore[assignment,misc]

sequence_shapes = {
    SHAPE_LIST,
    SHAPE_SET,
    SHAPE_FROZENSET,
    SHAPE_TUPLE,
    SHAPE_SEQUENCE,
    SHAPE_TUPLE_ELLIPSIS,
}

sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}

@dataclass
class GenerateJsonSchema:  # type: ignore[no-redef]
    ref_template: str

class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
    pass

def with_info_plain_validator_function(  # type: ignore[misc]
    function: Callable[..., Any],
    *,
    ref: str | None = None,
    metadata: Any = None,
    serialization: Any = None,
) -> Any:
    return {}

def get_model_definitions(
    *,
    flat_models: set[type[BaseModel] | type[Enum]],
    model_name_map: dict[type[BaseModel] | type[Enum], str],
) -> dict[str, Any]:
    definitions: dict[str, dict[str, Any]] = {}
    for model in flat_models:
        m_schema, m_definitions, m_nested_models = model_process_schema(
            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(m_definitions)
        model_name = model_name_map[model]
        if "description" in m_schema:
            m_schema["description"] = m_schema["description"].split("\f")[0]
        definitions[model_name] = m_schema
    return definitions

def is_pv1_scalar_field(field: ModelField) -> bool:
    from fastapi import params

    field_info = field.field_info
    if not (
        field.shape == SHAPE_SINGLETON  # type: ignore[attr-defined]
        and not lenient_issubclass(field.type_, BaseModel)
        and not lenient_issubclass(field.type_, dict)
        and not field_annotation_is_sequence(field.type_)
        and not is_dataclass(field.type_)
        and not isinstance(field_info, params.Body)
    ):
        return False
    if field.sub_fields:  # type: ignore[attr-defined]
        if not all(
            is_pv1_scalar_field(f)
            for f in field.sub_fields  # type: ignore[attr-defined]
        ):
            return False
    return True

def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
    if (field.shape in sequence_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
        field.type_, BaseModel
    ):
        if field.sub_fields is not None:  # type: ignore[attr-defined]
            for sub_field in field.sub_fields:  # type: ignore[attr-defined]
                if not is_pv1_scalar_field(sub_field):
                    return False
        return True
    if _annotation_is_sequence(field.type_):
        return True
    return False

def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
    use_errors: list[Any] = []
    for error in errors:
        if isinstance(error, ErrorWrapper):
            new_errors = ValidationError(  # type: ignore[call-arg]
                errors=[error], model=RequestErrorModel
            ).errors()
            use_errors.extend(new_errors)
        elif isinstance(error, list):
            use_errors.extend(_normalize_errors(error))
        else:
            use_errors.append(error)
    return use_errors

def _model_rebuild(model: type[BaseModel]) -> None:
    model.update_forward_refs()

def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
    return model.dict(**kwargs)

def _get_model_config(model: BaseModel) -> Any:
    return model.__config__  # type: ignore[attr-defined]

def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    # This expects that GenerateJsonSchema was already used to generate the definitions
    return field_schema(  # type: ignore[no-any-return]
        field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )[0]

def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
    models = get_flat_models_from_fields(fields, known_models=set())
    return get_model_name_map(models)  # type: ignore[no-any-return]

def get_definitions(
    *,
    fields: list[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> tuple[
    dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    dict[str, dict[str, Any]],
]:
    models = get_flat_models_from_fields(fields, known_models=set())
    return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map)

def is_scalar_field(field: ModelField) -> bool:
    return is_pv1_scalar_field(field)

def is_sequence_field(field: ModelField) -> bool:
    return field.shape in sequence_shapes or _annotation_is_sequence(field.type_)  # type: ignore[attr-defined]

def is_scalar_sequence_field(field: ModelField) -> bool:
    return is_pv1_scalar_sequence_field(field)

def is_bytes_field(field: ModelField) -> bool:
    return lenient_issubclass(field.type_, bytes)

def is_bytes_sequence_field(field: ModelField) -> bool:
    return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)  # type: ignore[attr-defined]

def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    return copy(field_info)

def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    return sequence_shape_to_type[field.shape](value)  # type: ignore[no-any-return,attr-defined]

def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
    missing_field_error = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
    new_error = ValidationError([missing_field_error], RequestErrorModel)
    return new_error.errors()[0]  # type: ignore[return-value]

def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
    BodyModel = create_model(model_name)
    for f in fields:
        BodyModel.__fields__[f.name] = f  # type: ignore[index]
    return BodyModel

def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
    return list(model.__fields__.values())  # type: ignore[attr-defined]

# _v2_compat.py
from collections import deque
from collections.abc import Sequence
from copy import copy
from typing import (
    Any,
    Deque,
    FrozenSet,
    List,
    Literal,
    Set,
    Tuple,
    get_origin,
)

from fastapi._compat import (
    ModelField,
    field_annotation_is_scalar,
    field_annotation_is_scalar_sequence,
    field_annotation_is_sequence,
    is_bytes_or_nonable_bytes_annotation,
    is_bytes_sequence_annotation,
)
from fastapi.types import ModelNameMap
from pydantic import BaseModel, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import ValidationError as ValidationError
from pydantic._internal._schema_generation_shared import (  # type: ignore[attr-defined]
    GetJsonSchemaHandler as GetJsonSchemaHandler,
)
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url

sequence_annotation_to_type = {
    Sequence: list,
    List: list,  # noqa: UP006
    list: list,
    Tuple: tuple,  # noqa: UP006
    tuple: tuple,
    Set: set,  # noqa: UP006
    set: set,
    FrozenSet: frozenset,  # noqa: UP006
    frozenset: frozenset,
    Deque: deque,  # noqa: UP006
    deque: deque,
}

sequence_types = tuple(sequence_annotation_to_type.keys())

try:
    from pydantic_core.core_schema import (
        with_info_plain_validator_function as with_info_plain_validator_function,
    )
except ImportError:  # pragma: no cover
    from pydantic_core.core_schema import (
        general_plain_validator_function as with_info_plain_validator_function,  # noqa: F401
    )

RequiredParam = PydanticUndefined
Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any

class BaseConfig:
    pass

class ErrorWrapper(Exception):
    pass

def get_annotation_from_field_info(annotation: Any, field_info: FieldInfo, field_name: str) -> Any:
    return annotation

def _normalize_errors(errors: Sequence[Any]) -> list[dict[str, Any]]:
    return errors  # type: ignore[return-value]

def _model_rebuild(model: type[BaseModel]) -> None:
    model.model_rebuild()

def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any:
    return model.model_dump(mode=mode, **kwargs)

def _get_model_config(model: BaseModel) -> Any:
    return model.model_config

def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    separate_input_output_schemas: bool = True,
) -> dict[str, Any]:
    override_mode: Literal["validation"] | None = (
        None if separate_input_output_schemas else "validation"
    )
    # This expects that GenerateJsonSchema was already used to generate the definitions
    json_schema = field_mapping[(field, override_mode or field.mode)]
    if "$ref" not in json_schema:
        # TODO remove when deprecating Pydantic v1
        # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
        json_schema["title"] = field.field_info.title or field.alias.title().replace("_", " ")
    return json_schema

def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
    return {}

def get_definitions(
    *,
    fields: list[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> tuple[
    dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
    dict[str, dict[str, Any]],
]:
    override_mode: Literal["validation"] | None = (
        None if separate_input_output_schemas else "validation"
    )
    inputs = [
        (field, override_mode or field.mode, field._type_adapter.core_schema) for field in fields
    ]
    field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
    return field_mapping, definitions  # type: ignore[return-value]

def is_scalar_field(field: ModelField) -> bool:
    from fastapi import params

    return field_annotation_is_scalar(field.field_info.annotation) and not isinstance(
        field.field_info, params.Body
    )

def is_sequence_field(field: ModelField) -> bool:
    return field_annotation_is_sequence(field.field_info.annotation)

def is_scalar_sequence_field(field: ModelField) -> bool:
    return field_annotation_is_scalar_sequence(field.field_info.annotation)

def is_bytes_field(field: ModelField) -> bool:
    return is_bytes_or_nonable_bytes_annotation(field.type_)

def is_bytes_sequence_field(field: ModelField) -> bool:
    return is_bytes_sequence_annotation(field.type_)

def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    cls = type(field_info)
    merged_field_info = cls.from_annotation(annotation)
    new_field_info = copy(field_info)
    new_field_info.metadata = merged_field_info.metadata
    new_field_info.annotation = merged_field_info.annotation
    return new_field_info

def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
    assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
    return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]

def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
    error = ValidationError.from_exception_data(
        "Field required", [{"type": "missing", "loc": loc, "input": {}}]
    ).errors(include_url=False)[0]
    error["input"] = None
    return error  # type: ignore[return-value]

def create_body_model(*, fields: Sequence[ModelField], model_name: str) -> type[BaseModel]:
    field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
    BodyModel: type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
    return BodyModel

def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
    return [
        ModelField(field_info=field_info, name=name)
        for name, field_info in model.model_fields.items()
    ]