增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -1,11 +1,11 @@
import http.client
import inspect
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
from collections.abc import Sequence
from typing import Any, Optional, Union, cast
from fastapi import routing
from fastapi._compat import (
JsonSchemaValue,
ModelField,
Undefined,
get_compat_model_name_map,
@@ -19,8 +19,10 @@ from fastapi.dependencies.utils import (
_get_flat_fields_from_params,
get_flat_dependant,
get_flat_params,
get_validation_alias,
)
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import FastAPIDeprecationWarning
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, ParamTypes
@@ -36,8 +38,6 @@ from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from typing_extensions import Literal
from .._compat import _is_model_field
validation_error_definition = {
"title": "ValidationError",
"type": "object",
@@ -65,7 +65,7 @@ validation_error_response_definition = {
},
}
status_code_ranges: Dict[str, str] = {
status_code_ranges: dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
@@ -77,18 +77,27 @@ status_code_ranges: Dict[str, str] = {
def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
# Use a dict to merge scopes for same security scheme
operation_security_dict: dict[str, list[str]] = {}
for security_dependency in flat_dependant._security_dependencies:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
security_dependency._security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_name = security_dependency._security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
# Merge scopes for the same security scheme
if security_name not in operation_security_dict:
operation_security_dict[security_name] = []
for scope in security_dependency.oauth_scopes or []:
if scope not in operation_security_dict[security_name]:
operation_security_dict[security_name].append(scope)
operation_security = [
{name: scopes} for name, scopes in operation_security_dict.items()
]
return security_definitions, operation_security
@@ -96,11 +105,11 @@ def _get_openapi_operation_parameters(
*,
dependant: Dependant,
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
parameters = []
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
@@ -132,7 +141,7 @@ def _get_openapi_operation_parameters(
field_mapping=field_mapping,
separate_input_output_schemas=separate_input_output_schemas,
)
name = param.alias
name = get_validation_alias(param)
convert_underscores = getattr(
param.field_info,
"convert_underscores",
@@ -140,7 +149,7 @@ def _get_openapi_operation_parameters(
)
if (
param_type == ParamTypes.header
and param.alias == param.name
and name == param.name
and convert_underscores
):
name = param.name.replace("_", "-")
@@ -169,14 +178,14 @@ def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
) -> Optional[dict[str, Any]]:
if not body_field:
return None
assert _is_model_field(body_field)
assert isinstance(body_field, ModelField)
body_schema = get_schema_from_model_field(
field=body_field,
model_name_map=model_name_map,
@@ -186,10 +195,10 @@ def get_openapi_operation_request_body(
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
request_body_oai: dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_media_content: Dict[str, Any] = {"schema": body_schema}
request_media_content: dict[str, Any] = {"schema": body_schema}
if field_info.openapi_examples:
request_media_content["examples"] = jsonable_encoder(
field_info.openapi_examples
@@ -204,9 +213,9 @@ def generate_operation_id(
*, route: routing.APIRoute, method: str
) -> str: # pragma: nocover
warnings.warn(
"fastapi.openapi.utils.generate_operation_id() was deprecated, "
message="fastapi.openapi.utils.generate_operation_id() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
category=FastAPIDeprecationWarning,
stacklevel=2,
)
if route.operation_id:
@@ -222,9 +231,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
def get_openapi_operation_metadata(
*, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
*, route: routing.APIRoute, method: str, operation_ids: set[str]
) -> dict[str, Any]:
operation: dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
@@ -250,19 +259,19 @@ def get_openapi_operation_metadata(
def get_openapi_path(
*,
route: routing.APIRoute,
operation_ids: Set[str],
operation_ids: set[str],
model_name_map: ModelNameMap,
field_mapping: Dict[
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
field_mapping: dict[
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
],
separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
security_schemes: dict[str, Any] = {}
definitions: dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[Response] = route.response_class.value
current_response_class: type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
@@ -272,7 +281,7 @@ def get_openapi_path(
operation = get_openapi_operation_metadata(
route=route, method=method, operation_ids=operation_ids
)
parameters: List[Dict[str, Any]] = []
parameters: list[dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(
flat_dependant=flat_dependant
@@ -380,7 +389,7 @@ def get_openapi_path(
"An additional response must be a dict"
)
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
additional_field_schema: Optional[dict[str, Any]] = None
if field:
additional_field_schema = get_schema_from_model_field(
field=field,
@@ -435,17 +444,17 @@ def get_openapi_path(
def get_fields_from_routes(
routes: Sequence[BaseRoute],
) -> List[ModelField]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: List[ModelField] = []
) -> list[ModelField]:
body_fields_from_routes: list[ModelField] = []
responses_from_routes: list[ModelField] = []
request_fields_from_routes: list[ModelField] = []
callback_flat_models: list[ModelField] = []
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(
route, routing.APIRoute
):
if route.body_field:
assert _is_model_field(route.body_field), (
assert isinstance(route.body_field, ModelField), (
"A request body must be a Pydantic Field"
)
body_fields_from_routes.append(route.body_field)
@@ -473,15 +482,15 @@ def get_openapi(
description: Optional[str] = None,
routes: Sequence[BaseRoute],
webhooks: Optional[Sequence[BaseRoute]] = None,
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
tags: Optional[list[dict[str, Any]]] = None,
servers: Optional[list[dict[str, Union[str, Any]]]] = None,
terms_of_service: Optional[str] = None,
contact: Optional[Dict[str, Union[str, Any]]] = None,
license_info: Optional[Dict[str, Union[str, Any]]] = None,
contact: Optional[dict[str, Union[str, Any]]] = None,
license_info: Optional[dict[str, Union[str, Any]]] = None,
separate_input_output_schemas: bool = True,
external_docs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
info: Dict[str, Any] = {"title": title, "version": version}
external_docs: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
info: dict[str, Any] = {"title": title, "version": version}
if summary:
info["summary"] = summary
if description:
@@ -492,13 +501,13 @@ def get_openapi(
info["contact"] = contact
if license_info:
info["license"] = license_info
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
output: dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
webhook_paths: Dict[str, Dict[str, Any]] = {}
operation_ids: Set[str] = set()
components: dict[str, dict[str, Any]] = {}
paths: dict[str, dict[str, Any]] = {}
webhook_paths: dict[str, dict[str, Any]] = {}
operation_ids: set[str] = set()
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
model_name_map = get_compat_model_name_map(all_fields)
field_mapping, definitions = get_definitions(