chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.auth.oauth2 import (
|
||||
OAuthClientProvider,
|
||||
PKCEParameters,
|
||||
TokenStorage,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OAuthClientProvider",
|
||||
"OAuthFlowError",
|
||||
"OAuthRegistrationError",
|
||||
"OAuthTokenError",
|
||||
"PKCEParameters",
|
||||
"TokenStorage",
|
||||
]
|
||||
@@ -0,0 +1,10 @@
|
||||
class OAuthFlowError(Exception):
|
||||
"""Base exception for OAuth flow errors."""
|
||||
|
||||
|
||||
class OAuthTokenError(OAuthFlowError):
|
||||
"""Raised when token operations fail."""
|
||||
|
||||
|
||||
class OAuthRegistrationError(OAuthFlowError):
|
||||
"""Raised when client registration fails."""
|
||||
@@ -0,0 +1,148 @@
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
|
||||
|
||||
class JWTParameters(BaseModel):
|
||||
"""JWT parameters."""
|
||||
|
||||
assertion: str | None = Field(
|
||||
default=None,
|
||||
description="JWT assertion for JWT authentication. "
|
||||
"Will be used instead of generating a new assertion if provided.",
|
||||
)
|
||||
|
||||
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
|
||||
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
|
||||
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
|
||||
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
|
||||
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
|
||||
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
|
||||
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
|
||||
|
||||
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
|
||||
if self.assertion is not None:
|
||||
# Prebuilt JWT (e.g. acquired out-of-band)
|
||||
assertion = self.assertion
|
||||
else:
|
||||
if not self.jwt_signing_key:
|
||||
raise OAuthFlowError("Missing signing key for JWT bearer grant") # pragma: no cover
|
||||
if not self.issuer:
|
||||
raise OAuthFlowError("Missing issuer for JWT bearer grant") # pragma: no cover
|
||||
if not self.subject:
|
||||
raise OAuthFlowError("Missing subject for JWT bearer grant") # pragma: no cover
|
||||
|
||||
audience = self.audience if self.audience else with_audience_fallback
|
||||
if not audience:
|
||||
raise OAuthFlowError("Missing audience for JWT bearer grant") # pragma: no cover
|
||||
|
||||
now = int(time.time())
|
||||
claims: dict[str, Any] = {
|
||||
"iss": self.issuer,
|
||||
"sub": self.subject,
|
||||
"aud": audience,
|
||||
"exp": now + self.jwt_lifetime_seconds,
|
||||
"iat": now,
|
||||
"jti": str(uuid4()),
|
||||
}
|
||||
claims.update(self.claims or {})
|
||||
|
||||
assertion = jwt.encode(
|
||||
claims,
|
||||
self.jwt_signing_key,
|
||||
algorithm=self.jwt_signing_algorithm or "RS256",
|
||||
)
|
||||
return assertion
|
||||
|
||||
|
||||
class RFC7523OAuthClientProvider(OAuthClientProvider):
|
||||
"""OAuth client provider for RFC7532 clients."""
|
||||
|
||||
jwt_parameters: JWTParameters | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
storage: TokenStorage,
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
|
||||
timeout: float = 300.0,
|
||||
jwt_parameters: JWTParameters | None = None,
|
||||
) -> None:
|
||||
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
|
||||
self.jwt_parameters = jwt_parameters
|
||||
|
||||
async def _exchange_token_authorization_code(
|
||||
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
|
||||
) -> httpx.Request: # pragma: no cover
|
||||
"""Build token exchange request for authorization_code flow."""
|
||||
token_data = token_data or {}
|
||||
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
|
||||
self._add_client_authentication_jwt(token_data=token_data)
|
||||
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
|
||||
"""Perform the authorization flow."""
|
||||
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
|
||||
token_request = await self._exchange_token_jwt_bearer()
|
||||
return token_request
|
||||
else:
|
||||
return await super()._perform_authorization()
|
||||
|
||||
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
|
||||
"""Add JWT assertion for client authentication to token endpoint parameters."""
|
||||
if not self.jwt_parameters:
|
||||
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
|
||||
if not self.context.oauth_metadata:
|
||||
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
|
||||
|
||||
# We need to set the audience to the issuer identifier of the authorization server
|
||||
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
|
||||
issuer = str(self.context.oauth_metadata.issuer)
|
||||
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
|
||||
|
||||
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
|
||||
token_data["client_assertion"] = assertion
|
||||
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
|
||||
# We need to set the audience to the resource server, the audience is difference from the one in claims
|
||||
# it represents the resource server that will validate the token
|
||||
token_data["audience"] = self.context.get_resource_url()
|
||||
|
||||
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
|
||||
"""Build token exchange request for JWT bearer grant."""
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("Missing client info") # pragma: no cover
|
||||
if not self.jwt_parameters:
|
||||
raise OAuthFlowError("Missing JWT parameters") # pragma: no cover
|
||||
if not self.context.oauth_metadata:
|
||||
raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover
|
||||
|
||||
# We need to set the audience to the issuer identifier of the authorization server
|
||||
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
|
||||
issuer = str(self.context.oauth_metadata.issuer)
|
||||
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
|
||||
|
||||
token_data = {
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
|
||||
"assertion": assertion,
|
||||
}
|
||||
|
||||
if self.context.should_include_resource_param(self.context.protocol_version): # pragma: no branch
|
||||
token_data["resource"] = self.context.get_resource_url()
|
||||
|
||||
if self.context.client_metadata.scope: # pragma: no branch
|
||||
token_data["scope"] = self.context.client_metadata.scope
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
return httpx.Request(
|
||||
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urlencode, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthFlowError, OAuthTokenError
|
||||
from mcp.client.auth.utils import (
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
create_client_registration_request,
|
||||
create_oauth_metadata_request,
|
||||
extract_field_from_www_auth,
|
||||
extract_resource_metadata_from_www_auth,
|
||||
extract_scope_from_www_auth,
|
||||
get_client_metadata_scopes,
|
||||
handle_auth_metadata_response,
|
||||
handle_protected_resource_response,
|
||||
handle_registration_response,
|
||||
handle_token_response_scopes,
|
||||
)
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.shared.auth_utils import (
|
||||
calculate_token_expiry,
|
||||
check_resource_allowed,
|
||||
resource_url_from_server_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PKCEParameters(BaseModel):
|
||||
"""PKCE (Proof Key for Code Exchange) parameters."""
|
||||
|
||||
code_verifier: str = Field(..., min_length=43, max_length=128)
|
||||
code_challenge: str = Field(..., min_length=43, max_length=128)
|
||||
|
||||
@classmethod
|
||||
def generate(cls) -> "PKCEParameters":
|
||||
"""Generate new PKCE parameters."""
|
||||
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
|
||||
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
|
||||
|
||||
|
||||
class TokenStorage(Protocol):
|
||||
"""Protocol for token storage implementations."""
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
"""Get stored tokens."""
|
||||
...
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
"""Store tokens."""
|
||||
...
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
"""Get stored client information."""
|
||||
...
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""Store client information."""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthContext:
|
||||
"""OAuth flow context."""
|
||||
|
||||
server_url: str
|
||||
client_metadata: OAuthClientMetadata
|
||||
storage: TokenStorage
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
|
||||
timeout: float = 300.0
|
||||
|
||||
# Discovered metadata
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None = None
|
||||
oauth_metadata: OAuthMetadata | None = None
|
||||
auth_server_url: str | None = None
|
||||
protocol_version: str | None = None
|
||||
|
||||
# Client registration
|
||||
client_info: OAuthClientInformationFull | None = None
|
||||
|
||||
# Token management
|
||||
current_tokens: OAuthToken | None = None
|
||||
token_expiry_time: float | None = None
|
||||
|
||||
# State
|
||||
lock: anyio.Lock = field(default_factory=anyio.Lock)
|
||||
|
||||
def get_authorization_base_url(self, server_url: str) -> str:
|
||||
"""Extract base URL by removing path component."""
|
||||
parsed = urlparse(server_url)
|
||||
return f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
def update_token_expiry(self, token: OAuthToken) -> None:
|
||||
"""Update token expiry time using shared util function."""
|
||||
self.token_expiry_time = calculate_token_expiry(token.expires_in)
|
||||
|
||||
def is_token_valid(self) -> bool:
|
||||
"""Check if current token is valid."""
|
||||
return bool(
|
||||
self.current_tokens
|
||||
and self.current_tokens.access_token
|
||||
and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
|
||||
)
|
||||
|
||||
def can_refresh_token(self) -> bool:
|
||||
"""Check if token can be refreshed."""
|
||||
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""Clear current tokens."""
|
||||
self.current_tokens = None
|
||||
self.token_expiry_time = None
|
||||
|
||||
def get_resource_url(self) -> str:
|
||||
"""Get resource URL for RFC 8707.
|
||||
|
||||
Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
|
||||
"""
|
||||
resource = resource_url_from_server_url(self.server_url)
|
||||
|
||||
# If PRM provides a resource that's a valid parent, use it
|
||||
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
|
||||
prm_resource = str(self.protected_resource_metadata.resource)
|
||||
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
|
||||
resource = prm_resource
|
||||
|
||||
return resource
|
||||
|
||||
def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
|
||||
"""Determine if the resource parameter should be included in OAuth requests.
|
||||
|
||||
Returns True if:
|
||||
- Protected resource metadata is available, OR
|
||||
- MCP-Protocol-Version header is 2025-06-18 or later
|
||||
"""
|
||||
# If we have protected resource metadata, include the resource param
|
||||
if self.protected_resource_metadata is not None:
|
||||
return True
|
||||
|
||||
# If no protocol version provided, don't include resource param
|
||||
if not protocol_version:
|
||||
return False
|
||||
|
||||
# Check if protocol version is 2025-06-18 or later
|
||||
# Version format is YYYY-MM-DD, so string comparison works
|
||||
return protocol_version >= "2025-06-18"
|
||||
|
||||
|
||||
class OAuthClientProvider(httpx.Auth):
|
||||
"""
|
||||
OAuth2 authentication for httpx.
|
||||
Handles OAuth flow with automatic client registration and token storage.
|
||||
"""
|
||||
|
||||
requires_response_body = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
client_metadata: OAuthClientMetadata,
|
||||
storage: TokenStorage,
|
||||
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
|
||||
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
|
||||
timeout: float = 300.0,
|
||||
):
|
||||
"""Initialize OAuth2 authentication."""
|
||||
self.context = OAuthContext(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
timeout=timeout,
|
||||
)
|
||||
self._initialized = False
|
||||
|
||||
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
|
||||
"""
|
||||
Handle protected resource metadata discovery response.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
|
||||
Returns:
|
||||
True if metadata was successfully discovered, False if we should try next URL
|
||||
"""
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
metadata = ProtectedResourceMetadata.model_validate_json(content)
|
||||
self.context.protected_resource_metadata = metadata
|
||||
if metadata.authorization_servers: # pragma: no branch
|
||||
self.context.auth_server_url = str(metadata.authorization_servers[0])
|
||||
return True
|
||||
|
||||
except ValidationError: # pragma: no cover
|
||||
# Invalid metadata - try next URL
|
||||
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
|
||||
return False
|
||||
elif response.status_code == 404: # pragma: no cover
|
||||
# Not found - try next URL in fallback chain
|
||||
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
|
||||
return False
|
||||
else:
|
||||
# Other error - fail immediately
|
||||
raise OAuthFlowError(
|
||||
f"Protected Resource Metadata request failed: {response.status_code}"
|
||||
) # pragma: no cover
|
||||
|
||||
async def _register_client(self) -> httpx.Request | None:
|
||||
"""Build registration request or skip if already registered."""
|
||||
if self.context.client_info:
|
||||
return None
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
|
||||
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
registration_url = urljoin(auth_base_url, "/register")
|
||||
|
||||
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
|
||||
return httpx.Request(
|
||||
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
|
||||
"""Perform the authorization flow."""
|
||||
auth_code, code_verifier = await self._perform_authorization_code_grant()
|
||||
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
|
||||
return token_request
|
||||
|
||||
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
|
||||
"""Perform the authorization redirect and get auth code."""
|
||||
if self.context.client_metadata.redirect_uris is None:
|
||||
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.redirect_handler:
|
||||
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.callback_handler:
|
||||
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
|
||||
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
auth_endpoint = urljoin(auth_base_url, "/authorize")
|
||||
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
|
||||
|
||||
# Generate PKCE parameters
|
||||
pkce_params = PKCEParameters.generate()
|
||||
state = secrets.token_urlsafe(32)
|
||||
|
||||
auth_params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.context.client_info.client_id,
|
||||
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
|
||||
"state": state,
|
||||
"code_challenge": pkce_params.code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover
|
||||
|
||||
if self.context.client_metadata.scope: # pragma: no branch
|
||||
auth_params["scope"] = self.context.client_metadata.scope
|
||||
|
||||
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
|
||||
await self.context.redirect_handler(authorization_url)
|
||||
|
||||
# Wait for callback
|
||||
auth_code, returned_state = await self.context.callback_handler()
|
||||
|
||||
if returned_state is None or not secrets.compare_digest(returned_state, state):
|
||||
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover
|
||||
|
||||
if not auth_code:
|
||||
raise OAuthFlowError("No authorization code received") # pragma: no cover
|
||||
|
||||
# Return auth code and code verifier for token exchange
|
||||
return auth_code, pkce_params.code_verifier
|
||||
|
||||
def _get_token_endpoint(self) -> str:
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint)
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
return token_url
|
||||
|
||||
async def _exchange_token_authorization_code(
|
||||
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
|
||||
) -> httpx.Request:
|
||||
"""Build token exchange request for authorization_code flow."""
|
||||
if self.context.client_metadata.redirect_uris is None:
|
||||
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
|
||||
if not self.context.client_info:
|
||||
raise OAuthFlowError("Missing client info") # pragma: no cover
|
||||
|
||||
token_url = self._get_token_endpoint()
|
||||
token_data = token_data or {}
|
||||
token_data.update(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": auth_code,
|
||||
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
|
||||
"client_id": self.context.client_info.client_id,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
)
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
token_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
if self.context.client_info.client_secret:
|
||||
token_data["client_secret"] = self.context.client_info.client_secret
|
||||
|
||||
return httpx.Request(
|
||||
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
|
||||
async def _handle_token_response(self, response: httpx.Response) -> None:
|
||||
"""Handle token exchange response."""
|
||||
if response.status_code != 200:
|
||||
body = await response.aread() # pragma: no cover
|
||||
body_text = body.decode("utf-8") # pragma: no cover
|
||||
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
|
||||
|
||||
# Parse and validate response with scope validation
|
||||
token_response = await handle_token_response_scopes(response)
|
||||
|
||||
# Store tokens in context
|
||||
self.context.current_tokens = token_response
|
||||
self.context.update_token_expiry(token_response)
|
||||
await self.context.storage.set_tokens(token_response)
|
||||
|
||||
async def _refresh_token(self) -> httpx.Request:
|
||||
"""Build token refresh request."""
|
||||
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
|
||||
raise OAuthTokenError("No refresh token available") # pragma: no cover
|
||||
|
||||
if not self.context.client_info:
|
||||
raise OAuthTokenError("No client info available") # pragma: no cover
|
||||
|
||||
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
|
||||
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
|
||||
else:
|
||||
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
|
||||
token_url = urljoin(auth_base_url, "/token")
|
||||
|
||||
refresh_data = {
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": self.context.current_tokens.refresh_token,
|
||||
"client_id": self.context.client_info.client_id,
|
||||
}
|
||||
|
||||
# Only include resource param if conditions are met
|
||||
if self.context.should_include_resource_param(self.context.protocol_version):
|
||||
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
|
||||
|
||||
if self.context.client_info.client_secret: # pragma: no branch
|
||||
refresh_data["client_secret"] = self.context.client_info.client_secret
|
||||
|
||||
return httpx.Request(
|
||||
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
|
||||
)
|
||||
|
||||
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
|
||||
"""Handle token refresh response. Returns True if successful."""
|
||||
if response.status_code != 200:
|
||||
logger.warning(f"Token refresh failed: {response.status_code}")
|
||||
self.context.clear_tokens()
|
||||
return False
|
||||
|
||||
try:
|
||||
content = await response.aread()
|
||||
token_response = OAuthToken.model_validate_json(content)
|
||||
|
||||
self.context.current_tokens = token_response
|
||||
self.context.update_token_expiry(token_response)
|
||||
await self.context.storage.set_tokens(token_response)
|
||||
|
||||
return True
|
||||
except ValidationError:
|
||||
logger.exception("Invalid refresh response")
|
||||
self.context.clear_tokens()
|
||||
return False
|
||||
|
||||
async def _initialize(self) -> None: # pragma: no cover
|
||||
"""Load stored tokens and client info."""
|
||||
self.context.current_tokens = await self.context.storage.get_tokens()
|
||||
self.context.client_info = await self.context.storage.get_client_info()
|
||||
self._initialized = True
|
||||
|
||||
def _add_auth_header(self, request: httpx.Request) -> None:
|
||||
"""Add authorization header to request if we have valid tokens."""
|
||||
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
|
||||
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
|
||||
|
||||
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
|
||||
content = await response.aread()
|
||||
metadata = OAuthMetadata.model_validate_json(content)
|
||||
self.context.oauth_metadata = metadata
|
||||
|
||||
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
||||
"""HTTPX auth flow integration."""
|
||||
async with self.context.lock:
|
||||
if not self._initialized:
|
||||
await self._initialize() # pragma: no cover
|
||||
|
||||
# Capture protocol version from request headers
|
||||
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
|
||||
|
||||
if not self.context.is_token_valid() and self.context.can_refresh_token():
|
||||
# Try to refresh token
|
||||
refresh_request = await self._refresh_token() # pragma: no cover
|
||||
refresh_response = yield refresh_request # pragma: no cover
|
||||
|
||||
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
|
||||
# Refresh failed, need full re-authentication
|
||||
self._initialized = False
|
||||
|
||||
if self.context.is_token_valid():
|
||||
self._add_auth_header(request)
|
||||
|
||||
response = yield request
|
||||
|
||||
if response.status_code == 401:
|
||||
# Perform full OAuth flow
|
||||
try:
|
||||
# OAuth flow must be inline due to generator constraints
|
||||
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
|
||||
|
||||
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
|
||||
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url, self.context.server_url
|
||||
)
|
||||
|
||||
for url in prm_discovery_urls: # pragma: no branch
|
||||
discovery_request = create_oauth_metadata_request(url)
|
||||
|
||||
discovery_response = yield discovery_request # sending request
|
||||
|
||||
prm = await handle_protected_resource_response(discovery_response)
|
||||
if prm:
|
||||
self.context.protected_resource_metadata = prm
|
||||
|
||||
# todo: try all authorization_servers to find the OASM
|
||||
assert (
|
||||
len(prm.authorization_servers) > 0
|
||||
) # this is always true as authorization_servers has a min length of 1
|
||||
|
||||
self.context.auth_server_url = str(prm.authorization_servers[0])
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Protected resource metadata discovery failed: {url}")
|
||||
|
||||
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
self.context.auth_server_url, self.context.server_url
|
||||
)
|
||||
|
||||
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
|
||||
for url in asm_discovery_urls: # pragma: no cover
|
||||
oauth_metadata_request = create_oauth_metadata_request(url)
|
||||
oauth_metadata_response = yield oauth_metadata_request
|
||||
|
||||
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
|
||||
if not ok:
|
||||
break
|
||||
if ok and asm:
|
||||
self.context.oauth_metadata = asm
|
||||
break
|
||||
else:
|
||||
logger.debug(f"OAuth metadata discovery failed: {url}")
|
||||
|
||||
# Step 3: Apply scope selection strategy
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
www_auth_resource_metadata_url,
|
||||
self.context.protected_resource_metadata,
|
||||
self.context.oauth_metadata,
|
||||
)
|
||||
|
||||
# Step 4: Register client if needed
|
||||
registration_request = create_client_registration_request(
|
||||
self.context.oauth_metadata,
|
||||
self.context.client_metadata,
|
||||
self.context.get_authorization_base_url(self.context.server_url),
|
||||
)
|
||||
if not self.context.client_info:
|
||||
registration_response = yield registration_request
|
||||
client_information = await handle_registration_response(registration_response)
|
||||
self.context.client_info = client_information
|
||||
await self.context.storage.set_client_info(client_information)
|
||||
|
||||
# Step 5: Perform authorization and complete token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
elif response.status_code == 403:
|
||||
# Step 1: Extract error field from WWW-Authenticate header
|
||||
error = extract_field_from_www_auth(response, "error")
|
||||
|
||||
# Step 2: Check if we need to step-up authorization
|
||||
if error == "insufficient_scope": # pragma: no branch
|
||||
try:
|
||||
# Step 2a: Update the required scopes
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
|
||||
)
|
||||
|
||||
# Step 2b: Perform (re-)authorization and token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
@@ -0,0 +1,267 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import Request, Response
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
|
||||
"""
|
||||
Extract field from WWW-Authenticate header.
|
||||
|
||||
Returns:
|
||||
Field value if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
www_auth_header = response.headers.get("WWW-Authenticate")
|
||||
if not www_auth_header:
|
||||
return None
|
||||
|
||||
# Pattern matches: field_name="value" or field_name=value (unquoted)
|
||||
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
|
||||
match = re.search(pattern, www_auth_header)
|
||||
|
||||
if match:
|
||||
# Return quoted value if present, otherwise unquoted value
|
||||
return match.group(1) or match.group(2)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_scope_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract scope parameter from WWW-Authenticate header as per RFC6750.
|
||||
|
||||
Returns:
|
||||
Scope string if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
return extract_field_from_www_auth(response, "scope")
|
||||
|
||||
|
||||
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
|
||||
|
||||
Returns:
|
||||
Resource metadata URL if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
if not response or response.status_code != 401:
|
||||
return None # pragma: no cover
|
||||
|
||||
return extract_field_from_www_auth(response, "resource_metadata")
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build ordered list of URLs to try for protected resource metadata discovery.
|
||||
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Args:
|
||||
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
|
||||
server_url: server url
|
||||
|
||||
Returns:
|
||||
Ordered list of URLs to try for discovery
|
||||
"""
|
||||
urls: list[str] = []
|
||||
|
||||
# Priority 1: WWW-Authenticate header with resource_metadata parameter
|
||||
if www_auth_url:
|
||||
urls.append(www_auth_url)
|
||||
|
||||
# Priority 2-3: Well-known URIs (RFC 9728)
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Priority 2: Path-based well-known URI (if server has a path component)
|
||||
if parsed.path and parsed.path != "/":
|
||||
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
|
||||
urls.append(path_based_url)
|
||||
|
||||
# Priority 3: Root-based well-known URI
|
||||
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
urls.append(root_based_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def get_client_metadata_scopes(
|
||||
www_authenticate_scope: str | None,
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None,
|
||||
authorization_server_metadata: OAuthMetadata | None = None,
|
||||
) -> str | None:
|
||||
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
|
||||
# Per MCP spec, scope selection priority order:
|
||||
# 1. Use scope from WWW-Authenticate header (if provided)
|
||||
# 2. Use all scopes from PRM scopes_supported (if available)
|
||||
# 3. Omit scope parameter if neither is available
|
||||
|
||||
if www_authenticate_scope is not None:
|
||||
# Priority 1: WWW-Authenticate header scope
|
||||
return www_authenticate_scope
|
||||
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
|
||||
# Priority 2: PRM scopes_supported
|
||||
return " ".join(protected_resource_metadata.scopes_supported)
|
||||
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
|
||||
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
|
||||
else:
|
||||
# Priority 3: Omit scope parameter
|
||||
return None
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Generate ordered list of (url, type) tuples for discovery attempts.
|
||||
|
||||
Args:
|
||||
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
|
||||
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
|
||||
"""
|
||||
|
||||
if not auth_server_url:
|
||||
# Legacy path using the 2025-03-26 spec:
|
||||
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||||
parsed = urlparse(server_url)
|
||||
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
|
||||
|
||||
urls: list[str] = []
|
||||
parsed = urlparse(auth_server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# RFC 8414: Path-aware OAuth discovery
|
||||
if parsed.path and parsed.path != "/":
|
||||
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oauth_path))
|
||||
|
||||
# RFC 8414 section 5: Path-aware OIDC discovery
|
||||
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
|
||||
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
return urls
|
||||
|
||||
# OAuth root
|
||||
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
|
||||
|
||||
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
async def handle_protected_resource_response(
|
||||
response: Response,
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""
|
||||
Handle protected resource metadata discovery response.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
|
||||
Returns:
|
||||
True if metadata was successfully discovered, False if we should try next URL
|
||||
"""
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
metadata = ProtectedResourceMetadata.model_validate_json(content)
|
||||
return metadata
|
||||
|
||||
except ValidationError: # pragma: no cover
|
||||
# Invalid metadata - try next URL
|
||||
return None
|
||||
else:
|
||||
# Not found - try next URL in fallback chain
|
||||
return None
|
||||
|
||||
|
||||
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
asm = OAuthMetadata.model_validate_json(content)
|
||||
return True, asm
|
||||
except ValidationError: # pragma: no cover
|
||||
return True, None
|
||||
elif response.status_code < 400 or response.status_code >= 500:
|
||||
return False, None # Non-4XX error, stop trying
|
||||
return True, None
|
||||
|
||||
|
||||
def create_oauth_metadata_request(url: str) -> Request:
|
||||
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
|
||||
|
||||
|
||||
def create_client_registration_request(
|
||||
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
|
||||
) -> Request:
|
||||
"""Build registration request or skip if already registered."""
|
||||
|
||||
if auth_server_metadata and auth_server_metadata.registration_endpoint:
|
||||
registration_url = str(auth_server_metadata.registration_endpoint)
|
||||
else:
|
||||
registration_url = urljoin(auth_base_url, "/register")
|
||||
|
||||
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
|
||||
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
|
||||
|
||||
|
||||
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
|
||||
"""Handle registration response."""
|
||||
if response.status_code not in (200, 201):
|
||||
await response.aread()
|
||||
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
|
||||
|
||||
try:
|
||||
content = await response.aread()
|
||||
client_info = OAuthClientInformationFull.model_validate_json(content)
|
||||
return client_info
|
||||
# self.context.client_info = client_info
|
||||
# await self.context.storage.set_client_info(client_info)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthRegistrationError(f"Invalid registration response: {e}")
|
||||
|
||||
|
||||
async def handle_token_response_scopes(
|
||||
response: Response,
|
||||
) -> OAuthToken:
|
||||
"""Parse and validate token response with optional scope validation.
|
||||
|
||||
Parses token response JSON. Callers should check response.status_code before calling.
|
||||
|
||||
Args:
|
||||
response: HTTP response from token endpoint (status already checked by caller)
|
||||
|
||||
Returns:
|
||||
Validated OAuthToken model
|
||||
|
||||
Raises:
|
||||
OAuthTokenError: If response JSON is invalid
|
||||
"""
|
||||
try:
|
||||
content = await response.aread()
|
||||
token_response = OAuthToken.model_validate_json(content)
|
||||
return token_response
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthTokenError(f"Invalid token response: {e}")
|
||||
Reference in New Issue
Block a user