增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,44 +1,9 @@
from .main import BaseConfig as BaseConfig
from .main import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .main import RequiredParam as RequiredParam
from .main import Undefined as Undefined
from .main import UndefinedType as UndefinedType
from .main import Url as Url
from .main import Validator as Validator
from .main import _get_model_config as _get_model_config
from .main import _is_error_wrapper as _is_error_wrapper
from .main import _is_model_class as _is_model_class
from .main import _is_model_field as _is_model_field
from .main import _is_undefined as _is_undefined
from .main import _model_dump as _model_dump
from .main import _model_rebuild as _model_rebuild
from .main import copy_field_info as copy_field_info
from .main import create_body_model as create_body_model
from .main import evaluate_forwardref as evaluate_forwardref
from .main import get_annotation_from_field_info as get_annotation_from_field_info
from .main import get_cached_model_fields as get_cached_model_fields
from .main import get_compat_model_name_map as get_compat_model_name_map
from .main import get_definitions as get_definitions
from .main import get_missing_field_error as get_missing_field_error
from .main import get_schema_from_model_field as get_schema_from_model_field
from .main import is_bytes_field as is_bytes_field
from .main import is_bytes_sequence_field as is_bytes_sequence_field
from .main import is_scalar_field as is_scalar_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_sequence_field as is_sequence_field
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
from .may_v1 import CoreSchema as CoreSchema
from .may_v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .may_v1 import JsonSchemaValue as JsonSchemaValue
from .may_v1 import _normalize_errors as _normalize_errors
from .model_field import ModelField as ModelField
from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import is_pydantic_v1_model_class as is_pydantic_v1_model_class
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
)
@@ -48,3 +13,29 @@ from .shared import (
from .shared import lenient_issubclass as lenient_issubclass
from .shared import sequence_types as sequence_types
from .shared import value_is_sequence as value_is_sequence
from .v2 import BaseConfig as BaseConfig
from .v2 import ModelField as ModelField
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import _regenerate_error_with_loc as _regenerate_error_with_loc
from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_cached_model_fields as get_cached_model_fields
from .v2 import get_compat_model_name_map as get_compat_model_name_map
from .v2 import get_definitions as get_definitions
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_schema_from_model_field as get_schema_from_model_field
from .v2 import is_bytes_field as is_bytes_field
from .v2 import is_bytes_sequence_field as is_bytes_sequence_field
from .v2 import is_scalar_field as is_scalar_field
from .v2 import is_scalar_sequence_field as is_scalar_sequence_field
from .v2 import is_sequence_field as is_sequence_field
from .v2 import serialize_sequence_value as serialize_sequence_value
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)

View File

