增加环绕侦察场景适配
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
import http.client
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from fastapi import routing
|
||||
from fastapi._compat import (
|
||||
JsonSchemaValue,
|
||||
ModelField,
|
||||
Undefined,
|
||||
get_compat_model_name_map,
|
||||
@@ -19,8 +19,10 @@ from fastapi.dependencies.utils import (
|
||||
_get_flat_fields_from_params,
|
||||
get_flat_dependant,
|
||||
get_flat_params,
|
||||
get_validation_alias,
|
||||
)
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.exceptions import FastAPIDeprecationWarning
|
||||
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX
|
||||
from fastapi.openapi.models import OpenAPI
|
||||
from fastapi.params import Body, ParamTypes
|
||||
@@ -36,8 +38,6 @@ from starlette.responses import JSONResponse
|
||||
from starlette.routing import BaseRoute
|
||||
from typing_extensions import Literal
|
||||
|
||||
from .._compat import _is_model_field
|
||||
|
||||
validation_error_definition = {
|
||||
"title": "ValidationError",
|
||||
"type": "object",
|
||||
@@ -65,7 +65,7 @@ validation_error_response_definition = {
|
||||
},
|
||||
}
|
||||
|
||||
status_code_ranges: Dict[str, str] = {
|
||||
status_code_ranges: dict[str, str] = {
|
||||
"1XX": "Information",
|
||||
"2XX": "Success",
|
||||
"3XX": "Redirection",
|
||||
@@ -77,18 +77,27 @@ status_code_ranges: Dict[str, str] = {
|
||||
|
||||
def get_openapi_security_definitions(
|
||||
flat_dependant: Dependant,
|
||||
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
|
||||
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
|
||||
security_definitions = {}
|
||||
operation_security = []
|
||||
for security_requirement in flat_dependant.security_requirements:
|
||||
# Use a dict to merge scopes for same security scheme
|
||||
operation_security_dict: dict[str, list[str]] = {}
|
||||
for security_dependency in flat_dependant._security_dependencies:
|
||||
security_definition = jsonable_encoder(
|
||||
security_requirement.security_scheme.model,
|
||||
security_dependency._security_scheme.model,
|
||||
by_alias=True,
|
||||
exclude_none=True,
|
||||
)
|
||||
security_name = security_requirement.security_scheme.scheme_name
|
||||
security_name = security_dependency._security_scheme.scheme_name
|
||||
security_definitions[security_name] = security_definition
|
||||
operation_security.append({security_name: security_requirement.scopes})
|
||||
# Merge scopes for the same security scheme
|
||||
if security_name not in operation_security_dict:
|
||||
operation_security_dict[security_name] = []
|
||||
for scope in security_dependency.oauth_scopes or []:
|
||||
if scope not in operation_security_dict[security_name]:
|
||||
operation_security_dict[security_name].append(scope)
|
||||
operation_security = [
|
||||
{name: scopes} for name, scopes in operation_security_dict.items()
|
||||
]
|
||||
return security_definitions, operation_security
|
||||
|
||||
|
||||
@@ -96,11 +105,11 @@ def _get_openapi_operation_parameters(
|
||||
*,
|
||||
dependant: Dependant,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
field_mapping: dict[
|
||||
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> List[Dict[str, Any]]:
|
||||
) -> list[dict[str, Any]]:
|
||||
parameters = []
|
||||
flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
|
||||
path_params = _get_flat_fields_from_params(flat_dependant.path_params)
|
||||
@@ -132,7 +141,7 @@ def _get_openapi_operation_parameters(
|
||||
field_mapping=field_mapping,
|
||||
separate_input_output_schemas=separate_input_output_schemas,
|
||||
)
|
||||
name = param.alias
|
||||
name = get_validation_alias(param)
|
||||
convert_underscores = getattr(
|
||||
param.field_info,
|
||||
"convert_underscores",
|
||||
@@ -140,7 +149,7 @@ def _get_openapi_operation_parameters(
|
||||
)
|
||||
if (
|
||||
param_type == ParamTypes.header
|
||||
and param.alias == param.name
|
||||
and name == param.name
|
||||
and convert_underscores
|
||||
):
|
||||
name = param.name.replace("_", "-")
|
||||
@@ -169,14 +178,14 @@ def get_openapi_operation_request_body(
|
||||
*,
|
||||
body_field: Optional[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
field_mapping: dict[
|
||||
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> Optional[dict[str, Any]]:
|
||||
if not body_field:
|
||||
return None
|
||||
assert _is_model_field(body_field)
|
||||
assert isinstance(body_field, ModelField)
|
||||
body_schema = get_schema_from_model_field(
|
||||
field=body_field,
|
||||
model_name_map=model_name_map,
|
||||
@@ -186,10 +195,10 @@ def get_openapi_operation_request_body(
|
||||
field_info = cast(Body, body_field.field_info)
|
||||
request_media_type = field_info.media_type
|
||||
required = body_field.required
|
||||
request_body_oai: Dict[str, Any] = {}
|
||||
request_body_oai: dict[str, Any] = {}
|
||||
if required:
|
||||
request_body_oai["required"] = required
|
||||
request_media_content: Dict[str, Any] = {"schema": body_schema}
|
||||
request_media_content: dict[str, Any] = {"schema": body_schema}
|
||||
if field_info.openapi_examples:
|
||||
request_media_content["examples"] = jsonable_encoder(
|
||||
field_info.openapi_examples
|
||||
@@ -204,9 +213,9 @@ def generate_operation_id(
|
||||
*, route: routing.APIRoute, method: str
|
||||
) -> str: # pragma: nocover
|
||||
warnings.warn(
|
||||
"fastapi.openapi.utils.generate_operation_id() was deprecated, "
|
||||
message="fastapi.openapi.utils.generate_operation_id() was deprecated, "
|
||||
"it is not used internally, and will be removed soon",
|
||||
DeprecationWarning,
|
||||
category=FastAPIDeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if route.operation_id:
|
||||
@@ -222,9 +231,9 @@ def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
|
||||
|
||||
|
||||
def get_openapi_operation_metadata(
|
||||
*, route: routing.APIRoute, method: str, operation_ids: Set[str]
|
||||
) -> Dict[str, Any]:
|
||||
operation: Dict[str, Any] = {}
|
||||
*, route: routing.APIRoute, method: str, operation_ids: set[str]
|
||||
) -> dict[str, Any]:
|
||||
operation: dict[str, Any] = {}
|
||||
if route.tags:
|
||||
operation["tags"] = route.tags
|
||||
operation["summary"] = generate_operation_summary(route=route, method=method)
|
||||
@@ -250,19 +259,19 @@ def get_openapi_operation_metadata(
|
||||
def get_openapi_path(
|
||||
*,
|
||||
route: routing.APIRoute,
|
||||
operation_ids: Set[str],
|
||||
operation_ids: set[str],
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: Dict[
|
||||
Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
field_mapping: dict[
|
||||
tuple[ModelField, Literal["validation", "serialization"]], dict[str, Any]
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
path = {}
|
||||
security_schemes: Dict[str, Any] = {}
|
||||
definitions: Dict[str, Any] = {}
|
||||
security_schemes: dict[str, Any] = {}
|
||||
definitions: dict[str, Any] = {}
|
||||
assert route.methods is not None, "Methods must be a list"
|
||||
if isinstance(route.response_class, DefaultPlaceholder):
|
||||
current_response_class: Type[Response] = route.response_class.value
|
||||
current_response_class: type[Response] = route.response_class.value
|
||||
else:
|
||||
current_response_class = route.response_class
|
||||
assert current_response_class, "A response class is needed to generate OpenAPI"
|
||||
@@ -272,7 +281,7 @@ def get_openapi_path(
|
||||
operation = get_openapi_operation_metadata(
|
||||
route=route, method=method, operation_ids=operation_ids
|
||||
)
|
||||
parameters: List[Dict[str, Any]] = []
|
||||
parameters: list[dict[str, Any]] = []
|
||||
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
|
||||
security_definitions, operation_security = get_openapi_security_definitions(
|
||||
flat_dependant=flat_dependant
|
||||
@@ -380,7 +389,7 @@ def get_openapi_path(
|
||||
"An additional response must be a dict"
|
||||
)
|
||||
field = route.response_fields.get(additional_status_code)
|
||||
additional_field_schema: Optional[Dict[str, Any]] = None
|
||||
additional_field_schema: Optional[dict[str, Any]] = None
|
||||
if field:
|
||||
additional_field_schema = get_schema_from_model_field(
|
||||
field=field,
|
||||
@@ -435,17 +444,17 @@ def get_openapi_path(
|
||||
|
||||
def get_fields_from_routes(
|
||||
routes: Sequence[BaseRoute],
|
||||
) -> List[ModelField]:
|
||||
body_fields_from_routes: List[ModelField] = []
|
||||
responses_from_routes: List[ModelField] = []
|
||||
request_fields_from_routes: List[ModelField] = []
|
||||
callback_flat_models: List[ModelField] = []
|
||||
) -> list[ModelField]:
|
||||
body_fields_from_routes: list[ModelField] = []
|
||||
responses_from_routes: list[ModelField] = []
|
||||
request_fields_from_routes: list[ModelField] = []
|
||||
callback_flat_models: list[ModelField] = []
|
||||
for route in routes:
|
||||
if getattr(route, "include_in_schema", None) and isinstance(
|
||||
route, routing.APIRoute
|
||||
):
|
||||
if route.body_field:
|
||||
assert _is_model_field(route.body_field), (
|
||||
assert isinstance(route.body_field, ModelField), (
|
||||
"A request body must be a Pydantic Field"
|
||||
)
|
||||
body_fields_from_routes.append(route.body_field)
|
||||
@@ -473,15 +482,15 @@ def get_openapi(
|
||||
description: Optional[str] = None,
|
||||
routes: Sequence[BaseRoute],
|
||||
webhooks: Optional[Sequence[BaseRoute]] = None,
|
||||
tags: Optional[List[Dict[str, Any]]] = None,
|
||||
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
|
||||
tags: Optional[list[dict[str, Any]]] = None,
|
||||
servers: Optional[list[dict[str, Union[str, Any]]]] = None,
|
||||
terms_of_service: Optional[str] = None,
|
||||
contact: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[Dict[str, Union[str, Any]]] = None,
|
||||
contact: Optional[dict[str, Union[str, Any]]] = None,
|
||||
license_info: Optional[dict[str, Union[str, Any]]] = None,
|
||||
separate_input_output_schemas: bool = True,
|
||||
external_docs: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
info: Dict[str, Any] = {"title": title, "version": version}
|
||||
external_docs: Optional[dict[str, Any]] = None,
|
||||
) -> dict[str, Any]:
|
||||
info: dict[str, Any] = {"title": title, "version": version}
|
||||
if summary:
|
||||
info["summary"] = summary
|
||||
if description:
|
||||
@@ -492,13 +501,13 @@ def get_openapi(
|
||||
info["contact"] = contact
|
||||
if license_info:
|
||||
info["license"] = license_info
|
||||
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
|
||||
output: dict[str, Any] = {"openapi": openapi_version, "info": info}
|
||||
if servers:
|
||||
output["servers"] = servers
|
||||
components: Dict[str, Dict[str, Any]] = {}
|
||||
paths: Dict[str, Dict[str, Any]] = {}
|
||||
webhook_paths: Dict[str, Dict[str, Any]] = {}
|
||||
operation_ids: Set[str] = set()
|
||||
components: dict[str, dict[str, Any]] = {}
|
||||
paths: dict[str, dict[str, Any]] = {}
|
||||
webhook_paths: dict[str, dict[str, Any]] = {}
|
||||
operation_ids: set[str] = set()
|
||||
all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
|
||||
model_name_map = get_compat_model_name_map(all_fields)
|
||||
field_mapping, definitions = get_definitions(
|
||||
|
||||
Reference in New Issue
Block a user