增加环绕侦察场景适配
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,8 +1,8 @@
|
||||
import inspect
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any, Callable, List, Optional, Sequence, Union
|
||||
from functools import cached_property, partial
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from fastapi._compat import ModelField
|
||||
from fastapi.security.base import SecurityBase
|
||||
@@ -15,21 +15,27 @@ else: # pragma: no cover
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
|
||||
@dataclass
|
||||
class SecurityRequirement:
|
||||
security_scheme: SecurityBase
|
||||
scopes: Optional[Sequence[str]] = None
|
||||
def _unwrapped_call(call: Optional[Callable[..., Any]]) -> Any:
|
||||
if call is None:
|
||||
return call # pragma: no cover
|
||||
unwrapped = inspect.unwrap(_impartial(call))
|
||||
return unwrapped
|
||||
|
||||
|
||||
def _impartial(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
while isinstance(func, partial):
|
||||
func = func.func
|
||||
return func
|
||||
|
||||
|
||||
@dataclass
|
||||
class Dependant:
|
||||
path_params: List[ModelField] = field(default_factory=list)
|
||||
query_params: List[ModelField] = field(default_factory=list)
|
||||
header_params: List[ModelField] = field(default_factory=list)
|
||||
cookie_params: List[ModelField] = field(default_factory=list)
|
||||
body_params: List[ModelField] = field(default_factory=list)
|
||||
dependencies: List["Dependant"] = field(default_factory=list)
|
||||
security_requirements: List[SecurityRequirement] = field(default_factory=list)
|
||||
path_params: list[ModelField] = field(default_factory=list)
|
||||
query_params: list[ModelField] = field(default_factory=list)
|
||||
header_params: list[ModelField] = field(default_factory=list)
|
||||
cookie_params: list[ModelField] = field(default_factory=list)
|
||||
body_params: list[ModelField] = field(default_factory=list)
|
||||
dependencies: list["Dependant"] = field(default_factory=list)
|
||||
name: Optional[str] = None
|
||||
call: Optional[Callable[..., Any]] = None
|
||||
request_param_name: Optional[str] = None
|
||||
@@ -38,41 +44,145 @@ class Dependant:
|
||||
response_param_name: Optional[str] = None
|
||||
background_tasks_param_name: Optional[str] = None
|
||||
security_scopes_param_name: Optional[str] = None
|
||||
security_scopes: Optional[List[str]] = None
|
||||
own_oauth_scopes: Optional[list[str]] = None
|
||||
parent_oauth_scopes: Optional[list[str]] = None
|
||||
use_cache: bool = True
|
||||
path: Optional[str] = None
|
||||
scope: Union[Literal["function", "request"], None] = None
|
||||
|
||||
@cached_property
|
||||
def oauth_scopes(self) -> list[str]:
|
||||
scopes = self.parent_oauth_scopes.copy() if self.parent_oauth_scopes else []
|
||||
# This doesn't use a set to preserve order, just in case
|
||||
for scope in self.own_oauth_scopes or []:
|
||||
if scope not in scopes:
|
||||
scopes.append(scope)
|
||||
return scopes
|
||||
|
||||
@cached_property
|
||||
def cache_key(self) -> DependencyCacheKey:
|
||||
scopes_for_cache = (
|
||||
tuple(sorted(set(self.oauth_scopes or []))) if self._uses_scopes else ()
|
||||
)
|
||||
return (
|
||||
self.call,
|
||||
tuple(sorted(set(self.security_scopes or []))),
|
||||
scopes_for_cache,
|
||||
self.computed_scope or "",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if inspect.isgeneratorfunction(self.call):
|
||||
def _uses_scopes(self) -> bool:
|
||||
if self.own_oauth_scopes:
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isgeneratorfunction(dunder_call)
|
||||
if self.security_scopes_param_name is not None:
|
||||
return True
|
||||
if self._is_security_scheme:
|
||||
return True
|
||||
for sub_dep in self.dependencies:
|
||||
if sub_dep._uses_scopes:
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def _is_security_scheme(self) -> bool:
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
return isinstance(unwrapped, SecurityBase)
|
||||
|
||||
# Mainly to get the type of SecurityBase, but it's the same self.call
|
||||
@cached_property
|
||||
def _security_scheme(self) -> SecurityBase:
|
||||
unwrapped = _unwrapped_call(self.call)
|
||||
assert isinstance(unwrapped, SecurityBase)
|
||||
return unwrapped
|
||||
|
||||
@cached_property
|
||||
def _security_dependencies(self) -> list["Dependant"]:
|
||||
security_deps = [dep for dep in self.dependencies if dep._is_security_scheme]
|
||||
return security_deps
|
||||
|
||||
@cached_property
|
||||
def is_gen_callable(self) -> bool:
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(self.call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(self.call)):
|
||||
return True
|
||||
if inspect.isclass(_unwrapped_call(self.call)):
|
||||
return False
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(dunder_call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_call)):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isgeneratorfunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or inspect.isgeneratorfunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def is_async_gen_callable(self) -> bool:
|
||||
if inspect.isasyncgenfunction(self.call):
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(self.call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(self.call)):
|
||||
return True
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return inspect.isasyncgenfunction(dunder_call)
|
||||
if inspect.isclass(_unwrapped_call(self.call)):
|
||||
return False
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(dunder_call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_call)):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isasyncgenfunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or inspect.isasyncgenfunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def is_coroutine_callable(self) -> bool:
|
||||
if inspect.isroutine(self.call):
|
||||
return iscoroutinefunction(self.call)
|
||||
if inspect.isclass(self.call):
|
||||
if self.call is None:
|
||||
return False # pragma: no cover
|
||||
if inspect.isroutine(_impartial(self.call)) and iscoroutinefunction(
|
||||
_impartial(self.call)
|
||||
):
|
||||
return True
|
||||
if inspect.isroutine(_unwrapped_call(self.call)) and iscoroutinefunction(
|
||||
_unwrapped_call(self.call)
|
||||
):
|
||||
return True
|
||||
if inspect.isclass(_unwrapped_call(self.call)):
|
||||
return False
|
||||
dunder_call = getattr(self.call, "__call__", None) # noqa: B004
|
||||
return iscoroutinefunction(dunder_call)
|
||||
dunder_call = getattr(_impartial(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_call is None:
|
||||
return False # pragma: no cover
|
||||
if iscoroutinefunction(_impartial(dunder_call)) or iscoroutinefunction(
|
||||
_unwrapped_call(dunder_call)
|
||||
):
|
||||
return True
|
||||
dunder_unwrapped_call = getattr(_unwrapped_call(self.call), "__call__", None) # noqa: B004
|
||||
if dunder_unwrapped_call is None:
|
||||
return False # pragma: no cover
|
||||
if iscoroutinefunction(
|
||||
_impartial(dunder_unwrapped_call)
|
||||
) or iscoroutinefunction(_unwrapped_call(dunder_unwrapped_call)):
|
||||
return True
|
||||
return False
|
||||
|
||||
@cached_property
|
||||
def computed_scope(self) -> Union[str, None]:
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
import sys
|
||||
from collections.abc import Coroutine, Mapping, Sequence
|
||||
from contextlib import AsyncExitStack, contextmanager
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
@@ -21,17 +18,14 @@ from typing import (
|
||||
import anyio
|
||||
from fastapi import params
|
||||
from fastapi._compat import (
|
||||
PYDANTIC_V2,
|
||||
ModelField,
|
||||
RequiredParam,
|
||||
Undefined,
|
||||
_is_error_wrapper,
|
||||
_is_model_class,
|
||||
_regenerate_error_with_loc,
|
||||
copy_field_info,
|
||||
create_body_model,
|
||||
evaluate_forwardref,
|
||||
field_annotation_is_scalar,
|
||||
get_annotation_from_field_info,
|
||||
get_cached_model_fields,
|
||||
get_missing_field_error,
|
||||
is_bytes_field,
|
||||
@@ -42,23 +36,19 @@ from fastapi._compat import (
|
||||
is_uploadfile_or_nonable_uploadfile_annotation,
|
||||
is_uploadfile_sequence_annotation,
|
||||
lenient_issubclass,
|
||||
may_v1,
|
||||
sequence_types,
|
||||
serialize_sequence_value,
|
||||
value_is_sequence,
|
||||
)
|
||||
from fastapi._compat.shared import annotation_is_pydantic_v1
|
||||
from fastapi.background import BackgroundTasks
|
||||
from fastapi.concurrency import (
|
||||
asynccontextmanager,
|
||||
contextmanager_in_threadpool,
|
||||
)
|
||||
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
||||
from fastapi.dependencies.models import Dependant
|
||||
from fastapi.exceptions import DependencyScopeError
|
||||
from fastapi.logger import logger
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
||||
from fastapi.security.open_id_connect_url import OpenIdConnect
|
||||
from fastapi.security.oauth2 import SecurityScopes
|
||||
from fastapi.types import DependencyCacheKey
|
||||
from fastapi.utils import create_model_field, get_path_param_names
|
||||
from pydantic import BaseModel
|
||||
@@ -75,9 +65,7 @@ from starlette.datastructures import (
|
||||
from starlette.requests import HTTPConnection, Request
|
||||
from starlette.responses import Response
|
||||
from starlette.websockets import WebSocket
|
||||
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||||
|
||||
from .. import temp_pydantic_v1_params
|
||||
from typing_extensions import Literal, get_args, get_origin
|
||||
|
||||
multipart_not_installed_error = (
|
||||
'Form data requires "python-multipart" to be installed. \n'
|
||||
@@ -125,14 +113,14 @@ def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> De
|
||||
assert callable(depends.dependency), (
|
||||
"A parameter-less dependency must have a callable dependency"
|
||||
)
|
||||
use_security_scopes: List[str] = []
|
||||
own_oauth_scopes: list[str] = []
|
||||
if isinstance(depends, params.Security) and depends.scopes:
|
||||
use_security_scopes.extend(depends.scopes)
|
||||
own_oauth_scopes.extend(depends.scopes)
|
||||
return get_dependant(
|
||||
path=path,
|
||||
call=depends.dependency,
|
||||
scope=depends.scope,
|
||||
security_scopes=use_security_scopes,
|
||||
own_oauth_scopes=own_oauth_scopes,
|
||||
)
|
||||
|
||||
|
||||
@@ -140,11 +128,15 @@ def get_flat_dependant(
|
||||
dependant: Dependant,
|
||||
*,
|
||||
skip_repeats: bool = False,
|
||||
visited: Optional[List[DependencyCacheKey]] = None,
|
||||
visited: Optional[list[DependencyCacheKey]] = None,
|
||||
parent_oauth_scopes: Optional[list[str]] = None,
|
||||
) -> Dependant:
|
||||
if visited is None:
|
||||
visited = []
|
||||
visited.append(dependant.cache_key)
|
||||
use_parent_oauth_scopes = (parent_oauth_scopes or []) + (
|
||||
dependant.oauth_scopes or []
|
||||
)
|
||||
|
||||
flat_dependant = Dependant(
|
||||
path_params=dependant.path_params.copy(),
|
||||
@@ -152,36 +144,51 @@ def get_flat_dependant(
|
||||
header_params=dependant.header_params.copy(),
|
||||
cookie_params=dependant.cookie_params.copy(),
|
||||
body_params=dependant.body_params.copy(),
|
||||
security_requirements=dependant.security_requirements.copy(),
|
||||
name=dependant.name,
|
||||
call=dependant.call,
|
||||
request_param_name=dependant.request_param_name,
|
||||
websocket_param_name=dependant.websocket_param_name,
|
||||
http_connection_param_name=dependant.http_connection_param_name,
|
||||
response_param_name=dependant.response_param_name,
|
||||
background_tasks_param_name=dependant.background_tasks_param_name,
|
||||
security_scopes_param_name=dependant.security_scopes_param_name,
|
||||
own_oauth_scopes=dependant.own_oauth_scopes,
|
||||
parent_oauth_scopes=use_parent_oauth_scopes,
|
||||
use_cache=dependant.use_cache,
|
||||
path=dependant.path,
|
||||
scope=dependant.scope,
|
||||
)
|
||||
for sub_dependant in dependant.dependencies:
|
||||
if skip_repeats and sub_dependant.cache_key in visited:
|
||||
continue
|
||||
flat_sub = get_flat_dependant(
|
||||
sub_dependant, skip_repeats=skip_repeats, visited=visited
|
||||
sub_dependant,
|
||||
skip_repeats=skip_repeats,
|
||||
visited=visited,
|
||||
parent_oauth_scopes=flat_dependant.oauth_scopes,
|
||||
)
|
||||
flat_dependant.dependencies.append(flat_sub)
|
||||
flat_dependant.path_params.extend(flat_sub.path_params)
|
||||
flat_dependant.query_params.extend(flat_sub.query_params)
|
||||
flat_dependant.header_params.extend(flat_sub.header_params)
|
||||
flat_dependant.cookie_params.extend(flat_sub.cookie_params)
|
||||
flat_dependant.body_params.extend(flat_sub.body_params)
|
||||
flat_dependant.security_requirements.extend(flat_sub.security_requirements)
|
||||
flat_dependant.dependencies.extend(flat_sub.dependencies)
|
||||
|
||||
return flat_dependant
|
||||
|
||||
|
||||
def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
|
||||
def _get_flat_fields_from_params(fields: list[ModelField]) -> list[ModelField]:
|
||||
if not fields:
|
||||
return fields
|
||||
first_field = fields[0]
|
||||
if len(fields) == 1 and _is_model_class(first_field.type_):
|
||||
if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
return fields_to_extract
|
||||
return fields
|
||||
|
||||
|
||||
def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
def get_flat_params(dependant: Dependant) -> list[ModelField]:
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
query_params = _get_flat_fields_from_params(flat_dependant.query_params)
|
||||
@@ -190,9 +197,23 @@ def get_flat_params(dependant: Dependant) -> List[ModelField]:
|
||||
return path_params + query_params + header_params + cookie_params
|
||||
|
||||
|
||||
def _get_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
if sys.version_info >= (3, 10):
|
||||
try:
|
||||
signature = inspect.signature(call, eval_str=True)
|
||||
except NameError:
|
||||
# Handle type annotations with if TYPE_CHECKING, not used by FastAPI
|
||||
# e.g. dependency return types
|
||||
signature = inspect.signature(call)
|
||||
else:
|
||||
signature = inspect.signature(call)
|
||||
return signature
|
||||
|
||||
|
||||
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
signature = inspect.signature(call)
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
signature = _get_signature(call)
|
||||
unwrapped = inspect.unwrap(call)
|
||||
globalns = getattr(unwrapped, "__globals__", {})
|
||||
typed_params = [
|
||||
inspect.Parameter(
|
||||
name=param.name,
|
||||
@@ -206,7 +227,7 @@ def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
|
||||
return typed_signature
|
||||
|
||||
|
||||
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||
def get_typed_annotation(annotation: Any, globalns: dict[str, Any]) -> Any:
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
annotation = evaluate_forwardref(annotation, globalns, globalns)
|
||||
@@ -216,13 +237,14 @@ def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
|
||||
|
||||
|
||||
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
|
||||
signature = inspect.signature(call)
|
||||
signature = _get_signature(call)
|
||||
unwrapped = inspect.unwrap(call)
|
||||
annotation = signature.return_annotation
|
||||
|
||||
if annotation is inspect.Signature.empty:
|
||||
return None
|
||||
|
||||
globalns = getattr(call, "__globals__", {})
|
||||
globalns = getattr(unwrapped, "__globals__", {})
|
||||
return get_typed_annotation(annotation, globalns)
|
||||
|
||||
|
||||
@@ -231,7 +253,8 @@ def get_dependant(
|
||||
path: str,
|
||||
call: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
security_scopes: Optional[List[str]] = None,
|
||||
own_oauth_scopes: Optional[list[str]] = None,
|
||||
parent_oauth_scopes: Optional[list[str]] = None,
|
||||
use_cache: bool = True,
|
||||
scope: Union[Literal["function", "request"], None] = None,
|
||||
) -> Dependant:
|
||||
@@ -239,21 +262,15 @@ def get_dependant(
|
||||
call=call,
|
||||
name=name,
|
||||
path=path,
|
||||
security_scopes=security_scopes,
|
||||
use_cache=use_cache,
|
||||
scope=scope,
|
||||
own_oauth_scopes=own_oauth_scopes,
|
||||
parent_oauth_scopes=parent_oauth_scopes,
|
||||
)
|
||||
current_scopes = (parent_oauth_scopes or []) + (own_oauth_scopes or [])
|
||||
path_param_names = get_path_param_names(path)
|
||||
endpoint_signature = get_typed_signature(call)
|
||||
signature_params = endpoint_signature.parameters
|
||||
if isinstance(call, SecurityBase):
|
||||
use_scopes: List[str] = []
|
||||
if isinstance(call, (OAuth2, OpenIdConnect)):
|
||||
use_scopes = security_scopes or use_scopes
|
||||
security_requirement = SecurityRequirement(
|
||||
security_scheme=call, scopes=use_scopes
|
||||
)
|
||||
dependant.security_requirements.append(security_requirement)
|
||||
for param_name, param in signature_params.items():
|
||||
is_path_param = param_name in path_param_names
|
||||
param_details = analyze_param(
|
||||
@@ -274,15 +291,16 @@ def get_dependant(
|
||||
f'The dependency "{dependant.call.__name__}" has a scope of '
|
||||
'"request", it cannot depend on dependencies with scope "function".'
|
||||
)
|
||||
use_security_scopes = security_scopes or []
|
||||
sub_own_oauth_scopes: list[str] = []
|
||||
if isinstance(param_details.depends, params.Security):
|
||||
if param_details.depends.scopes:
|
||||
use_security_scopes.extend(param_details.depends.scopes)
|
||||
sub_own_oauth_scopes = list(param_details.depends.scopes)
|
||||
sub_dependant = get_dependant(
|
||||
path=path,
|
||||
call=param_details.depends.dependency,
|
||||
name=param_name,
|
||||
security_scopes=use_security_scopes,
|
||||
own_oauth_scopes=sub_own_oauth_scopes,
|
||||
parent_oauth_scopes=current_scopes,
|
||||
use_cache=param_details.depends.use_cache,
|
||||
scope=param_details.depends.scope,
|
||||
)
|
||||
@@ -298,9 +316,7 @@ def get_dependant(
|
||||
)
|
||||
continue
|
||||
assert param_details.field is not None
|
||||
if isinstance(
|
||||
param_details.field.field_info, (params.Body, temp_pydantic_v1_params.Body)
|
||||
):
|
||||
if isinstance(param_details.field.field_info, params.Body):
|
||||
dependant.body_params.append(param_details.field)
|
||||
else:
|
||||
add_param_to_fields(field=param_details.field, dependant=dependant)
|
||||
@@ -359,7 +375,7 @@ def analyze_param(
|
||||
fastapi_annotations = [
|
||||
arg
|
||||
for arg in annotated_args[1:]
|
||||
if isinstance(arg, (FieldInfo, may_v1.FieldInfo, params.Depends))
|
||||
if isinstance(arg, (FieldInfo, params.Depends))
|
||||
]
|
||||
fastapi_specific_annotations = [
|
||||
arg
|
||||
@@ -368,29 +384,27 @@ def analyze_param(
|
||||
arg,
|
||||
(
|
||||
params.Param,
|
||||
temp_pydantic_v1_params.Param,
|
||||
params.Body,
|
||||
temp_pydantic_v1_params.Body,
|
||||
params.Depends,
|
||||
),
|
||||
)
|
||||
]
|
||||
if fastapi_specific_annotations:
|
||||
fastapi_annotation: Union[
|
||||
FieldInfo, may_v1.FieldInfo, params.Depends, None
|
||||
] = fastapi_specific_annotations[-1]
|
||||
fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
|
||||
fastapi_specific_annotations[-1]
|
||||
)
|
||||
else:
|
||||
fastapi_annotation = None
|
||||
# Set default for Annotated FieldInfo
|
||||
if isinstance(fastapi_annotation, (FieldInfo, may_v1.FieldInfo)):
|
||||
if isinstance(fastapi_annotation, FieldInfo):
|
||||
# Copy `field_info` because we mutate `field_info.default` below.
|
||||
field_info = copy_field_info(
|
||||
field_info=fastapi_annotation, annotation=use_annotation
|
||||
field_info=fastapi_annotation, # type: ignore[arg-type]
|
||||
annotation=use_annotation,
|
||||
)
|
||||
assert field_info.default in {
|
||||
Undefined,
|
||||
may_v1.Undefined,
|
||||
} or field_info.default in {RequiredParam, may_v1.RequiredParam}, (
|
||||
assert (
|
||||
field_info.default == Undefined or field_info.default == RequiredParam
|
||||
), (
|
||||
f"`{field_info.__class__.__name__}` default value cannot be set in"
|
||||
f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
|
||||
)
|
||||
@@ -414,21 +428,20 @@ def analyze_param(
|
||||
)
|
||||
depends = value
|
||||
# Get FieldInfo from default value
|
||||
elif isinstance(value, (FieldInfo, may_v1.FieldInfo)):
|
||||
elif isinstance(value, FieldInfo):
|
||||
assert field_info is None, (
|
||||
"Cannot specify FastAPI annotations in `Annotated` and default value"
|
||||
f" together for {param_name!r}"
|
||||
)
|
||||
field_info = value
|
||||
if PYDANTIC_V2:
|
||||
if isinstance(field_info, FieldInfo):
|
||||
field_info.annotation = type_annotation
|
||||
field_info = value # type: ignore[assignment]
|
||||
if isinstance(field_info, FieldInfo):
|
||||
field_info.annotation = type_annotation
|
||||
|
||||
# Get Depends from type annotation
|
||||
if depends is not None and depends.dependency is None:
|
||||
# Copy `depends` before mutating it
|
||||
depends = copy(depends)
|
||||
depends.dependency = type_annotation
|
||||
depends = dataclasses.replace(depends, dependency=type_annotation)
|
||||
|
||||
# Handle non-param type annotations like Request
|
||||
if lenient_issubclass(
|
||||
@@ -459,14 +472,7 @@ def analyze_param(
|
||||
) or is_uploadfile_sequence_annotation(type_annotation):
|
||||
field_info = params.File(annotation=use_annotation, default=default_value)
|
||||
elif not field_annotation_is_scalar(annotation=type_annotation):
|
||||
if annotation_is_pydantic_v1(use_annotation):
|
||||
field_info = temp_pydantic_v1_params.Body(
|
||||
annotation=use_annotation, default=default_value
|
||||
)
|
||||
else:
|
||||
field_info = params.Body(
|
||||
annotation=use_annotation, default=default_value
|
||||
)
|
||||
field_info = params.Body(annotation=use_annotation, default=default_value)
|
||||
else:
|
||||
field_info = params.Query(annotation=use_annotation, default=default_value)
|
||||
|
||||
@@ -475,23 +481,17 @@ def analyze_param(
|
||||
if field_info is not None:
|
||||
# Handle field_info.in_
|
||||
if is_path_param:
|
||||
assert isinstance(
|
||||
field_info, (params.Path, temp_pydantic_v1_params.Path)
|
||||
), (
|
||||
assert isinstance(field_info, params.Path), (
|
||||
f"Cannot use `{field_info.__class__.__name__}` for path param"
|
||||
f" {param_name!r}"
|
||||
)
|
||||
elif (
|
||||
isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param))
|
||||
isinstance(field_info, params.Param)
|
||||
and getattr(field_info, "in_", None) is None
|
||||
):
|
||||
field_info.in_ = params.ParamTypes.query
|
||||
use_annotation_from_field_info = get_annotation_from_field_info(
|
||||
use_annotation,
|
||||
field_info,
|
||||
param_name,
|
||||
)
|
||||
if isinstance(field_info, (params.Form, temp_pydantic_v1_params.Form)):
|
||||
use_annotation_from_field_info = use_annotation
|
||||
if isinstance(field_info, params.Form):
|
||||
ensure_multipart_is_installed()
|
||||
if not field_info.alias and getattr(field_info, "convert_underscores", None):
|
||||
alias = param_name.replace("_", "-")
|
||||
@@ -503,20 +503,19 @@ def analyze_param(
|
||||
type_=use_annotation_from_field_info,
|
||||
default=field_info.default,
|
||||
alias=alias,
|
||||
required=field_info.default
|
||||
in (RequiredParam, may_v1.RequiredParam, Undefined),
|
||||
required=field_info.default in (RequiredParam, Undefined),
|
||||
field_info=field_info,
|
||||
)
|
||||
if is_path_param:
|
||||
assert is_scalar_field(field=field), (
|
||||
"Path params must be of one of the supported types"
|
||||
)
|
||||
elif isinstance(field_info, (params.Query, temp_pydantic_v1_params.Query)):
|
||||
elif isinstance(field_info, params.Query):
|
||||
assert (
|
||||
is_scalar_field(field)
|
||||
or is_scalar_sequence_field(field)
|
||||
or (
|
||||
_is_model_class(field.type_)
|
||||
lenient_issubclass(field.type_, BaseModel)
|
||||
# For Pydantic v1
|
||||
and getattr(field, "shape", 1) == 1
|
||||
)
|
||||
@@ -542,34 +541,34 @@ def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
|
||||
|
||||
|
||||
async def _solve_generator(
|
||||
*, dependant: Dependant, stack: AsyncExitStack, sub_values: Dict[str, Any]
|
||||
*, dependant: Dependant, stack: AsyncExitStack, sub_values: dict[str, Any]
|
||||
) -> Any:
|
||||
assert dependant.call
|
||||
if dependant.is_gen_callable:
|
||||
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||
elif dependant.is_async_gen_callable:
|
||||
if dependant.is_async_gen_callable:
|
||||
cm = asynccontextmanager(dependant.call)(**sub_values)
|
||||
elif dependant.is_gen_callable:
|
||||
cm = contextmanager_in_threadpool(contextmanager(dependant.call)(**sub_values))
|
||||
return await stack.enter_async_context(cm)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SolvedDependency:
|
||||
values: Dict[str, Any]
|
||||
errors: List[Any]
|
||||
values: dict[str, Any]
|
||||
errors: list[Any]
|
||||
background_tasks: Optional[StarletteBackgroundTasks]
|
||||
response: Response
|
||||
dependency_cache: Dict[DependencyCacheKey, Any]
|
||||
dependency_cache: dict[DependencyCacheKey, Any]
|
||||
|
||||
|
||||
async def solve_dependencies(
|
||||
*,
|
||||
request: Union[Request, WebSocket],
|
||||
dependant: Dependant,
|
||||
body: Optional[Union[Dict[str, Any], FormData]] = None,
|
||||
body: Optional[Union[dict[str, Any], FormData]] = None,
|
||||
background_tasks: Optional[StarletteBackgroundTasks] = None,
|
||||
response: Optional[Response] = None,
|
||||
dependency_overrides_provider: Optional[Any] = None,
|
||||
dependency_cache: Optional[Dict[DependencyCacheKey, Any]] = None,
|
||||
dependency_cache: Optional[dict[DependencyCacheKey, Any]] = None,
|
||||
# TODO: remove this parameter later, no longer used, not removing it yet as some
|
||||
# people might be monkey patching this function (although that's not supported)
|
||||
async_exit_stack: AsyncExitStack,
|
||||
@@ -583,8 +582,8 @@ async def solve_dependencies(
|
||||
assert isinstance(function_astack, AsyncExitStack), (
|
||||
"fastapi_function_astack not found in request scope"
|
||||
)
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Any] = []
|
||||
values: dict[str, Any] = {}
|
||||
errors: list[Any] = []
|
||||
if response is None:
|
||||
response = Response()
|
||||
del response.headers["content-length"]
|
||||
@@ -608,7 +607,7 @@ async def solve_dependencies(
|
||||
path=use_path,
|
||||
call=call,
|
||||
name=sub_dependant.name,
|
||||
security_scopes=sub_dependant.security_scopes,
|
||||
parent_oauth_scopes=sub_dependant.oauth_scopes,
|
||||
scope=sub_dependant.scope,
|
||||
)
|
||||
|
||||
@@ -690,7 +689,7 @@ async def solve_dependencies(
|
||||
values[dependant.response_param_name] = response
|
||||
if dependant.security_scopes_param_name:
|
||||
values[dependant.security_scopes_param_name] = SecurityScopes(
|
||||
scopes=dependant.security_scopes
|
||||
scopes=dependant.oauth_scopes
|
||||
)
|
||||
return SolvedDependency(
|
||||
values=values,
|
||||
@@ -702,18 +701,16 @@ async def solve_dependencies(
|
||||
|
||||
|
||||
def _validate_value_with_model_field(
|
||||
*, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
|
||||
) -> Tuple[Any, List[Any]]:
|
||||
*, field: ModelField, value: Any, values: dict[str, Any], loc: tuple[str, ...]
|
||||
) -> tuple[Any, list[Any]]:
|
||||
if value is None:
|
||||
if field.required:
|
||||
return None, [get_missing_field_error(loc=loc)]
|
||||
else:
|
||||
return deepcopy(field.default), []
|
||||
v_, errors_ = field.validate(value, values, loc=loc)
|
||||
if _is_error_wrapper(errors_): # type: ignore[arg-type]
|
||||
return None, [errors_]
|
||||
elif isinstance(errors_, list):
|
||||
new_errors = may_v1._regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
if isinstance(errors_, list):
|
||||
new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
|
||||
return None, new_errors
|
||||
else:
|
||||
return v_, []
|
||||
@@ -722,7 +719,7 @@ def _validate_value_with_model_field(
|
||||
def _get_multidict_value(
|
||||
field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
|
||||
) -> Any:
|
||||
alias = alias or field.alias
|
||||
alias = alias or get_validation_alias(field)
|
||||
if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
|
||||
value = values.getlist(alias)
|
||||
else:
|
||||
@@ -730,7 +727,7 @@ def _get_multidict_value(
|
||||
if (
|
||||
value is None
|
||||
or (
|
||||
isinstance(field.field_info, (params.Form, temp_pydantic_v1_params.Form))
|
||||
isinstance(field.field_info, params.Form)
|
||||
and isinstance(value, str) # For type checks
|
||||
and value == ""
|
||||
)
|
||||
@@ -746,9 +743,9 @@ def _get_multidict_value(
|
||||
def request_params_to_args(
|
||||
fields: Sequence[ModelField],
|
||||
received_params: Union[Mapping[str, Any], QueryParams, Headers],
|
||||
) -> Tuple[Dict[str, Any], List[Any]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
) -> tuple[dict[str, Any], list[Any]]:
|
||||
values: dict[str, Any] = {}
|
||||
errors: list[dict[str, Any]] = []
|
||||
|
||||
if not fields:
|
||||
return values, errors
|
||||
@@ -766,7 +763,7 @@ def request_params_to_args(
|
||||
first_field.field_info, "convert_underscores", True
|
||||
)
|
||||
|
||||
params_to_process: Dict[str, Any] = {}
|
||||
params_to_process: dict[str, Any] = {}
|
||||
|
||||
processed_keys = set()
|
||||
|
||||
@@ -779,27 +776,31 @@ def request_params_to_args(
|
||||
field.field_info, "convert_underscores", default_convert_underscores
|
||||
)
|
||||
if convert_underscores:
|
||||
alias = (
|
||||
field.alias
|
||||
if field.alias != field.name
|
||||
else field.name.replace("_", "-")
|
||||
)
|
||||
alias = get_validation_alias(field)
|
||||
if alias == field.name:
|
||||
alias = alias.replace("_", "-")
|
||||
value = _get_multidict_value(field, received_params, alias=alias)
|
||||
if value is not None:
|
||||
params_to_process[field.name] = value
|
||||
processed_keys.add(alias or field.alias)
|
||||
processed_keys.add(field.name)
|
||||
params_to_process[get_validation_alias(field)] = value
|
||||
processed_keys.add(alias or get_validation_alias(field))
|
||||
|
||||
for key, value in received_params.items():
|
||||
for key in received_params.keys():
|
||||
if key not in processed_keys:
|
||||
params_to_process[key] = value
|
||||
if hasattr(received_params, "getlist"):
|
||||
value = received_params.getlist(key)
|
||||
if isinstance(value, list) and (len(value) == 1):
|
||||
params_to_process[key] = value[0]
|
||||
else:
|
||||
params_to_process[key] = value
|
||||
else:
|
||||
params_to_process[key] = received_params.get(key)
|
||||
|
||||
if single_not_embedded_field:
|
||||
field_info = first_field.field_info
|
||||
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
loc: Tuple[str, ...] = (field_info.in_.value,)
|
||||
loc: tuple[str, ...] = (field_info.in_.value,)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=params_to_process, values=values, loc=loc
|
||||
)
|
||||
@@ -808,10 +809,10 @@ def request_params_to_args(
|
||||
for field in fields:
|
||||
value = _get_multidict_value(field, received_params)
|
||||
field_info = field.field_info
|
||||
assert isinstance(field_info, (params.Param, temp_pydantic_v1_params.Param)), (
|
||||
assert isinstance(field_info, params.Param), (
|
||||
"Params must be subclasses of Param"
|
||||
)
|
||||
loc = (field_info.in_.value, field.alias)
|
||||
loc = (field_info.in_.value, get_validation_alias(field))
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=field, value=value, values=values, loc=loc
|
||||
)
|
||||
@@ -835,13 +836,13 @@ def is_union_of_base_models(field_type: Any) -> bool:
|
||||
union_args = get_args(field_type)
|
||||
|
||||
for arg in union_args:
|
||||
if not _is_model_class(arg):
|
||||
if not lenient_issubclass(arg, BaseModel):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
def _should_embed_body_fields(fields: list[ModelField]) -> bool:
|
||||
if not fields:
|
||||
return False
|
||||
# More than one dependency could have the same field, it would show up as multiple
|
||||
@@ -857,8 +858,8 @@ def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
# If it's a Form (or File) field, it has to be a BaseModel (or a union of BaseModels) to be top level
|
||||
# otherwise it has to be embedded, so that the key value pair can be extracted
|
||||
if (
|
||||
isinstance(first_field.field_info, (params.Form, temp_pydantic_v1_params.Form))
|
||||
and not _is_model_class(first_field.type_)
|
||||
isinstance(first_field.field_info, params.Form)
|
||||
and not lenient_issubclass(first_field.type_, BaseModel)
|
||||
and not is_union_of_base_models(first_field.type_)
|
||||
):
|
||||
return True
|
||||
@@ -866,28 +867,28 @@ def _should_embed_body_fields(fields: List[ModelField]) -> bool:
|
||||
|
||||
|
||||
async def _extract_form_body(
|
||||
body_fields: List[ModelField],
|
||||
body_fields: list[ModelField],
|
||||
received_body: FormData,
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
values = {}
|
||||
|
||||
for field in body_fields:
|
||||
value = _get_multidict_value(field, received_body)
|
||||
field_info = field.field_info
|
||||
if (
|
||||
isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
|
||||
isinstance(field_info, params.File)
|
||||
and is_bytes_field(field)
|
||||
and isinstance(value, UploadFile)
|
||||
):
|
||||
value = await value.read()
|
||||
elif (
|
||||
is_bytes_sequence_field(field)
|
||||
and isinstance(field_info, (params.File, temp_pydantic_v1_params.File))
|
||||
and isinstance(field_info, params.File)
|
||||
and value_is_sequence(value)
|
||||
):
|
||||
# For types
|
||||
assert isinstance(value, sequence_types) # type: ignore[arg-type]
|
||||
results: List[Union[bytes, str]] = []
|
||||
assert isinstance(value, sequence_types)
|
||||
results: list[Union[bytes, str]] = []
|
||||
|
||||
async def process_fn(
|
||||
fn: Callable[[], Coroutine[Any, Any, Any]],
|
||||
@@ -900,30 +901,35 @@ async def _extract_form_body(
|
||||
tg.start_soon(process_fn, sub_value.read)
|
||||
value = serialize_sequence_value(field=field, value=results)
|
||||
if value is not None:
|
||||
values[field.alias] = value
|
||||
for key, value in received_body.items():
|
||||
if key not in values:
|
||||
values[key] = value
|
||||
values[get_validation_alias(field)] = value
|
||||
field_aliases = {get_validation_alias(field) for field in body_fields}
|
||||
for key in received_body.keys():
|
||||
if key not in field_aliases:
|
||||
param_values = received_body.getlist(key)
|
||||
if len(param_values) == 1:
|
||||
values[key] = param_values[0]
|
||||
else:
|
||||
values[key] = param_values
|
||||
return values
|
||||
|
||||
|
||||
async def request_body_to_args(
|
||||
body_fields: List[ModelField],
|
||||
received_body: Optional[Union[Dict[str, Any], FormData]],
|
||||
body_fields: list[ModelField],
|
||||
received_body: Optional[Union[dict[str, Any], FormData]],
|
||||
embed_body_fields: bool,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
values: Dict[str, Any] = {}
|
||||
errors: List[Dict[str, Any]] = []
|
||||
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
values: dict[str, Any] = {}
|
||||
errors: list[dict[str, Any]] = []
|
||||
assert body_fields, "request_body_to_args() should be called with fields"
|
||||
single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
|
||||
first_field = body_fields[0]
|
||||
body_to_process = received_body
|
||||
|
||||
fields_to_extract: List[ModelField] = body_fields
|
||||
fields_to_extract: list[ModelField] = body_fields
|
||||
|
||||
if (
|
||||
single_not_embedded_field
|
||||
and _is_model_class(first_field.type_)
|
||||
and lenient_issubclass(first_field.type_, BaseModel)
|
||||
and isinstance(received_body, FormData)
|
||||
):
|
||||
fields_to_extract = get_cached_model_fields(first_field.type_)
|
||||
@@ -932,17 +938,17 @@ async def request_body_to_args(
|
||||
body_to_process = await _extract_form_body(fields_to_extract, received_body)
|
||||
|
||||
if single_not_embedded_field:
|
||||
loc: Tuple[str, ...] = ("body",)
|
||||
loc: tuple[str, ...] = ("body",)
|
||||
v_, errors_ = _validate_value_with_model_field(
|
||||
field=first_field, value=body_to_process, values=values, loc=loc
|
||||
)
|
||||
return {first_field.name: v_}, errors_
|
||||
for field in body_fields:
|
||||
loc = ("body", field.alias)
|
||||
loc = ("body", get_validation_alias(field))
|
||||
value: Optional[Any] = None
|
||||
if body_to_process is not None:
|
||||
try:
|
||||
value = body_to_process.get(field.alias)
|
||||
value = body_to_process.get(get_validation_alias(field))
|
||||
# If the received body is a list, not a dict
|
||||
except AttributeError:
|
||||
errors.append(get_missing_field_error(loc))
|
||||
@@ -980,36 +986,23 @@ def get_body_field(
|
||||
fields=flat_dependant.body_params, model_name=model_name
|
||||
)
|
||||
required = any(True for f in flat_dependant.body_params if f.required)
|
||||
BodyFieldInfo_kwargs: Dict[str, Any] = {
|
||||
BodyFieldInfo_kwargs: dict[str, Any] = {
|
||||
"annotation": BodyModel,
|
||||
"alias": "body",
|
||||
}
|
||||
if not required:
|
||||
BodyFieldInfo_kwargs["default"] = None
|
||||
if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
|
||||
BodyFieldInfo: Type[params.Body] = params.File
|
||||
elif any(
|
||||
isinstance(f.field_info, temp_pydantic_v1_params.File)
|
||||
for f in flat_dependant.body_params
|
||||
):
|
||||
BodyFieldInfo: Type[temp_pydantic_v1_params.Body] = temp_pydantic_v1_params.File # type: ignore[no-redef]
|
||||
BodyFieldInfo: type[params.Body] = params.File
|
||||
elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
|
||||
BodyFieldInfo = params.Form
|
||||
elif any(
|
||||
isinstance(f.field_info, temp_pydantic_v1_params.Form)
|
||||
for f in flat_dependant.body_params
|
||||
):
|
||||
BodyFieldInfo = temp_pydantic_v1_params.Form # type: ignore[assignment]
|
||||
else:
|
||||
if annotation_is_pydantic_v1(BodyModel):
|
||||
BodyFieldInfo = temp_pydantic_v1_params.Body # type: ignore[assignment]
|
||||
else:
|
||||
BodyFieldInfo = params.Body
|
||||
BodyFieldInfo = params.Body
|
||||
|
||||
body_param_media_types = [
|
||||
f.field_info.media_type
|
||||
for f in flat_dependant.body_params
|
||||
if isinstance(f.field_info, (params.Body, temp_pydantic_v1_params.Body))
|
||||
if isinstance(f.field_info, params.Body)
|
||||
]
|
||||
if len(set(body_param_media_types)) == 1:
|
||||
BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
|
||||
@@ -1021,3 +1014,8 @@ def get_body_field(
|
||||
field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
|
||||
)
|
||||
return final_field
|
||||
|
||||
|
||||
def get_validation_alias(field: ModelField) -> str:
|
||||
va = getattr(field, "validation_alias", None)
|
||||
return va or field.alias
|
||||
|
||||
Reference in New Issue
Block a user