@@ -1,362 +0,0 @@
import sys
from functools import lru_cache
from typing import (
Any,
Dict,
List,
Sequence,
Tuple,
Type,
)
from fastapi._compat import may_v1
from fastapi._compat.shared import PYDANTIC_V2, lenient_issubclass
from fastapi.types import ModelNameMap
from pydantic import BaseModel
from typing_extensions import Literal
from .model_field import ModelField
if PYDANTIC_V2:
from .v2 import BaseConfig as BaseConfig
from .v2 import FieldInfo as FieldInfo
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
else:
from .v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from .v1 import FieldInfo as FieldInfo
from .v1 import ( # type: ignore[assignment]
PydanticSchemaGenerationError as PydanticSchemaGenerationError,
)
from .v1 import RequiredParam as RequiredParam
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url # type: ignore[assignment]
from .v1 import Validator as Validator
from .v1 import evaluate_forwardref as evaluate_forwardref
from .v1 import get_missing_field_error as get_missing_field_error
from .v1 import ( # type: ignore[assignment]
with_info_plain_validator_function as with_info_plain_validator_function,
)
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1.get_model_fields(model)
else:
from . import v2
return v2.get_model_fields(model) # type: ignore[return-value]
def _is_undefined(value: object) -> bool:
if isinstance(value, may_v1.UndefinedType):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.UndefinedType)
return False
def _get_model_config(model: BaseModel) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._get_model_config(model)
elif PYDANTIC_V2:
from . import v2
return v2._get_model_config(model)
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._model_dump(model, mode=mode, **kwargs)
elif PYDANTIC_V2:
from . import v2
return v2._model_dump(model, mode=mode, **kwargs)
def _is_error_wrapper(exc: Exception) -> bool:
if isinstance(exc, may_v1.ErrorWrapper):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(exc, v2.ErrorWrapper)
return False
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.copy_field_info(field_info=field_info, annotation=annotation)
else:
assert PYDANTIC_V2
from . import v2
return v2.copy_field_info(field_info=field_info, annotation=annotation)
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
if fields and isinstance(fields[0], may_v1.ModelField):
from fastapi._compat import v1
return v1.create_body_model(fields=fields, model_name=model_name)
else:
assert PYDANTIC_V2
from . import v2
return v2.create_body_model(fields=fields, model_name=model_name) # type: ignore[arg-type]
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
def is_bytes_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_field(field) # type: ignore[arg-type]
def is_bytes_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_sequence_field(field) # type: ignore[arg-type]
def is_scalar_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_field(field) # type: ignore[arg-type]
def is_scalar_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_sequence_field(field) # type: ignore[arg-type]
def is_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_sequence_field(field) # type: ignore[arg-type]
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.serialize_sequence_value(field=field, value=value)
else:
assert PYDANTIC_V2
from . import v2
return v2.serialize_sequence_value(field=field, value=value) # type: ignore[arg-type]
def _model_rebuild(model: Type[BaseModel]) -> None:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
v1._model_rebuild(model)
elif PYDANTIC_V2:
from . import v2
v2._model_rebuild(model)
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
v1_model_fields = [
field for field in fields if isinstance(field, may_v1.ModelField)
]
if v1_model_fields:
from fastapi._compat import v1
v1_flat_models = v1.get_flat_models_from_fields(
v1_model_fields, known_models=set()
)
all_flat_models = v1_flat_models
else:
all_flat_models = set()
if PYDANTIC_V2:
from . import v2
v2_model_fields = [
field for field in fields if isinstance(field, v2.ModelField)
]
v2_flat_models = v2.get_flat_models_from_fields(
v2_model_fields, known_models=set()
)
all_flat_models = all_flat_models.union(v2_flat_models)
model_name_map = v2.get_model_name_map(all_flat_models)
return model_name_map
from fastapi._compat import v1
model_name_map = v1.get_model_name_map(all_flat_models)
return model_name_map
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
Dict[str, Dict[str, Any]],
]:
if sys.version_info < (3, 14):
v1_fields = [field for field in fields if isinstance(field, may_v1.ModelField)]
v1_field_maps, v1_definitions = may_v1.get_definitions(
fields=v1_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
if not PYDANTIC_V2:
return v1_field_maps, v1_definitions
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
all_definitions = {**v1_definitions, **v2_definitions}
all_field_maps = {**v1_field_maps, **v2_field_maps}
return all_field_maps, all_definitions
# Pydantic v1 is not supported since Python 3.14
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
return v2_field_maps, v2_definitions
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.get_schema_from_model_field(
field=field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_schema_from_model_field(
field=field, # type: ignore[arg-type]
model_name_map=model_name_map,
field_mapping=field_mapping, # type: ignore[arg-type]
separate_input_output_schemas=separate_input_output_schemas,
)
def _is_model_field(value: Any) -> bool:
if isinstance(value, may_v1.ModelField):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.ModelField)
return False
def _is_model_class(value: Any) -> bool:
if lenient_issubclass(value, may_v1.BaseModel):
return True
elif PYDANTIC_V2:
from . import v2
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
return False

View File

@@ -1,123 +0,0 @@
import sys
from typing import Any, Dict, List, Literal, Sequence, Tuple, Type, Union
from fastapi.types import ModelNameMap
if sys.version_info >= (3, 14):
class AnyUrl:
pass
class BaseConfig:
pass
class BaseModel:
pass
class Color:
pass
class CoreSchema:
pass
class ErrorWrapper:
pass
class FieldInfo:
pass
class GetJsonSchemaHandler:
pass
class JsonSchemaValue:
pass
class ModelField:
pass
class NameEmail:
pass
class RequiredParam:
pass
class SecretBytes:
pass
class SecretStr:
pass
class Undefined:
pass
class UndefinedType:
pass
class Url:
pass
from .v2 import ValidationError, create_model
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
Dict[str, Dict[str, Any]],
]:
return {}, {} # pragma: no cover
else:
from .v1 import AnyUrl as AnyUrl
from .v1 import BaseConfig as BaseConfig
from .v1 import BaseModel as BaseModel
from .v1 import Color as Color
from .v1 import CoreSchema as CoreSchema
from .v1 import ErrorWrapper as ErrorWrapper
from .v1 import FieldInfo as FieldInfo
from .v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .v1 import JsonSchemaValue as JsonSchemaValue
from .v1 import ModelField as ModelField
from .v1 import NameEmail as NameEmail
from .v1 import RequiredParam as RequiredParam
from .v1 import SecretBytes as SecretBytes
from .v1 import SecretStr as SecretStr
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url
from .v1 import ValidationError, create_model
from .v1 import get_definitions as get_definitions
RequestErrorModel: Type[BaseModel] = create_model("Request")
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors

