chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
MCP OAuth server authorization components.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,5 @@
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
def stringify_pydantic_error(validation_error: ValidationError) -> str:
|
||||
return "\n".join(f"{'.'.join(str(loc) for loc in e['loc'])}: {e['msg']}" for e in validation_error.errors())
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Request handlers for MCP authorization endpoints.
|
||||
"""
|
||||
@@ -0,0 +1,224 @@
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.datastructures import FormData, QueryParams
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import RedirectResponse, Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import (
|
||||
AuthorizationErrorCode,
|
||||
AuthorizationParams,
|
||||
AuthorizeError,
|
||||
OAuthAuthorizationServerProvider,
|
||||
construct_redirect_uri,
|
||||
)
|
||||
from mcp.shared.auth import InvalidRedirectUriError, InvalidScopeError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthorizationRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
client_id: str = Field(..., description="The client ID")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="URL to redirect to after authorization")
|
||||
|
||||
# see OAuthClientMetadata; we only support `code`
|
||||
response_type: Literal["code"] = Field(..., description="Must be 'code' for authorization code flow")
|
||||
code_challenge: str = Field(..., description="PKCE code challenge")
|
||||
code_challenge_method: Literal["S256"] = Field("S256", description="PKCE code challenge method, must be S256")
|
||||
state: str | None = Field(None, description="Optional state parameter")
|
||||
scope: str | None = Field(
|
||||
None,
|
||||
description="Optional scope; if specified, should be a space-separated list of scope strings",
|
||||
)
|
||||
resource: str | None = Field(
|
||||
None,
|
||||
description="RFC 8707 resource indicator - the MCP server this token will be used with",
|
||||
)
|
||||
|
||||
|
||||
class AuthorizationErrorResponse(BaseModel):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None
|
||||
error_uri: AnyUrl | None = None
|
||||
# must be set if provided in the request
|
||||
state: str | None = None
|
||||
|
||||
|
||||
def best_effort_extract_string(key: str, params: None | FormData | QueryParams) -> str | None:
|
||||
if params is None: # pragma: no cover
|
||||
return None
|
||||
value = params.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class AnyUrlModel(RootModel[AnyUrl]):
|
||||
root: AnyUrl
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# implements authorization requests for grant_type=code;
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1
|
||||
|
||||
state = None
|
||||
redirect_uri = None
|
||||
client = None
|
||||
params = None
|
||||
|
||||
async def error_response(
|
||||
error: AuthorizationErrorCode,
|
||||
error_description: str | None,
|
||||
attempt_load_client: bool = True,
|
||||
):
|
||||
# Error responses take two different formats:
|
||||
# 1. The request has a valid client ID & redirect_uri: we issue a redirect
|
||||
# back to the redirect_uri with the error response fields as query
|
||||
# parameters. This allows the client to be notified of the error.
|
||||
# 2. Otherwise, we return an error response directly to the end user;
|
||||
# we choose to do so in JSON, but this is left undefined in the
|
||||
# specification.
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1
|
||||
#
|
||||
# This logic is a bit awkward to handle, because the error might be thrown
|
||||
# very early in request validation, before we've done the usual Pydantic
|
||||
# validation, loaded the client, etc. To handle this, error_response()
|
||||
# contains fallback logic which attempts to load the parameters directly
|
||||
# from the request.
|
||||
|
||||
nonlocal client, redirect_uri, state
|
||||
if client is None and attempt_load_client:
|
||||
# make last-ditch attempt to load the client
|
||||
client_id = best_effort_extract_string("client_id", params)
|
||||
client = await self.provider.get_client(client_id) if client_id else None
|
||||
if redirect_uri is None and client:
|
||||
# make last-ditch effort to load the redirect uri
|
||||
try:
|
||||
if params is not None and "redirect_uri" not in params:
|
||||
raw_redirect_uri = None
|
||||
else:
|
||||
raw_redirect_uri = AnyUrlModel.model_validate(
|
||||
best_effort_extract_string("redirect_uri", params)
|
||||
).root
|
||||
redirect_uri = client.validate_redirect_uri(raw_redirect_uri)
|
||||
except (ValidationError, InvalidRedirectUriError):
|
||||
# if the redirect URI is invalid, ignore it & just return the
|
||||
# initial error
|
||||
pass
|
||||
|
||||
# the error response MUST contain the state specified by the client, if any
|
||||
if state is None: # pragma: no cover
|
||||
# make last-ditch effort to load state
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
error_resp = AuthorizationErrorResponse(
|
||||
error=error,
|
||||
error_description=error_description,
|
||||
state=state,
|
||||
)
|
||||
|
||||
if redirect_uri and client:
|
||||
return RedirectResponse(
|
||||
url=construct_redirect_uri(str(redirect_uri), **error_resp.model_dump(exclude_none=True)),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
else:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=error_resp,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse request parameters
|
||||
if request.method == "GET":
|
||||
# Convert query_params to dict for pydantic validation
|
||||
params = request.query_params
|
||||
else:
|
||||
# Parse form data for POST requests
|
||||
params = await request.form()
|
||||
|
||||
# Save state if it exists, even before validation
|
||||
state = best_effort_extract_string("state", params)
|
||||
|
||||
try:
|
||||
auth_request = AuthorizationRequest.model_validate(params)
|
||||
state = auth_request.state # Update with validated state
|
||||
except ValidationError as validation_error:
|
||||
error: AuthorizationErrorCode = "invalid_request"
|
||||
for e in validation_error.errors():
|
||||
if e["loc"] == ("response_type",) and e["type"] == "literal_error":
|
||||
error = "unsupported_response_type"
|
||||
break
|
||||
return await error_response(error, stringify_pydantic_error(validation_error))
|
||||
|
||||
# Get client information
|
||||
client = await self.provider.get_client(
|
||||
auth_request.client_id,
|
||||
)
|
||||
if not client:
|
||||
# For client_id validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=f"Client ID '{auth_request.client_id}' not found",
|
||||
attempt_load_client=False,
|
||||
)
|
||||
|
||||
# Validate redirect_uri against client's registered URIs
|
||||
try:
|
||||
redirect_uri = client.validate_redirect_uri(auth_request.redirect_uri)
|
||||
except InvalidRedirectUriError as validation_error:
|
||||
# For redirect_uri validation errors, return direct error (no redirect)
|
||||
return await error_response(
|
||||
error="invalid_request",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Validate scope - for scope errors, we can redirect
|
||||
try:
|
||||
scopes = client.validate_scope(auth_request.scope)
|
||||
except InvalidScopeError as validation_error:
|
||||
# For scope errors, redirect with error parameters
|
||||
return await error_response(
|
||||
error="invalid_scope",
|
||||
error_description=validation_error.message,
|
||||
)
|
||||
|
||||
# Setup authorization parameters
|
||||
auth_params = AuthorizationParams(
|
||||
state=state,
|
||||
scopes=scopes,
|
||||
code_challenge=auth_request.code_challenge,
|
||||
redirect_uri=redirect_uri,
|
||||
redirect_uri_provided_explicitly=auth_request.redirect_uri is not None,
|
||||
resource=auth_request.resource, # RFC 8707
|
||||
)
|
||||
|
||||
try:
|
||||
# Let the provider pick the next URI to redirect to
|
||||
return RedirectResponse(
|
||||
url=await self.provider.authorize(
|
||||
client,
|
||||
auth_params,
|
||||
),
|
||||
status_code=302,
|
||||
headers={"Cache-Control": "no-store"},
|
||||
)
|
||||
except AuthorizeError as e:
|
||||
# Handle authorization errors as defined in RFC 6749 Section 4.1.2.1
|
||||
return await error_response(error=e.error, error_description=e.error_description)
|
||||
|
||||
except Exception as validation_error: # pragma: no cover
|
||||
# Catch-all for unexpected errors
|
||||
logger.exception("Unexpected error in authorization_handler", exc_info=validation_error)
|
||||
return await error_response(error="server_error", error_description="An unexpected error occurred")
|
||||
@@ -0,0 +1,29 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.shared.auth import OAuthMetadata, ProtectedResourceMetadata
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetadataHandler:
|
||||
metadata: OAuthMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtectedResourceMetadataHandler:
|
||||
metadata: ProtectedResourceMetadata
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
return PydanticJSONResponse(
|
||||
content=self.metadata,
|
||||
headers={"Cache-Control": "public, max-age=3600"}, # Cache for 1 hour
|
||||
)
|
||||
@@ -0,0 +1,131 @@
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, RegistrationError, RegistrationErrorCode
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationRequest(RootModel[OAuthClientMetadata]):
|
||||
# this wrapper is a no-op; it's just to separate out the types exposed to the
|
||||
# provider from what we use in the HTTP handler
|
||||
root: OAuthClientMetadata
|
||||
|
||||
|
||||
class RegistrationErrorResponse(BaseModel):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegistrationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
options: ClientRegistrationOptions
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
# Implements dynamic client registration as defined in https://datatracker.ietf.org/doc/html/rfc7591#section-3.1
|
||||
try:
|
||||
# Parse request body as JSON
|
||||
body = await request.json()
|
||||
client_metadata = OAuthClientMetadata.model_validate(body)
|
||||
|
||||
# Scope validation is handled below
|
||||
except ValidationError as validation_error:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id = str(uuid4())
|
||||
client_secret = None
|
||||
if client_metadata.token_endpoint_auth_method != "none": # pragma: no branch
|
||||
# cryptographically secure random 32-byte hex string
|
||||
client_secret = secrets.token_hex(32)
|
||||
|
||||
if client_metadata.scope is None and self.options.default_scopes is not None:
|
||||
client_metadata.scope = " ".join(self.options.default_scopes)
|
||||
elif client_metadata.scope is not None and self.options.valid_scopes is not None:
|
||||
requested_scopes = set(client_metadata.scope.split())
|
||||
valid_scopes = set(self.options.valid_scopes)
|
||||
if not requested_scopes.issubset(valid_scopes): # pragma: no branch
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="Requested scopes are not valid: "
|
||||
f"{', '.join(requested_scopes - valid_scopes)}",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
if not {"authorization_code", "refresh_token"}.issubset(set(client_metadata.grant_types)):
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="grant_types must be authorization_code and refresh_token",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# The MCP spec requires servers to use the authorization `code` flow
|
||||
# with PKCE
|
||||
if "code" not in client_metadata.response_types:
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(
|
||||
error="invalid_client_metadata",
|
||||
error_description="response_types must include 'code' for authorization_code grant",
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
client_id_issued_at = int(time.time())
|
||||
client_secret_expires_at = (
|
||||
client_id_issued_at + self.options.client_secret_expiry_seconds
|
||||
if self.options.client_secret_expiry_seconds is not None
|
||||
else None
|
||||
)
|
||||
|
||||
client_info = OAuthClientInformationFull(
|
||||
client_id=client_id,
|
||||
client_id_issued_at=client_id_issued_at,
|
||||
client_secret=client_secret,
|
||||
client_secret_expires_at=client_secret_expires_at,
|
||||
# passthrough information from the client request
|
||||
redirect_uris=client_metadata.redirect_uris,
|
||||
token_endpoint_auth_method=client_metadata.token_endpoint_auth_method,
|
||||
grant_types=client_metadata.grant_types,
|
||||
response_types=client_metadata.response_types,
|
||||
client_name=client_metadata.client_name,
|
||||
client_uri=client_metadata.client_uri,
|
||||
logo_uri=client_metadata.logo_uri,
|
||||
scope=client_metadata.scope,
|
||||
contacts=client_metadata.contacts,
|
||||
tos_uri=client_metadata.tos_uri,
|
||||
policy_uri=client_metadata.policy_uri,
|
||||
jwks_uri=client_metadata.jwks_uri,
|
||||
jwks=client_metadata.jwks,
|
||||
software_id=client_metadata.software_id,
|
||||
software_version=client_metadata.software_version,
|
||||
)
|
||||
try:
|
||||
# Register client
|
||||
await self.provider.register_client(client_info)
|
||||
|
||||
# Return client information
|
||||
return PydanticJSONResponse(content=client_info, status_code=201)
|
||||
except RegistrationError as e:
|
||||
# Handle registration errors as defined in RFC 7591 Section 3.2.2
|
||||
return PydanticJSONResponse(
|
||||
content=RegistrationErrorResponse(error=e.error, error_description=e.error_description),
|
||||
status_code=400,
|
||||
)
|
||||
@@ -0,0 +1,94 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from mcp.server.auth.errors import (
|
||||
stringify_pydantic_error,
|
||||
)
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import AccessToken, OAuthAuthorizationServerProvider, RefreshToken
|
||||
|
||||
|
||||
class RevocationRequest(BaseModel):
|
||||
"""
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7009#section-2.1
|
||||
"""
|
||||
|
||||
token: str
|
||||
token_type_hint: Literal["access_token", "refresh_token"] | None = None
|
||||
client_id: str
|
||||
client_secret: str | None
|
||||
|
||||
|
||||
class RevocationErrorResponse(BaseModel):
|
||||
error: Literal["invalid_request", "unauthorized_client"]
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class RevocationHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
async def handle(self, request: Request) -> Response:
|
||||
"""
|
||||
Handler for the OAuth 2.0 Token Revocation endpoint.
|
||||
"""
|
||||
try:
|
||||
form_data = await request.form()
|
||||
revocation_request = RevocationRequest.model_validate(dict(form_data))
|
||||
except ValidationError as e:
|
||||
return PydanticJSONResponse(
|
||||
status_code=400,
|
||||
content=RevocationErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(e),
|
||||
),
|
||||
)
|
||||
|
||||
# Authenticate client
|
||||
try:
|
||||
client = await self.client_authenticator.authenticate(
|
||||
revocation_request.client_id, revocation_request.client_secret
|
||||
)
|
||||
except AuthenticationError as e: # pragma: no cover
|
||||
return PydanticJSONResponse(
|
||||
status_code=401,
|
||||
content=RevocationErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
),
|
||||
)
|
||||
|
||||
loaders = [
|
||||
self.provider.load_access_token,
|
||||
partial(self.provider.load_refresh_token, client),
|
||||
]
|
||||
if revocation_request.token_type_hint == "refresh_token": # pragma: no cover
|
||||
loaders = reversed(loaders)
|
||||
|
||||
token: None | AccessToken | RefreshToken = None
|
||||
for loader in loaders:
|
||||
token = await loader(revocation_request.token)
|
||||
if token is not None:
|
||||
break
|
||||
|
||||
# if token is not found, just return HTTP 200 per the RFC
|
||||
if token and token.client_id == client.client_id:
|
||||
# Revoke token; provider is not meant to be able to do validation
|
||||
# at this point that would result in an error
|
||||
await self.provider.revoke_token(token)
|
||||
|
||||
# Return successful empty response
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,238 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import AnyHttpUrl, AnyUrl, BaseModel, Field, RootModel, ValidationError
|
||||
from starlette.requests import Request
|
||||
|
||||
from mcp.server.auth.errors import stringify_pydantic_error
|
||||
from mcp.server.auth.json_response import PydanticJSONResponse
|
||||
from mcp.server.auth.middleware.client_auth import AuthenticationError, ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenError, TokenErrorCode
|
||||
from mcp.shared.auth import OAuthToken
|
||||
|
||||
|
||||
class AuthorizationCodeRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3
|
||||
grant_type: Literal["authorization_code"]
|
||||
code: str = Field(..., description="The authorization code")
|
||||
redirect_uri: AnyUrl | None = Field(None, description="Must be the same as redirect URI provided in /authorize")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# See https://datatracker.ietf.org/doc/html/rfc7636#section-4.5
|
||||
code_verifier: str = Field(..., description="PKCE code verifier")
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class RefreshTokenRequest(BaseModel):
|
||||
# See https://datatracker.ietf.org/doc/html/rfc6749#section-6
|
||||
grant_type: Literal["refresh_token"]
|
||||
refresh_token: str = Field(..., description="The refresh token")
|
||||
scope: str | None = Field(None, description="Optional scope parameter")
|
||||
client_id: str
|
||||
# we use the client_secret param, per https://datatracker.ietf.org/doc/html/rfc6749#section-2.3.1
|
||||
client_secret: str | None = None
|
||||
# RFC 8707 resource indicator
|
||||
resource: str | None = Field(None, description="Resource indicator for the token")
|
||||
|
||||
|
||||
class TokenRequest(
|
||||
RootModel[
|
||||
Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
]
|
||||
):
|
||||
root: Annotated[
|
||||
AuthorizationCodeRequest | RefreshTokenRequest,
|
||||
Field(discriminator="grant_type"),
|
||||
]
|
||||
|
||||
|
||||
class TokenErrorResponse(BaseModel):
|
||||
"""
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-5.2
|
||||
"""
|
||||
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
error_uri: AnyHttpUrl | None = None
|
||||
|
||||
|
||||
class TokenSuccessResponse(RootModel[OAuthToken]):
|
||||
# this is just a wrapper over OAuthToken; the only reason we do this
|
||||
# is to have some separation between the HTTP response type, and the
|
||||
# type returned by the provider
|
||||
root: OAuthToken
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenHandler:
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any]
|
||||
client_authenticator: ClientAuthenticator
|
||||
|
||||
def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
|
||||
status_code = 200
|
||||
if isinstance(obj, TokenErrorResponse):
|
||||
status_code = 400
|
||||
|
||||
return PydanticJSONResponse(
|
||||
content=obj,
|
||||
status_code=status_code,
|
||||
headers={
|
||||
"Cache-Control": "no-store",
|
||||
"Pragma": "no-cache",
|
||||
},
|
||||
)
|
||||
|
||||
async def handle(self, request: Request):
|
||||
try:
|
||||
form_data = await request.form()
|
||||
token_request = TokenRequest.model_validate(dict(form_data)).root
|
||||
except ValidationError as validation_error:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=stringify_pydantic_error(validation_error),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
client_info = await self.client_authenticator.authenticate(
|
||||
client_id=token_request.client_id,
|
||||
client_secret=token_request.client_secret,
|
||||
)
|
||||
except AuthenticationError as e: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="unauthorized_client",
|
||||
error_description=e.message,
|
||||
)
|
||||
)
|
||||
|
||||
if token_request.grant_type not in client_info.grant_types: # pragma: no cover
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="unsupported_grant_type",
|
||||
error_description=(f"Unsupported grant type (supported grant types are {client_info.grant_types})"),
|
||||
)
|
||||
)
|
||||
|
||||
tokens: OAuthToken
|
||||
|
||||
match token_request:
|
||||
case AuthorizationCodeRequest():
|
||||
auth_code = await self.provider.load_authorization_code(client_info, token_request.code)
|
||||
if auth_code is None or auth_code.client_id != token_request.client_id:
|
||||
# if code belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
# make auth codes expire after a deadline
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.5
|
||||
if auth_code.expires_at < time.time():
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="authorization code has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# verify redirect_uri doesn't change between /authorize and /tokens
|
||||
# see https://datatracker.ietf.org/doc/html/rfc6749#section-10.6
|
||||
if auth_code.redirect_uri_provided_explicitly:
|
||||
authorize_request_redirect_uri = auth_code.redirect_uri
|
||||
else: # pragma: no cover
|
||||
authorize_request_redirect_uri = None
|
||||
|
||||
# Convert both sides to strings for comparison to handle AnyUrl vs string issues
|
||||
token_redirect_str = str(token_request.redirect_uri) if token_request.redirect_uri is not None else None
|
||||
auth_redirect_str = (
|
||||
str(authorize_request_redirect_uri) if authorize_request_redirect_uri is not None else None
|
||||
)
|
||||
|
||||
if token_redirect_str != auth_redirect_str:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_request",
|
||||
error_description=("redirect_uri did not match the one used when creating auth code"),
|
||||
)
|
||||
)
|
||||
|
||||
# Verify PKCE code verifier
|
||||
sha256 = hashlib.sha256(token_request.code_verifier.encode()).digest()
|
||||
hashed_code_verifier = base64.urlsafe_b64encode(sha256).decode().rstrip("=")
|
||||
|
||||
if hashed_code_verifier != auth_code.code_challenge:
|
||||
# see https://datatracker.ietf.org/doc/html/rfc7636#section-4.6
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="incorrect code_verifier",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange authorization code for tokens
|
||||
tokens = await self.provider.exchange_authorization_code(client_info, auth_code)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
case RefreshTokenRequest(): # pragma: no cover
|
||||
refresh_token = await self.provider.load_refresh_token(client_info, token_request.refresh_token)
|
||||
if refresh_token is None or refresh_token.client_id != token_request.client_id:
|
||||
# if token belongs to different client, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token does not exist",
|
||||
)
|
||||
)
|
||||
|
||||
if refresh_token.expires_at and refresh_token.expires_at < time.time():
|
||||
# if the refresh token has expired, pretend it doesn't exist
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_grant",
|
||||
error_description="refresh token has expired",
|
||||
)
|
||||
)
|
||||
|
||||
# Parse scopes if provided
|
||||
scopes = token_request.scope.split(" ") if token_request.scope else refresh_token.scopes
|
||||
|
||||
for scope in scopes:
|
||||
if scope not in refresh_token.scopes:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error="invalid_scope",
|
||||
error_description=(f"cannot request scope `{scope}` not provided by refresh token"),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Exchange refresh token for new tokens
|
||||
tokens = await self.provider.exchange_refresh_token(client_info, refresh_token, scopes)
|
||||
except TokenError as e:
|
||||
return self.response(
|
||||
TokenErrorResponse(
|
||||
error=e.error,
|
||||
error_description=e.error_description,
|
||||
)
|
||||
)
|
||||
|
||||
return self.response(TokenSuccessResponse(root=tokens))
|
||||
@@ -0,0 +1,10 @@
|
||||
from typing import Any
|
||||
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class PydanticJSONResponse(JSONResponse):
|
||||
# use pydantic json serialization instead of the stock `json.dumps`,
|
||||
# so that we can handle serializing pydantic models like AnyHttpUrl
|
||||
def render(self, content: Any) -> bytes:
|
||||
return content.model_dump_json(exclude_none=True).encode("utf-8")
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Middleware for MCP authorization.
|
||||
"""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,48 @@
|
||||
import contextvars
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
# Create a contextvar to store the authenticated user
|
||||
# The default is None, indicating no authenticated user is present
|
||||
auth_context_var = contextvars.ContextVar[AuthenticatedUser | None]("auth_context", default=None)
|
||||
|
||||
|
||||
def get_access_token() -> AccessToken | None:
|
||||
"""
|
||||
Get the access token from the current context.
|
||||
|
||||
Returns:
|
||||
The access token if an authenticated user is available, None otherwise.
|
||||
"""
|
||||
auth_user = auth_context_var.get()
|
||||
return auth_user.access_token if auth_user else None
|
||||
|
||||
|
||||
class AuthContextMiddleware:
|
||||
"""
|
||||
Middleware that extracts the authenticated user from the request
|
||||
and sets it in a contextvar for easy access throughout the request lifecycle.
|
||||
|
||||
This middleware should be added after the AuthenticationMiddleware in the
|
||||
middleware stack to ensure that the user is properly authenticated before
|
||||
being stored in the context.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
user = scope.get("user")
|
||||
if isinstance(user, AuthenticatedUser):
|
||||
# Set the authenticated user in the contextvar
|
||||
token = auth_context_var.set(user)
|
||||
try:
|
||||
await self.app(scope, receive, send)
|
||||
finally:
|
||||
auth_context_var.reset(token)
|
||||
else:
|
||||
# No authenticated user, just process the request
|
||||
await self.app(scope, receive, send)
|
||||
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.authentication import AuthCredentials, AuthenticationBackend, SimpleUser
|
||||
from starlette.requests import HTTPConnection
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from mcp.server.auth.provider import AccessToken, TokenVerifier
|
||||
|
||||
|
||||
class AuthenticatedUser(SimpleUser):
|
||||
"""User with authentication info."""
|
||||
|
||||
def __init__(self, auth_info: AccessToken):
|
||||
super().__init__(auth_info.client_id)
|
||||
self.access_token = auth_info
|
||||
self.scopes = auth_info.scopes
|
||||
|
||||
|
||||
class BearerAuthBackend(AuthenticationBackend):
|
||||
"""
|
||||
Authentication backend that validates Bearer tokens using a TokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, token_verifier: TokenVerifier):
|
||||
self.token_verifier = token_verifier
|
||||
|
||||
async def authenticate(self, conn: HTTPConnection):
|
||||
auth_header = next(
|
||||
(conn.headers.get(key) for key in conn.headers if key.lower() == "authorization"),
|
||||
None,
|
||||
)
|
||||
if not auth_header or not auth_header.lower().startswith("bearer "):
|
||||
return None
|
||||
|
||||
token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Validate the token with the verifier
|
||||
auth_info = await self.token_verifier.verify_token(token)
|
||||
|
||||
if not auth_info:
|
||||
return None
|
||||
|
||||
if auth_info.expires_at and auth_info.expires_at < int(time.time()):
|
||||
return None
|
||||
|
||||
return AuthCredentials(auth_info.scopes), AuthenticatedUser(auth_info)
|
||||
|
||||
|
||||
class RequireAuthMiddleware:
|
||||
"""
|
||||
Middleware that requires a valid Bearer token in the Authorization header.
|
||||
|
||||
This will validate the token with the auth provider and store the resulting
|
||||
auth info in the request state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: Any,
|
||||
required_scopes: list[str],
|
||||
resource_metadata_url: AnyHttpUrl | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize the middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application
|
||||
required_scopes: List of scopes that the token must have
|
||||
resource_metadata_url: Optional protected resource metadata URL for WWW-Authenticate header
|
||||
"""
|
||||
self.app = app
|
||||
self.required_scopes = required_scopes
|
||||
self.resource_metadata_url = resource_metadata_url
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
auth_user = scope.get("user")
|
||||
if not isinstance(auth_user, AuthenticatedUser):
|
||||
await self._send_auth_error(
|
||||
send, status_code=401, error="invalid_token", description="Authentication required"
|
||||
)
|
||||
return
|
||||
|
||||
auth_credentials = scope.get("auth")
|
||||
|
||||
for required_scope in self.required_scopes:
|
||||
# auth_credentials should always be provided; this is just paranoia
|
||||
if auth_credentials is None or required_scope not in auth_credentials.scopes:
|
||||
await self._send_auth_error(
|
||||
send, status_code=403, error="insufficient_scope", description=f"Required scope: {required_scope}"
|
||||
)
|
||||
return
|
||||
|
||||
await self.app(scope, receive, send)
|
||||
|
||||
async def _send_auth_error(self, send: Send, status_code: int, error: str, description: str) -> None:
|
||||
"""Send an authentication error response with WWW-Authenticate header."""
|
||||
# Build WWW-Authenticate header value
|
||||
www_auth_parts = [f'error="{error}"', f'error_description="{description}"']
|
||||
if self.resource_metadata_url: # pragma: no cover
|
||||
www_auth_parts.append(f'resource_metadata="{self.resource_metadata_url}"')
|
||||
|
||||
www_authenticate = f"Bearer {', '.join(www_auth_parts)}"
|
||||
|
||||
# Send response
|
||||
body = {"error": error, "error_description": description}
|
||||
body_bytes = json.dumps(body).encode()
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.start",
|
||||
"status": status_code,
|
||||
"headers": [
|
||||
(b"content-type", b"application/json"),
|
||||
(b"content-length", str(len(body_bytes)).encode()),
|
||||
(b"www-authenticate", www_authenticate.encode()),
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
await send(
|
||||
{
|
||||
"type": "http.response.body",
|
||||
"body": body_bytes,
|
||||
}
|
||||
)
|
||||
@@ -0,0 +1,51 @@
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
|
||||
|
||||
class AuthenticationError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message # pragma: no cover
|
||||
|
||||
|
||||
class ClientAuthenticator:
|
||||
"""
|
||||
ClientAuthenticator is a callable which validates requests from a client
|
||||
application, used to verify /token calls.
|
||||
If, during registration, the client requested to be issued a secret, the
|
||||
authenticator asserts that /token calls must be authenticated with
|
||||
that same token.
|
||||
NOTE: clients can opt for no authentication during registration, in which case this
|
||||
logic is skipped.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
|
||||
"""
|
||||
Initialize the dependency.
|
||||
|
||||
Args:
|
||||
provider: Provider to look up client information
|
||||
"""
|
||||
self.provider = provider
|
||||
|
||||
async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
|
||||
# Look up client information
|
||||
client = await self.provider.get_client(client_id)
|
||||
if not client:
|
||||
raise AuthenticationError("Invalid client_id") # pragma: no cover
|
||||
|
||||
# If client from the store expects a secret, validate that the request provides
|
||||
# that secret
|
||||
if client.client_secret: # pragma: no branch
|
||||
if not client_secret:
|
||||
raise AuthenticationError("Client secret is required") # pragma: no cover
|
||||
|
||||
if client.client_secret != client_secret:
|
||||
raise AuthenticationError("Invalid client_secret") # pragma: no cover
|
||||
|
||||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
|
||||
raise AuthenticationError("Client secret has expired") # pragma: no cover
|
||||
|
||||
return client
|
||||
@@ -0,0 +1,301 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, Literal, Protocol, TypeVar
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
|
||||
|
||||
|
||||
class AuthorizationParams(BaseModel):
|
||||
state: str | None
|
||||
scopes: list[str] | None
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class AuthorizationCode(BaseModel):
|
||||
code: str
|
||||
scopes: list[str]
|
||||
expires_at: float
|
||||
client_id: str
|
||||
code_challenge: str
|
||||
redirect_uri: AnyUrl
|
||||
redirect_uri_provided_explicitly: bool
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
class RefreshToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
|
||||
|
||||
class AccessToken(BaseModel):
|
||||
token: str
|
||||
client_id: str
|
||||
scopes: list[str]
|
||||
expires_at: int | None = None
|
||||
resource: str | None = None # RFC 8707 resource indicator
|
||||
|
||||
|
||||
RegistrationErrorCode = Literal[
|
||||
"invalid_redirect_uri",
|
||||
"invalid_client_metadata",
|
||||
"invalid_software_statement",
|
||||
"unapproved_software_statement",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegistrationError(Exception):
|
||||
error: RegistrationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
AuthorizationErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
"unsupported_response_type",
|
||||
"invalid_scope",
|
||||
"server_error",
|
||||
"temporarily_unavailable",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthorizeError(Exception):
|
||||
error: AuthorizationErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
TokenErrorCode = Literal[
|
||||
"invalid_request",
|
||||
"invalid_client",
|
||||
"invalid_grant",
|
||||
"unauthorized_client",
|
||||
"unsupported_grant_type",
|
||||
"invalid_scope",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TokenError(Exception):
|
||||
error: TokenErrorCode
|
||||
error_description: str | None = None
|
||||
|
||||
|
||||
class TokenVerifier(Protocol):
|
||||
"""Protocol for verifying bearer tokens."""
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify a bearer token and return access info if valid."""
|
||||
|
||||
|
||||
# NOTE: FastMCP doesn't render any of these types in the user response, so it's
|
||||
# OK to add fields to subclasses which should not be exposed externally.
|
||||
AuthorizationCodeT = TypeVar("AuthorizationCodeT", bound=AuthorizationCode)
|
||||
RefreshTokenT = TypeVar("RefreshTokenT", bound=RefreshToken)
|
||||
AccessTokenT = TypeVar("AccessTokenT", bound=AccessToken)
|
||||
|
||||
|
||||
class OAuthAuthorizationServerProvider(Protocol, Generic[AuthorizationCodeT, RefreshTokenT, AccessTokenT]):
|
||||
async def get_client(self, client_id: str) -> OAuthClientInformationFull | None:
|
||||
"""
|
||||
Retrieves client information by client ID.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_id: The ID of the client to retrieve.
|
||||
|
||||
Returns:
|
||||
The client information, or None if the client does not exist.
|
||||
"""
|
||||
|
||||
async def register_client(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""
|
||||
Saves client information as part of registering it.
|
||||
|
||||
Implementors MAY raise NotImplementedError if dynamic client registration is
|
||||
disabled in ClientRegistrationOptions.
|
||||
|
||||
Args:
|
||||
client_info: The client metadata to register.
|
||||
|
||||
Raises:
|
||||
RegistrationError: If the client metadata is invalid.
|
||||
"""
|
||||
|
||||
async def authorize(self, client: OAuthClientInformationFull, params: AuthorizationParams) -> str:
|
||||
"""
|
||||
Called as part of the /authorize endpoint, and returns a URL that the client
|
||||
will be redirected to.
|
||||
Many MCP implementations will redirect to a third-party provider to perform
|
||||
a second OAuth exchange with that provider. In this sort of setup, the client
|
||||
has an OAuth connection with the MCP server, and the MCP server has an OAuth
|
||||
connection with the 3rd-party provider. At the end of this flow, the client
|
||||
should be redirected to the redirect_uri from params.redirect_uri.
|
||||
|
||||
+--------+ +------------+ +-------------------+
|
||||
| | | | | |
|
||||
| Client | --> | MCP Server | --> | 3rd Party OAuth |
|
||||
| | | | | Server |
|
||||
+--------+ +------------+ +-------------------+
|
||||
| ^ |
|
||||
+------------+ | | |
|
||||
| | | | Redirect |
|
||||
|redirect_uri|<-----+ +------------------+
|
||||
| |
|
||||
+------------+
|
||||
|
||||
Implementations will need to define another handler on the MCP server return
|
||||
flow to perform the second redirect, and generate and store an authorization
|
||||
code as part of completing the OAuth authorization step.
|
||||
|
||||
Implementations SHOULD generate an authorization code with at least 160 bits of
|
||||
entropy,
|
||||
and MUST generate an authorization code with at least 128 bits of entropy.
|
||||
See https://datatracker.ietf.org/doc/html/rfc6749#section-10.10.
|
||||
|
||||
Args:
|
||||
client: The client requesting authorization.
|
||||
params: The parameters of the authorization request.
|
||||
|
||||
Returns:
|
||||
A URL to redirect the client to for authorization.
|
||||
|
||||
Raises:
|
||||
AuthorizeError: If the authorization request is invalid.
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: str
|
||||
) -> AuthorizationCodeT | None:
|
||||
"""
|
||||
Loads an AuthorizationCode by its code.
|
||||
|
||||
Args:
|
||||
client: The client that requested the authorization code.
|
||||
authorization_code: The authorization code to get the challenge for.
|
||||
|
||||
Returns:
|
||||
The AuthorizationCode, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_authorization_code(
|
||||
self, client: OAuthClientInformationFull, authorization_code: AuthorizationCodeT
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges an authorization code for an access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the authorization code.
|
||||
authorization_code: The authorization code to exchange.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_refresh_token(self, client: OAuthClientInformationFull, refresh_token: str) -> RefreshTokenT | None:
|
||||
"""
|
||||
Loads a RefreshToken by its token string.
|
||||
|
||||
Args:
|
||||
client: The client that is requesting to load the refresh token.
|
||||
refresh_token: The refresh token string to load.
|
||||
|
||||
Returns:
|
||||
The RefreshToken object if found, or None if not found.
|
||||
"""
|
||||
...
|
||||
|
||||
async def exchange_refresh_token(
|
||||
self,
|
||||
client: OAuthClientInformationFull,
|
||||
refresh_token: RefreshTokenT,
|
||||
scopes: list[str],
|
||||
) -> OAuthToken:
|
||||
"""
|
||||
Exchanges a refresh token for an access token and refresh token.
|
||||
|
||||
Implementations SHOULD rotate both the access token and refresh token.
|
||||
|
||||
Args:
|
||||
client: The client exchanging the refresh token.
|
||||
refresh_token: The refresh token to exchange.
|
||||
scopes: Optional scopes to request with the new access token.
|
||||
|
||||
Returns:
|
||||
The OAuth token, containing access and refresh tokens.
|
||||
|
||||
Raises:
|
||||
TokenError: If the request is invalid
|
||||
"""
|
||||
...
|
||||
|
||||
async def load_access_token(self, token: str) -> AccessTokenT | None:
|
||||
"""
|
||||
Loads an access token by its token.
|
||||
|
||||
Args:
|
||||
token: The access token to verify.
|
||||
|
||||
Returns:
|
||||
The AuthInfo, or None if the token is invalid.
|
||||
"""
|
||||
|
||||
async def revoke_token(
|
||||
self,
|
||||
token: AccessTokenT | RefreshTokenT,
|
||||
) -> None:
|
||||
"""
|
||||
Revokes an access or refresh token.
|
||||
|
||||
If the given token is invalid or already revoked, this method should do nothing.
|
||||
|
||||
Implementations SHOULD revoke both the access token and its corresponding
|
||||
refresh token, regardless of which of the access token or refresh token is
|
||||
provided.
|
||||
|
||||
Args:
|
||||
token: the token to revoke
|
||||
"""
|
||||
|
||||
|
||||
def construct_redirect_uri(redirect_uri_base: str, **params: str | None) -> str:
|
||||
parsed_uri = urlparse(redirect_uri_base)
|
||||
query_params = [(k, v) for k, vs in parse_qs(parsed_uri.query).items() for v in vs]
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
query_params.append((k, v))
|
||||
|
||||
redirect_uri = urlunparse(parsed_uri._replace(query=urlencode(query_params)))
|
||||
return redirect_uri
|
||||
|
||||
|
||||
class ProviderTokenVerifier(TokenVerifier):
|
||||
"""Token verifier that uses an OAuthAuthorizationServerProvider.
|
||||
|
||||
This is provided for backwards compatibility with existing auth_server_provider
|
||||
configurations. For new implementations using AS/RS separation, consider using
|
||||
the TokenVerifier protocol with a dedicated implementation like IntrospectionTokenVerifier.
|
||||
"""
|
||||
|
||||
def __init__(self, provider: "OAuthAuthorizationServerProvider[AuthorizationCode, RefreshToken, AccessToken]"):
|
||||
self.provider = provider
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify token using the provider's load_access_token method."""
|
||||
return await self.provider.load_access_token(token)
|
||||
@@ -0,0 +1,253 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import AnyHttpUrl
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
from starlette.routing import Route, request_response # type: ignore
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from mcp.server.auth.handlers.authorize import AuthorizationHandler
|
||||
from mcp.server.auth.handlers.metadata import MetadataHandler
|
||||
from mcp.server.auth.handlers.register import RegistrationHandler
|
||||
from mcp.server.auth.handlers.revoke import RevocationHandler
|
||||
from mcp.server.auth.handlers.token import TokenHandler
|
||||
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
|
||||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
|
||||
from mcp.server.auth.settings import ClientRegistrationOptions, RevocationOptions
|
||||
from mcp.server.streamable_http import MCP_PROTOCOL_VERSION_HEADER
|
||||
from mcp.shared.auth import OAuthMetadata
|
||||
|
||||
|
||||
def validate_issuer_url(url: AnyHttpUrl):
|
||||
"""
|
||||
Validate that the issuer URL meets OAuth 2.0 requirements.
|
||||
|
||||
Args:
|
||||
url: The issuer URL to validate
|
||||
|
||||
Raises:
|
||||
ValueError: If the issuer URL is invalid
|
||||
"""
|
||||
|
||||
# RFC 8414 requires HTTPS, but we allow localhost HTTP for testing
|
||||
if (
|
||||
url.scheme != "https"
|
||||
and url.host != "localhost"
|
||||
and (url.host is not None and not url.host.startswith("127.0.0.1"))
|
||||
):
|
||||
raise ValueError("Issuer URL must be HTTPS") # pragma: no cover
|
||||
|
||||
# No fragments or query parameters allowed
|
||||
if url.fragment:
|
||||
raise ValueError("Issuer URL must not have a fragment") # pragma: no cover
|
||||
if url.query:
|
||||
raise ValueError("Issuer URL must not have a query string") # pragma: no cover
|
||||
|
||||
|
||||
AUTHORIZATION_PATH = "/authorize"
|
||||
TOKEN_PATH = "/token"
|
||||
REGISTRATION_PATH = "/register"
|
||||
REVOCATION_PATH = "/revoke"
|
||||
|
||||
|
||||
def cors_middleware(
|
||||
handler: Callable[[Request], Response | Awaitable[Response]],
|
||||
allow_methods: list[str],
|
||||
) -> ASGIApp:
|
||||
cors_app = CORSMiddleware(
|
||||
app=request_response(handler),
|
||||
allow_origins="*",
|
||||
allow_methods=allow_methods,
|
||||
allow_headers=[MCP_PROTOCOL_VERSION_HEADER],
|
||||
)
|
||||
return cors_app
|
||||
|
||||
|
||||
def create_auth_routes(
|
||||
provider: OAuthAuthorizationServerProvider[Any, Any, Any],
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None = None,
|
||||
client_registration_options: ClientRegistrationOptions | None = None,
|
||||
revocation_options: RevocationOptions | None = None,
|
||||
) -> list[Route]:
|
||||
validate_issuer_url(issuer_url)
|
||||
|
||||
client_registration_options = client_registration_options or ClientRegistrationOptions()
|
||||
revocation_options = revocation_options or RevocationOptions()
|
||||
metadata = build_metadata(
|
||||
issuer_url,
|
||||
service_documentation_url,
|
||||
client_registration_options,
|
||||
revocation_options,
|
||||
)
|
||||
client_authenticator = ClientAuthenticator(provider)
|
||||
|
||||
# Create routes
|
||||
# Allow CORS requests for endpoints meant to be hit by the OAuth client
|
||||
# (with the client secret). This is intended to support things like MCP Inspector,
|
||||
# where the client runs in a web browser.
|
||||
routes = [
|
||||
Route(
|
||||
"/.well-known/oauth-authorization-server",
|
||||
endpoint=cors_middleware(
|
||||
MetadataHandler(metadata).handle,
|
||||
["GET", "OPTIONS"],
|
||||
),
|
||||
methods=["GET", "OPTIONS"],
|
||||
),
|
||||
Route(
|
||||
AUTHORIZATION_PATH,
|
||||
# do not allow CORS for authorization endpoint;
|
||||
# clients should just redirect to this
|
||||
endpoint=AuthorizationHandler(provider).handle,
|
||||
methods=["GET", "POST"],
|
||||
),
|
||||
Route(
|
||||
TOKEN_PATH,
|
||||
endpoint=cors_middleware(
|
||||
TokenHandler(provider, client_authenticator).handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
),
|
||||
]
|
||||
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
registration_handler = RegistrationHandler(
|
||||
provider,
|
||||
options=client_registration_options,
|
||||
)
|
||||
routes.append(
|
||||
Route(
|
||||
REGISTRATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
registration_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
revocation_handler = RevocationHandler(provider, client_authenticator)
|
||||
routes.append(
|
||||
Route(
|
||||
REVOCATION_PATH,
|
||||
endpoint=cors_middleware(
|
||||
revocation_handler.handle,
|
||||
["POST", "OPTIONS"],
|
||||
),
|
||||
methods=["POST", "OPTIONS"],
|
||||
)
|
||||
)
|
||||
|
||||
return routes
|
||||
|
||||
|
||||
def build_metadata(
|
||||
issuer_url: AnyHttpUrl,
|
||||
service_documentation_url: AnyHttpUrl | None,
|
||||
client_registration_options: ClientRegistrationOptions,
|
||||
revocation_options: RevocationOptions,
|
||||
) -> OAuthMetadata:
|
||||
authorization_url = AnyHttpUrl(str(issuer_url).rstrip("/") + AUTHORIZATION_PATH)
|
||||
token_url = AnyHttpUrl(str(issuer_url).rstrip("/") + TOKEN_PATH)
|
||||
|
||||
# Create metadata
|
||||
metadata = OAuthMetadata(
|
||||
issuer=issuer_url,
|
||||
authorization_endpoint=authorization_url,
|
||||
token_endpoint=token_url,
|
||||
scopes_supported=client_registration_options.valid_scopes,
|
||||
response_types_supported=["code"],
|
||||
response_modes_supported=None,
|
||||
grant_types_supported=["authorization_code", "refresh_token"],
|
||||
token_endpoint_auth_methods_supported=["client_secret_post"],
|
||||
token_endpoint_auth_signing_alg_values_supported=None,
|
||||
service_documentation=service_documentation_url,
|
||||
ui_locales_supported=None,
|
||||
op_policy_uri=None,
|
||||
op_tos_uri=None,
|
||||
introspection_endpoint=None,
|
||||
code_challenge_methods_supported=["S256"],
|
||||
)
|
||||
|
||||
# Add registration endpoint if supported
|
||||
if client_registration_options.enabled: # pragma: no branch
|
||||
metadata.registration_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REGISTRATION_PATH)
|
||||
|
||||
# Add revocation endpoint if supported
|
||||
if revocation_options.enabled: # pragma: no branch
|
||||
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
|
||||
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def build_resource_metadata_url(resource_server_url: AnyHttpUrl) -> AnyHttpUrl:
|
||||
"""
|
||||
Build RFC 9728 compliant protected resource metadata URL.
|
||||
|
||||
Inserts /.well-known/oauth-protected-resource between host and resource path
|
||||
as specified in RFC 9728 §3.1.
|
||||
|
||||
Args:
|
||||
resource_server_url: The resource server URL (e.g., https://example.com/mcp)
|
||||
|
||||
Returns:
|
||||
The metadata URL (e.g., https://example.com/.well-known/oauth-protected-resource/mcp)
|
||||
"""
|
||||
parsed = urlparse(str(resource_server_url))
|
||||
# Handle trailing slash: if path is just "/", treat as empty
|
||||
resource_path = parsed.path if parsed.path != "/" else ""
|
||||
return AnyHttpUrl(f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-protected-resource{resource_path}")
|
||||
|
||||
|
||||
def create_protected_resource_routes(
|
||||
resource_url: AnyHttpUrl,
|
||||
authorization_servers: list[AnyHttpUrl],
|
||||
scopes_supported: list[str] | None = None,
|
||||
resource_name: str | None = None,
|
||||
resource_documentation: AnyHttpUrl | None = None,
|
||||
) -> list[Route]:
|
||||
"""
|
||||
Create routes for OAuth 2.0 Protected Resource Metadata (RFC 9728).
|
||||
|
||||
Args:
|
||||
resource_url: The URL of this resource server
|
||||
authorization_servers: List of authorization servers that can issue tokens
|
||||
scopes_supported: Optional list of scopes supported by this resource
|
||||
|
||||
Returns:
|
||||
List of Starlette routes for protected resource metadata
|
||||
"""
|
||||
from mcp.server.auth.handlers.metadata import ProtectedResourceMetadataHandler
|
||||
from mcp.shared.auth import ProtectedResourceMetadata
|
||||
|
||||
metadata = ProtectedResourceMetadata(
|
||||
resource=resource_url,
|
||||
authorization_servers=authorization_servers,
|
||||
scopes_supported=scopes_supported,
|
||||
resource_name=resource_name,
|
||||
resource_documentation=resource_documentation,
|
||||
# bearer_methods_supported defaults to ["header"] in the model
|
||||
)
|
||||
|
||||
handler = ProtectedResourceMetadataHandler(metadata)
|
||||
|
||||
# RFC 9728 §3.1: Register route at /.well-known/oauth-protected-resource + resource path
|
||||
metadata_url = build_resource_metadata_url(resource_url)
|
||||
# Extract just the path part for route registration
|
||||
parsed = urlparse(str(metadata_url))
|
||||
well_known_path = parsed.path
|
||||
|
||||
return [
|
||||
Route(
|
||||
well_known_path,
|
||||
endpoint=cors_middleware(handler.handle, ["GET", "OPTIONS"]),
|
||||
methods=["GET", "OPTIONS"],
|
||||
)
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
from pydantic import AnyHttpUrl, BaseModel, Field
|
||||
|
||||
|
||||
class ClientRegistrationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
client_secret_expiry_seconds: int | None = None
|
||||
valid_scopes: list[str] | None = None
|
||||
default_scopes: list[str] | None = None
|
||||
|
||||
|
||||
class RevocationOptions(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
class AuthSettings(BaseModel):
|
||||
issuer_url: AnyHttpUrl = Field(
|
||||
...,
|
||||
description="OAuth authorization server URL that issues tokens for this resource server.",
|
||||
)
|
||||
service_documentation_url: AnyHttpUrl | None = None
|
||||
client_registration_options: ClientRegistrationOptions | None = None
|
||||
revocation_options: RevocationOptions | None = None
|
||||
required_scopes: list[str] | None = None
|
||||
|
||||
# Resource Server settings (when operating as RS only)
|
||||
resource_server_url: AnyHttpUrl | None = Field(
|
||||
...,
|
||||
description="The URL of the MCP server to be used as the resource identifier "
|
||||
"and base route to look up OAuth Protected Resource Metadata.",
|
||||
)
|
||||
Reference in New Issue
Block a user