增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,6 +1,6 @@
"""FastAPI framework, high performance, easy to learn, fast to code, ready for production"""
__version__ = "0.121.2"
__version__ = "0.128.0"
from starlette import status as status

View File

@@ -1,44 +1,9 @@
from .main import BaseConfig as BaseConfig
from .main import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .main import RequiredParam as RequiredParam
from .main import Undefined as Undefined
from .main import UndefinedType as UndefinedType
from .main import Url as Url
from .main import Validator as Validator
from .main import _get_model_config as _get_model_config
from .main import _is_error_wrapper as _is_error_wrapper
from .main import _is_model_class as _is_model_class
from .main import _is_model_field as _is_model_field
from .main import _is_undefined as _is_undefined
from .main import _model_dump as _model_dump
from .main import _model_rebuild as _model_rebuild
from .main import copy_field_info as copy_field_info
from .main import create_body_model as create_body_model
from .main import evaluate_forwardref as evaluate_forwardref
from .main import get_annotation_from_field_info as get_annotation_from_field_info
from .main import get_cached_model_fields as get_cached_model_fields
from .main import get_compat_model_name_map as get_compat_model_name_map
from .main import get_definitions as get_definitions
from .main import get_missing_field_error as get_missing_field_error
from .main import get_schema_from_model_field as get_schema_from_model_field
from .main import is_bytes_field as is_bytes_field
from .main import is_bytes_sequence_field as is_bytes_sequence_field
from .main import is_scalar_field as is_scalar_field
from .main import is_scalar_sequence_field as is_scalar_sequence_field
from .main import is_sequence_field as is_sequence_field
from .main import serialize_sequence_value as serialize_sequence_value
from .main import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
from .may_v1 import CoreSchema as CoreSchema
from .may_v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .may_v1 import JsonSchemaValue as JsonSchemaValue
from .may_v1 import _normalize_errors as _normalize_errors
from .model_field import ModelField as ModelField
from .shared import PYDANTIC_V2 as PYDANTIC_V2
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
from .shared import is_pydantic_v1_model_class as is_pydantic_v1_model_class
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
from .shared import (
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
)
@@ -48,3 +13,29 @@ from .shared import (
from .shared import lenient_issubclass as lenient_issubclass
from .shared import sequence_types as sequence_types
from .shared import value_is_sequence as value_is_sequence
from .v2 import BaseConfig as BaseConfig
from .v2 import ModelField as ModelField
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import _regenerate_error_with_loc as _regenerate_error_with_loc
from .v2 import copy_field_info as copy_field_info
from .v2 import create_body_model as create_body_model
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_cached_model_fields as get_cached_model_fields
from .v2 import get_compat_model_name_map as get_compat_model_name_map
from .v2 import get_definitions as get_definitions
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import get_schema_from_model_field as get_schema_from_model_field
from .v2 import is_bytes_field as is_bytes_field
from .v2 import is_bytes_sequence_field as is_bytes_sequence_field
from .v2 import is_scalar_field as is_scalar_field
from .v2 import is_scalar_sequence_field as is_scalar_sequence_field
from .v2 import is_sequence_field as is_sequence_field
from .v2 import serialize_sequence_value as serialize_sequence_value
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)

View File

@@ -1,362 +0,0 @@
import sys
from functools import lru_cache
from typing import (
Any,
Dict,
List,
Sequence,
Tuple,
Type,
)
from fastapi._compat import may_v1
from fastapi._compat.shared import PYDANTIC_V2, lenient_issubclass
from fastapi.types import ModelNameMap
from pydantic import BaseModel
from typing_extensions import Literal
from .model_field import ModelField
if PYDANTIC_V2:
from .v2 import BaseConfig as BaseConfig
from .v2 import FieldInfo as FieldInfo
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from .v2 import RequiredParam as RequiredParam
from .v2 import Undefined as Undefined
from .v2 import UndefinedType as UndefinedType
from .v2 import Url as Url
from .v2 import Validator as Validator
from .v2 import evaluate_forwardref as evaluate_forwardref
from .v2 import get_missing_field_error as get_missing_field_error
from .v2 import (
with_info_plain_validator_function as with_info_plain_validator_function,
)
else:
from .v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from .v1 import FieldInfo as FieldInfo
from .v1 import ( # type: ignore[assignment]
PydanticSchemaGenerationError as PydanticSchemaGenerationError,
)
from .v1 import RequiredParam as RequiredParam
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url # type: ignore[assignment]
from .v1 import Validator as Validator
from .v1 import evaluate_forwardref as evaluate_forwardref
from .v1 import get_missing_field_error as get_missing_field_error
from .v1 import ( # type: ignore[assignment]
with_info_plain_validator_function as with_info_plain_validator_function,
)
@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1.get_model_fields(model)
else:
from . import v2
return v2.get_model_fields(model) # type: ignore[return-value]
def _is_undefined(value: object) -> bool:
if isinstance(value, may_v1.UndefinedType):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.UndefinedType)
return False
def _get_model_config(model: BaseModel) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._get_model_config(model)
elif PYDANTIC_V2:
from . import v2
return v2._get_model_config(model)
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
if isinstance(model, may_v1.BaseModel):
from fastapi._compat import v1
return v1._model_dump(model, mode=mode, **kwargs)
elif PYDANTIC_V2:
from . import v2
return v2._model_dump(model, mode=mode, **kwargs)
def _is_error_wrapper(exc: Exception) -> bool:
if isinstance(exc, may_v1.ErrorWrapper):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(exc, v2.ErrorWrapper)
return False
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.copy_field_info(field_info=field_info, annotation=annotation)
else:
assert PYDANTIC_V2
from . import v2
return v2.copy_field_info(field_info=field_info, annotation=annotation)
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
if fields and isinstance(fields[0], may_v1.ModelField):
from fastapi._compat import v1
return v1.create_body_model(fields=fields, model_name=model_name)
else:
assert PYDANTIC_V2
from . import v2
return v2.create_body_model(fields=fields, model_name=model_name) # type: ignore[arg-type]
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
if isinstance(field_info, may_v1.FieldInfo):
from fastapi._compat import v1
return v1.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_annotation_from_field_info(
annotation=annotation, field_info=field_info, field_name=field_name
)
def is_bytes_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_field(field) # type: ignore[arg-type]
def is_bytes_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_bytes_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_bytes_sequence_field(field) # type: ignore[arg-type]
def is_scalar_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_field(field) # type: ignore[arg-type]
def is_scalar_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_scalar_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_scalar_sequence_field(field) # type: ignore[arg-type]
def is_sequence_field(field: ModelField) -> bool:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.is_sequence_field(field)
else:
assert PYDANTIC_V2
from . import v2
return v2.is_sequence_field(field) # type: ignore[arg-type]
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.serialize_sequence_value(field=field, value=value)
else:
assert PYDANTIC_V2
from . import v2
return v2.serialize_sequence_value(field=field, value=value) # type: ignore[arg-type]
def _model_rebuild(model: Type[BaseModel]) -> None:
if lenient_issubclass(model, may_v1.BaseModel):
from fastapi._compat import v1
v1._model_rebuild(model)
elif PYDANTIC_V2:
from . import v2
v2._model_rebuild(model)
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
v1_model_fields = [
field for field in fields if isinstance(field, may_v1.ModelField)
]
if v1_model_fields:
from fastapi._compat import v1
v1_flat_models = v1.get_flat_models_from_fields(
v1_model_fields, known_models=set()
)
all_flat_models = v1_flat_models
else:
all_flat_models = set()
if PYDANTIC_V2:
from . import v2
v2_model_fields = [
field for field in fields if isinstance(field, v2.ModelField)
]
v2_flat_models = v2.get_flat_models_from_fields(
v2_model_fields, known_models=set()
)
all_flat_models = all_flat_models.union(v2_flat_models)
model_name_map = v2.get_model_name_map(all_flat_models)
return model_name_map
from fastapi._compat import v1
model_name_map = v1.get_model_name_map(all_flat_models)
return model_name_map
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
Dict[str, Dict[str, Any]],
]:
if sys.version_info < (3, 14):
v1_fields = [field for field in fields if isinstance(field, may_v1.ModelField)]
v1_field_maps, v1_definitions = may_v1.get_definitions(
fields=v1_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
if not PYDANTIC_V2:
return v1_field_maps, v1_definitions
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
all_definitions = {**v1_definitions, **v2_definitions}
all_field_maps = {**v1_field_maps, **v2_field_maps}
return all_field_maps, all_definitions
# Pydantic v1 is not supported since Python 3.14
else:
from . import v2
v2_fields = [field for field in fields if isinstance(field, v2.ModelField)]
v2_field_maps, v2_definitions = v2.get_definitions(
fields=v2_fields,
model_name_map=model_name_map,
separate_input_output_schemas=separate_input_output_schemas,
)
return v2_field_maps, v2_definitions
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]],
may_v1.JsonSchemaValue,
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
if isinstance(field, may_v1.ModelField):
from fastapi._compat import v1
return v1.get_schema_from_model_field(
field=field,
model_name_map=model_name_map,
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
else:
assert PYDANTIC_V2
from . import v2
return v2.get_schema_from_model_field(
field=field, # type: ignore[arg-type]
model_name_map=model_name_map,
field_mapping=field_mapping, # type: ignore[arg-type]
separate_input_output_schemas=separate_input_output_schemas,
)
def _is_model_field(value: Any) -> bool:
if isinstance(value, may_v1.ModelField):
return True
elif PYDANTIC_V2:
from . import v2
return isinstance(value, v2.ModelField)
return False
def _is_model_class(value: Any) -> bool:
if lenient_issubclass(value, may_v1.BaseModel):
return True
elif PYDANTIC_V2:
from . import v2
return lenient_issubclass(value, v2.BaseModel) # type: ignore[attr-defined]
return False

View File

@@ -1,123 +0,0 @@
import sys
from typing import Any, Dict, List, Literal, Sequence, Tuple, Type, Union
from fastapi.types import ModelNameMap
if sys.version_info >= (3, 14):
class AnyUrl:
pass
class BaseConfig:
pass
class BaseModel:
pass
class Color:
pass
class CoreSchema:
pass
class ErrorWrapper:
pass
class FieldInfo:
pass
class GetJsonSchemaHandler:
pass
class JsonSchemaValue:
pass
class ModelField:
pass
class NameEmail:
pass
class RequiredParam:
pass
class SecretBytes:
pass
class SecretStr:
pass
class Undefined:
pass
class UndefinedType:
pass
class Url:
pass
from .v2 import ValidationError, create_model
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
Dict[str, Dict[str, Any]],
]:
return {}, {} # pragma: no cover
else:
from .v1 import AnyUrl as AnyUrl
from .v1 import BaseConfig as BaseConfig
from .v1 import BaseModel as BaseModel
from .v1 import Color as Color
from .v1 import CoreSchema as CoreSchema
from .v1 import ErrorWrapper as ErrorWrapper
from .v1 import FieldInfo as FieldInfo
from .v1 import GetJsonSchemaHandler as GetJsonSchemaHandler
from .v1 import JsonSchemaValue as JsonSchemaValue
from .v1 import ModelField as ModelField
from .v1 import NameEmail as NameEmail
from .v1 import RequiredParam as RequiredParam
from .v1 import SecretBytes as SecretBytes
from .v1 import SecretStr as SecretStr
from .v1 import Undefined as Undefined
from .v1 import UndefinedType as UndefinedType
from .v1 import Url as Url
from .v1 import ValidationError, create_model
from .v1 import get_definitions as get_definitions
RequestErrorModel: Type[BaseModel] = create_model("Request")
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
use_errors: List[Any] = []
for error in errors:
if isinstance(error, ErrorWrapper):
new_errors = ValidationError( # type: ignore[call-arg]
errors=[error], model=RequestErrorModel
).errors()
use_errors.extend(new_errors)
elif isinstance(error, list):
use_errors.extend(_normalize_errors(error))
else:
use_errors.append(error)
return use_errors
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
updated_loc_errors: List[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())}
for err in _normalize_errors(errors)
]
return updated_loc_errors

View File