View File

@@ -1,53 +0,0 @@
from typing import (
Any,
Dict,
List,
Tuple,
Union,
)
from fastapi.types import IncEx
from pydantic.fields import FieldInfo
from typing_extensions import Literal, Protocol
class ModelField(Protocol):
field_info: "FieldInfo"
name: str
mode: Literal["validation", "serialization"] = "validation"
_version: Literal["v1", "v2"] = "v1"
@property
def alias(self) -> str: ...
@property
def required(self) -> bool: ...
@property
def default(self) -> Any: ...
@property
def type_(self) -> Any: ...
def get_default(self) -> Any: ...
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: ...
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any: ...

View File

@@ -1,36 +1,24 @@
import sys
import types
import typing
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import is_dataclass
from typing import (
Annotated,
Any,
Deque,
FrozenSet,
List,
Mapping,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import may_v1
from fastapi.types import UnionType
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import get_args, get_origin
# Copy from Pydantic v2, compatible with v1
if sys.version_info < (3, 9):
# Pydantic no longer supports Python 3.8, this might be incorrect, but the code
# this is used for is also never reached in this codebase, as it's a copy of
# Pydantic's lenient_issubclass, just for compatibility with v1
# TODO: remove when dropping support for Python 3.8
WithArgsTypes: Tuple[Any, ...] = ()
elif sys.version_info < (3, 10):
if sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # type: ignore[attr-defined]
else:
WithArgsTypes: tuple[Any, ...] = (
@@ -45,26 +33,21 @@ PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
Sequence: list,
List: list,
list: list,
Tuple: tuple,
tuple: tuple,
Set: set,
set: set,
FrozenSet: frozenset,
frozenset: frozenset,
Deque: deque,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())
Url: Type[Any]
Url: type[Any]
# Copy of Pydantic v2, compatible with v1
def lenient_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...], None]
) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
@@ -74,13 +57,13 @@ def lenient_issubclass(
raise # pragma: no cover
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types) # type: ignore[arg-type]
return lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
@@ -93,20 +76,18 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
return (
lenient_issubclass(
annotation, (BaseModel, may_v1.BaseModel, Mapping, UploadFile)
)
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
@@ -127,7 +108,7 @@ def field_annotation_is_scalar(annotation: Any) -> bool:
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
@@ -196,13 +177,27 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
)
def is_pydantic_v1_model_instance(obj: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return isinstance(obj, v1.BaseModel)
def is_pydantic_v1_model_class(cls: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return lenient_issubclass(cls, v1.BaseModel)
def annotation_is_pydantic_v1(annotation: Any) -> bool:
if lenient_issubclass(annotation, may_v1.BaseModel):
if is_pydantic_v1_model_class(annotation):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, may_v1.BaseModel):
if is_pydantic_v1_model_class(arg):
return True
if field_annotation_is_sequence(annotation):
for sub_annotation in get_args(annotation):

View File

@@ -1,312 +0,0 @@
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import shared
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from fastapi.types import ModelNameMap
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import Literal
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = Ellipsis
if not PYDANTIC_V2:
from pydantic import BaseConfig as BaseConfig
from pydantic import BaseModel as BaseModel
from pydantic import ValidationError as ValidationError
from pydantic import create_model as create_model
from pydantic.class_validators import Validator as Validator
from pydantic.color import Color as Color
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined]
from pydantic.fields import Undefined as Undefined # type: ignore[attr-defined]
from pydantic.fields import ( # type: ignore[attr-defined]
UndefinedType as UndefinedType,
)
from pydantic.networks import AnyUrl as AnyUrl
from pydantic.networks import NameEmail as NameEmail
from pydantic.schema import TypeModelSet as TypeModelSet
from pydantic.schema import (
field_schema,
model_process_schema,
)
from pydantic.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.schema import get_flat_models_from_field as get_flat_models_from_field
from pydantic.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.schema import get_model_name_map as get_model_name_map
from pydantic.types import SecretBytes as SecretBytes
from pydantic.types import SecretStr as SecretStr
from pydantic.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.utils import lenient_issubclass as lenient_issubclass
else:
from pydantic.v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from pydantic.v1 import BaseModel as BaseModel # type: ignore[assignment]
from pydantic.v1 import ( # type: ignore[assignment]
ValidationError as ValidationError,
)
from pydantic.v1 import create_model as create_model # type: ignore[no-redef]
from pydantic.v1.class_validators import Validator as Validator
from pydantic.v1.color import Color as Color # type: ignore[assignment]
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.v1.errors import MissingError
from pydantic.v1.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.v1.fields import FieldInfo as FieldInfo # type: ignore[assignment]
from pydantic.v1.fields import ModelField as ModelField
from pydantic.v1.fields import Undefined as Undefined
from pydantic.v1.fields import UndefinedType as UndefinedType
from pydantic.v1.networks import AnyUrl as AnyUrl
from pydantic.v1.networks import ( # type: ignore[assignment]
NameEmail as NameEmail,
)
from pydantic.v1.schema import TypeModelSet as TypeModelSet
from pydantic.v1.schema import (
field_schema,
model_process_schema,
)
from pydantic.v1.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.v1.schema import (
get_flat_models_from_field as get_flat_models_from_field,
)
from pydantic.v1.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.v1.schema import get_model_name_map as get_model_name_map
from pydantic.v1.types import ( # type: ignore[assignment]
SecretBytes as SecretBytes,
)
from pydantic.v1.types import ( # type: ignore[assignment]
SecretStr as SecretStr,
)
from pydantic.v1.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.v1.utils import lenient_issubclass as lenient_issubclass
GetJsonSchemaHandler = Any
JsonSchemaValue = Dict[str, Any]
CoreSchema = Any
Url = AnyUrl
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_FROZENSET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
@dataclass
class GenerateJsonSchema:
ref_template: str
class PydanticSchemaGenerationError(Exception):
pass
RequestErrorModel: Type[BaseModel] = create_model("Request")
def with_info_plain_validator_function(
function: Callable[..., Any],
*,
ref: Union[str, None] = None,
metadata: Any = None,
serialization: Any = None,
) -> Any:
return {}
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
for m_schema in definitions.values():
if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, dict)
and not shared.field_annotation_is_sequence(field.type_)
and not is_dataclass(field.type_)
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_pv1_scalar_field(f) for f in field.sub_fields):
return False
return True
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_pv1_scalar_field(sub_field):
return False
return True
if shared._annotation_is_sequence(field.type_):
return True
return False
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0]
# def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
# models = get_flat_models_from_fields(fields, known_models=set())
# return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
]:
models = get_flat_models_from_fields(fields, known_models=set())
return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or shared._annotation_is_sequence(field.type_)
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]

