增加环绕侦察场景适配
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,22 +1,51 @@
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.openapi.models import APIKey, APIKeyIn
|
||||
from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class APIKeyBase(SecurityBase):
|
||||
@staticmethod
|
||||
def check_api_key(api_key: Optional[str], auto_error: bool) -> Optional[str]:
|
||||
def __init__(
|
||||
self,
|
||||
location: APIKeyIn,
|
||||
name: str,
|
||||
description: Union[str, None],
|
||||
scheme_name: Union[str, None],
|
||||
auto_error: bool,
|
||||
):
|
||||
self.auto_error = auto_error
|
||||
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": location},
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The WWW-Authenticate header is not standardized for API Key authentication but
|
||||
the HTTP specification requires that an error of 401 "Unauthorized" must
|
||||
include a WWW-Authenticate header.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc9110#name-401-unauthorized
|
||||
|
||||
For this, this method sends a custom challenge `APIKey`.
|
||||
"""
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "APIKey"},
|
||||
)
|
||||
|
||||
def check_api_key(self, api_key: Optional[str]) -> Optional[str]:
|
||||
if not api_key:
|
||||
if auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
if self.auto_error:
|
||||
raise self.make_not_authenticated_error()
|
||||
return None
|
||||
return api_key
|
||||
|
||||
@@ -100,17 +129,17 @@ class APIKeyQuery(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.query},
|
||||
super().__init__(
|
||||
location=APIKeyIn.query,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.query_params.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyHeader(APIKeyBase):
|
||||
@@ -188,17 +217,17 @@ class APIKeyHeader(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.header},
|
||||
super().__init__(
|
||||
location=APIKeyIn.header,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.headers.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
|
||||
class APIKeyCookie(APIKeyBase):
|
||||
@@ -276,14 +305,14 @@ class APIKeyCookie(APIKeyBase):
|
||||
),
|
||||
] = True,
|
||||
):
|
||||
self.model: APIKey = APIKey(
|
||||
**{"in": APIKeyIn.cookie},
|
||||
super().__init__(
|
||||
location=APIKeyIn.cookie,
|
||||
name=name,
|
||||
scheme_name=scheme_name,
|
||||
description=description,
|
||||
auto_error=auto_error,
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
api_key = request.cookies.get(self.model.name)
|
||||
return self.check_api_key(api_key, self.auto_error)
|
||||
return self.check_api_key(api_key)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import binascii
|
||||
from base64 import b64decode
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.exceptions import HTTPException
|
||||
@@ -10,8 +10,7 @@ from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from pydantic import BaseModel
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class HTTPBasicCredentials(BaseModel):
|
||||
@@ -76,10 +75,22 @@ class HTTPBase(SecurityBase):
|
||||
description: Optional[str] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
self.model = HTTPBaseModel(scheme=scheme, description=description)
|
||||
self.model: HTTPBaseModel = HTTPBaseModel(
|
||||
scheme=scheme, description=description
|
||||
)
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> dict[str, str]:
|
||||
return {"WWW-Authenticate": f"{self.model.scheme.title()}"}
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=self.make_authenticate_headers(),
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self, request: Request
|
||||
) -> Optional[HTTPAuthorizationCredentials]:
|
||||
@@ -87,9 +98,7 @@ class HTTPBase(SecurityBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -99,6 +108,8 @@ class HTTPBasic(HTTPBase):
|
||||
"""
|
||||
HTTP Basic authentication.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7617
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -185,36 +196,28 @@ class HTTPBasic(HTTPBase):
|
||||
self.realm = realm
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_authenticate_headers(self) -> dict[str, str]:
|
||||
if self.realm:
|
||||
return {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
return {"WWW-Authenticate": "Basic"}
|
||||
|
||||
async def __call__( # type: ignore
|
||||
self, request: Request
|
||||
) -> Optional[HTTPBasicCredentials]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if self.realm:
|
||||
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
|
||||
else:
|
||||
unauthorized_headers = {"WWW-Authenticate": "Basic"}
|
||||
if not authorization or scheme.lower() != "basic":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
invalid_user_credentials_exc = HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid authentication credentials",
|
||||
headers=unauthorized_headers,
|
||||
)
|
||||
try:
|
||||
data = b64decode(param).decode("ascii")
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error):
|
||||
raise invalid_user_credentials_exc # noqa: B904
|
||||
except (ValueError, UnicodeDecodeError, binascii.Error) as e:
|
||||
raise self.make_not_authenticated_error() from e
|
||||
username, separator, password = data.partition(":")
|
||||
if not separator:
|
||||
raise invalid_user_credentials_exc
|
||||
raise self.make_not_authenticated_error()
|
||||
return HTTPBasicCredentials(username=username, password=password)
|
||||
|
||||
|
||||
@@ -306,17 +309,12 @@ class HTTPBearer(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
@@ -326,6 +324,12 @@ class HTTPDigest(HTTPBase):
|
||||
"""
|
||||
HTTP Digest authentication.
|
||||
|
||||
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
|
||||
but it doesn't implement the full Digest scheme, you would need to to subclass it
|
||||
and implement it in your code.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc7616
|
||||
|
||||
## Usage
|
||||
|
||||
Create an instance object and use that object as the dependency in `Depends()`.
|
||||
@@ -408,17 +412,12 @@ class HTTPDigest(HTTPBase):
|
||||
scheme, credentials = get_authorization_scheme_param(authorization)
|
||||
if not (authorization and scheme and credentials):
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
if scheme.lower() != "digest":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN,
|
||||
detail="Invalid authentication credentials",
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
from typing import Annotated, Any, Optional, Union, cast
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.exceptions import HTTPException
|
||||
@@ -8,10 +8,7 @@ from fastapi.param_functions import Form
|
||||
from fastapi.security.base import SecurityBase
|
||||
from fastapi.security.utils import get_authorization_scheme_param
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
||||
|
||||
# TODO: import from typing when deprecating Python 3.9
|
||||
from typing_extensions import Annotated
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class OAuth2PasswordRequestForm:
|
||||
@@ -323,7 +320,7 @@ class OAuth2(SecurityBase):
|
||||
self,
|
||||
*,
|
||||
flows: Annotated[
|
||||
Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]],
|
||||
Union[OAuthFlowsModel, dict[str, dict[str, Any]]],
|
||||
Doc(
|
||||
"""
|
||||
The dictionary of OAuth2 flows.
|
||||
@@ -377,13 +374,33 @@ class OAuth2(SecurityBase):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
"""
|
||||
The OAuth 2 specification doesn't define the challenge that should be used,
|
||||
because a `Bearer` token is not really the only option to authenticate.
|
||||
|
||||
But declaring any other authentication challenge would be application-specific
|
||||
as it's not defined in the specification.
|
||||
|
||||
For practical reasons, this method uses the `Bearer` challenge by default, as
|
||||
it's probably the most common one.
|
||||
|
||||
If you are implementing an OAuth2 authentication scheme other than the provided
|
||||
ones in FastAPI (based on bearer tokens), you might want to override this.
|
||||
|
||||
Ref: https://datatracker.ietf.org/doc/html/rfc6749
|
||||
"""
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return authorization
|
||||
@@ -420,7 +437,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
),
|
||||
] = None,
|
||||
scopes: Annotated[
|
||||
Optional[Dict[str, str]],
|
||||
Optional[dict[str, str]],
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 scopes that would be required by the *path operations* that
|
||||
@@ -491,11 +508,7 @@ class OAuth2PasswordBearer(OAuth2):
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return param
|
||||
@@ -537,7 +550,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
),
|
||||
] = None,
|
||||
scopes: Annotated[
|
||||
Optional[Dict[str, str]],
|
||||
Optional[dict[str, str]],
|
||||
Doc(
|
||||
"""
|
||||
The OAuth2 scopes that would be required by the *path operations* that
|
||||
@@ -601,11 +614,7 @@ class OAuth2AuthorizationCodeBearer(OAuth2):
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None # pragma: nocover
|
||||
return param
|
||||
@@ -627,7 +636,7 @@ class SecurityScopes:
|
||||
def __init__(
|
||||
self,
|
||||
scopes: Annotated[
|
||||
Optional[List[str]],
|
||||
Optional[list[str]],
|
||||
Doc(
|
||||
"""
|
||||
This will be filled by FastAPI.
|
||||
@@ -636,7 +645,7 @@ class SecurityScopes:
|
||||
] = None,
|
||||
):
|
||||
self.scopes: Annotated[
|
||||
List[str],
|
||||
list[str],
|
||||
Doc(
|
||||
"""
|
||||
The list of all the scopes required by dependencies.
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
from typing import Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from annotated_doc import Doc
|
||||
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
|
||||
from fastapi.security.base import SecurityBase
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.requests import Request
|
||||
from starlette.status import HTTP_403_FORBIDDEN
|
||||
from typing_extensions import Annotated
|
||||
from starlette.status import HTTP_401_UNAUTHORIZED
|
||||
|
||||
|
||||
class OpenIdConnect(SecurityBase):
|
||||
"""
|
||||
OpenID Connect authentication class. An instance of it would be used as a
|
||||
dependency.
|
||||
|
||||
**Warning**: this is only a stub to connect the components with OpenAPI in FastAPI,
|
||||
but it doesn't implement the full OpenIdConnect scheme, for example, it doesn't use
|
||||
the OpenIDConnect URL. You would need to to subclass it and implement it in your
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -73,13 +77,18 @@ class OpenIdConnect(SecurityBase):
|
||||
self.scheme_name = scheme_name or self.__class__.__name__
|
||||
self.auto_error = auto_error
|
||||
|
||||
def make_not_authenticated_error(self) -> HTTPException:
|
||||
return HTTPException(
|
||||
status_code=HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
|
||||
)
|
||||
raise self.make_not_authenticated_error()
|
||||
else:
|
||||
return None
|
||||
return authorization
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_authorization_scheme_param(
|
||||
authorization_header_value: Optional[str],
|
||||
) -> Tuple[str, str]:
|
||||
) -> tuple[str, str]:
|
||||
if not authorization_header_value:
|
||||
return "", ""
|
||||
scheme, _, param = authorization_header_value.partition(" ")
|
||||
|
||||
Reference in New Issue
Block a user