@@ -1,53 +0,0 @@
from typing import (
Any,
Dict,
List,
Tuple,
Union,
)
from fastapi.types import IncEx
from pydantic.fields import FieldInfo
from typing_extensions import Literal, Protocol
class ModelField(Protocol):
field_info: "FieldInfo"
name: str
mode: Literal["validation", "serialization"] = "validation"
_version: Literal["v1", "v2"] = "v1"
@property
def alias(self) -> str: ...
@property
def required(self) -> bool: ...
@property
def default(self) -> Any: ...
@property
def type_(self) -> Any: ...
def get_default(self) -> Any: ...
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]: ...
def serialize(
self,
value: Any,
*,
mode: Literal["json", "python"] = "json",
include: Union[IncEx, None] = None,
exclude: Union[IncEx, None] = None,
by_alias: bool = True,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any: ...

View File

@@ -1,36 +1,24 @@
import sys
import types
import typing
import warnings
from collections import deque
from collections.abc import Mapping, Sequence
from dataclasses import is_dataclass
from typing import (
Annotated,
Any,
Deque,
FrozenSet,
List,
Mapping,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import may_v1
from fastapi.types import UnionType
from pydantic import BaseModel
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import get_args, get_origin
# Copy from Pydantic v2, compatible with v1
if sys.version_info < (3, 9):
# Pydantic no longer supports Python 3.8, this might be incorrect, but the code
# this is used for is also never reached in this codebase, as it's a copy of
# Pydantic's lenient_issubclass, just for compatibility with v1
# TODO: remove when dropping support for Python 3.8
WithArgsTypes: Tuple[Any, ...] = ()
elif sys.version_info < (3, 10):
if sys.version_info < (3, 10):
WithArgsTypes: tuple[Any, ...] = (typing._GenericAlias, types.GenericAlias) # type: ignore[attr-defined]
else:
WithArgsTypes: tuple[Any, ...] = (
@@ -45,26 +33,21 @@ PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
sequence_annotation_to_type = {
Sequence: list,
List: list,
list: list,
Tuple: tuple,
tuple: tuple,
Set: set,
set: set,
FrozenSet: frozenset,
frozenset: frozenset,
Deque: deque,
deque: deque,
}
sequence_types = tuple(sequence_annotation_to_type.keys())
Url: Type[Any]
Url: type[Any]
# Copy of Pydantic v2, compatible with v1
def lenient_issubclass(
cls: Any, class_or_tuple: Union[Type[Any], Tuple[Type[Any], ...], None]
cls: Any, class_or_tuple: Union[type[Any], tuple[type[Any], ...], None]
) -> bool:
try:
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
@@ -74,13 +57,13 @@ def lenient_issubclass(
raise # pragma: no cover
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
if lenient_issubclass(annotation, (str, bytes)):
return False
return lenient_issubclass(annotation, sequence_types) # type: ignore[arg-type]
return lenient_issubclass(annotation, sequence_types)
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
@@ -93,20 +76,18 @@ def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
def value_is_sequence(value: Any) -> bool:
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes)) # type: ignore[arg-type]
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def _annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
return (
lenient_issubclass(
annotation, (BaseModel, may_v1.BaseModel, Mapping, UploadFile)
)
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_complex(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
@@ -127,7 +108,7 @@ def field_annotation_is_scalar(annotation: Any) -> bool:
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
def field_annotation_is_scalar_sequence(annotation: Union[type[Any], None]) -> bool:
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
at_least_one_scalar_sequence = False
@@ -196,13 +177,27 @@ def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
)
def is_pydantic_v1_model_instance(obj: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return isinstance(obj, v1.BaseModel)
def is_pydantic_v1_model_class(cls: Any) -> bool:
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
from pydantic import v1
return lenient_issubclass(cls, v1.BaseModel)
def annotation_is_pydantic_v1(annotation: Any) -> bool:
if lenient_issubclass(annotation, may_v1.BaseModel):
if is_pydantic_v1_model_class(annotation):
return True
origin = get_origin(annotation)
if origin is Union or origin is UnionType:
for arg in get_args(annotation):
if lenient_issubclass(arg, may_v1.BaseModel):
if is_pydantic_v1_model_class(arg):
return True
if field_annotation_is_sequence(annotation):
for sub_annotation in get_args(annotation):

View File

@@ -1,312 +0,0 @@
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import (
Any,
Callable,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi._compat import shared
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
from fastapi.types import ModelNameMap
from pydantic.version import VERSION as PYDANTIC_VERSION
from typing_extensions import Literal
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2
# Keeping old "Required" functionality from Pydantic V1, without
# shadowing typing.Required.
RequiredParam: Any = Ellipsis
if not PYDANTIC_V2:
from pydantic import BaseConfig as BaseConfig
from pydantic import BaseModel as BaseModel
from pydantic import ValidationError as ValidationError
from pydantic import create_model as create_model
from pydantic.class_validators import Validator as Validator
from pydantic.color import Color as Color
from pydantic.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.errors import MissingError
from pydantic.fields import ( # type: ignore[attr-defined]
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.fields import FieldInfo as FieldInfo
from pydantic.fields import ModelField as ModelField # type: ignore[attr-defined]
from pydantic.fields import Undefined as Undefined # type: ignore[attr-defined]
from pydantic.fields import ( # type: ignore[attr-defined]
UndefinedType as UndefinedType,
)
from pydantic.networks import AnyUrl as AnyUrl
from pydantic.networks import NameEmail as NameEmail
from pydantic.schema import TypeModelSet as TypeModelSet
from pydantic.schema import (
field_schema,
model_process_schema,
)
from pydantic.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.schema import get_flat_models_from_field as get_flat_models_from_field
from pydantic.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.schema import get_model_name_map as get_model_name_map
from pydantic.types import SecretBytes as SecretBytes
from pydantic.types import SecretStr as SecretStr
from pydantic.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.utils import lenient_issubclass as lenient_issubclass
else:
from pydantic.v1 import BaseConfig as BaseConfig # type: ignore[assignment]
from pydantic.v1 import BaseModel as BaseModel # type: ignore[assignment]
from pydantic.v1 import ( # type: ignore[assignment]
ValidationError as ValidationError,
)
from pydantic.v1 import create_model as create_model # type: ignore[no-redef]
from pydantic.v1.class_validators import Validator as Validator
from pydantic.v1.color import Color as Color # type: ignore[assignment]
from pydantic.v1.error_wrappers import ErrorWrapper as ErrorWrapper
from pydantic.v1.errors import MissingError
from pydantic.v1.fields import (
SHAPE_FROZENSET,
SHAPE_LIST,
SHAPE_SEQUENCE,
SHAPE_SET,
SHAPE_SINGLETON,
SHAPE_TUPLE,
SHAPE_TUPLE_ELLIPSIS,
)
from pydantic.v1.fields import FieldInfo as FieldInfo # type: ignore[assignment]
from pydantic.v1.fields import ModelField as ModelField
from pydantic.v1.fields import Undefined as Undefined
from pydantic.v1.fields import UndefinedType as UndefinedType
from pydantic.v1.networks import AnyUrl as AnyUrl
from pydantic.v1.networks import ( # type: ignore[assignment]
NameEmail as NameEmail,
)
from pydantic.v1.schema import TypeModelSet as TypeModelSet
from pydantic.v1.schema import (
field_schema,
model_process_schema,
)
from pydantic.v1.schema import (
get_annotation_from_field_info as get_annotation_from_field_info,
)
from pydantic.v1.schema import (
get_flat_models_from_field as get_flat_models_from_field,
)
from pydantic.v1.schema import (
get_flat_models_from_fields as get_flat_models_from_fields,
)
from pydantic.v1.schema import get_model_name_map as get_model_name_map
from pydantic.v1.types import ( # type: ignore[assignment]
SecretBytes as SecretBytes,
)
from pydantic.v1.types import ( # type: ignore[assignment]
SecretStr as SecretStr,
)
from pydantic.v1.typing import evaluate_forwardref as evaluate_forwardref
from pydantic.v1.utils import lenient_issubclass as lenient_issubclass
GetJsonSchemaHandler = Any
JsonSchemaValue = Dict[str, Any]
CoreSchema = Any
Url = AnyUrl
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_FROZENSET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_shape_to_type = {
SHAPE_LIST: list,
SHAPE_SET: set,
SHAPE_TUPLE: tuple,
SHAPE_SEQUENCE: list,
SHAPE_TUPLE_ELLIPSIS: list,
}
@dataclass
class GenerateJsonSchema:
ref_template: str
class PydanticSchemaGenerationError(Exception):
pass
RequestErrorModel: Type[BaseModel] = create_model("Request")
def with_info_plain_validator_function(
function: Callable[..., Any],
*,
ref: Union[str, None] = None,
metadata: Any = None,
serialization: Any = None,
) -> Any:
return {}
def get_model_definitions(
*,
flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
definitions: Dict[str, Dict[str, Any]] = {}
for model in flat_models:
m_schema, m_definitions, m_nested_models = model_process_schema(
model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
definitions.update(m_definitions)
model_name = model_name_map[model]
definitions[model_name] = m_schema
for m_schema in definitions.values():
if "description" in m_schema:
m_schema["description"] = m_schema["description"].split("\f")[0]
return definitions
def is_pv1_scalar_field(field: ModelField) -> bool:
from fastapi import params
field_info = field.field_info
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, dict)
and not shared.field_annotation_is_sequence(field.type_)
and not is_dataclass(field.type_)
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_pv1_scalar_field(f) for f in field.sub_fields):
return False
return True
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
if (field.shape in sequence_shapes) and not lenient_issubclass(
field.type_, BaseModel
):
if field.sub_fields is not None:
for sub_field in field.sub_fields:
if not is_pv1_scalar_field(sub_field):
return False
return True
if shared._annotation_is_sequence(field.type_):
return True
return False
def _model_rebuild(model: Type[BaseModel]) -> None:
model.update_forward_refs()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.dict(**kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.__config__ # type: ignore[attr-defined]
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
return field_schema( # type: ignore[no-any-return]
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)[0]
# def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
# models = get_flat_models_from_fields(fields, known_models=set())
# return get_model_name_map(models) # type: ignore[no-any-return]
def get_definitions(
*,
fields: List[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
]:
models = get_flat_models_from_fields(fields, known_models=set())
return {}, get_model_definitions(flat_models=models, model_name_map=model_name_map)
def is_scalar_field(field: ModelField) -> bool:
return is_pv1_scalar_field(field)
def is_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes or shared._annotation_is_sequence(field.type_)
def is_scalar_sequence_field(field: ModelField) -> bool:
return is_pv1_scalar_sequence_field(field)
def is_bytes_field(field: ModelField) -> bool:
return lenient_issubclass(field.type_, bytes) # type: ignore[no-any-return]
def is_bytes_sequence_field(field: ModelField) -> bool:
return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
return copy(field_info)
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
return sequence_shape_to_type[field.shape](value) # type: ignore[no-any-return]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
missing_field_error = ErrorWrapper(MissingError(), loc=loc)
new_error = ValidationError([missing_field_error], RequestErrorModel)
return new_error.errors()[0] # type: ignore[return-value]
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
BodyModel = create_model(model_name)
for f in fields:
BodyModel.__fields__[f.name] = f # type: ignore[index]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return list(model.__fields__.values()) # type: ignore[attr-defined]

View File

@@ -1,24 +1,21 @@
import re
import warnings
from collections.abc import Sequence
from copy import copy, deepcopy
from dataclasses import dataclass
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
Annotated,
Any,
Dict,
List,
Sequence,
Set,
Tuple,
Type,
Union,
cast,
)
from fastapi._compat import may_v1, shared
from fastapi._compat import shared
from fastapi.openapi.constants import REF_TEMPLATE
from fastapi.types import IncEx, ModelNameMap
from pydantic import BaseModel, TypeAdapter, create_model
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
from pydantic import ValidationError as ValidationError
@@ -33,7 +30,7 @@ from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
from pydantic_core import CoreSchema as CoreSchema
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from pydantic_core import Url as Url
from typing_extensions import Annotated, Literal, get_args, get_origin
from typing_extensions import Literal, get_args, get_origin
try:
from pydantic_core.core_schema import (
@@ -50,6 +47,45 @@ UndefinedType = PydanticUndefinedType
evaluate_forwardref = eval_type_lenient
Validator = Any
# TODO: remove when dropping support for Pydantic < v2.12.3
_Attrs = {
"default": ...,
"default_factory": None,
"alias": None,
"alias_priority": None,
"validation_alias": None,
"serialization_alias": None,
"title": None,
"field_title_generator": None,
"description": None,
"examples": None,
"exclude": None,
"exclude_if": None,
"discriminator": None,
"deprecated": None,
"json_schema_extra": None,
"frozen": None,
"validate_default": None,
"repr": True,
"init": None,
"init_var": None,
"kw_only": None,
}
# TODO: remove when dropping support for Pydantic < v2.12.3
def asdict(field_info: FieldInfo) -> dict[str, Any]:
attributes = {}
for attr in _Attrs:
value = getattr(field_info, attr, Undefined)
if value is not Undefined:
attributes[attr] = value
return {
"annotation": field_info.annotation,
"metadata": field_info.metadata,
"attributes": attributes,
}
class BaseConfig:
pass
@@ -64,12 +100,25 @@ class ModelField:
field_info: FieldInfo
name: str
mode: Literal["validation", "serialization"] = "validation"
config: Union[ConfigDict, None] = None
@property
def alias(self) -> str:
a = self.field_info.alias
return a if a is not None else self.name
@property
def validation_alias(self) -> Union[str, None]:
va = self.field_info.validation_alias
if isinstance(va, str) and va:
return va
return None
@property
def serialization_alias(self) -> Union[str, None]:
sa = self.field_info.serialization_alias
return sa or None
@property
def required(self) -> bool:
return self.field_info.is_required()
@@ -94,8 +143,19 @@ class ModelField:
warnings.simplefilter(
"ignore", category=UnsupportedFieldAttributeWarning
)
# TODO: remove after dropping support for Python 3.8 and
# setting the min Pydantic to v2.12.3 that adds asdict()
field_dict = asdict(self.field_info)
annotated_args = (
field_dict["annotation"],
*field_dict["metadata"],
# this FieldInfo needs to be created again so that it doesn't include
# the old field info metadata and only the rest of the attributes
Field(**field_dict["attributes"]),
)
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
Annotated[self.field_info.annotation, self.field_info]
Annotated[annotated_args],
config=self.config,
)
def get_default(self) -> Any:
@@ -106,17 +166,17 @@ class ModelField:
def validate(
self,
value: Any,
values: Dict[str, Any] = {}, # noqa: B006
values: dict[str, Any] = {}, # noqa: B006
*,
loc: Tuple[Union[int, str], ...] = (),
) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
loc: tuple[Union[int, str], ...] = (),
) -> tuple[Any, Union[list[dict[str, Any]], None]]:
try:
return (
self._type_adapter.validate_python(value, from_attributes=True),
None,
)
except ValidationError as exc:
return None, may_v1._regenerate_error_with_loc(
return None, _regenerate_error_with_loc(
errors=exc.errors(include_url=False), loc_prefix=loc
)
@@ -151,44 +211,39 @@ class ModelField:
return id(self)
def get_annotation_from_field_info(
annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
return annotation
def _model_rebuild(model: Type[BaseModel]) -> None:
model.model_rebuild()
def _model_dump(
model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
return model.model_dump(mode=mode, **kwargs)
def _get_model_config(model: BaseModel) -> Any:
return model.model_config
def _has_computed_fields(field: ModelField) -> bool:
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
"computed_fields", []
)
return len(computed_fields) > 0
def get_schema_from_model_field(
*,
field: ModelField,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
) -> dict[str, Any]:
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
None
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
)
field_alias = (
(field.validation_alias or field.alias)
if field.mode == "validation"
else (field.serialization_alias or field.alias)
)
# This expects that GenerateJsonSchema was already used to generate the definitions
json_schema = field_mapping[(field, override_mode or field.mode)]
if "$ref" not in json_schema:
# TODO remove when deprecating Pydantic v1
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
json_schema["title"] = field.field_info.title or field.alias.title().replace(
json_schema["title"] = field.field_info.title or field_alias.title().replace(
"_", " "
)
return json_schema
@@ -199,14 +254,11 @@ def get_definitions(
fields: Sequence[ModelField],
model_name_map: ModelNameMap,
separate_input_output_schemas: bool = True,
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Dict[str, Any]],
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, dict[str, Any]],
]:
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
override_mode: Union[Literal["validation"], None] = (
None if separate_input_output_schemas else "validation"
)
validation_fields = [field for field in fields if field.mode == "validation"]
serialization_fields = [field for field in fields if field.mode == "serialization"]
flat_validation_models = get_flat_models_from_fields(
@@ -236,13 +288,20 @@ def get_definitions(
unique_flat_model_fields = {
f for f in flat_model_fields if f.type_ not in input_types
}
inputs = [
(field, override_mode or field.mode, field._type_adapter.core_schema)
(
field,
(
field.mode
if (separate_input_output_schemas or _has_computed_fields(field))
else "validation"
),
field._type_adapter.core_schema,
)
for field in list(fields) + list(unique_flat_model_fields)
]
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
for item_def in cast(Dict[str, Dict[str, Any]], definitions).values():
for item_def in cast(dict[str, dict[str, Any]], definitions).values():
if "description" in item_def:
item_description = cast(str, item_def["description"]).split("\f")[0]
item_def["description"] = item_description
@@ -256,9 +315,9 @@ def get_definitions(
def _replace_refs(
*,
schema: Dict[str, Any],
old_name_to_new_name_map: Dict[str, str],
) -> Dict[str, Any]:
schema: dict[str, Any],
old_name_to_new_name_map: dict[str, str],
) -> dict[str, Any]:
new_schema = deepcopy(schema)
for key, value in new_schema.items():
if key == "$ref":
@@ -293,18 +352,18 @@ def _replace_refs(
def _remap_definitions_and_field_mappings(
*,
model_name_map: ModelNameMap,
definitions: Dict[str, Any],
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
definitions: dict[str, Any],
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
],
) -> Tuple[
Dict[Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
Dict[str, Any],
) -> tuple[
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
dict[str, Any],
]:
old_name_to_new_name_map = {}
for field_key, schema in field_mapping.items():
model = field_key[0].type_
if model not in model_name_map:
if model not in model_name_map or "$ref" not in schema:
continue
new_name = model_name_map[model]
old_name = schema["$ref"].split("/")[-1]
@@ -312,8 +371,8 @@ def _remap_definitions_and_field_mappings(
continue
old_name_to_new_name_map[old_name] = new_name
new_field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
new_field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
] = {}
for field_key, schema in field_mapping.items():
new_schema = _replace_refs(
@@ -371,11 +430,18 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
if origin_type is Union or origin_type is UnionType: # Handle optional sequences
union_args = get_args(field.field_info.annotation)
for union_arg in union_args:
if union_arg is type(None):
continue
origin_type = get_origin(union_arg) or union_arg
break
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return]
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
).errors(include_url=False)[0]
@@ -385,50 +451,67 @@ def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
def create_body_model(
*, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
) -> type[BaseModel]:
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
BodyModel: Type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
BodyModel: type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
return BodyModel
def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
return [
ModelField(field_info=field_info, name=name)
for name, field_info in model.model_fields.items()
]
def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
model_fields: list[ModelField] = []
for name, field_info in model.model_fields.items():
type_ = field_info.annotation
if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
model_config = None
else:
model_config = model.model_config
model_fields.append(
ModelField(
field_info=field_info,
name=name,
config=model_config,
)
)
return model_fields
@lru_cache
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
return get_model_fields(model) # type: ignore[return-value]
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
# Pydantic v2 and allow mixing the models
TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]]
TypeModelSet = Set[TypeModelOrEnum]
TypeModelOrEnum = Union[type["BaseModel"], type[Enum]]
TypeModelSet = set[TypeModelOrEnum]
def normalize_name(name: str) -> str:
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def get_model_name_map(unique_models: TypeModelSet) -> Dict[TypeModelOrEnum, str]:
def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str]:
name_model_map = {}
conflicting_names: Set[str] = set()
for model in unique_models:
model_name = normalize_name(model.__name__)
if model_name in conflicting_names:
model_name = get_long_model_name(model)
name_model_map[model_name] = model
elif model_name in name_model_map:
conflicting_names.add(model_name)
conflicting_model = name_model_map.pop(model_name)
name_model_map[get_long_model_name(conflicting_model)] = conflicting_model
name_model_map[get_long_model_name(model)] = model
else:
name_model_map[model_name] = model
name_model_map[model_name] = model
return {v: k for k, v in name_model_map.items()}
def get_compat_model_name_map(fields: list[ModelField]) -> ModelNameMap:
all_flat_models = set()
v2_model_fields = [field for field in fields if isinstance(field, ModelField)]
v2_flat_models = get_flat_models_from_fields(v2_model_fields, known_models=set())
all_flat_models = all_flat_models.union(v2_flat_models) # type: ignore[arg-type]
model_name_map = get_model_name_map(all_flat_models) # type: ignore[arg-type]
return model_name_map
def get_flat_models_from_model(
model: Type["BaseModel"], known_models: Union[TypeModelSet, None] = None
model: type["BaseModel"], known_models: Union[TypeModelSet, None] = None
) -> TypeModelSet:
known_models = known_models or set()
fields = get_model_fields(model)
@@ -475,5 +558,11 @@ def get_flat_models_from_fields(
return known_models
def get_long_model_name(model: TypeModelOrEnum) -> str:
return f"{model.__module__}__{model.__qualname__}".replace(".", "__")
def _regenerate_error_with_loc(
*, errors: Sequence[Any], loc_prefix: tuple[Union[str, int], ...]
) -> list[dict[str, Any]]:
updated_loc_errors: list[Any] = [
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
]
return updated_loc_errors

View File

@@ -1,14 +1,10 @@
from collections.abc import Awaitable, Coroutine, Sequence
from enum import Enum
from typing import (
Annotated,
Any,
Awaitable,
Callable,
Coroutine,
Dict,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
@@ -44,7 +40,7 @@ from starlette.requests import Request
from starlette.responses import HTMLResponse, JSONResponse, Response
from starlette.routing import BaseRoute
from starlette.types import ASGIApp, ExceptionHandler, Lifespan, Receive, Scope, Send
from typing_extensions import Annotated, deprecated
from typing_extensions import deprecated
AppType = TypeVar("AppType", bound="FastAPI")
@@ -81,7 +77,7 @@ class FastAPI(Starlette):
),
] = False,
routes: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
**Note**: you probably shouldn't use this parameter, it is inherited
@@ -230,7 +226,7 @@ class FastAPI(Starlette):
),
] = "/openapi.json",
openapi_tags: Annotated[
Optional[List[Dict[str, Any]]],
Optional[list[dict[str, Any]]],
Doc(
"""
A list of tags used by OpenAPI, these are the same `tags` you can set
@@ -290,7 +286,7 @@ class FastAPI(Starlette):
),
] = None,
servers: Annotated[
Optional[List[Dict[str, Union[str, Any]]]],
Optional[list[dict[str, Union[str, Any]]]],
Doc(
"""
A `list` of `dict`s with connectivity information to a target server.
@@ -301,7 +297,12 @@ class FastAPI(Starlette):
browser tabs open). Or if you want to leave fixed the possible URLs.
If the servers `list` is not provided, or is an empty `list`, the
default value would be a `dict` with a `url` value of `/`.
`servers` property in the generated OpenAPI will be:
* a `dict` with a `url` value of the application's mounting point
(`root_path`) if it's different from `/`.
* otherwise, the `servers` property will be omitted from the OpenAPI
schema.
Each item in the `list` is a `dict` containing:
@@ -356,7 +357,7 @@ class FastAPI(Starlette):
),
] = None,
default_response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
The default response class to be used.
@@ -462,7 +463,7 @@ class FastAPI(Starlette):
),
] = "/docs/oauth2-redirect",
swagger_ui_init_oauth: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
OAuth2 configuration for the Swagger UI, by default shown at `/docs`.
@@ -488,8 +489,8 @@ class FastAPI(Starlette):
] = None,
exception_handlers: Annotated[
Optional[
Dict[
Union[int, Type[Exception]],
dict[
Union[int, type[Exception]],
Callable[[Request, Any], Coroutine[Any, Any, Response]],
]
],
@@ -562,7 +563,7 @@ class FastAPI(Starlette):
),
] = None,
contact: Annotated[
Optional[Dict[str, Union[str, Any]]],
Optional[dict[str, Union[str, Any]]],
Doc(
"""
A dictionary with the contact information for the exposed API.
@@ -595,7 +596,7 @@ class FastAPI(Starlette):
),
] = None,
license_info: Annotated[
Optional[Dict[str, Union[str, Any]]],
Optional[dict[str, Union[str, Any]]],
Doc(
"""
A dictionary with the license information for the exposed API.
@@ -684,7 +685,7 @@ class FastAPI(Starlette):
),
] = True,
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses to be shown in OpenAPI.
@@ -700,7 +701,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
OpenAPI callbacks that should apply to all *path operations*.
@@ -757,7 +758,7 @@ class FastAPI(Starlette):
),
] = True,
swagger_ui_parameters: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Parameters to configure Swagger UI, the autogenerated interactive API
@@ -815,7 +816,7 @@ class FastAPI(Starlette):
),
] = True,
openapi_external_docs: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
This field allows you to provide additional external documentation links.
@@ -901,7 +902,7 @@ class FastAPI(Starlette):
"""
),
] = "3.1.0"
self.openapi_schema: Optional[Dict[str, Any]] = None
self.openapi_schema: Optional[dict[str, Any]] = None
if self.openapi_url:
assert self.title, "A title must be provided for OpenAPI, e.g.: 'My API'"
assert self.version, "A version must be provided for OpenAPI, e.g.: '2.1.0'"
@@ -944,7 +945,7 @@ class FastAPI(Starlette):
),
] = State()
self.dependency_overrides: Annotated[
Dict[Callable[..., Any], Callable[..., Any]],
dict[Callable[..., Any], Callable[..., Any]],
Doc(
"""
A dictionary with overrides for the dependencies.
@@ -975,7 +976,7 @@ class FastAPI(Starlette):
responses=responses,
generate_unique_id_function=generate_unique_id_function,
)
self.exception_handlers: Dict[
self.exception_handlers: dict[
Any, Callable[[Request, Any], Union[Response, Awaitable[Response]]]
] = {} if exception_handlers is None else dict(exception_handlers)
self.exception_handlers.setdefault(HTTPException, http_exception_handler)
@@ -988,7 +989,7 @@ class FastAPI(Starlette):
websocket_request_validation_exception_handler, # type: ignore
)
self.user_middleware: List[Middleware] = (
self.user_middleware: list[Middleware] = (
[] if middleware is None else list(middleware)
)
self.middleware_stack: Union[ASGIApp, None] = None
@@ -1042,7 +1043,7 @@ class FastAPI(Starlette):
app = cls(app, *args, **kwargs)
return app
def openapi(self) -> Dict[str, Any]:
def openapi(self) -> dict[str, Any]:
"""
Generate the OpenAPI schema of the application. This is called by FastAPI
internally.
@@ -1140,14 +1141,14 @@ class FastAPI(Starlette):
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
tags: Optional[list[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
responses: Optional[dict[Union[int, str], dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
methods: Optional[list[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -1156,11 +1157,11 @@ class FastAPI(Starlette):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
response_class: Union[type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
name: Optional[str] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
generate_unique_id
),
@@ -1198,14 +1199,14 @@ class FastAPI(Starlette):
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
tags: Optional[list[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
responses: Optional[dict[Union[int, str], dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
methods: Optional[list[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -1214,9 +1215,9 @@ class FastAPI(Starlette):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
response_class: type[Response] = Default(JSONResponse),
name: Optional[str] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Callable[[routing.APIRoute], str] = Default(
generate_unique_id
),
@@ -1338,7 +1339,7 @@ class FastAPI(Starlette):
*,
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "",
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to all the *path operations* in this
@@ -1380,7 +1381,7 @@ class FastAPI(Starlette):
),
] = None,
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses to be shown in OpenAPI.
@@ -1447,7 +1448,7 @@ class FastAPI(Starlette):
),
] = True,
default_response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Default response class to be used for the *path operations* in this
@@ -1475,7 +1476,7 @@ class FastAPI(Starlette):
),
] = Default(JSONResponse),
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -1598,7 +1599,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -1664,7 +1665,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -1805,7 +1806,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -1826,7 +1827,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -1842,7 +1843,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -1971,7 +1972,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -2037,7 +2038,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2178,7 +2179,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2199,7 +2200,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2215,7 +2216,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -2349,7 +2350,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -2415,7 +2416,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2556,7 +2557,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2577,7 +2578,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2593,7 +2594,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -2727,7 +2728,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -2793,7 +2794,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2934,7 +2935,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2955,7 +2956,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2971,7 +2972,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3100,7 +3101,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3166,7 +3167,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -3307,7 +3308,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -3328,7 +3329,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -3344,7 +3345,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3473,7 +3474,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3539,7 +3540,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -3680,7 +3681,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -3701,7 +3702,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -3717,7 +3718,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3846,7 +3847,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3912,7 +3913,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -4053,7 +4054,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -4074,7 +4075,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -4090,7 +4091,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -4224,7 +4225,7 @@ class FastAPI(Starlette):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -4290,7 +4291,7 @@ class FastAPI(Starlette):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -4431,7 +4432,7 @@ class FastAPI(Starlette):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -4452,7 +4453,7 @@ class FastAPI(Starlette):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -4468,7 +4469,7 @@ class FastAPI(Starlette):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -4623,7 +4624,7 @@ class FastAPI(Starlette):
def exception_handler(
self,
exc_class_or_status_code: Annotated[
Union[int, Type[Exception]],
Union[int, type[Exception]],
Doc(
"""
The Exception class this would handle, or a status code.

View File

@@ -1,8 +1,8 @@
from typing import Any, Callable
from typing import Annotated, Any, Callable
from annotated_doc import Doc
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from typing_extensions import Annotated, ParamSpec
from typing_extensions import ParamSpec
P = ParamSpec("P")

View File

@@ -1,5 +1,7 @@
from collections.abc import AsyncGenerator
from contextlib import AbstractContextManager
from contextlib import asynccontextmanager as asynccontextmanager
from typing import AsyncGenerator, ContextManager, TypeVar
from typing import TypeVar
import anyio.to_thread
from anyio import CapacityLimiter
@@ -14,7 +16,7 @@ _T = TypeVar("_T")
@asynccontextmanager
async def contextmanager_in_threadpool(
cm: ContextManager[_T],
cm: AbstractContextManager[_T],
) -> AsyncGenerator[_T, None]:
# blocking __exit__ from running waiting on a free thread
# can create race conditions/deadlocks if the context manager itself

View File

@@ -1,21 +1,16 @@
from collections.abc import Mapping
from typing import (
Annotated,
Any,
BinaryIO,
Callable,
Dict,
Iterable,
Optional,
Type,
TypeVar,
cast,
)
from annotated_doc import Doc
from fastapi._compat import (
CoreSchema,
GetJsonSchemaHandler,
JsonSchemaValue,
)
from pydantic import GetJsonSchemaHandler
from starlette.datastructures import URL as URL # noqa: F401
from starlette.datastructures import Address as Address # noqa: F401
from starlette.datastructures import FormData as FormData # noqa: F401
@@ -23,7 +18,6 @@ from starlette.datastructures import Headers as Headers # noqa: F401
from starlette.datastructures import QueryParams as QueryParams # noqa: F401
from starlette.datastructures import State as State # noqa: F401
from starlette.datastructures import UploadFile as StarletteUploadFile
from typing_extensions import Annotated
class UploadFile(StarletteUploadFile):
@@ -137,37 +131,22 @@ class UploadFile(StarletteUploadFile):
"""
return await super().close()
@classmethod
def __get_validators__(cls: Type["UploadFile"]) -> Iterable[Callable[..., Any]]:
yield cls.validate
@classmethod
def validate(cls: Type["UploadFile"], v: Any) -> Any:
if not isinstance(v, StarletteUploadFile):
raise ValueError(f"Expected UploadFile, received: {type(v)}")
return v
@classmethod
def _validate(cls, __input_value: Any, _: Any) -> "UploadFile":
if not isinstance(__input_value, StarletteUploadFile):
raise ValueError(f"Expected UploadFile, received: {type(__input_value)}")
return cast(UploadFile, __input_value)
# TODO: remove when deprecating Pydantic v1
@classmethod
def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
field_schema.update({"type": "string", "format": "binary"})
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
cls, core_schema: Mapping[str, Any], handler: GetJsonSchemaHandler
) -> dict[str, Any]:
return {"type": "string", "format": "binary"}
@classmethod
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
) -> CoreSchema:
cls, source: type[Any], handler: Callable[[Any], Mapping[str, Any]]
) -> Mapping[str, Any]:
from ._compat.v2 import with_info_plain_validator_function
return with_info_plain_validator_function(cls._validate)

View File

@@ -1,8 +1,8 @@
import inspect
import sys
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any, Callable, List, Optional, Sequence, Union
from functools import cached_property, partial
from typing import Any, Callable, Optional, Union
from fastapi._compat import ModelField
from fastapi.security.base import SecurityBase
@@ -15,21 +15,27 @@ else: # pragma: no cover
from asyncio import iscoroutinefunction
@dataclass
class SecurityRequirement:
security_scheme: SecurityBase
scopes: Optional[Sequence[str]] = None
def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
if call is None:
return call # pragma: no cover
unwrapped = inspect.unwrap(_impartial(call))
return unwrapped
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
while isinstance(func, partial):
func = func.func
return func
@dataclass
class Dependant:
path_params: List[ModelField] = field(default_factory=list)
query_params: List[ModelField] = field(default_factory=list)
header_params: List[ModelField] = field(default_factory=list)
cookie_params: List[ModelField] = field(default_factory=list)
body_params: List[ModelField] = field(default_factory=list)
dependencies: List["Dependant"] = field(default_factory=list)
security_requirements: List[SecurityRequirement] = field(default_factory=list)
path_params: list[ModelField] = field(default_factory=list)
query_params: list[ModelField] = field(default_factory=list)
header_params: list[ModelField] = field(default_factory=list)
cookie_params: list[ModelField] = field(default_factory=list)
body_params: list[ModelField] = field(default_factory=list)
dependencies: list["Dependant"] = field(default_factory=list)
name: Optional[str] = None
call: Optional[Callable[..., Any]] = None
request_param_name: Optional[str] = None
@@ -38,41 +44,145 @@ class Dependant:
response_param_name: Optional[str] = None
background_tasks_param_name: Optional[str] = None
security_scopes_param_name: Optional[str] = None
security_scopes: Optional[List[str]] = None
own_oauth_scopes: Optional[list[str]] = None
parent_oauth_scopes: Optional[list[str]] = None
use_cache: bool = True
path: Optional[str] = None
scope: Union[Literal["function", "request"], None] = None
@cached_property
def oauth_scopes(self) -> list[str]:
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
# This doesn't use a set to preserve order, just in case
for scope in self.own_oauth_scopes or []:
if scope not in scopes:
scopes.append(scope)
return scopes
@cached_property
def cache_key(self) -> DependencyCacheKey:
scopes_for_cache = (
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
)
return (
self.call,
tuple(sorted(set(self.security_scopes or []))),
scopes_for_cache,
self.computed_scope or "",
)
@cached_property
def is_gen_callable(self) -> bool:
if inspect.isgeneratorfunction(self.call):
def _uses_scopes(self) -> bool:
if self.own_oauth_scopes:
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isgeneratorfunction(dunder_call)
if self.security_scopes_param_name is not None:
return True
if self._is_security_scheme:
return True
for sub_dep in self.dependencies:
if sub_dep._uses_scopes:
return True
return False
@cached_property
def _is_security_scheme(self) -> bool:
if self.call is None:
return False # pragma: no cover
unwrapped = _unwrapped_call(self.call)
return isinstance(unwrapped, SecurityBase)
# Mainly to get the type of SecurityBase, but it's the same self.call
@cached_property
def _security_scheme(self) -> SecurityBase:
unwrapped = _unwrapped_call(self.call)
assert isinstance(unwrapped, SecurityBase)
return unwrapped
@cached_property
def _security_dependencies(self) -> list["Dependant"]:
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
return security_deps
@cached_property
def is_gen_callable(self) -> bool:
if self.call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(self.call)
) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
return True
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isgeneratorfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_async_gen_callable(self) -> bool:
if inspect.isasyncgenfunction(self.call):
if self.call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(self.call)
) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
return True
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return inspect.isasyncgenfunction(dunder_call)
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if inspect.isasyncgenfunction(
_impartial(dunder_unwrapped_call)
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def is_coroutine_callable(self) -> bool:
if inspect.isroutine(self.call):
return iscoroutinefunction(self.call)
if inspect.isclass(self.call):
if self.call is None:
return False # pragma: no cover
if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
_impartial(self.call)
):
return True
if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
_unwrapped_call(self.call)
):
return True
if inspect.isclass(_unwrapped_call(self.call)):
return False
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
return iscoroutinefunction(dunder_call)
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
if dunder_call is None:
return False # pragma: no cover
if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
_unwrapped_call(dunder_call)
):
return True
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
if dunder_unwrapped_call is None:
return False # pragma: no cover
if iscoroutinefunction(
_impartial(dunder_unwrapped_call)
) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
return True
return False
@cached_property
def computed_scope(self) -> Union[str, None]:

View File

@@ -1,19 +1,16 @@
import dataclasses
import inspect
import sys
from collections.abc import Coroutine, Mapping, Sequence
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from dataclasses import dataclass
from typing import (
Annotated,
Any,
Callable,
Coroutine,
Dict,
ForwardRef,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)
@@ -21,17 +18,14 @@ from typing import (
import anyio
from fastapi import params
from fastapi._compat import (
PYDANTIC_V2,
ModelField,
RequiredParam,
Undefined,
_is_error_wrapper,
_is_model_class,
_regenerate_error_with_loc,
copy_field_info,
create_body_model,
evaluate_forwardref,
field_annotation_is_scalar,
get_annotation_from_field_info,
get_cached_model_fields,
get_missing_field_error,
is_bytes_field,
@@ -42,23 +36,19 @@ from fastapi._compat import (
is_uploadfile_or_nonable_uploadfile_annotation,
is_uploadfile_sequence_annotation,
lenient_issubclass,
may_v1,
sequence_types,
serialize_sequence_value,
value_is_sequence,
)
from fastapi._compat.shared import annotation_is_pydantic_v1
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
asynccontextmanager,
contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.dependencies.models import Dependant
from fastapi.exceptions import DependencyScopeError
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.security.oauth2 import SecurityScopes
from fastapi.types import DependencyCacheKey
from fastapi.utils import create_model_field, get_path_param_names
from pydantic import BaseModel
@@ -75,9 +65,7 @@ from starlette.datastructures import (
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_extensions import Annotated, Literal, get_args, get_origin
from .. import temp_pydantic_v1_params
from typing_extensions import Literal, get_args, get_origin
multipart_not_installed_error = (
'Form data requires "python-multipart" to be installed. \n'
@@ -125,14 +113,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
assert callable(depends.dependency), (
"A parameter-less dependency must have a callable dependency"
)
use_security_scopes: List[str] = []
own_oauth_scopes: list[str] = []
if isinstance(depends, params.Security) and depends.scopes:
use_security_scopes.extend(depends.scopes)
own_oauth_scopes.extend(depends.scopes)
return get_dependant(
path=path,
call=depends.dependency,
scope=depends.scope,
security_scopes=use_security_scopes,
own_oauth_scopes=own_oauth_scopes,
)
@@ -140,11 +128,15 @@ def get_flat_dependant(
dependant: Dependant,
*,
skip_repeats: bool = False,
visited: Optional[List[DependencyCacheKey]] = None,
visited: Optional[list[DependencyCacheKey]] = None,
parent_oauth_scopes: Optional[list[str]] = None,
) -> Dependant:
if visited is None:
visited = []
visited.append(dependant.cache_key)
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
dependant.oauth_scopes or []
)
flat_dependant = Dependant(
path_params=dependant.path_params.copy(),
@@ -152,36 +144,51 @@ def get_flat_dependant(
header_params=dependant.header_params.copy(),
cookie_params=dependant.cookie_params.copy(),
body_params=dependant.body_params.copy(),
security_requirements=dependant.security_requirements.copy(),
name=dependant.name,
call=dependant.call,
request_param_name=dependant.request_param_name,
websocket_param_name=dependant.websocket_param_name,
http_connection_param_name=dependant.http_connection_param_name,
response_param_name=dependant.response_param_name,
background_tasks_param_name=dependant.background_tasks_param_name,
security_scopes_param_name=dependant.security_scopes_param_name,
own_oauth_scopes=dependant.own_oauth_scopes,
parent_oauth_scopes=use_parent_oauth_scopes,
use_cache=dependant.use_cache,
path=dependant.path,
scope=dependant.scope,
)
for sub_dependant in dependant.dependencies:
if skip_repeats and sub_dependant.cache_key in visited:
continue
flat_sub = get_flat_dependant(
sub_dependant, skip_repeats=skip_repeats, visited=visited
sub_dependant,
skip_repeats=skip_repeats,
visited=visited,
parent_oauth_scopes=flat_dependant.oauth_scopes,
)
flat_dependant.dependencies.append(flat_sub)
flat_dependant.path_params.extend(flat_sub.path_params)
flat_dependant.query_params.extend(flat_sub.query_params)
flat_dependant.header_params.extend(flat_sub.header_params)
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
flat_dependant.body_params.extend(flat_sub.body_params)
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
flat_dependant.dependencies.extend(flat_sub.dependencies)
return flat_dependant
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
if not fields:
return fields
first_field = fields[0]
if len(fields) == 1 and _is_model_class(first_field.type_):
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
fields_to_extract = get_cached_model_fields(first_field.type_)
return fields_to_extract
return fields
def get_flat_params(dependant: Dependant) -> List[ModelField]:
def get_flat_params(dependant: Dependant) -> list[ModelField]:
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
@@ -190,9 +197,23 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
return path_params + query_params + header_params + cookie_params
def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
if sys.version_info >= (3, 10):
try:
signature = inspect.signature(call, eval_str=True)
except NameError:
# Handle type annotations with if TYPE_CHECKING, not used by FastAPI
# e.g. dependency return types
signature = inspect.signature(call)
else:
signature = inspect.signature(call)
return signature
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
signature = _get_signature(call)
unwrapped = inspect.unwrap(call)
globalns = getattr(unwrapped, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
@@ -206,7 +227,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
return typed_signature
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
@@ -216,13 +237,14 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
signature = inspect.signature(call)
signature = _get_signature(call)
unwrapped = inspect.unwrap(call)
annotation = signature.return_annotation
if annotation is inspect.Signature.empty:
return None
globalns = getattr(call, "__globals__", {})
globalns = getattr(unwrapped, "__globals__", {})
return get_typed_annotation(annotation, globalns)
@@ -231,7 +253,8 @@ def get_dependant(
path: str,
call: Callable[..., Any],
name: Optional[str] = None,
security_scopes: Optional[List[str]] = None,
own_oauth_scopes: Optional[list[str]] = None,
parent_oauth_scopes: Optional[list[str]] = None,
use_cache: bool = True,
scope: Union[Literal["function", "request"], None] = None,
) -> Dependant:
@@ -239,21 +262,15 @@ def get_dependant(
call=call,
name=name,
path=path,
security_scopes=security_scopes,
use_cache=use_cache,
scope=scope,
own_oauth_scopes=own_oauth_scopes,
parent_oauth_scopes=parent_oauth_scopes,
)
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
path_param_names = get_path_param_names(path)
endpoint_signature = get_typed_signature(call)
signature_params = endpoint_signature.parameters
if isinstance(call, SecurityBase):
use_scopes: List[str] = []
if isinstance(call, (OAuth2, OpenIdConnect)):
use_scopes = security_scopes or use_scopes
security_requirement = SecurityRequirement(
security_scheme=call, scopes=use_scopes
)
dependant.security_requirements.append(security_requirement)
for param_name, param in signature_params.items():
is_path_param = param_name in path_param_names
param_details = analyze_param(
@@ -274,15 +291,16 @@ def get_dependant(
f'The dependency "{dependant.call.__name__}" has a scope of '
'"request", it cannot depend on dependencies with scope "function".'
)
use_security_scopes = security_scopes or []
sub_own_oauth_scopes: list[str] = []
if isinstance(param_details.depends, params.Security):
if param_details.depends.scopes:
use_security_scopes.extend(param_details.depends.scopes)
sub_own_oauth_scopes = list(param_details.depends.scopes)
sub_dependant = get_dependant(
path=path,
call=param_details.depends.dependency,
name=param_name,
security_scopes=use_security_scopes,
own_oauth_scopes=sub_own_oauth_scopes,
parent_oauth_scopes=current_scopes,
use_cache=param_details.depends.use_cache,
scope=param_details.depends.scope,
)
@@ -298,9 +316,7 @@ def get_dependant(
)
continue
assert param_details.field is not None
if isinstance(
param_details.field.field_info, (params.Body, temp_pydantic_v1_params.Body)
):
if isinstance(param_details.field.field_info, params.Body):
dependant.body_params.append(param_details.field)
else:
add_param_to_fields(field=param_details.field, dependant=dependant)
@@ -359,7 +375,7 @@ def analyze_param(
fastapi_annotations = [
arg
for arg in annotated_args[1:]
if isinstance(arg, (FieldInfo, may_v1.FieldInfo, params.Depends))
if isinstance(arg, (FieldInfo, params.Depends))
]
fastapi_specific_annotations = [
arg
@@ -368,29 +384,27 @@ def analyze_param(
arg,
(
params.Param,
temp_pydantic_v1_params.Param,
params.Body,
temp_pydantic_v1_params.Body,
params.Depends,
),
)
]
if fastapi_specific_annotations:
fastapi_annotation: Union[
FieldInfo, may_v1.FieldInfo, params.Depends, None
] = fastapi_specific_annotations[-1]
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
fastapi_specific_annotations[-1]
)
else:
fastapi_annotation = None
# Set default for Annotated FieldInfo
if isinstance(fastapi_annotation, (FieldInfo, may_v1.FieldInfo)):
if isinstance(fastapi_annotation, FieldInfo):
# Copy `field_info` because we mutate `field_info.default` below.
field_info = copy_field_info(
field_info=fastapi_annotation, annotation=use_annotation
field_info=fastapi_annotation, # type: ignore[arg-type]
annotation=use_annotation,
)
assert field_info.default in {
Undefined,
may_v1.Undefined,
} or field_info.default in {RequiredParam, may_v1.RequiredParam}, (
assert (
field_info.default == Undefined or field_info.default == RequiredParam
), (
f"`{field_info.__class__.__name__}` default value cannot be set in"
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
)
@@ -414,21 +428,20 @@ def analyze_param(
)
depends = value
# Get FieldInfo from default value
elif isinstance(value, (FieldInfo, may_v1.FieldInfo)):
elif isinstance(value, FieldInfo):
assert field_info is None, (
"Cannot specify FastAPI annotations in `Annotated` and default value"
f" together for {param_name!r}"
)
field_info = value
if PYDANTIC_V2:
if isinstance(field_info, FieldInfo):
field_info.annotation = type_annotation
field_info = value # type: ignore[assignment]
if isinstance(field_info, FieldInfo):
field_info.annotation = type_annotation
# Get Depends from type annotation
if depends is not None and depends.dependency is None:
# Copy `depends` before mutating it
depends = copy(depends)
depends.dependency = type_annotation
depends = dataclasses.replace(depends, dependency=type_annotation)
# Handle non-param type annotations like Request
if lenient_issubclass(
@@ -459,14 +472,7 @@ def analyze_param(
) or is_uploadfile_sequence_annotation(type_annotation):
field_info = params.File(annotation=use_annotation, default=default_value)
elif not field_annotation_is_scalar(annotation=type_annotation):
if annotation_is_pydantic_v1(use_annotation):
field_info = temp_pydantic_v1_params.Body(
annotation=use_annotation, default=default_value
)
else:
field_info = params.Body(
annotation=use_annotation, default=default_value
)
field_info = params.Body(annotation=use_annotation, default=default_value)
else:
field_info = params.Query(annotation=use_annotation, default=default_value)
@@ -475,23 +481,17 @@ def analyze_param(
if field_info is not None:
# Handle field_info.in_
if is_path_param:
assert isinstance(
field_info, (params.Path, temp_pydantic_v1_params.Path)
), (
assert isinstance(field_info, params.Path), (
f"Cannot use `{field_info.__class__.__name__}` for path param"
f" {param_name!r}"
)
elif (
isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param))
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = params.ParamTypes.query
use_annotation_from_field_info = get_annotation_from_field_info(
use_annotation,
field_info,
param_name,
)
if isinstance(field_info, (params.Form, temp_pydantic_v1_params.Form)):
use_annotation_from_field_info = use_annotation
if isinstance(field_info, params.Form):
ensure_multipart_is_installed()
if not field_info.alias and getattr(field_info, "convert_underscores", None):
alias = param_name.replace("_", "-")
@@ -503,20 +503,19 @@ def analyze_param(
type_=use_annotation_from_field_info,
default=field_info.default,
alias=alias,
required=field_info.default
in (RequiredParam, may_v1.RequiredParam, Undefined),
required=field_info.default in (RequiredParam, Undefined),
field_info=field_info,
)
if is_path_param:
assert is_scalar_field(field=field), (
"Path params must be of one of the supported types"
)
elif isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)):
elif isinstance(field_info, params.Query):
assert (
is_scalar_field(field)
or is_scalar_sequence_field(field)
or (
_is_model_class(field.type_)
lenient_issubclass(field.type_, BaseModel)
# For Pydantic v1
and getattr(field, "shape", 1) == 1
)
@@ -542,34 +541,34 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
async def _solve_generator(
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
) -> Any:
assert dependant.call
if dependant.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
elif dependant.is_async_gen_callable:
if dependant.is_async_gen_callable:
cm = asynccontextmanager(dependant.call)(**sub_values)
elif dependant.is_gen_callable:
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
return await stack.enter_async_context(cm)
@dataclass
class SolvedDependency:
values: Dict[str, Any]
errors: List[Any]
values: dict[str, Any]
errors: list[Any]
background_tasks: Optional[StarletteBackgroundTasks]
response: Response
dependency_cache: Dict[DependencyCacheKey, Any]
dependency_cache: dict[DependencyCacheKey, Any]
async def solve_dependencies(
*,
request: Union[Request, WebSocket],
dependant: Dependant,
body: Optional[Union[Dict[str, Any], FormData]] = None,
body: Optional[Union[dict[str, Any], FormData]] = None,
background_tasks: Optional[StarletteBackgroundTasks] = None,
response: Optional[Response] = None,
dependency_overrides_provider: Optional[Any] = None,
dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
# TODO: remove this parameter later, no longer used, not removing it yet as some
# people might be monkey patching this function (although that's not supported)
async_exit_stack: AsyncExitStack,
@@ -583,8 +582,8 @@ async def solve_dependencies(
assert isinstance(function_astack, AsyncExitStack), (
"fastapi_function_astack not found in request scope"
)
values: Dict[str, Any] = {}
errors: List[Any] = []
values: dict[str, Any] = {}
errors: list[Any] = []
if response is None:
response = Response()
del response.headers["content-length"]
@@ -608,7 +607,7 @@ async def solve_dependencies(
path=use_path,
call=call,
name=sub_dependant.name,
security_scopes=sub_dependant.security_scopes,
parent_oauth_scopes=sub_dependant.oauth_scopes,
scope=sub_dependant.scope,
)
@@ -690,7 +689,7 @@ async def solve_dependencies(
values[dependant.response_param_name] = response
if dependant.security_scopes_param_name:
values[dependant.security_scopes_param_name] = SecurityScopes(
scopes=dependant.security_scopes
scopes=dependant.oauth_scopes
)
return SolvedDependency(
values=values,
@@ -702,18 +701,16 @@ async def solve_dependencies(
def _validate_value_with_model_field(
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
) -> Tuple[Any, List[Any]]:
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
) -> tuple[Any, list[Any]]:
if value is None:
if field.required:
return None, [get_missing_field_error(loc=loc)]
else:
return deepcopy(field.default), []
v_, errors_ = field.validate(value, values, loc=loc)
if _is_error_wrapper(errors_): # type: ignore[arg-type]
return None, [errors_]
elif isinstance(errors_, list):
new_errors = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=())
if isinstance(errors_, list):
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
return None, new_errors
else:
return v_, []
@@ -722,7 +719,7 @@ def _validate_value_with_model_field(
def _get_multidict_value(
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
) -> Any:
alias = alias or field.alias
alias = alias or get_validation_alias(field)
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
value = values.getlist(alias)
else:
@@ -730,7 +727,7 @@ def _get_multidict_value(
if (
value is None
or (
isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form))
isinstance(field.field_info, params.Form)
and isinstance(value, str) # For type checks
and value == ""
)
@@ -746,9 +743,9 @@ def _get_multidict_value(
def request_params_to_args(
fields: Sequence[ModelField],
received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
) -> tuple[dict[str, Any], list[Any]]:
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []
if not fields:
return values, errors
@@ -766,7 +763,7 @@ def request_params_to_args(
first_field.field_info, "convert_underscores", True
)
params_to_process: Dict[str, Any] = {}
params_to_process: dict[str, Any] = {}
processed_keys = set()
@@ -779,27 +776,31 @@ def request_params_to_args(
field.field_info, "convert_underscores", default_convert_underscores
)
if convert_underscores:
alias = (
field.alias
if field.alias != field.name
else field.name.replace("_", "-")
)
alias = get_validation_alias(field)
if alias == field.name:
alias = alias.replace("_", "-")
value = _get_multidict_value(field, received_params, alias=alias)
if value is not None:
params_to_process[field.name] = value
processed_keys.add(alias or field.alias)
processed_keys.add(field.name)
params_to_process[get_validation_alias(field)] = value
processed_keys.add(alias or get_validation_alias(field))
for key, value in received_params.items():
for key in received_params.keys():
if key not in processed_keys:
params_to_process[key] = value
if hasattr(received_params, "getlist"):
value = received_params.getlist(key)
if isinstance(value, list) and (len(value) == 1):
params_to_process[key] = value[0]
else:
params_to_process[key] = value
else:
params_to_process[key] = received_params.get(key)
if single_not_embedded_field:
field_info = first_field.field_info
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc: Tuple[str, ...] = (field_info.in_.value,)
loc: tuple[str, ...] = (field_info.in_.value,)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=params_to_process, values=values, loc=loc
)
@@ -808,10 +809,10 @@ def request_params_to_args(
for field in fields:
value = _get_multidict_value(field, received_params)
field_info = field.field_info
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
assert isinstance(field_info, params.Param), (
"Params must be subclasses of Param"
)
loc = (field_info.in_.value, field.alias)
loc = (field_info.in_.value, get_validation_alias(field))
v_, errors_ = _validate_value_with_model_field(
field=field, value=value, values=values, loc=loc
)
@@ -835,13 +836,13 @@ def is_union_of_base_models(field_type: Any) -> bool:
union_args = get_args(field_type)
for arg in union_args:
if not _is_model_class(arg):
if not lenient_issubclass(arg, BaseModel):
return False
return True
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
if not fields:
return False
# More than one dependency could have the same field, it would show up as multiple
@@ -857,8 +858,8 @@ def _should_embed_body_fields(fields: List[ModelField]) -> bool:
# If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
# otherwise it has to be embedded, so that the key value pair can be extracted
if (
isinstance(first_field.field_info, (params.Form, temp_pydantic_v1_params.Form))
and not _is_model_class(first_field.type_)
isinstance(first_field.field_info, params.Form)
and not lenient_issubclass(first_field.type_, BaseModel)
and not is_union_of_base_models(first_field.type_)
):
return True
@@ -866,28 +867,28 @@ def _should_embed_body_fields(fields: List[ModelField]) -> bool:
async def _extract_form_body(
body_fields: List[ModelField],
body_fields: list[ModelField],
received_body: FormData,
) -> Dict[str, Any]:
) -> dict[str, Any]:
values = {}
for field in body_fields:
value = _get_multidict_value(field, received_body)
field_info = field.field_info
if (
isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
isinstance(field_info, params.File)
and is_bytes_field(field)
and isinstance(value, UploadFile)
):
value = await value.read()
elif (
is_bytes_sequence_field(field)
and isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
and isinstance(field_info, params.File)
and value_is_sequence(value)
):
# For types
assert isinstance(value, sequence_types) # type: ignore[arg-type]
results: List[Union[bytes, str]] = []
assert isinstance(value, sequence_types)
results: list[Union[bytes, str]] = []
async def process_fn(
fn: Callable[[], Coroutine[Any, Any, Any]],
@@ -900,30 +901,35 @@ async def _extract_form_body(
tg.start_soon(process_fn, sub_value.read)
value = serialize_sequence_value(field=field, value=results)
if value is not None:
values[field.alias] = value
for key, value in received_body.items():
if key not in values:
values[key] = value
values[get_validation_alias(field)] = value
field_aliases = {get_validation_alias(field) for field in body_fields}
for key in received_body.keys():
if key not in field_aliases:
param_values = received_body.getlist(key)
if len(param_values) == 1:
values[key] = param_values[0]
else:
values[key] = param_values
return values
async def request_body_to_args(
body_fields: List[ModelField],
received_body: Optional[Union[Dict[str, Any], FormData]],
body_fields: list[ModelField],
received_body: Optional[Union[dict[str, Any], FormData]],
embed_body_fields: bool,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
values: Dict[str, Any] = {}
errors: List[Dict[str, Any]] = []
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
values: dict[str, Any] = {}
errors: list[dict[str, Any]] = []
assert body_fields, "request_body_to_args() should be called with fields"
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
first_field = body_fields[0]
body_to_process = received_body
fields_to_extract: List[ModelField] = body_fields
fields_to_extract: list[ModelField] = body_fields
if (
single_not_embedded_field
and _is_model_class(first_field.type_)
and lenient_issubclass(first_field.type_, BaseModel)
and isinstance(received_body, FormData)
):
fields_to_extract = get_cached_model_fields(first_field.type_)
@@ -932,17 +938,17 @@ async def request_body_to_args(
body_to_process = await _extract_form_body(fields_to_extract, received_body)
if single_not_embedded_field:
loc: Tuple[str, ...] = ("body",)
loc: tuple[str, ...] = ("body",)
v_, errors_ = _validate_value_with_model_field(
field=first_field, value=body_to_process, values=values, loc=loc
)
return {first_field.name: v_}, errors_
for field in body_fields:
loc = ("body", field.alias)
loc = ("body", get_validation_alias(field))
value: Optional[Any] = None
if body_to_process is not None:
try:
value = body_to_process.get(field.alias)
value = body_to_process.get(get_validation_alias(field))
# If the received body is a list, not a dict
except AttributeError:
errors.append(get_missing_field_error(loc))
@@ -980,36 +986,23 @@ def get_body_field(
fields=flat_dependant.body_params, model_name=model_name
)
required = any(True for f in flat_dependant.body_params if f.required)
BodyFieldInfo_kwargs: Dict[str, Any] = {
BodyFieldInfo_kwargs: dict[str, Any] = {
"annotation": BodyModel,
"alias": "body",
}
if not required:
BodyFieldInfo_kwargs["default"] = None
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
BodyFieldInfo: Type[params.Body] = params.File
elif any(
isinstance(f.field_info, temp_pydantic_v1_params.File)
for f in flat_dependant.body_params
):
BodyFieldInfo: Type[temp_pydantic_v1_params.Body] = temp_pydantic_v1_params.File # type: ignore[no-redef]
BodyFieldInfo: type[params.Body] = params.File
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
BodyFieldInfo = params.Form
elif any(
isinstance(f.field_info, temp_pydantic_v1_params.Form)
for f in flat_dependant.body_params
):
BodyFieldInfo = temp_pydantic_v1_params.Form # type: ignore[assignment]
else:
if annotation_is_pydantic_v1(BodyModel):
BodyFieldInfo = temp_pydantic_v1_params.Body # type: ignore[assignment]
else:
BodyFieldInfo = params.Body
BodyFieldInfo = params.Body
body_param_media_types = [
f.field_info.media_type
for f in flat_dependant.body_params
if isinstance(f.field_info, (params.Body, temp_pydantic_v1_params.Body))
if isinstance(f.field_info, params.Body)
]
if len(set(body_param_media_types)) == 1:
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
@@ -1021,3 +1014,8 @@ def get_body_field(
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
)
return final_field
def get_validation_alias(field: ModelField) -> str:
va = getattr(field, "validation_alias", None)
return va or field.alias

View File

@@ -14,19 +14,22 @@ from ipaddress import (
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Annotated, Any, Callable, Optional, Union
from uuid import UUID
from annotated_doc import Doc
from fastapi._compat import may_v1
from fastapi.exceptions import PydanticV1NotSupportedError
from fastapi.types import IncEx
from pydantic import BaseModel
from pydantic.color import Color
from pydantic.networks import AnyUrl, NameEmail
from pydantic.types import SecretBytes, SecretStr
from typing_extensions import Annotated
from pydantic_core import PydanticUndefinedType
from ._compat import Url, _is_undefined, _model_dump
from ._compat import (
Url,
is_pydantic_v1_model_instance,
)
# Taken from Pydantic v1 as is
@@ -34,14 +37,14 @@ def isoformat(o: Union[datetime.date, datetime.time]) -> str:
return o.isoformat()
# Taken from Pydantic v1 as is
# Adapted from Pydantic v1
# TODO: pv2 should this return strings instead?
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
"""
Encodes a Decimal as int of there's no exponent, otherwise float
Encodes a Decimal as int if there's no exponent, otherwise float
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
where a integer (but not int typed) is used. Encoding this as a float
where an integer (but not int typed) is used. Encoding this as a float
results in failed round-tripping between encode and parse.
Our Id type is a prime example of this.
@@ -50,17 +53,20 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
>>> decimal_encoder(Decimal("1"))
1
>>> decimal_encoder(Decimal("NaN"))
nan
"""
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
exponent = dec_value.as_tuple().exponent
if isinstance(exponent, int) and exponent >= 0:
return int(dec_value)
else:
return float(dec_value)
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = {
bytes: lambda o: o.decode(),
Color: str,
may_v1.Color: str,
datetime.date: isoformat,
datetime.datetime: isoformat,
datetime.time: isoformat,
@@ -77,26 +83,21 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
IPv6Interface: str,
IPv6Network: str,
NameEmail: str,
may_v1.NameEmail: str,
Path: str,
Pattern: lambda o: o.pattern,
SecretBytes: str,
may_v1.SecretBytes: str,
SecretStr: str,
may_v1.SecretStr: str,
set: list,
UUID: str,
Url: str,
may_v1.Url: str,
AnyUrl: str,
may_v1.AnyUrl: str,
}
def generate_encoders_by_class_tuples(
type_encoder_map: Dict[Any, Callable[[Any], Any]],
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(
type_encoder_map: dict[Any, Callable[[Any], Any]],
) -> dict[Callable[[Any], Any], tuple[Any, ...]]:
encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(
tuple
)
for type_, encoder in type_encoder_map.items():
@@ -176,7 +177,7 @@ def jsonable_encoder(
),
] = False,
custom_encoder: Annotated[
Optional[Dict[Any, Callable[[Any], Any]]],
Optional[dict[Any, Callable[[Any], Any]]],
Doc(
"""
Pydantic's `custom_encoder` parameter, passed to Pydantic models to define
@@ -221,15 +222,8 @@ def jsonable_encoder(
include = set(include)
if exclude is not None and not isinstance(exclude, (set, dict)):
exclude = set(exclude)
if isinstance(obj, (BaseModel, may_v1.BaseModel)):
# TODO: remove when deprecating Pydantic v1
encoders: Dict[Any, Any] = {}
if isinstance(obj, may_v1.BaseModel):
encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined]
if custom_encoder:
encoders = {**encoders, **custom_encoder}
obj_dict = _model_dump(
obj,
if isinstance(obj, BaseModel):
obj_dict = obj.model_dump(
mode="json",
include=include,
exclude=exclude,
@@ -238,14 +232,10 @@ def jsonable_encoder(
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
)
if "__root__" in obj_dict:
obj_dict = obj_dict["__root__"]
return jsonable_encoder(
obj_dict,
exclude_none=exclude_none,
exclude_defaults=exclude_defaults,
# TODO: remove when deprecating Pydantic v1
custom_encoder=encoders,
sqlalchemy_safe=sqlalchemy_safe,
)
if dataclasses.is_dataclass(obj):
@@ -268,7 +258,7 @@ def jsonable_encoder(
return str(obj)
if isinstance(obj, (str, int, float, type(None))):
return obj
if _is_undefined(obj):
if isinstance(obj, PydanticUndefinedType):
return None
if isinstance(obj, dict):
encoded_dict = {}
@@ -328,11 +318,15 @@ def jsonable_encoder(
for encoder, classes_tuple in encoders_by_class_tuples.items():
if isinstance(obj, classes_tuple):
return encoder(obj)
if is_pydantic_v1_model_instance(obj):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" Please update the model {obj!r}."
)
try:
data = dict(obj)
except Exception as e:
errors: List[Exception] = []
errors: list[Exception] = []
errors.append(e)
try:
data = vars(obj)

View File

@@ -1,10 +1,17 @@
from typing import Any, Dict, Optional, Sequence, Type, Union
from collections.abc import Sequence
from typing import Annotated, Any, Optional, TypedDict, Union
from annotated_doc import Doc
from pydantic import BaseModel, create_model
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.exceptions import WebSocketException as StarletteWebSocketException
from typing_extensions import Annotated
class EndpointContext(TypedDict, total=False):
function: str
path: str
file: str
line: int
class HTTPException(StarletteHTTPException):
@@ -55,7 +62,7 @@ class HTTPException(StarletteHTTPException):
),
] = None,
headers: Annotated[
Optional[Dict[str, str]],
Optional[dict[str, str]],
Doc(
"""
Any headers to send to the client in the response.
@@ -137,8 +144,8 @@ class WebSocketException(StarletteWebSocketException):
super().__init__(code=code, reason=reason)
RequestErrorModel: Type[BaseModel] = create_model("Request")
WebSocketErrorModel: Type[BaseModel] = create_model("WebSocket")
RequestErrorModel: type[BaseModel] = create_model("Request")
WebSocketErrorModel: type[BaseModel] = create_model("WebSocket")
class FastAPIError(RuntimeError):
@@ -155,30 +162,85 @@ class DependencyScopeError(FastAPIError):
class ValidationException(Exception):
def __init__(self, errors: Sequence[Any]) -> None:
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
self._errors = errors
self.endpoint_ctx = endpoint_ctx
ctx = endpoint_ctx or {}
self.endpoint_function = ctx.get("function")
self.endpoint_path = ctx.get("path")
self.endpoint_file = ctx.get("file")
self.endpoint_line = ctx.get("line")
def errors(self) -> Sequence[Any]:
return self._errors
def _format_endpoint_context(self) -> str:
if not (self.endpoint_file and self.endpoint_line and self.endpoint_function):
if self.endpoint_path:
return f"\n Endpoint: {self.endpoint_path}"
return ""
context = f'\n File "{self.endpoint_file}", line {self.endpoint_line}, in {self.endpoint_function}'
if self.endpoint_path:
context += f"\n {self.endpoint_path}"
return context
def __str__(self) -> str:
message = f"{len(self._errors)} validation error{'s' if len(self._errors) != 1 else ''}:\n"
for err in self._errors:
message += f" {err}\n"
message += self._format_endpoint_context()
return message.rstrip()
class RequestValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
class WebSocketRequestValidationError(ValidationException):
pass
def __init__(
self,
errors: Sequence[Any],
*,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
class ResponseValidationError(ValidationException):
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
super().__init__(errors)
def __init__(
self,
errors: Sequence[Any],
*,
body: Any = None,
endpoint_ctx: Optional[EndpointContext] = None,
) -> None:
super().__init__(errors, endpoint_ctx=endpoint_ctx)
self.body = body
def __str__(self) -> str:
message = f"{len(self._errors)} validation errors:\n"
for err in self._errors:
message += f" {err}\n"
return message
class PydanticV1NotSupportedError(FastAPIError):
"""
A pydantic.v1 model is used, which is no longer supported.
"""
class FastAPIDeprecationWarning(UserWarning):
"""
A custom deprecation warning as DeprecationWarning is ignored
Ref: https://sethmlarson.dev/deprecations-via-warnings-dont-work-for-python-libraries
"""

View File

@@ -1,13 +1,12 @@
import json
from typing import Any, Dict, Optional
from typing import Annotated, Any, Optional
from annotated_doc import Doc
from fastapi.encoders import jsonable_encoder
from starlette.responses import HTMLResponse
from typing_extensions import Annotated
swagger_ui_default_parameters: Annotated[
Dict[str, Any],
dict[str, Any],
Doc(
"""
Default configurations for Swagger UI.
@@ -82,7 +81,7 @@ def get_swagger_ui_html(
),
] = None,
init_oauth: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
A dictionary with Swagger UI OAuth2 initialization configurations.
@@ -90,7 +89,7 @@ def get_swagger_ui_html(
),
] = None,
swagger_ui_parameters: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Configuration parameters for Swagger UI.

View File

@@ -1,17 +1,16 @@
from collections.abc import Iterable, Mapping
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
from typing import Annotated, Any, Callable, Optional, Union
from fastapi._compat import (
PYDANTIC_V2,
CoreSchema,
GetJsonSchemaHandler,
JsonSchemaValue,
_model_rebuild,
with_info_plain_validator_function,
)
from fastapi._compat import with_info_plain_validator_function
from fastapi.logger import logger
from pydantic import AnyUrl, BaseModel, Field
from typing_extensions import Annotated, Literal, TypedDict
from pydantic import (
AnyUrl,
BaseModel,
Field,
GetJsonSchemaHandler,
)
from typing_extensions import Literal, TypedDict
from typing_extensions import deprecated as typing_deprecated
try:
@@ -44,25 +43,19 @@ except ImportError: # pragma: no cover
@classmethod
def __get_pydantic_json_schema__(
cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
cls, core_schema: Mapping[str, Any], handler: GetJsonSchemaHandler
) -> dict[str, Any]:
return {"type": "string", "format": "email"}
@classmethod
def __get_pydantic_core_schema__(
cls, source: Type[Any], handler: Callable[[Any], CoreSchema]
) -> CoreSchema:
cls, source: type[Any], handler: Callable[[Any], Mapping[str, Any]]
) -> Mapping[str, Any]:
return with_info_plain_validator_function(cls._validate)
class BaseModelWithConfig(BaseModel):
if PYDANTIC_V2:
model_config = {"extra": "allow"}
else:
class Config:
extra = "allow"
model_config = {"extra": "allow"}
class Contact(BaseModelWithConfig):
@@ -88,7 +81,7 @@ class Info(BaseModelWithConfig):
class ServerVariable(BaseModelWithConfig):
enum: Annotated[Optional[List[str]], Field(min_length=1)] = None
enum: Annotated[Optional[list[str]], Field(min_length=1)] = None
default: str
description: Optional[str] = None
@@ -96,7 +89,7 @@ class ServerVariable(BaseModelWithConfig):
class Server(BaseModelWithConfig):
url: Union[AnyUrl, str]
description: Optional[str] = None
variables: Optional[Dict[str, ServerVariable]] = None
variables: Optional[dict[str, ServerVariable]] = None
class Reference(BaseModel):
@@ -105,7 +98,7 @@ class Reference(BaseModel):
class Discriminator(BaseModel):
propertyName: str
mapping: Optional[Dict[str, str]] = None
mapping: Optional[dict[str, str]] = None
class XML(BaseModelWithConfig):
@@ -137,34 +130,34 @@ class Schema(BaseModelWithConfig):
dynamicAnchor: Optional[str] = Field(default=None, alias="$dynamicAnchor")
ref: Optional[str] = Field(default=None, alias="$ref")
dynamicRef: Optional[str] = Field(default=None, alias="$dynamicRef")
defs: Optional[Dict[str, "SchemaOrBool"]] = Field(default=None, alias="$defs")
defs: Optional[dict[str, "SchemaOrBool"]] = Field(default=None, alias="$defs")
comment: Optional[str] = Field(default=None, alias="$comment")
# Ref: JSON Schema 2020-12: https://json-schema.org/draft/2020-12/json-schema-core.html#name-a-vocabulary-for-applying-s
# A Vocabulary for Applying Subschemas
allOf: Optional[List["SchemaOrBool"]] = None
anyOf: Optional[List["SchemaOrBool"]] = None
oneOf: Optional[List["SchemaOrBool"]] = None
allOf: Optional[list["SchemaOrBool"]] = None
anyOf: Optional[list["SchemaOrBool"]] = None
oneOf: Optional[list["SchemaOrBool"]] = None
not_: Optional["SchemaOrBool"] = Field(default=None, alias="not")
if_: Optional["SchemaOrBool"] = Field(default=None, alias="if")
then: Optional["SchemaOrBool"] = None
else_: Optional["SchemaOrBool"] = Field(default=None, alias="else")
dependentSchemas: Optional[Dict[str, "SchemaOrBool"]] = None
prefixItems: Optional[List["SchemaOrBool"]] = None
dependentSchemas: Optional[dict[str, "SchemaOrBool"]] = None
prefixItems: Optional[list["SchemaOrBool"]] = None
# TODO: uncomment and remove below when deprecating Pydantic v1
# It generates a list of schemas for tuples, before prefixItems was available
# items: Optional["SchemaOrBool"] = None
items: Optional[Union["SchemaOrBool", List["SchemaOrBool"]]] = None
items: Optional[Union["SchemaOrBool", list["SchemaOrBool"]]] = None
contains: Optional["SchemaOrBool"] = None
properties: Optional[Dict[str, "SchemaOrBool"]] = None
patternProperties: Optional[Dict[str, "SchemaOrBool"]] = None
properties: Optional[dict[str, "SchemaOrBool"]] = None
patternProperties: Optional[dict[str, "SchemaOrBool"]] = None
additionalProperties: Optional["SchemaOrBool"] = None
propertyNames: Optional["SchemaOrBool"] = None
unevaluatedItems: Optional["SchemaOrBool"] = None
unevaluatedProperties: Optional["SchemaOrBool"] = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-a-vocabulary-for-structural
# A Vocabulary for Structural Validation
type: Optional[Union[SchemaType, List[SchemaType]]] = None
enum: Optional[List[Any]] = None
type: Optional[Union[SchemaType, list[SchemaType]]] = None
enum: Optional[list[Any]] = None
const: Optional[Any] = None
multipleOf: Optional[float] = Field(default=None, gt=0)
maximum: Optional[float] = None
@@ -181,8 +174,8 @@ class Schema(BaseModelWithConfig):
minContains: Optional[int] = Field(default=None, ge=0)
maxProperties: Optional[int] = Field(default=None, ge=0)
minProperties: Optional[int] = Field(default=None, ge=0)
required: Optional[List[str]] = None
dependentRequired: Optional[Dict[str, Set[str]]] = None
required: Optional[list[str]] = None
dependentRequired: Optional[dict[str, set[str]]] = None
# Ref: JSON Schema Validation 2020-12: https://json-schema.org/draft/2020-12/json-schema-validation.html#name-vocabularies-for-semantic-c
# Vocabularies for Semantic Content With "format"
format: Optional[str] = None
@@ -199,7 +192,7 @@ class Schema(BaseModelWithConfig):
deprecated: Optional[bool] = None
readOnly: Optional[bool] = None
writeOnly: Optional[bool] = None
examples: Optional[List[Any]] = None
examples: Optional[list[Any]] = None
# Ref: OpenAPI 3.1.0: https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#schema-object
# Schema Object
discriminator: Optional[Discriminator] = None
@@ -225,13 +218,7 @@ class Example(TypedDict, total=False):
value: Optional[Any]
externalValue: Optional[AnyUrl]
if PYDANTIC_V2: # type: ignore [misc]
__pydantic_config__ = {"extra": "allow"}
else:
class Config:
extra = "allow"
__pydantic_config__ = {"extra": "allow"} # type: ignore[misc]
class ParameterInType(Enum):
@@ -243,7 +230,7 @@ class ParameterInType(Enum):
class Encoding(BaseModelWithConfig):
contentType: Optional[str] = None
headers: Optional[Dict[str, Union["Header", Reference]]] = None
headers: Optional[dict[str, Union["Header", Reference]]] = None
style: Optional[str] = None
explode: Optional[bool] = None
allowReserved: Optional[bool] = None
@@ -252,8 +239,8 @@ class Encoding(BaseModelWithConfig):
class MediaType(BaseModelWithConfig):
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
encoding: Optional[Dict[str, Encoding]] = None
examples: Optional[dict[str, Union[Example, Reference]]] = None
encoding: Optional[dict[str, Encoding]] = None
class ParameterBase(BaseModelWithConfig):
@@ -266,9 +253,9 @@ class ParameterBase(BaseModelWithConfig):
allowReserved: Optional[bool] = None
schema_: Optional[Union[Schema, Reference]] = Field(default=None, alias="schema")
example: Optional[Any] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
examples: Optional[dict[str, Union[Example, Reference]]] = None
# Serialization rules for more complex scenarios
content: Optional[Dict[str, MediaType]] = None
content: Optional[dict[str, MediaType]] = None
class Parameter(ParameterBase):
@@ -282,14 +269,14 @@ class Header(ParameterBase):
class RequestBody(BaseModelWithConfig):
description: Optional[str] = None
content: Dict[str, MediaType]
content: dict[str, MediaType]
required: Optional[bool] = None
class Link(BaseModelWithConfig):
operationRef: Optional[str] = None
operationId: Optional[str] = None
parameters: Optional[Dict[str, Union[Any, str]]] = None
parameters: Optional[dict[str, Union[Any, str]]] = None
requestBody: Optional[Union[Any, str]] = None
description: Optional[str] = None
server: Optional[Server] = None
@@ -297,25 +284,25 @@ class Link(BaseModelWithConfig):
class Response(BaseModelWithConfig):
description: str
headers: Optional[Dict[str, Union[Header, Reference]]] = None
content: Optional[Dict[str, MediaType]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
headers: Optional[dict[str, Union[Header, Reference]]] = None
content: Optional[dict[str, MediaType]] = None
links: Optional[dict[str, Union[Link, Reference]]] = None
class Operation(BaseModelWithConfig):
tags: Optional[List[str]] = None
tags: Optional[list[str]] = None
summary: Optional[str] = None
description: Optional[str] = None
externalDocs: Optional[ExternalDocumentation] = None
operationId: Optional[str] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
parameters: Optional[list[Union[Parameter, Reference]]] = None
requestBody: Optional[Union[RequestBody, Reference]] = None
# Using Any for Specification Extensions
responses: Optional[Dict[str, Union[Response, Any]]] = None
callbacks: Optional[Dict[str, Union[Dict[str, "PathItem"], Reference]]] = None
responses: Optional[dict[str, Union[Response, Any]]] = None
callbacks: Optional[dict[str, Union[dict[str, "PathItem"], Reference]]] = None
deprecated: Optional[bool] = None
security: Optional[List[Dict[str, List[str]]]] = None
servers: Optional[List[Server]] = None
security: Optional[list[dict[str, list[str]]]] = None
servers: Optional[list[Server]] = None
class PathItem(BaseModelWithConfig):
@@ -330,8 +317,8 @@ class PathItem(BaseModelWithConfig):
head: Optional[Operation] = None
patch: Optional[Operation] = None
trace: Optional[Operation] = None
servers: Optional[List[Server]] = None
parameters: Optional[List[Union[Parameter, Reference]]] = None
servers: Optional[list[Server]] = None
parameters: Optional[list[Union[Parameter, Reference]]] = None
class SecuritySchemeType(Enum):
@@ -370,7 +357,7 @@ class HTTPBearer(HTTPBase):
class OAuthFlow(BaseModelWithConfig):
refreshUrl: Optional[str] = None
scopes: Dict[str, str] = {}
scopes: dict[str, str] = {}
class OAuthFlowImplicit(OAuthFlow):
@@ -413,17 +400,17 @@ SecurityScheme = Union[APIKey, HTTPBase, OAuth2, OpenIdConnect, HTTPBearer]
class Components(BaseModelWithConfig):
schemas: Optional[Dict[str, Union[Schema, Reference]]] = None
responses: Optional[Dict[str, Union[Response, Reference]]] = None
parameters: Optional[Dict[str, Union[Parameter, Reference]]] = None
examples: Optional[Dict[str, Union[Example, Reference]]] = None
requestBodies: Optional[Dict[str, Union[RequestBody, Reference]]] = None
headers: Optional[Dict[str, Union[Header, Reference]]] = None
securitySchemes: Optional[Dict[str, Union[SecurityScheme, Reference]]] = None
links: Optional[Dict[str, Union[Link, Reference]]] = None
schemas: Optional[dict[str, Union[Schema, Reference]]] = None
responses: Optional[dict[str, Union[Response, Reference]]] = None
parameters: Optional[dict[str, Union[Parameter, Reference]]] = None
examples: Optional[dict[str, Union[Example, Reference]]] = None
requestBodies: Optional[dict[str, Union[RequestBody, Reference]]] = None
headers: Optional[dict[str, Union[Header, Reference]]] = None
securitySchemes: Optional[dict[str, Union[SecurityScheme, Reference]]] = None
links: Optional[dict[str, Union[Link, Reference]]] = None
# Using Any for Specification Extensions
callbacks: Optional[Dict[str, Union[Dict[str, PathItem], Reference, Any]]] = None
pathItems: Optional[Dict[str, Union[PathItem, Reference]]] = None
callbacks: Optional[dict[str, Union[dict[str, PathItem], Reference, Any]]] = None
pathItems: Optional[dict[str, Union[PathItem, Reference]]] = None
class Tag(BaseModelWithConfig):
@@ -436,16 +423,16 @@ class OpenAPI(BaseModelWithConfig):
openapi: str
info: Info
jsonSchemaDialect: Optional[str] = None
servers: Optional[List[Server]] = None
servers: Optional[list[Server]] = None
# Using Any for Specification Extensions
paths: Optional[Dict[str, Union[PathItem, Any]]] = None
webhooks: Optional[Dict[str, Union[PathItem, Reference]]] = None
paths: Optional[dict[str, Union[PathItem, Any]]] = None
webhooks: Optional[dict[str, Union[PathItem, Reference]]] = None
components: Optional[Components] = None
security: Optional[List[Dict[str, List[str]]]] = None
tags: Optional[List[Tag]] = None
security: Optional[list[dict[str, list[str]]]] = None
tags: Optional[list[Tag]] = None
externalDocs: Optional[ExternalDocumentation] = None
_model_rebuild(Schema)
_model_rebuild(Operation)
_model_rebuild(Encoding)
Schema.model_rebuild()
Operation.model_rebuild()
Encoding.model_rebuild()

View File

@@ -1,11 +1,11 @@
import http.client
import inspect
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from fastapi import routing
from fastapi._compat import (
JsonSchemaValue,
ModelField,
Undefined,
get_compat_model_name_map,
@@ -19,8 +19,10 @@ from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
get_flat_params,
get_validation_alias,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, ParamTypes
@@ -36,8 +38,6 @@ from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from typing_extensions import Literal
from .._compat import _is_model_field
validation_error_definition = {
"title": "ValidationError",
"type": "object",
@@ -65,7 +65,7 @@ validation_error_response_definition = {
},
}
status_code_ranges: Dict[str, str] = {
status_code_ranges: dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
@@ -77,18 +77,27 @@ status_code_ranges: Dict[str, str] = {
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
# Use a dict to merge scopes for same security scheme
operation_security_dict: dict[str, list[str]] = {}
for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_name = security_dependency._security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
operation_security = [
{name: scopes} for name, scopes in operation_security_dict.items()
]
return security_definitions, operation_security
@@ -96,11 +105,11 @@ def _get_openapi_operation_parameters(
*,
dependant: Dependant,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
parameters = []
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
@@ -132,7 +141,7 @@ def _get_openapi_operation_parameters(
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
name = param.alias
name = get_validation_alias(param)
convert_underscores = getattr(
param.field_info,
"convert_underscores",
@@ -140,7 +149,7 @@ def _get_openapi_operation_parameters(
)
if (
param_type == ParamTypes.header
and param.alias == param.name
and name == param.name
and convert_underscores
):
name = param.name.replace("_", "-")
@@ -169,14 +178,14 @@ def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
if not body_field:
return None
assert _is_model_field(body_field)
assert isinstance(body_field, ModelField)
body_schema = get_schema_from_model_field(
field=body_field,
model_name_map=model_name_map,
@@ -186,10 +195,10 @@ def get_openapi_operation_request_body(
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
request_body_oai: dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_media_content: Dict[str, Any] = {"schema": body_schema}
request_media_content: dict[str, Any] = {"schema": body_schema}
if field_info.openapi_examples:
request_media_content["examples"] = jsonable_encoder(
field_info.openapi_examples
@@ -204,9 +213,9 @@ def generate_operation_id(
*, route: routing.APIRoute, method: str
) -> str: # pragma: nocover
warnings.warn(
"fastapi.openapi.utils.generate_operation_id() was deprecated, "
message="fastapi.openapi.utils.generate_operation_id() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=2,
)
if route.operation_id:
@@ -222,9 +231,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
*, route: routing.APIRoute, method: str, operation_ids: set[str]
) -> dict[str, Any]:
operation: dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
@@ -250,19 +259,19 @@ def get_openapi_operation_metadata(
def get_openapi_path(
*,
route: routing.APIRoute,
operation_ids: Set[str],
operation_ids: set[str],
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
security_schemes: dict[str, Any] = {}
definitions: dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[Response] = route.response_class.value
current_response_class: type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
@@ -272,7 +281,7 @@ def get_openapi_path(
operation = get_openapi_operation_metadata(
route=route, method=method, operation_ids=operation_ids
)
parameters: List[Dict[str, Any]] = []
parameters: list[dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
@@ -380,7 +389,7 @@ def get_openapi_path(
"An additional response must be a dict"
)
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
additional_field_schema: Optional[dict[str, Any]] = None
if field:
additional_field_schema = get_schema_from_model_field(
field=field,
@@ -435,17 +444,17 @@ def get_openapi_path(
def get_fields_from_routes(
routes: Sequence[BaseRoute],
) -> List[ModelField]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: List[ModelField] = []
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert _is_model_field(route.body_field), (
assert isinstance(route.body_field, ModelField), (
"A request body must be a Pydantic Field"
)
body_fields_from_routes.append(route.body_field)
@@ -473,15 +482,15 @@ def get_openapi(
description: Optional[str] = None,
routes: Sequence[BaseRoute],
webhooks: Optional[Sequence[BaseRoute]] = None,
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
tags: Optional[list[dict[str, Any]]] = None,
servers: Optional[list[dict[str, Union[str, Any]]]] = None,
terms_of_service: Optional[str] = None,
contact: Optional[Dict[str, Union[str, Any]]] = None,
license_info: Optional[Dict[str, Union[str, Any]]] = None,
contact: Optional[dict[str, Union[str, Any]]] = None,
license_info: Optional[dict[str, Union[str, Any]]] = None,
separate_input_output_schemas: bool = True,
external_docs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
info: Dict[str, Any] = {"title": title, "version": version}
external_docs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
info: dict[str, Any] = {"title": title, "version": version}
if summary:
info["summary"] = summary
if description:
@@ -492,13 +501,13 @@ def get_openapi(
info["contact"] = contact
if license_info:
info["license"] = license_info
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
output: dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
webhook_paths: Dict[str, Dict[str, Any]] = {}
operation_ids: Set[str] = set()
components: dict[str, dict[str, Any]] = {}
paths: dict[str, dict[str, Any]] = {}
webhook_paths: dict[str, dict[str, Any]] = {}
operation_ids: set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields)
field_mapping, definitions = get_definitions(

View File

@@ -1,10 +1,12 @@
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from collections.abc import Sequence
from typing import Annotated, Any, Callable, Optional, Union
from annotated_doc import Doc
from fastapi import params
from fastapi._compat import Undefined
from fastapi.openapi.models import Example
from typing_extensions import Annotated, Literal, deprecated
from pydantic import AliasChoices, AliasPath
from typing_extensions import Literal, deprecated
_Unset: Any = Undefined
@@ -53,10 +55,8 @@ def Path( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -209,7 +209,7 @@ def Path( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -224,7 +224,7 @@ def Path( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -262,7 +262,7 @@ def Path( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -378,10 +378,8 @@ def Query( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -534,7 +532,7 @@ def Query( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -549,7 +547,7 @@ def Query( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -587,7 +585,7 @@ def Query( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -682,10 +680,8 @@ def Header( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -849,7 +845,7 @@ def Header( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -864,7 +860,7 @@ def Header( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -902,7 +898,7 @@ def Header( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -998,10 +994,8 @@ def Cookie( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -1154,7 +1148,7 @@ def Cookie( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -1169,7 +1163,7 @@ def Cookie( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -1207,7 +1201,7 @@ def Cookie( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -1325,10 +1319,8 @@ def Body( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -1481,7 +1473,7 @@ def Body( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -1496,7 +1488,7 @@ def Body( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -1534,7 +1526,7 @@ def Body( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -1640,10 +1632,8 @@ def Form( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -1796,7 +1786,7 @@ def Form( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -1811,7 +1801,7 @@ def Form( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -1849,7 +1839,7 @@ def Form( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.
@@ -1954,10 +1944,8 @@ def File( # noqa: N802
"""
),
] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Annotated[
Union[str, None],
Union[str, AliasPath, AliasChoices, None],
Doc(
"""
'Whitelist' validation step. The parameter field will be the single one
@@ -2110,7 +2098,7 @@ def File( # noqa: N802
),
] = _Unset,
examples: Annotated[
Optional[List[Any]],
Optional[list[Any]],
Doc(
"""
Example values for this field.
@@ -2125,7 +2113,7 @@ def File( # noqa: N802
),
] = _Unset,
openapi_examples: Annotated[
Optional[Dict[str, Example]],
Optional[dict[str, Example]],
Doc(
"""
OpenAPI-specific examples.
@@ -2163,7 +2151,7 @@ def File( # noqa: N802
),
] = True,
json_schema_extra: Annotated[
Union[Dict[str, Any], None],
Union[dict[str, Any], None],
Doc(
"""
Any additional JSON schema data.

View File

@@ -1,15 +1,16 @@
import warnings
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Annotated, Any, Callable, Optional, Union
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.models import Example
from pydantic import AliasChoices, AliasPath
from pydantic.fields import FieldInfo
from typing_extensions import Annotated, Literal, deprecated
from typing_extensions import Literal, deprecated
from ._compat import (
PYDANTIC_V2,
PYDANTIC_VERSION_MINOR_TUPLE,
Undefined,
)
@@ -34,9 +35,7 @@ class Param(FieldInfo): # type: ignore[misc]
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -59,7 +58,7 @@ class Param(FieldInfo): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -67,16 +66,16 @@ class Param(FieldInfo): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=4,
)
self.example = example
@@ -106,29 +105,28 @@ class Param(FieldInfo): # type: ignore[misc]
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
if PYDANTIC_V2:
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
else:
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
kwargs["deprecated"] = deprecated
if serialization_alias in (_Unset, None) and isinstance(alias, str):
serialization_alias = alias
if validation_alias in (_Unset, None):
validation_alias = alias
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
@@ -148,9 +146,7 @@ class Path(Param): # type: ignore[misc]
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -173,7 +169,7 @@ class Path(Param): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -181,10 +177,10 @@ class Path(Param): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
assert default is ..., "Path parameters cannot have a default value"
@@ -234,9 +230,7 @@ class Query(Param): # type: ignore[misc]
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -259,7 +253,7 @@ class Query(Param): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -267,10 +261,10 @@ class Query(Param): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
@@ -318,9 +312,7 @@ class Header(Param): # type: ignore[misc]
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
convert_underscores: bool = True,
title: Optional[str] = None,
@@ -344,7 +336,7 @@ class Header(Param): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -352,10 +344,10 @@ class Header(Param): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
@@ -404,9 +396,7 @@ class Cookie(Param): # type: ignore[misc]
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -429,7 +419,7 @@ class Cookie(Param): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -437,10 +427,10 @@ class Cookie(Param): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
@@ -488,9 +478,7 @@ class Body(FieldInfo): # type: ignore[misc]
media_type: str = "application/json",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -513,7 +501,7 @@ class Body(FieldInfo): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -521,10 +509,10 @@ class Body(FieldInfo): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
self.embed = embed
@@ -532,7 +520,7 @@ class Body(FieldInfo): # type: ignore[misc]
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=4,
)
self.example = example
@@ -562,29 +550,26 @@ class Body(FieldInfo): # type: ignore[misc]
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
if PYDANTIC_V2:
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
else:
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
kwargs["deprecated"] = deprecated
if serialization_alias in (_Unset, None) and isinstance(alias, str):
serialization_alias = alias
if validation_alias in (_Unset, None):
validation_alias = alias
kwargs.update(
{
"annotation": annotation,
"alias_priority": alias_priority,
"validation_alias": validation_alias,
"serialization_alias": serialization_alias,
"strict": strict,
"json_schema_extra": current_json_schema_extra,
}
)
kwargs["pattern"] = pattern or regex
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
@@ -604,9 +589,7 @@ class Form(Body): # type: ignore[misc]
media_type: str = "application/x-www-form-urlencoded",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -629,7 +612,7 @@ class Form(Body): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -637,10 +620,10 @@ class Form(Body): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
@@ -688,9 +671,7 @@ class File(Form): # type: ignore[misc]
media_type: str = "multipart/form-data",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
validation_alias: Union[str, AliasPath, AliasChoices, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
@@ -713,7 +694,7 @@ class File(Form): # type: ignore[misc]
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
examples: Optional[list[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
@@ -721,10 +702,10 @@ class File(Form): # type: ignore[misc]
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
openapi_examples: Optional[dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
json_schema_extra: Union[dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
@@ -762,13 +743,13 @@ class File(Form): # type: ignore[misc]
)
@dataclass
@dataclass(frozen=True)
class Depends:
dependency: Optional[Callable[..., Any]] = None
use_cache: bool = True
scope: Union[Literal["function", "request"], None] = None
@dataclass
@dataclass(frozen=True)
class Security(Depends):
scopes: Optional[Sequence[str]] = None

View File

@@ -1,37 +1,31 @@
import dataclasses
import email.message
import functools
import inspect
import json
import sys
from collections.abc import (
AsyncIterator,
Awaitable,
Collection,
Coroutine,
Mapping,
Sequence,
)
from contextlib import AsyncExitStack, asynccontextmanager
from enum import Enum, IntEnum
from typing import (
Annotated,
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
Coroutine,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
from annotated_doc import Doc
from fastapi import params, temp_pydantic_v1_params
from fastapi import params
from fastapi._compat import (
ModelField,
Undefined,
_get_model_config,
_model_dump,
_normalize_errors,
annotation_is_pydantic_v1,
lenient_issubclass,
)
from fastapi.datastructures import Default, DefaultPlaceholder
@@ -47,7 +41,9 @@ from fastapi.dependencies.utils import (
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import (
EndpointContext,
FastAPIError,
PydanticV1NotSupportedError,
RequestValidationError,
ResponseValidationError,
WebSocketRequestValidationError,
@@ -60,7 +56,6 @@ from fastapi.utils import (
get_value_or_default,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from starlette import routing
from starlette._exception_handler import wrap_app_handling_exceptions
from starlette._utils import is_async_callable
@@ -77,12 +72,7 @@ from starlette.routing import (
from starlette.routing import Mount as Mount # noqa
from starlette.types import AppType, ASGIApp, Lifespan, Receive, Scope, Send
from starlette.websockets import WebSocket
from typing_extensions import Annotated, deprecated
if sys.version_info >= (3, 13): # pragma: no cover
from inspect import iscoroutinefunction
else: # pragma: no cover
from asyncio import iscoroutinefunction
from typing_extensions import deprecated
# Copy of starlette.routing.request_response modified to include the
@@ -153,54 +143,6 @@ def websocket_session(
return app
def _prepare_response_content(
res: Any,
*,
exclude_unset: bool,
exclude_defaults: bool = False,
exclude_none: bool = False,
) -> Any:
if isinstance(res, BaseModel):
read_with_orm_mode = getattr(_get_model_config(res), "read_with_orm_mode", None)
if read_with_orm_mode:
# Let from_orm extract the data from this model instead of converting
# it now to a dict.
# Otherwise, there's no way to extract lazy data that requires attribute
# access instead of dict iteration, e.g. lazy relationships.
return res
return _model_dump(
res,
by_alias=True,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
elif isinstance(res, list):
return [
_prepare_response_content(
item,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for item in res
]
elif isinstance(res, dict):
return {
k: _prepare_response_content(
v,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
assert not isinstance(res, type)
return dataclasses.asdict(res)
return res
def _merge_lifespan_context(
original_context: Lifespan[Any], nested_context: Lifespan[Any]
) -> Lifespan[Any]:
@@ -218,6 +160,33 @@ def _merge_lifespan_context(
return merged_lifespan # type: ignore[return-value]
# Cache for endpoint context to avoid re-extracting on every request
_endpoint_context_cache: dict[int, EndpointContext] = {}
def _extract_endpoint_context(func: Any) -> EndpointContext:
"""Extract endpoint context with caching to avoid repeated file I/O."""
func_id = id(func)
if func_id in _endpoint_context_cache:
return _endpoint_context_cache[func_id]
try:
ctx: EndpointContext = {}
if (source_file := inspect.getsourcefile(func)) is not None:
ctx["file"] = source_file
if (line_number := inspect.getsourcelines(func)[1]) is not None:
ctx["line"] = line_number
if (func_name := getattr(func, "__name__", None)) is not None:
ctx["function"] = func_name
except Exception:
ctx = EndpointContext()
_endpoint_context_cache[func_id] = ctx
return ctx
async def serialize_response(
*,
field: Optional[ModelField] = None,
@@ -229,17 +198,10 @@ async def serialize_response(
exclude_defaults: bool = False,
exclude_none: bool = False,
is_coroutine: bool = True,
endpoint_ctx: Optional[EndpointContext] = None,
) -> Any:
if field:
errors = []
if not hasattr(field, "serialize"):
# pydantic v1
response_content = _prepare_response_content(
response_content,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
if is_coroutine:
value, errors_ = field.validate(response_content, {}, loc=("response",))
else:
@@ -248,25 +210,15 @@ async def serialize_response(
)
if isinstance(errors_, list):
errors.extend(errors_)
elif errors_:
errors.append(errors_)
if errors:
ctx = endpoint_ctx or EndpointContext()
raise ResponseValidationError(
errors=_normalize_errors(errors), body=response_content
errors=errors,
body=response_content,
endpoint_ctx=ctx,
)
if hasattr(field, "serialize"):
return field.serialize(
value,
include=include,
exclude=exclude,
by_alias=by_alias,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
return jsonable_encoder(
return field.serialize(
value,
include=include,
exclude=exclude,
@@ -275,12 +227,13 @@ async def serialize_response(
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
else:
return jsonable_encoder(response_content)
async def run_endpoint_function(
*, dependant: Dependant, values: Dict[str, Any], is_coroutine: bool
*, dependant: Dependant, values: dict[str, Any], is_coroutine: bool
) -> Any:
# Only called by get_request_handler. Has been split into its own function to
# facilitate profiling endpoints, since inner functions are harder to profile.
@@ -296,7 +249,7 @@ def get_request_handler(
dependant: Dependant,
body_field: Optional[ModelField] = None,
status_code: Optional[int] = None,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(JSONResponse),
response_class: Union[type[Response], DefaultPlaceholder] = Default(JSONResponse),
response_field: Optional[ModelField] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -308,12 +261,10 @@ def get_request_handler(
embed_body_fields: bool = False,
) -> Callable[[Request], Coroutine[Any, Any, Response]]:
assert dependant.call is not None, "dependant.call must be a function"
is_coroutine = iscoroutinefunction(dependant.call)
is_body_form = body_field and isinstance(
body_field.field_info, (params.Form, temp_pydantic_v1_params.Form)
)
is_coroutine = dependant.is_coroutine_callable
is_body_form = body_field and isinstance(body_field.field_info, params.Form)
if isinstance(response_class, DefaultPlaceholder):
actual_response_class: Type[Response] = response_class.value
actual_response_class: type[Response] = response_class.value
else:
actual_response_class = response_class
@@ -324,6 +275,18 @@ def get_request_handler(
"fastapi_middleware_astack not found in request scope"
)
# Extract endpoint context for error messages
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = request.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"{request.method} {mount_path}{dependant.path}"
# Read body and auto-close files
try:
body: Any = None
@@ -361,6 +324,7 @@ def get_request_handler(
}
],
body=e.doc,
endpoint_ctx=endpoint_ctx,
)
raise validation_error from e
except HTTPException:
@@ -373,7 +337,7 @@ def get_request_handler(
raise http_error from e
# Solve dependencies and run path operation function, auto-closing dependencies
errors: List[Any] = []
errors: list[Any] = []
async_exit_stack = request.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
@@ -398,7 +362,7 @@ def get_request_handler(
raw_response.background = solved_result.background_tasks
response = raw_response
else:
response_args: Dict[str, Any] = {
response_args: dict[str, Any] = {
"background": solved_result.background_tasks
}
# If status_code was set, use it, otherwise use the default from the
@@ -420,6 +384,7 @@ def get_request_handler(
exclude_defaults=response_model_exclude_defaults,
exclude_none=response_model_exclude_none,
is_coroutine=is_coroutine,
endpoint_ctx=endpoint_ctx,
)
response = actual_response_class(content, **response_args)
if not is_body_allowed_for_status_code(response.status_code):
@@ -427,7 +392,7 @@ def get_request_handler(
response.headers.raw.extend(solved_result.response.headers.raw)
if errors:
validation_error = RequestValidationError(
_normalize_errors(errors), body=body
errors, body=body, endpoint_ctx=endpoint_ctx
)
raise validation_error
@@ -444,6 +409,15 @@ def get_websocket_app(
embed_body_fields: bool = False,
) -> Callable[[WebSocket], Coroutine[Any, Any, Any]]:
async def app(websocket: WebSocket) -> None:
endpoint_ctx = (
_extract_endpoint_context(dependant.call)
if dependant.call
else EndpointContext()
)
if dependant.path:
# For mounted sub-apps, include the mount path prefix
mount_path = websocket.scope.get("root_path", "").rstrip("/")
endpoint_ctx["path"] = f"WS {mount_path}{dependant.path}"
async_exit_stack = websocket.scope.get("fastapi_inner_astack")
assert isinstance(async_exit_stack, AsyncExitStack), (
"fastapi_inner_astack not found in request scope"
@@ -457,7 +431,8 @@ def get_websocket_app(
)
if solved_result.errors:
raise WebSocketRequestValidationError(
_normalize_errors(solved_result.errors)
solved_result.errors,
endpoint_ctx=endpoint_ctx,
)
assert dependant.call is not None, "dependant.call must be a function"
await dependant.call(**solved_result.values)
@@ -500,7 +475,7 @@ class APIWebSocketRoute(routing.WebSocketRoute):
)
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
@@ -515,15 +490,15 @@ class APIRoute(routing.Route):
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
tags: Optional[list[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
responses: Optional[dict[Union[int, str], dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
name: Optional[str] = None,
methods: Optional[Union[Set[str], List[str]]] = None,
methods: Optional[Union[set[str], list[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -532,12 +507,12 @@ class APIRoute(routing.Route):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
response_class: Union[type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
dependency_overrides_provider: Optional[Any] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
callbacks: Optional[list[BaseRoute]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Union[
Callable[["APIRoute"], str], DefaultPlaceholder
] = Default(generate_unique_id),
@@ -573,7 +548,7 @@ class APIRoute(routing.Route):
self.path_regex, self.path_format, self.param_convertors = compile_path(path)
if methods is None:
methods = ["GET"]
self.methods: Set[str] = {method.upper() for method in methods}
self.methods: set[str] = {method.upper() for method in methods}
if isinstance(generate_unique_id_function, DefaultPlaceholder):
current_generate_unique_id: Callable[[APIRoute], str] = (
generate_unique_id_function.value
@@ -590,6 +565,11 @@ class APIRoute(routing.Route):
f"Status code {status_code} must not have a response body"
)
response_name = "Response_" + self.unique_id
if annotation_is_pydantic_v1(self.response_model):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" Please update the response model {self.response_model!r}."
)
self.response_field = create_model_field(
name=response_name,
type_=self.response_model,
@@ -623,12 +603,17 @@ class APIRoute(routing.Route):
f"Status code {additional_status_code} must not have a response body"
)
response_name = f"Response_{additional_status_code}_{self.unique_id}"
if annotation_is_pydantic_v1(model):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" In responses={{}}, please update {model}."
)
response_field = create_model_field(
name=response_name, type_=model, mode="serialization"
)
response_fields[additional_status_code] = response_field
if response_fields:
self.response_fields: Dict[Union[int, str], ModelField] = response_fields
self.response_fields: dict[Union[int, str], ModelField] = response_fields
else:
self.response_fields = {}
@@ -669,7 +654,7 @@ class APIRoute(routing.Route):
embed_body_fields=self._embed_body_fields,
)
def matches(self, scope: Scope) -> Tuple[Match, Scope]:
def matches(self, scope: Scope) -> tuple[Match, Scope]:
match, child_scope = super().matches(scope)
if match != Match.NONE:
child_scope["route"] = self
@@ -708,7 +693,7 @@ class APIRouter(routing.Router):
*,
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "",
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to all the *path operations* in this
@@ -734,7 +719,7 @@ class APIRouter(routing.Router):
),
] = None,
default_response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
The default response class to be used.
@@ -745,7 +730,7 @@ class APIRouter(routing.Router):
),
] = Default(JSONResponse),
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses to be shown in OpenAPI.
@@ -761,7 +746,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
OpenAPI callbacks that should apply to all *path operations* in this
@@ -775,7 +760,7 @@ class APIRouter(routing.Router):
),
] = None,
routes: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
**Note**: you probably shouldn't use this parameter, it is inherited
@@ -826,7 +811,7 @@ class APIRouter(routing.Router):
),
] = None,
route_class: Annotated[
Type[APIRoute],
type[APIRoute],
Doc(
"""
Custom route (*path operation*) class to be used by this router.
@@ -932,7 +917,7 @@ class APIRouter(routing.Router):
"A path prefix must not end with '/', as the routes will start with '/'"
)
self.prefix = prefix
self.tags: List[Union[str, Enum]] = tags or []
self.tags: list[Union[str, Enum]] = tags or []
self.dependencies = list(dependencies or [])
self.deprecated = deprecated
self.include_in_schema = include_in_schema
@@ -969,14 +954,14 @@ class APIRouter(routing.Router):
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
tags: Optional[list[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
responses: Optional[dict[Union[int, str], dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[Union[Set[str], List[str]]] = None,
methods: Optional[Union[set[str], list[str]]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -985,13 +970,13 @@ class APIRouter(routing.Router):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Union[Type[Response], DefaultPlaceholder] = Default(
response_class: Union[type[Response], DefaultPlaceholder] = Default(
JSONResponse
),
name: Optional[str] = None,
route_class_override: Optional[Type[APIRoute]] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
route_class_override: Optional[type[APIRoute]] = None,
callbacks: Optional[list[BaseRoute]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Union[
Callable[[APIRoute], str], DefaultPlaceholder
] = Default(generate_unique_id),
@@ -1050,14 +1035,14 @@ class APIRouter(routing.Router):
*,
response_model: Any = Default(None),
status_code: Optional[int] = None,
tags: Optional[List[Union[str, Enum]]] = None,
tags: Optional[list[Union[str, Enum]]] = None,
dependencies: Optional[Sequence[params.Depends]] = None,
summary: Optional[str] = None,
description: Optional[str] = None,
response_description: str = "Successful Response",
responses: Optional[Dict[Union[int, str], Dict[str, Any]]] = None,
responses: Optional[dict[Union[int, str], dict[str, Any]]] = None,
deprecated: Optional[bool] = None,
methods: Optional[List[str]] = None,
methods: Optional[list[str]] = None,
operation_id: Optional[str] = None,
response_model_include: Optional[IncEx] = None,
response_model_exclude: Optional[IncEx] = None,
@@ -1066,10 +1051,10 @@ class APIRouter(routing.Router):
response_model_exclude_defaults: bool = False,
response_model_exclude_none: bool = False,
include_in_schema: bool = True,
response_class: Type[Response] = Default(JSONResponse),
response_class: type[Response] = Default(JSONResponse),
name: Optional[str] = None,
callbacks: Optional[List[BaseRoute]] = None,
openapi_extra: Optional[Dict[str, Any]] = None,
callbacks: Optional[list[BaseRoute]] = None,
openapi_extra: Optional[dict[str, Any]] = None,
generate_unique_id_function: Callable[[APIRoute], str] = Default(
generate_unique_id
),
@@ -1209,7 +1194,7 @@ class APIRouter(routing.Router):
*,
prefix: Annotated[str, Doc("An optional path prefix for the router.")] = "",
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to all the *path operations* in this
@@ -1235,7 +1220,7 @@ class APIRouter(routing.Router):
),
] = None,
default_response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
The default response class to be used.
@@ -1246,7 +1231,7 @@ class APIRouter(routing.Router):
),
] = Default(JSONResponse),
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses to be shown in OpenAPI.
@@ -1262,7 +1247,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
OpenAPI callbacks that should apply to all *path operations* in this
@@ -1367,7 +1352,7 @@ class APIRouter(routing.Router):
current_tags.extend(tags)
if route.tags:
current_tags.extend(route.tags)
current_dependencies: List[params.Depends] = []
current_dependencies: list[params.Depends] = []
if dependencies:
current_dependencies.extend(dependencies)
if route.dependencies:
@@ -1508,7 +1493,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -1574,7 +1559,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -1715,7 +1700,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -1736,7 +1721,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -1752,7 +1737,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -1885,7 +1870,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -1951,7 +1936,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2092,7 +2077,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2113,7 +2098,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2129,7 +2114,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -2267,7 +2252,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -2333,7 +2318,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2474,7 +2459,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2495,7 +2480,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2511,7 +2496,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -2649,7 +2634,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -2715,7 +2700,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -2856,7 +2841,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -2877,7 +2862,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -2893,7 +2878,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3026,7 +3011,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3092,7 +3077,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -3233,7 +3218,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -3254,7 +3239,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -3270,7 +3255,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3403,7 +3388,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3469,7 +3454,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -3610,7 +3595,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -3631,7 +3616,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -3647,7 +3632,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -3785,7 +3770,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -3851,7 +3836,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -3992,7 +3977,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -4013,7 +3998,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -4029,7 +4014,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path
@@ -4167,7 +4152,7 @@ class APIRouter(routing.Router):
),
] = None,
tags: Annotated[
Optional[List[Union[str, Enum]]],
Optional[list[Union[str, Enum]]],
Doc(
"""
A list of tags to be applied to the *path operation*.
@@ -4233,7 +4218,7 @@ class APIRouter(routing.Router):
),
] = "Successful Response",
responses: Annotated[
Optional[Dict[Union[int, str], Dict[str, Any]]],
Optional[dict[Union[int, str], dict[str, Any]]],
Doc(
"""
Additional responses that could be returned by this *path operation*.
@@ -4374,7 +4359,7 @@ class APIRouter(routing.Router):
),
] = True,
response_class: Annotated[
Type[Response],
type[Response],
Doc(
"""
Response class to be used for this *path operation*.
@@ -4395,7 +4380,7 @@ class APIRouter(routing.Router):
),
] = None,
callbacks: Annotated[
Optional[List[BaseRoute]],
Optional[list[BaseRoute]],
Doc(
"""
List of *path operations* that will be used as OpenAPI callbacks.
@@ -4411,7 +4396,7 @@ class APIRouter(routing.Router):
),
] = None,
openapi_extra: Annotated[
Optional[Dict[str, Any]],
Optional[dict[str, Any]],
Doc(
"""
Extra metadata to be included in the OpenAPI schema for this *path

View File

@@ -1,22 +1,51 @@
from typing import Optional
from typing import Annotated, Optional, Union
from annotated_doc import Doc
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated
from starlette.status import HTTP_401_UNAUTHORIZED
class APIKeyBase(SecurityBase):
@staticmethod
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]:
def __init__(
self,
location: APIKeyIn,
name: str,
description: Union[str, None],
scheme_name: Union[str, None],
auto_error: bool,
):
self.auto_error = auto_error
self.model: APIKey = APIKey(
**{"in": location},
name=name,
description=description,
)
self.scheme_name = scheme_name or self.__class__.__name__
def make_not_authenticated_error(self) -> HTTPException:
"""
The WWW-Authenticate header is not standardized for API Key authentication but
the HTTP specification requires that an error of 401 "Unauthorized" must
include a WWW-Authenticate header.
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
For this, this method sends a custom challenge `APIKey`.
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "APIKey"},
)
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
if not api_key:
if auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
if self.auto_error:
raise self.make_not_authenticated_error()
return None
return api_key
@@ -100,17 +129,17 @@ class APIKeyQuery(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.query},
super().__init__(
location=APIKeyIn.query,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)
return self.check_api_key(api_key)
class APIKeyHeader(APIKeyBase):
@@ -188,17 +217,17 @@ class APIKeyHeader(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.header},
super().__init__(
location=APIKeyIn.header,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.headers.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)
return self.check_api_key(api_key)
class APIKeyCookie(APIKeyBase):
@@ -276,14 +305,14 @@ class APIKeyCookie(APIKeyBase):
),
] = True,
):
self.model: APIKey = APIKey(
**{"in": APIKeyIn.cookie},
super().__init__(
location=APIKeyIn.cookie,
name=name,
scheme_name=scheme_name,
description=description,
auto_error=auto_error,
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
return self.check_api_key(api_key, self.auto_error)
return self.check_api_key(api_key)

View File

@@ -1,6 +1,6 @@
import binascii
from base64 import b64decode
from typing import Optional
from typing import Annotated, Optional
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
@@ -10,8 +10,7 @@ from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated
from starlette.status import HTTP_401_UNAUTHORIZED
class HTTPBasicCredentials(BaseModel):
@@ -76,10 +75,22 @@ class HTTPBase(SecurityBase):
description: Optional[str] = None,
auto_error: bool = True,
):
self.model = HTTPBaseModel(scheme=scheme, description=description)
self.model: HTTPBaseModel = HTTPBaseModel(
scheme=scheme, description=description
)
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_authenticate_headers(self) -> dict[str, str]:
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=self.make_authenticate_headers(),
)
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
@@ -87,9 +98,7 @@ class HTTPBase(SecurityBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -99,6 +108,8 @@ class HTTPBasic(HTTPBase):
"""
HTTP Basic authentication.
Ref: https://datatracker.ietf.org/doc/html/rfc7617
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -185,36 +196,28 @@ class HTTPBasic(HTTPBase):
self.realm = realm
self.auto_error = auto_error
def make_authenticate_headers(self) -> dict[str, str]:
if self.realm:
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
return {"WWW-Authenticate": "Basic"}
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
else:
unauthorized_headers = {"WWW-Authenticate": "Basic"}
if not authorization or scheme.lower() != "basic":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers=unauthorized_headers,
)
raise self.make_not_authenticated_error()
else:
return None
invalid_user_credentials_exc = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
headers=unauthorized_headers,
)
try:
data = b64decode(param).decode("ascii")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise invalid_user_credentials_exc # noqa: B904
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
raise self.make_not_authenticated_error() from e
username, separator, password = data.partition(":")
if not separator:
raise invalid_user_credentials_exc
raise self.make_not_authenticated_error()
return HTTPBasicCredentials(username=username, password=password)
@@ -306,17 +309,12 @@ class HTTPBearer(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
@@ -326,6 +324,12 @@ class HTTPDigest(HTTPBase):
"""
HTTP Digest authentication.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full Digest scheme, you would need to to subclass it
and implement it in your code.
Ref: https://datatracker.ietf.org/doc/html/rfc7616
## Usage
Create an instance object and use that object as the dependency in `Depends()`.
@@ -408,17 +412,12 @@ class HTTPDigest(HTTPBase):
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
if scheme.lower() != "digest":
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="Invalid authentication credentials",
)
raise self.make_not_authenticated_error()
else:
return None
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union, cast
from typing import Annotated, Any, Optional, Union, cast
from annotated_doc import Doc
from fastapi.exceptions import HTTPException
@@ -8,10 +8,7 @@ from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
# TODO: import from typing when deprecating Python 3.9
from typing_extensions import Annotated
from starlette.status import HTTP_401_UNAUTHORIZED
class OAuth2PasswordRequestForm:
@@ -323,7 +320,7 @@ class OAuth2(SecurityBase):
self,
*,
flows: Annotated[
Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]],
Union[OAuthFlowsModel, dict[str, dict[str, Any]]],
Doc(
"""
The dictionary of OAuth2 flows.
@@ -377,13 +374,33 @@ class OAuth2(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
"""
The OAuth 2 specification doesn't define the challenge that should be used,
because a `Bearer` token is not really the only option to authenticate.
But declaring any other authentication challenge would be application-specific
as it's not defined in the specification.
For practical reasons, this method uses the `Bearer` challenge by default, as
it's probably the most common one.
If you are implementing an OAuth2 authentication scheme other than the provided
ones in FastAPI (based on bearer tokens), you might want to override this.
Ref: https://datatracker.ietf.org/doc/html/rfc6749
"""
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization
@@ -420,7 +437,7 @@ class OAuth2PasswordBearer(OAuth2):
),
] = None,
scopes: Annotated[
Optional[Dict[str, str]],
Optional[dict[str, str]],
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
@@ -491,11 +508,7 @@ class OAuth2PasswordBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None
return param
@@ -537,7 +550,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
),
] = None,
scopes: Annotated[
Optional[Dict[str, str]],
Optional[dict[str, str]],
Doc(
"""
The OAuth2 scopes that would be required by the *path operations* that
@@ -601,11 +614,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
raise self.make_not_authenticated_error()
else:
return None # pragma: nocover
return param
@@ -627,7 +636,7 @@ class SecurityScopes:
def __init__(
self,
scopes: Annotated[
Optional[List[str]],
Optional[list[str]],
Doc(
"""
This will be filled by FastAPI.
@@ -636,7 +645,7 @@ class SecurityScopes:
] = None,
):
self.scopes: Annotated[
List[str],
list[str],
Doc(
"""
The list of all the scopes required by dependencies.

View File

@@ -1,18 +1,22 @@
from typing import Optional
from typing import Annotated, Optional
from annotated_doc import Doc
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated
from starlette.status import HTTP_401_UNAUTHORIZED
class OpenIdConnect(SecurityBase):
"""
OpenID Connect authentication class. An instance of it would be used as a
dependency.
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
the OpenIDConnect URL. You would need to to subclass it and implement it in your
code.
"""
def __init__(
@@ -73,13 +77,18 @@ class OpenIdConnect(SecurityBase):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
def make_not_authenticated_error(self) -> HTTPException:
return HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
async def __call__(self, request: Request) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
)
raise self.make_not_authenticated_error()
else:
return None
return authorization

View File

@@ -1,9 +1,9 @@
from typing import Optional, Tuple
from typing import Optional
def get_authorization_scheme_param(
authorization_header_value: Optional[str],
) -> Tuple[str, str]:
) -> tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")

View File

@@ -1,724 +0,0 @@
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from fastapi.openapi.models import Example
from fastapi.params import ParamTypes
from typing_extensions import Annotated, deprecated
from ._compat.may_v1 import FieldInfo, Undefined
from ._compat.shared import PYDANTIC_VERSION_MINOR_TUPLE
_Unset: Any = Undefined
class Param(FieldInfo): # type: ignore[misc]
in_: ParamTypes
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Path(Param): # type: ignore[misc]
in_ = ParamTypes.path
def __init__(
self,
default: Any = ...,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
assert default is ..., "Path parameters cannot have a default value"
self.in_ = self.in_
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Query(Param): # type: ignore[misc]
in_ = ParamTypes.query
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Header(Param): # type: ignore[misc]
in_ = ParamTypes.header
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
convert_underscores: bool = True,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Cookie(Param): # type: ignore[misc]
in_ = ParamTypes.cookie
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class Body(FieldInfo): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
embed: Union[bool, None] = None,
media_type: str = "application/json",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
if example is not _Unset:
warnings.warn(
"`example` has been deprecated, please use `examples` instead",
category=DeprecationWarning,
stacklevel=4,
)
self.example = example
self.include_in_schema = include_in_schema
self.openapi_examples = openapi_examples
kwargs = dict(
default=default,
default_factory=default_factory,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
discriminator=discriminator,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
**extra,
)
if examples is not None:
kwargs["examples"] = examples
if regex is not None:
warnings.warn(
"`regex` has been deprecated, please use `pattern` instead",
category=DeprecationWarning,
stacklevel=4,
)
current_json_schema_extra = json_schema_extra or extra
if PYDANTIC_VERSION_MINOR_TUPLE < (2, 7):
self.deprecated = deprecated
else:
kwargs["deprecated"] = deprecated
kwargs["regex"] = pattern or regex
kwargs.update(**current_json_schema_extra)
use_kwargs = {k: v for k, v in kwargs.items() if v is not _Unset}
super().__init__(**use_kwargs)
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.default})"
class Form(Body): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
media_type: str = "application/x-www-form-urlencoded",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)
class File(Form): # type: ignore[misc]
def __init__(
self,
default: Any = Undefined,
*,
default_factory: Union[Callable[[], Any], None] = _Unset,
annotation: Optional[Any] = None,
media_type: str = "multipart/form-data",
alias: Optional[str] = None,
alias_priority: Union[int, None] = _Unset,
# TODO: update when deprecating Pydantic v1, import these types
# validation_alias: str | AliasPath | AliasChoices | None
validation_alias: Union[str, None] = None,
serialization_alias: Union[str, None] = None,
title: Optional[str] = None,
description: Optional[str] = None,
gt: Optional[float] = None,
ge: Optional[float] = None,
lt: Optional[float] = None,
le: Optional[float] = None,
min_length: Optional[int] = None,
max_length: Optional[int] = None,
pattern: Optional[str] = None,
regex: Annotated[
Optional[str],
deprecated(
"Deprecated in FastAPI 0.100.0 and Pydantic v2, use `pattern` instead."
),
] = None,
discriminator: Union[str, None] = None,
strict: Union[bool, None] = _Unset,
multiple_of: Union[float, None] = _Unset,
allow_inf_nan: Union[bool, None] = _Unset,
max_digits: Union[int, None] = _Unset,
decimal_places: Union[int, None] = _Unset,
examples: Optional[List[Any]] = None,
example: Annotated[
Optional[Any],
deprecated(
"Deprecated in OpenAPI 3.1.0 that now uses JSON Schema 2020-12, "
"although still supported. Use examples instead."
),
] = _Unset,
openapi_examples: Optional[Dict[str, Example]] = None,
deprecated: Union[deprecated, str, bool, None] = None,
include_in_schema: bool = True,
json_schema_extra: Union[Dict[str, Any], None] = None,
**extra: Any,
):
super().__init__(
default=default,
default_factory=default_factory,
annotation=annotation,
media_type=media_type,
alias=alias,
alias_priority=alias_priority,
validation_alias=validation_alias,
serialization_alias=serialization_alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
pattern=pattern,
regex=regex,
discriminator=discriminator,
strict=strict,
multiple_of=multiple_of,
allow_inf_nan=allow_inf_nan,
max_digits=max_digits,
decimal_places=decimal_places,
deprecated=deprecated,
example=example,
examples=examples,
openapi_examples=openapi_examples,
include_in_schema=include_in_schema,
json_schema_extra=json_schema_extra,
**extra,
)

View File

@@ -1,11 +1,11 @@
import types
from enum import Enum
from typing import Any, Callable, Dict, Optional, Set, Tuple, Type, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from pydantic import BaseModel
DecoratedCallable = TypeVar("DecoratedCallable", bound=Callable[..., Any])
UnionType = getattr(types, "UnionType", Union)
ModelNameMap = Dict[Union[Type[BaseModel], Type[Enum]], str]
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
DependencyCacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...], str]
ModelNameMap = dict[Union[type[BaseModel], type[Enum]], str]
IncEx = Union[set[int], set[str], dict[int, Any], dict[str, Any]]
DependencyCacheKey = tuple[Optional[Callable[..., Any]], tuple[str, ...], str]

View File

@@ -1,22 +1,16 @@
import re
import warnings
from dataclasses import is_dataclass
from collections.abc import MutableMapping
from typing import (
TYPE_CHECKING,
Any,
Dict,
MutableMapping,
Optional,
Set,
Type,
Union,
cast,
)
from weakref import WeakKeyDictionary
import fastapi
from fastapi._compat import (
PYDANTIC_V2,
BaseConfig,
ModelField,
PydanticSchemaGenerationError,
@@ -24,19 +18,20 @@ from fastapi._compat import (
UndefinedType,
Validator,
annotation_is_pydantic_v1,
lenient_issubclass,
may_v1,
)
from fastapi.datastructures import DefaultPlaceholder, DefaultType
from fastapi.exceptions import FastAPIDeprecationWarning, PydanticV1NotSupportedError
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from typing_extensions import Literal
from ._compat import v2
if TYPE_CHECKING: # pragma: nocover
from .routing import APIRoute
# Cache for `create_cloned_field`
_CLONED_TYPES_CACHE: MutableMapping[Type[BaseModel], Type[BaseModel]] = (
_CLONED_TYPES_CACHE: MutableMapping[type[BaseModel], type[BaseModel]] = (
WeakKeyDictionary()
)
@@ -58,7 +53,7 @@ def is_body_allowed_for_status_code(status_code: Union[int, str, None]) -> bool:
return not (current_status_code < 200 or current_status_code in {204, 205, 304})
def get_path_param_names(path: str) -> Set[str]:
def get_path_param_names(path: str) -> set[str]:
return set(re.findall("{(.*?)}", path))
@@ -76,132 +71,47 @@ _invalid_args_message = (
def create_model_field(
name: str,
type_: Any,
class_validators: Optional[Dict[str, Validator]] = None,
class_validators: Optional[dict[str, Validator]] = None,
default: Optional[Any] = Undefined,
required: Union[bool, UndefinedType] = Undefined,
model_config: Union[Type[BaseConfig], None] = None,
model_config: Union[type[BaseConfig], None] = None,
field_info: Optional[FieldInfo] = None,
alias: Optional[str] = None,
mode: Literal["validation", "serialization"] = "validation",
version: Literal["1", "auto"] = "auto",
) -> ModelField:
if annotation_is_pydantic_v1(type_):
raise PydanticV1NotSupportedError(
"pydantic.v1 models are no longer supported by FastAPI."
f" Please update the response model {type_!r}."
)
class_validators = class_validators or {}
v1_model_config = may_v1.BaseConfig
v1_field_info = field_info or may_v1.FieldInfo()
v1_kwargs = {
"name": name,
"field_info": v1_field_info,
"type_": type_,
"class_validators": class_validators,
"default": default,
"required": required,
"model_config": v1_model_config,
"alias": alias,
}
if (
annotation_is_pydantic_v1(type_)
or isinstance(field_info, may_v1.FieldInfo)
or version == "1"
):
from fastapi._compat import v1
try:
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
except RuntimeError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
elif PYDANTIC_V2:
from ._compat import v2
field_info = field_info or FieldInfo(
annotation=type_, default=default, alias=alias
)
kwargs = {"mode": mode, "name": name, "field_info": field_info}
try:
return v2.ModelField(**kwargs) # type: ignore[return-value,arg-type]
except PydanticSchemaGenerationError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
# Pydantic v2 is not installed, but it's not a Pydantic v1 ModelField, it could be
# a Pydantic v1 type, like a constrained int
from fastapi._compat import v1
field_info = field_info or FieldInfo(annotation=type_, default=default, alias=alias)
kwargs = {"mode": mode, "name": name, "field_info": field_info}
try:
return v1.ModelField(**v1_kwargs) # type: ignore[no-any-return]
except RuntimeError:
raise fastapi.exceptions.FastAPIError(_invalid_args_message) from None
return v2.ModelField(**kwargs) # type: ignore[return-value,arg-type]
except PydanticSchemaGenerationError:
raise fastapi.exceptions.FastAPIError(
_invalid_args_message.format(type_=type_)
) from None
def create_cloned_field(
field: ModelField,
*,
cloned_types: Optional[MutableMapping[Type[BaseModel], Type[BaseModel]]] = None,
cloned_types: Optional[MutableMapping[type[BaseModel], type[BaseModel]]] = None,
) -> ModelField:
if PYDANTIC_V2:
from ._compat import v2
if isinstance(field, v2.ModelField):
return field
from fastapi._compat import v1
# cloned_types caches already cloned types to support recursive models and improve
# performance by avoiding unnecessary cloning
if cloned_types is None:
cloned_types = _CLONED_TYPES_CACHE
original_type = field.type_
if is_dataclass(original_type) and hasattr(original_type, "__pydantic_model__"):
original_type = original_type.__pydantic_model__
use_type = original_type
if lenient_issubclass(original_type, v1.BaseModel):
original_type = cast(Type[v1.BaseModel], original_type)
use_type = cloned_types.get(original_type)
if use_type is None:
use_type = v1.create_model(original_type.__name__, __base__=original_type)
cloned_types[original_type] = use_type
for f in original_type.__fields__.values():
use_type.__fields__[f.name] = create_cloned_field(
f,
cloned_types=cloned_types,
)
new_field = create_model_field(name=field.name, type_=use_type, version="1")
new_field.has_alias = field.has_alias # type: ignore[attr-defined]
new_field.alias = field.alias # type: ignore[misc]
new_field.class_validators = field.class_validators # type: ignore[attr-defined]
new_field.default = field.default # type: ignore[misc]
new_field.default_factory = field.default_factory # type: ignore[attr-defined]
new_field.required = field.required # type: ignore[misc]
new_field.model_config = field.model_config # type: ignore[attr-defined]
new_field.field_info = field.field_info
new_field.allow_none = field.allow_none # type: ignore[attr-defined]
new_field.validate_always = field.validate_always # type: ignore[attr-defined]
if field.sub_fields: # type: ignore[attr-defined]
new_field.sub_fields = [ # type: ignore[attr-defined]
create_cloned_field(sub_field, cloned_types=cloned_types)
for sub_field in field.sub_fields # type: ignore[attr-defined]
]
if field.key_field: # type: ignore[attr-defined]
new_field.key_field = create_cloned_field( # type: ignore[attr-defined]
field.key_field, # type: ignore[attr-defined]
cloned_types=cloned_types,
)
new_field.validators = field.validators # type: ignore[attr-defined]
new_field.pre_validators = field.pre_validators # type: ignore[attr-defined]
new_field.post_validators = field.post_validators # type: ignore[attr-defined]
new_field.parse_json = field.parse_json # type: ignore[attr-defined]
new_field.shape = field.shape # type: ignore[attr-defined]
new_field.populate_validators() # type: ignore[attr-defined]
return new_field
return field
def generate_operation_id_for_path(
*, name: str, path: str, method: str
) -> str: # pragma: nocover
warnings.warn(
"fastapi.utils.generate_operation_id_for_path() was deprecated, "
message="fastapi.utils.generate_operation_id_for_path() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=2,
)
operation_id = f"{name}{path}"
@@ -218,7 +128,7 @@ def generate_unique_id(route: "APIRoute") -> str:
return operation_id
def deep_dict_update(main_dict: Dict[Any, Any], update_dict: Dict[Any, Any]) -> None:
def deep_dict_update(main_dict: dict[Any, Any], update_dict: dict[Any, Any]) -> None:
for key, value in update_dict.items():
if (
key in main_dict