View File

@@ -1,24 +1,21 @@
import re
import warnings
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Annotated,
Any,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
cast,
)
from fastapi._compat import may_v1, shared
from fastapi._compat import shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
@@ -33,7 +30,7 @@ from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
from typing_extensions import Annotated, Literal, get_args, get_origin
from typing_extensions import Literal, get_args, get_origin
try:
from pydantic_core.core_schema import (
@@ -50,6 +47,45 @@ UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any
# TODO: remove when dropping support for Pydantic < v2.12.3
_Attrs = {
"default": ...,
"default_factory": None,
"alias": None,
"alias_priority": None,
"validation_alias": None,
"serialization_alias": None,
"title": None,
"field_title_generator": None,
"description": None,
"examples": None,
"exclude": None,
"exclude_if": None,
"discriminator": None,
"deprecated": None,
"json_schema_extra": None,
"frozen": None,
"validate_default": None,
"repr": True,
"init": None,
"init_var": None,
"kw_only": None,
}
# TODO: remove when dropping support for Pydantic < v2.12.3
def asdict(field_info: FieldInfo) -> dict[str, Any]:
attributes = {}
for attr in _Attrs:
value = getattr(field_info, attr, Undefined)
if value is not Undefined:
attributes[attr] = value
return {
"annotation": field_info.annotation,
"metadata": field_info.metadata,
"attributes": attributes,
}
class BaseConfig:
pass
@@ -64,12 +100,25 @@ class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
config: Union[ConfigDict, None] = None
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def validation_alias(self) -> Union[str, None]:
va = self.field_info.validation_alias
if isinstance(va, str) and va:
return va
return None
@property
def serialization_alias(self) -> Union[str, None]:
sa = self.field_info.serialization_alias
return sa or None
@property
def required(self) -> bool:
return self.field_info.is_required()
@@ -94,8 +143,19 @@ class ModelField:
warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning
)
# TODO: remove after dropping support for Python 3.8 and
# setting the min Pydantic to v2.12.3 that adds asdict()
field_dict = asdict(self.field_info)
annotated_args = (
field_dict["annotation"],
*field_dict["metadata"],
# this FieldInfo needs to be created again so that it doesn't include
# the old field info metadata and only the rest of the attributes
Field(**field_dict["attributes"]),
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
Annotated[annotated_args],
config=self.config,
)
def get_default(self) -> Any:
@@ -106,17 +166,17 @@ class ModelField:
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
values: dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
loc: tuple[Union[int, str], ...] = (),
) -> tuple[Any, Union[list[dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, may_v1._regenerate_error_with_loc(
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
@@ -151,44 +211,39 @@ class ModelField:
return id(self)
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
return annotation
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def _has_computed_fields(field: ModelField) -> bool:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
return len(computed_fields) > 0
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
None
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
)
field_alias = (
(field.validation_alias or field.alias)
if field.mode == "validation"
else (field.serialization_alias or field.alias)
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = field.field_info.title or field.alias.title().replace(
json_schema["title"] = field.field_info.title or field_alias.title().replace(
"_", " "
)
return json_schema
@@ -199,14 +254,11 @@ def get_definitions(
fields: Sequence[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, dict[str, Any]],
]:
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields(
@@ -236,13 +288,20 @@ def get_definitions(
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
}
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
(
field,
(
field.mode
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
),
field._type_adapter.core_schema,
)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
for item_def in cast(dict[str, dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
@@ -256,9 +315,9 @@ def get_definitions(
def _replace_refs(
*,
schema: Dict[str, Any],
old_name_to_new_name_map: Dict[str, str],
) -> Dict[str, Any]:
schema: dict[str, Any],
old_name_to_new_name_map: dict[str, str],
) -> dict[str, Any]:
new_schema = deepcopy(schema)
for key, value in new_schema.items():
if key == "$ref":
@@ -293,18 +352,18 @@ def _replace_refs(
def _remap_definitions_and_field_mappings(
*,
model_name_map: ModelNameMap,
definitions: Dict[str, Any],
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
definitions: dict[str, Any],
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Any],
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, Any],
]:
old_name_to_new_name_map = {}
for field_key, schema in field_mapping.items():
model = field_key[0].type_
if model not in model_name_map:
if model not in model_name_map or "$ref" not in schema:
continue
new_name = model_name_map[model]
old_name = schema["$ref"].split("/")[-1]
@@ -312,8 +371,8 @@ def _remap_definitions_and_field_mappings(
continue
old_name_to_new_name_map[old_name] = new_name
new_field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
new_field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
] = {}
for field_key, schema in field_mapping.items():
new_schema = _replace_refs(
@@ -371,11 +430,18 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
if origin_type is Union or origin_type is UnionType: # Handle optional sequences
union_args = get_args(field.field_info.annotation)
for union_arg in union_args:
if union_arg is type(None):
continue
origin_type = get_origin(union_arg) or union_arg
break
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
@@ -385,50 +451,67 @@ def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
) -> type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
BodyModel: type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return [
ModelField(field_info=field_info, name=name)
for name, field_info in model.model_fields.items()
]
def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
model_fields: list[ModelField] = []
for name, field_info in model.model_fields.items():
type_ = field_info.annotation
if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
model_config = None
else:
model_config = model.model_config
model_fields.append(
ModelField(
field_info=field_info,
name=name,
config=model_config,
)
)
return model_fields
@lru_cache
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
return get_model_fields(model) # type: ignore[return-value]
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
# Pydantic v2 and allow mixing the models
TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]]
TypeModelSet = Set[TypeModelOrEnum]
TypeModelOrEnum = Union[type["BaseModel"], type[Enum]]
TypeModelSet = set[TypeModelOrEnum]
def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str]:
name_model_map = {}
conflicting_names: Set[str] = set()
for model in unique_models:
model_name = normalize_name(model.__name__)
if model_name in conflicting_names:
model_name = get_long_model_name(model)
name_model_map[model_name] = model
elif model_name in name_model_map:
conflicting_names.add(model_name)
conflicting_model = name_model_map.pop(model_name)
name_model_map[get_long_model_name(conflicting_model)] = conflicting_model
name_model_map[get_long_model_name(model)] = model
else:
name_model_map[model_name] = model
name_model_map[model_name] = model
return {v: k for k, v in name_model_map.items()}
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
all_flat_models = set()
v2_model_fields = [field for field in fields if isinstance(field, ModelField)]
v2_flat_models = get_flat_models_from_fields(v2_model_fields, known_models=set())
all_flat_models = all_flat_models.union(v2_flat_models) # type: ignore[arg-type]
model_name_map = get_model_name_map(all_flat_models) # type: ignore[arg-type]
return model_name_map
def get_flat_models_from_model(
model: Type["BaseModel"], known_models: Union[TypeModelSet, None] = None
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet:
known_models = known_models or set()
fields = get_model_fields(model)
@@ -475,5 +558,11 @@ def get_flat_models_from_fields(
return known_models
def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: tuple[Union[str, int], ...]
) -> list[dict[str, Any]]:
updated_loc_errors: list[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
]
return updated_loc_errors