chore: 添加虚拟环境到仓库

- 添加 backend_service/venv 虚拟环境
- 包含所有Python依赖包
- 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,85 @@
import argparse
import logging
import sys
from functools import partial
from urllib.parse import urlparse
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
import mcp.types as types
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.shared.message import SessionMessage
from mcp.shared.session import RequestResponder
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, Exception):
logger.error("Error: %s", message)
return
logger.info("Received message from server: %s", message)
async def run_session(
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
client_info: types.Implementation | None = None,
):
async with ClientSession(
read_stream,
write_stream,
message_handler=message_handler,
client_info=client_info,
) as session:
logger.info("Initializing session")
await session.initialize()
logger.info("Initialized")
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
env_dict = dict(env)
if urlparse(command_or_url).scheme in ("http", "https"):
# Use SSE client for HTTP(S) URLs
async with sse_client(command_or_url) as streams:
await run_session(*streams)
else:
# Use stdio client for commands
server_parameters = StdioServerParameters(command=command_or_url, args=args, env=env_dict)
async with stdio_client(server_parameters) as streams:
await run_session(*streams)
def cli():
parser = argparse.ArgumentParser()
parser.add_argument("command_or_url", help="Command or URL to connect to")
parser.add_argument("args", nargs="*", help="Additional arguments")
parser.add_argument(
"-e",
"--env",
nargs=2,
action="append",
metavar=("KEY", "VALUE"),
help="Environment variables to set. Can be used multiple times.",
default=[],
)
args = parser.parse_args()
anyio.run(partial(main, args.command_or_url, args.args, args.env), backend="trio")
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,21 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
from mcp.client.auth.oauth2 import (
OAuthClientProvider,
PKCEParameters,
TokenStorage,
)
__all__ = [
"OAuthClientProvider",
"OAuthFlowError",
"OAuthRegistrationError",
"OAuthTokenError",
"PKCEParameters",
"TokenStorage",
]

View File

@@ -0,0 +1,10 @@
class OAuthFlowError(Exception):
"""Base exception for OAuth flow errors."""
class OAuthTokenError(OAuthFlowError):
"""Raised when token operations fail."""
class OAuthRegistrationError(OAuthFlowError):
"""Raised when client registration fails."""

View File

@@ -0,0 +1,148 @@
import time
from collections.abc import Awaitable, Callable
from typing import Any
from uuid import uuid4
import httpx
import jwt
from pydantic import BaseModel, Field
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
from mcp.shared.auth import OAuthClientMetadata
class JWTParameters(BaseModel):
"""JWT parameters."""
assertion: str | None = Field(
default=None,
description="JWT assertion for JWT authentication. "
"Will be used instead of generating a new assertion if provided.",
)
issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
audience: str | None = Field(default=None, description="Audience for JWT assertions.")
claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")
def to_assertion(self, with_audience_fallback: str | None = None) -> str:
if self.assertion is not None:
# Prebuilt JWT (e.g. acquired out-of-band)
assertion = self.assertion
else:
if not self.jwt_signing_key:
raise OAuthFlowError("Missing signing key for JWT bearer grant") # pragma: no cover
if not self.issuer:
raise OAuthFlowError("Missing issuer for JWT bearer grant") # pragma: no cover
if not self.subject:
raise OAuthFlowError("Missing subject for JWT bearer grant") # pragma: no cover
audience = self.audience if self.audience else with_audience_fallback
if not audience:
raise OAuthFlowError("Missing audience for JWT bearer grant") # pragma: no cover
now = int(time.time())
claims: dict[str, Any] = {
"iss": self.issuer,
"sub": self.subject,
"aud": audience,
"exp": now + self.jwt_lifetime_seconds,
"iat": now,
"jti": str(uuid4()),
}
claims.update(self.claims or {})
assertion = jwt.encode(
claims,
self.jwt_signing_key,
algorithm=self.jwt_signing_algorithm or "RS256",
)
return assertion
class RFC7523OAuthClientProvider(OAuthClientProvider):
"""OAuth client provider for RFC7532 clients."""
jwt_parameters: JWTParameters | None = None
def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
jwt_parameters: JWTParameters | None = None,
) -> None:
super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
self.jwt_parameters = jwt_parameters
async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
) -> httpx.Request: # pragma: no cover
"""Build token exchange request for authorization_code flow."""
token_data = token_data or {}
if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
self._add_client_authentication_jwt(token_data=token_data)
return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)
async def _perform_authorization(self) -> httpx.Request: # pragma: no cover
"""Perform the authorization flow."""
if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
token_request = await self._exchange_token_jwt_bearer()
return token_request
else:
return await super()._perform_authorization()
def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]): # pragma: no cover
"""Add JWT assertion for client authentication to token endpoint parameters."""
if not self.jwt_parameters:
raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
if not self.context.oauth_metadata:
raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")
# We need to set the audience to the issuer identifier of the authorization server
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
issuer = str(self.context.oauth_metadata.issuer)
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
# When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
token_data["client_assertion"] = assertion
token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
# We need to set the audience to the resource server, the audience is difference from the one in claims
# it represents the resource server that will validate the token
token_data["audience"] = self.context.get_resource_url()
async def _exchange_token_jwt_bearer(self) -> httpx.Request:
"""Build token exchange request for JWT bearer grant."""
if not self.context.client_info:
raise OAuthFlowError("Missing client info") # pragma: no cover
if not self.jwt_parameters:
raise OAuthFlowError("Missing JWT parameters") # pragma: no cover
if not self.context.oauth_metadata:
raise OAuthTokenError("Missing OAuth metadata") # pragma: no cover
# We need to set the audience to the issuer identifier of the authorization server
# https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
issuer = str(self.context.oauth_metadata.issuer)
assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)
token_data = {
"grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
"assertion": assertion,
}
if self.context.should_include_resource_param(self.context.protocol_version): # pragma: no branch
token_data["resource"] = self.context.get_resource_url()
if self.context.client_metadata.scope: # pragma: no branch
token_data["scope"] = self.context.client_metadata.scope
token_url = self._get_token_endpoint()
return httpx.Request(
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)

View File

@@ -0,0 +1,557 @@
"""
OAuth2 Authentication implementation for HTTPX.
Implements authorization code flow with PKCE and automatic token refresh.
"""
import base64
import hashlib
import logging
import secrets
import string
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Protocol
from urllib.parse import urlencode, urljoin, urlparse
import anyio
import httpx
from pydantic import BaseModel, Field, ValidationError
from mcp.client.auth import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
create_client_registration_request,
create_oauth_metadata_request,
extract_field_from_www_auth,
extract_resource_metadata_from_www_auth,
extract_scope_from_www_auth,
get_client_metadata_scopes,
handle_auth_metadata_response,
handle_protected_resource_response,
handle_registration_response,
handle_token_response_scopes,
)
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.shared.auth_utils import (
calculate_token_expiry,
check_resource_allowed,
resource_url_from_server_url,
)
logger = logging.getLogger(__name__)
class PKCEParameters(BaseModel):
"""PKCE (Proof Key for Code Exchange) parameters."""
code_verifier: str = Field(..., min_length=43, max_length=128)
code_challenge: str = Field(..., min_length=43, max_length=128)
@classmethod
def generate(cls) -> "PKCEParameters":
"""Generate new PKCE parameters."""
code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128))
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=")
return cls(code_verifier=code_verifier, code_challenge=code_challenge)
class TokenStorage(Protocol):
"""Protocol for token storage implementations."""
async def get_tokens(self) -> OAuthToken | None:
"""Get stored tokens."""
...
async def set_tokens(self, tokens: OAuthToken) -> None:
"""Store tokens."""
...
async def get_client_info(self) -> OAuthClientInformationFull | None:
"""Get stored client information."""
...
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Store client information."""
...
@dataclass
class OAuthContext:
"""OAuth flow context."""
server_url: str
client_metadata: OAuthClientMetadata
storage: TokenStorage
redirect_handler: Callable[[str], Awaitable[None]] | None
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
timeout: float = 300.0
# Discovered metadata
protected_resource_metadata: ProtectedResourceMetadata | None = None
oauth_metadata: OAuthMetadata | None = None
auth_server_url: str | None = None
protocol_version: str | None = None
# Client registration
client_info: OAuthClientInformationFull | None = None
# Token management
current_tokens: OAuthToken | None = None
token_expiry_time: float | None = None
# State
lock: anyio.Lock = field(default_factory=anyio.Lock)
def get_authorization_base_url(self, server_url: str) -> str:
"""Extract base URL by removing path component."""
parsed = urlparse(server_url)
return f"{parsed.scheme}://{parsed.netloc}"
def update_token_expiry(self, token: OAuthToken) -> None:
"""Update token expiry time using shared util function."""
self.token_expiry_time = calculate_token_expiry(token.expires_in)
def is_token_valid(self) -> bool:
"""Check if current token is valid."""
return bool(
self.current_tokens
and self.current_tokens.access_token
and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
)
def can_refresh_token(self) -> bool:
"""Check if token can be refreshed."""
return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)
def clear_tokens(self) -> None:
"""Clear current tokens."""
self.current_tokens = None
self.token_expiry_time = None
def get_resource_url(self) -> str:
"""Get resource URL for RFC 8707.
Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
"""
resource = resource_url_from_server_url(self.server_url)
# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource
return resource
def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
"""Determine if the resource parameter should be included in OAuth requests.
Returns True if:
- Protected resource metadata is available, OR
- MCP-Protocol-Version header is 2025-06-18 or later
"""
# If we have protected resource metadata, include the resource param
if self.protected_resource_metadata is not None:
return True
# If no protocol version provided, don't include resource param
if not protocol_version:
return False
# Check if protocol version is 2025-06-18 or later
# Version format is YYYY-MM-DD, so string comparison works
return protocol_version >= "2025-06-18"
class OAuthClientProvider(httpx.Auth):
"""
OAuth2 authentication for httpx.
Handles OAuth flow with automatic client registration and token storage.
"""
requires_response_body = True
def __init__(
self,
server_url: str,
client_metadata: OAuthClientMetadata,
storage: TokenStorage,
redirect_handler: Callable[[str], Awaitable[None]] | None = None,
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
timeout: float = 300.0,
):
"""Initialize OAuth2 authentication."""
self.context = OAuthContext(
server_url=server_url,
client_metadata=client_metadata,
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
timeout=timeout,
)
self._initialized = False
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
"""
Handle protected resource metadata discovery response.
Per SEP-985, supports fallback when discovery fails at one URL.
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
self.context.protected_resource_metadata = metadata
if metadata.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(metadata.authorization_servers[0])
return True
except ValidationError: # pragma: no cover
# Invalid metadata - try next URL
logger.warning(f"Invalid protected resource metadata at {response.request.url}")
return False
elif response.status_code == 404: # pragma: no cover
# Not found - try next URL in fallback chain
logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
return False
else:
# Other error - fail immediately
raise OAuthFlowError(
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover
async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
if self.context.client_info:
return None
if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
registration_url = urljoin(auth_base_url, "/register")
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
return httpx.Request(
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
)
async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
token_request = await self._exchange_token_authorization_code(auth_code, code_verifier)
return token_request
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
"""Perform the authorization redirect and get auth code."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
if not self.context.redirect_handler:
raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover
if not self.context.callback_handler:
raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover
if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
auth_endpoint = urljoin(auth_base_url, "/authorize")
if not self.context.client_info:
raise OAuthFlowError("No client info available for authorization") # pragma: no cover
# Generate PKCE parameters
pkce_params = PKCEParameters.generate()
state = secrets.token_urlsafe(32)
auth_params = {
"response_type": "code",
"client_id": self.context.client_info.client_id,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"state": state,
"code_challenge": pkce_params.code_challenge,
"code_challenge_method": "S256",
}
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover
if self.context.client_metadata.scope: # pragma: no branch
auth_params["scope"] = self.context.client_metadata.scope
authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
await self.context.redirect_handler(authorization_url)
# Wait for callback
auth_code, returned_state = await self.context.callback_handler()
if returned_state is None or not secrets.compare_digest(returned_state, state):
raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover
if not auth_code:
raise OAuthFlowError("No authorization code received") # pragma: no cover
# Return auth code and code verifier for token exchange
return auth_code, pkce_params.code_verifier
def _get_token_endpoint(self) -> str:
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
token_url = str(self.context.oauth_metadata.token_endpoint)
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
token_url = urljoin(auth_base_url, "/token")
return token_url
async def _exchange_token_authorization_code(
self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = {}
) -> httpx.Request:
"""Build token exchange request for authorization_code flow."""
if self.context.client_metadata.redirect_uris is None:
raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover
if not self.context.client_info:
raise OAuthFlowError("Missing client info") # pragma: no cover
token_url = self._get_token_endpoint()
token_data = token_data or {}
token_data.update(
{
"grant_type": "authorization_code",
"code": auth_code,
"redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
"client_id": self.context.client_info.client_id,
"code_verifier": code_verifier,
}
)
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url() # RFC 8707
if self.context.client_info.client_secret:
token_data["client_secret"] = self.context.client_info.client_secret
return httpx.Request(
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)
async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
if response.status_code != 200:
body = await response.aread() # pragma: no cover
body_text = body.decode("utf-8") # pragma: no cover
raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover
# Parse and validate response with scope validation
token_response = await handle_token_response_scopes(response)
# Store tokens in context
self.context.current_tokens = token_response
self.context.update_token_expiry(token_response)
await self.context.storage.set_tokens(token_response)
async def _refresh_token(self) -> httpx.Request:
"""Build token refresh request."""
if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
raise OAuthTokenError("No refresh token available") # pragma: no cover
if not self.context.client_info:
raise OAuthTokenError("No client info available") # pragma: no cover
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint:
token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
token_url = urljoin(auth_base_url, "/token")
refresh_data = {
"grant_type": "refresh_token",
"refresh_token": self.context.current_tokens.refresh_token,
"client_id": self.context.client_info.client_id,
}
# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707
if self.context.client_info.client_secret: # pragma: no branch
refresh_data["client_secret"] = self.context.client_info.client_secret
return httpx.Request(
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)
async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover
"""Handle token refresh response. Returns True if successful."""
if response.status_code != 200:
logger.warning(f"Token refresh failed: {response.status_code}")
self.context.clear_tokens()
return False
try:
content = await response.aread()
token_response = OAuthToken.model_validate_json(content)
self.context.current_tokens = token_response
self.context.update_token_expiry(token_response)
await self.context.storage.set_tokens(token_response)
return True
except ValidationError:
logger.exception("Invalid refresh response")
self.context.clear_tokens()
return False
async def _initialize(self) -> None: # pragma: no cover
"""Load stored tokens and client info."""
self.context.current_tokens = await self.context.storage.get_tokens()
self.context.client_info = await self.context.storage.get_client_info()
self._initialized = True
def _add_auth_header(self, request: httpx.Request) -> None:
"""Add authorization header to request if we have valid tokens."""
if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch
request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}"
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
content = await response.aread()
metadata = OAuthMetadata.model_validate_json(content)
self.context.oauth_metadata = metadata
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover
# Capture protocol version from request headers
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
# Refresh failed, need full re-authentication
self._initialized = False
if self.context.is_token_valid():
self._add_auth_header(request)
response = yield request
if response.status_code == 401:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)
for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_response = yield discovery_request # sending request
prm = await handle_protected_resource_response(discovery_response)
if prm:
self.context.protected_resource_metadata = prm
# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1
self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no cover
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
self.context.oauth_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")
# Step 3: Apply scope selection strategy
self.context.client_metadata.scope = get_client_metadata_scopes(
www_auth_resource_metadata_url,
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)
# Step 4: Register client if needed
registration_request = create_client_registration_request(
self.context.oauth_metadata,
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
if not self.context.client_info:
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)
# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
# Retry with new tokens
self._add_auth_header(request)
yield request
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")
# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope": # pragma: no branch
try:
# Step 2a: Update the required scopes
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
)
# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
# Retry with new tokens
self._add_auth_header(request)
yield request

View File

@@ -0,0 +1,267 @@
import logging
import re
from urllib.parse import urljoin, urlparse
from httpx import Request, Response
from pydantic import ValidationError
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
from mcp.shared.auth import (
OAuthClientInformationFull,
OAuthClientMetadata,
OAuthMetadata,
OAuthToken,
ProtectedResourceMetadata,
)
from mcp.types import LATEST_PROTOCOL_VERSION
logger = logging.getLogger(__name__)
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
"""
Extract field from WWW-Authenticate header.
Returns:
Field value if found in WWW-Authenticate header, None otherwise
"""
www_auth_header = response.headers.get("WWW-Authenticate")
if not www_auth_header:
return None
# Pattern matches: field_name="value" or field_name=value (unquoted)
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
match = re.search(pattern, www_auth_header)
if match:
# Return quoted value if present, otherwise unquoted value
return match.group(1) or match.group(2)
return None
def extract_scope_from_www_auth(response: Response) -> str | None:
"""
Extract scope parameter from WWW-Authenticate header as per RFC6750.
Returns:
Scope string if found in WWW-Authenticate header, None otherwise
"""
return extract_field_from_www_auth(response, "scope")
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
"""
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
Returns:
Resource metadata URL if found in WWW-Authenticate header, None otherwise
"""
if not response or response.status_code != 401:
return None # pragma: no cover
return extract_field_from_www_auth(response, "resource_metadata")
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
"""
Build ordered list of URLs to try for protected resource metadata discovery.
Per SEP-985, the client MUST:
1. Try resource_metadata from WWW-Authenticate header (if present)
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
Args:
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
server_url: server url
Returns:
Ordered list of URLs to try for discovery
"""
urls: list[str] = []
# Priority 1: WWW-Authenticate header with resource_metadata parameter
if www_auth_url:
urls.append(www_auth_url)
# Priority 2-3: Well-known URIs (RFC 9728)
parsed = urlparse(server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
# Priority 2: Path-based well-known URI (if server has a path component)
if parsed.path and parsed.path != "/":
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
urls.append(path_based_url)
# Priority 3: Root-based well-known URI
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
urls.append(root_based_url)
return urls
def get_client_metadata_scopes(
www_authenticate_scope: str | None,
protected_resource_metadata: ProtectedResourceMetadata | None,
authorization_server_metadata: OAuthMetadata | None = None,
) -> str | None:
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
# Per MCP spec, scope selection priority order:
# 1. Use scope from WWW-Authenticate header (if provided)
# 2. Use all scopes from PRM scopes_supported (if available)
# 3. Omit scope parameter if neither is available
if www_authenticate_scope is not None:
# Priority 1: WWW-Authenticate header scope
return www_authenticate_scope
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
# Priority 2: PRM scopes_supported
return " ".join(protected_resource_metadata.scopes_supported)
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
else:
# Priority 3: Omit scope parameter
return None
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Generate ordered list of (url, type) tuples for discovery attempts.
Args:
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
"""
if not auth_server_url:
# Legacy path using the 2025-03-26 spec:
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
parsed = urlparse(server_url)
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
urls: list[str] = []
parsed = urlparse(auth_server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
# RFC 8414: Path-aware OAuth discovery
if parsed.path and parsed.path != "/":
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oauth_path))
# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oidc_path))
# https://openid.net/specs/openid-connect-discovery-1_0.html
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
urls.append(urljoin(base_url, oidc_path))
return urls
# OAuth root
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
# https://openid.net/specs/openid-connect-discovery-1_0.html
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
return urls
async def handle_protected_resource_response(
response: Response,
) -> ProtectedResourceMetadata | None:
"""
Handle protected resource metadata discovery response.
Per SEP-985, supports fallback when discovery fails at one URL.
Returns:
True if metadata was successfully discovered, False if we should try next URL
"""
if response.status_code == 200:
try:
content = await response.aread()
metadata = ProtectedResourceMetadata.model_validate_json(content)
return metadata
except ValidationError: # pragma: no cover
# Invalid metadata - try next URL
return None
else:
# Not found - try next URL in fallback chain
return None
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
if response.status_code == 200:
try:
content = await response.aread()
asm = OAuthMetadata.model_validate_json(content)
return True, asm
except ValidationError: # pragma: no cover
return True, None
elif response.status_code < 400 or response.status_code >= 500:
return False, None # Non-4XX error, stop trying
return True, None
def create_oauth_metadata_request(url: str) -> Request:
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
def create_client_registration_request(
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
) -> Request:
"""Build registration request or skip if already registered."""
if auth_server_metadata and auth_server_metadata.registration_endpoint:
registration_url = str(auth_server_metadata.registration_endpoint)
else:
registration_url = urljoin(auth_base_url, "/register")
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
"""Handle registration response."""
if response.status_code not in (200, 201):
await response.aread()
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
try:
content = await response.aread()
client_info = OAuthClientInformationFull.model_validate_json(content)
return client_info
# self.context.client_info = client_info
# await self.context.storage.set_client_info(client_info)
except ValidationError as e: # pragma: no cover
raise OAuthRegistrationError(f"Invalid registration response: {e}")
async def handle_token_response_scopes(
response: Response,
) -> OAuthToken:
"""Parse and validate token response with optional scope validation.
Parses token response JSON. Callers should check response.status_code before calling.
Args:
response: HTTP response from token endpoint (status already checked by caller)
Returns:
Validated OAuthToken model
Raises:
OAuthTokenError: If response JSON is invalid
"""
try:
content = await response.aread()
token_response = OAuthToken.model_validate_json(content)
return token_response
except ValidationError as e: # pragma: no cover
raise OAuthTokenError(f"Invalid token response: {e}")

View File

@@ -0,0 +1,555 @@
import logging
from datetime import timedelta
from typing import Any, Protocol, overload
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from jsonschema import SchemaError, ValidationError, validate
from pydantic import AnyUrl, TypeAdapter
from typing_extensions import deprecated
import mcp.types as types
from mcp.shared.context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
logger = logging.getLogger("client")
class SamplingFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch
class ElicitationFnT(Protocol):
async def __call__(
self,
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
class ListRootsFnT(Protocol):
async def __call__(
self, context: RequestContext["ClientSession", Any]
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
class LoggingFnT(Protocol):
async def __call__(
self,
params: types.LoggingMessageNotificationParams,
) -> None: ... # pragma: no branch
class MessageHandlerFnT(Protocol):
async def __call__(
self,
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None: ... # pragma: no branch
async def _default_message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
await anyio.lowlevel.checkpoint()
async def _default_sampling_callback(
context: RequestContext["ClientSession", Any],
params: types.CreateMessageRequestParams,
) -> types.CreateMessageResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Sampling not supported",
)
async def _default_elicitation_callback(
context: RequestContext["ClientSession", Any],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)
async def _default_list_roots_callback(
context: RequestContext["ClientSession", Any],
) -> types.ListRootsResult | types.ErrorData:
return types.ErrorData(
code=types.INVALID_REQUEST,
message="List roots not supported",
)
async def _default_logging_callback(
params: types.LoggingMessageNotificationParams,
) -> None:
pass
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
class ClientSession(
BaseSession[
types.ClientRequest,
types.ClientNotification,
types.ClientResult,
types.ServerRequest,
types.ServerNotification,
]
):
def __init__(
self,
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
write_stream: MemoryObjectSendStream[SessionMessage],
read_timeout_seconds: timedelta | None = None,
sampling_callback: SamplingFnT | None = None,
elicitation_callback: ElicitationFnT | None = None,
list_roots_callback: ListRootsFnT | None = None,
logging_callback: LoggingFnT | None = None,
message_handler: MessageHandlerFnT | None = None,
client_info: types.Implementation | None = None,
) -> None:
super().__init__(
read_stream,
write_stream,
types.ServerRequest,
types.ServerNotification,
read_timeout_seconds=read_timeout_seconds,
)
self._client_info = client_info or DEFAULT_CLIENT_INFO
self._sampling_callback = sampling_callback or _default_sampling_callback
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
self._logging_callback = logging_callback or _default_logging_callback
self._message_handler = message_handler or _default_message_handler
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
self._server_capabilities: types.ServerCapabilities | None = None
async def initialize(self) -> types.InitializeResult:
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
elicitation = (
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
)
roots = (
# TODO: Should this be based on whether we
# _will_ send notifications, or only whether
# they're supported?
types.RootsCapability(listChanged=True)
if self._list_roots_callback is not _default_list_roots_callback
else None
)
result = await self.send_request(
types.ClientRequest(
types.InitializeRequest(
params=types.InitializeRequestParams(
protocolVersion=types.LATEST_PROTOCOL_VERSION,
capabilities=types.ClientCapabilities(
sampling=sampling,
elicitation=elicitation,
experimental=None,
roots=roots,
),
clientInfo=self._client_info,
),
)
),
types.InitializeResult,
)
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
self._server_capabilities = result.capabilities
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
return result
def get_server_capabilities(self) -> types.ServerCapabilities | None:
"""Return the server capabilities received during initialization.
Returns None if the session has not been initialized yet.
"""
return self._server_capabilities
async def send_ping(self) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(
types.ClientRequest(types.PingRequest()),
types.EmptyResult,
)
async def send_progress_notification(
self,
progress_token: str | int,
progress: float,
total: float | None = None,
message: str | None = None,
) -> None:
"""Send a progress notification."""
await self.send_notification(
types.ClientNotification(
types.ProgressNotification(
params=types.ProgressNotificationParams(
progressToken=progress_token,
progress=progress,
total=total,
message=message,
),
),
)
)
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
"""Send a logging/setLevel request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SetLevelRequest(
params=types.SetLevelRequestParams(level=level),
)
),
types.EmptyResult,
)
@overload
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
@overload
async def list_resources(self) -> types.ListResourcesResult: ...
async def list_resources(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourcesResult:
"""Send a resources/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
types.ListResourcesResult,
)
@overload
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(
self, *, params: types.PaginatedRequestParams | None
) -> types.ListResourceTemplatesResult: ...
@overload
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
async def list_resource_templates(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListResourceTemplatesResult:
"""Send a resources/templates/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
types.ListResourceTemplatesResult,
)
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
"""Send a resources/read request."""
return await self.send_request(
types.ClientRequest(
types.ReadResourceRequest(
params=types.ReadResourceRequestParams(uri=uri),
)
),
types.ReadResourceResult,
)
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/subscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.SubscribeRequest(
params=types.SubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
"""Send a resources/unsubscribe request."""
return await self.send_request( # pragma: no cover
types.ClientRequest(
types.UnsubscribeRequest(
params=types.UnsubscribeRequestParams(uri=uri),
)
),
types.EmptyResult,
)
async def call_tool(
self,
name: str,
arguments: dict[str, Any] | None = None,
read_timeout_seconds: timedelta | None = None,
progress_callback: ProgressFnT | None = None,
*,
meta: dict[str, Any] | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
_meta: types.RequestParams.Meta | None = None
if meta is not None:
_meta = types.RequestParams.Meta(**meta)
result = await self.send_request(
types.ClientRequest(
types.CallToolRequest(
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
)
),
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
)
if not result.isError:
await self._validate_tool_result(name, result)
return result
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
"""Validate the structured content of a tool result against its output schema."""
if name not in self._tool_output_schemas:
# refresh output schema cache
await self.list_tools()
output_schema = None
if name in self._tool_output_schemas:
output_schema = self._tool_output_schemas.get(name)
else:
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
if output_schema is not None:
if result.structuredContent is None:
raise RuntimeError(
f"Tool {name} has an output schema but did not return structured content"
) # pragma: no cover
try:
validate(result.structuredContent, output_schema)
except ValidationError as e:
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
except SchemaError as e: # pragma: no cover
raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
@overload
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
@overload
async def list_prompts(self) -> types.ListPromptsResult: ...
async def list_prompts(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListPromptsResult:
"""Send a prompts/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
return await self.send_request(
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
types.ListPromptsResult,
)
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
"""Send a prompts/get request."""
return await self.send_request(
types.ClientRequest(
types.GetPromptRequest(
params=types.GetPromptRequestParams(name=name, arguments=arguments),
)
),
types.GetPromptResult,
)
async def complete(
self,
ref: types.ResourceTemplateReference | types.PromptReference,
argument: dict[str, str],
context_arguments: dict[str, str] | None = None,
) -> types.CompleteResult:
"""Send a completion/complete request."""
context = None
if context_arguments is not None:
context = types.CompletionContext(arguments=context_arguments)
return await self.send_request(
types.ClientRequest(
types.CompleteRequest(
params=types.CompleteRequestParams(
ref=ref,
argument=types.CompletionArgument(**argument),
context=context,
),
)
),
types.CompleteResult,
)
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...
@overload
async def list_tools(self) -> types.ListToolsResult: ...
async def list_tools(
self,
cursor: str | None = None,
*,
params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
"""Send a tools/list request.
Args:
cursor: Simple cursor string for pagination (deprecated, use params instead)
params: Full pagination parameters including cursor and any future fields
"""
if params is not None and cursor is not None:
raise ValueError("Cannot specify both cursor and params")
if params is not None:
request_params = params
elif cursor is not None:
request_params = types.PaginatedRequestParams(cursor=cursor)
else:
request_params = None
result = await self.send_request(
types.ClientRequest(types.ListToolsRequest(params=request_params)),
types.ListToolsResult,
)
# Cache tool output schemas for future validation
# Note: don't clear the cache, as we may be using a cursor
for tool in result.tools:
self._tool_output_schemas[tool.name] = tool.outputSchema
return result
async def send_roots_list_changed(self) -> None: # pragma: no cover
"""Send a roots/list_changed notification."""
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
ctx = RequestContext[ClientSession, Any](
request_id=responder.request_id,
meta=responder.request_meta,
session=self,
lifespan_context=None,
)
match responder.request.root:
case types.CreateMessageRequest(params=params):
with responder:
response = await self._sampling_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ElicitRequest(params=params):
with responder:
response = await self._elicitation_callback(ctx, params)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.ListRootsRequest():
with responder:
response = await self._list_roots_callback(ctx)
client_response = ClientResponse.validate_python(response)
await responder.respond(client_response)
case types.PingRequest(): # pragma: no cover
with responder:
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
async def _handle_incoming(
self,
req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
"""Handle incoming messages by forwarding to the message handler."""
await self._message_handler(req)
async def _received_notification(self, notification: types.ServerNotification) -> None:
"""Handle notifications from the server."""
# Process specific notification types
match notification.root:
case types.LoggingMessageNotification(params=params):
await self._logging_callback(params)
case _:
pass

View File

@@ -0,0 +1,366 @@
"""
SessionGroup concurrently manages multiple MCP session connections.
Tools, resources, and prompts are aggregated across servers. Servers may
be connected to or disconnected from at any point after initialization.
This abstractions can handle naming collisions using a custom user-provided
hook.
"""
import contextlib
import logging
from collections.abc import Callable
from datetime import timedelta
from types import TracebackType
from typing import Any, TypeAlias
import anyio
from pydantic import BaseModel
from typing_extensions import Self
import mcp
from mcp import types
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.exceptions import McpError
class SseServerParameters(BaseModel):
"""Parameters for intializing a sse_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: float = 5
# Timeout for SSE read operations.
sse_read_timeout: float = 60 * 5
class StreamableHttpParameters(BaseModel):
"""Parameters for intializing a streamablehttp_client."""
# The endpoint URL.
url: str
# Optional headers to include in requests.
headers: dict[str, Any] | None = None
# HTTP timeout for regular operations.
timeout: timedelta = timedelta(seconds=30)
# Timeout for SSE read operations.
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
# Close the client session when the transport closes.
terminate_on_close: bool = True
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
class ClientSessionGroup:
"""Client for managing connections to multiple MCP servers.
This class is responsible for encapsulating management of server connections.
It aggregates tools, resources, and prompts from all connected servers.
For auxiliary handlers, such as resource subscription, this is delegated to
the client and can be accessed via the session.
Example Usage:
name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
async with ClientSessionGroup(component_name_hook=name_fn) as group:
for server_param in server_params:
await group.connect_to_server(server_param)
...
"""
class _ComponentNames(BaseModel):
"""Used for reverse index to find components."""
prompts: set[str] = set()
resources: set[str] = set()
tools: set[str] = set()
# Standard MCP components.
_prompts: dict[str, types.Prompt]
_resources: dict[str, types.Resource]
_tools: dict[str, types.Tool]
# Client-server connection management.
_sessions: dict[mcp.ClientSession, _ComponentNames]
_tool_to_session: dict[str, mcp.ClientSession]
_exit_stack: contextlib.AsyncExitStack
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
# Optional fn consuming (component_name, serverInfo) for custom names.
# This is provide a means to mitigate naming conflicts across servers.
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
_component_name_hook: _ComponentNameHook | None
def __init__(
self,
exit_stack: contextlib.AsyncExitStack | None = None,
component_name_hook: _ComponentNameHook | None = None,
) -> None:
"""Initializes the MCP client."""
self._tools = {}
self._resources = {}
self._prompts = {}
self._sessions = {}
self._tool_to_session = {}
if exit_stack is None:
self._exit_stack = contextlib.AsyncExitStack()
self._owns_exit_stack = True
else:
self._exit_stack = exit_stack
self._owns_exit_stack = False
self._session_exit_stacks = {}
self._component_name_hook = component_name_hook
async def __aenter__(self) -> Self: # pragma: no cover
# Enter the exit stack only if we created it ourselves
if self._owns_exit_stack:
await self._exit_stack.__aenter__()
return self
async def __aexit__(
self,
_exc_type: type[BaseException] | None,
_exc_val: BaseException | None,
_exc_tb: TracebackType | None,
) -> bool | None: # pragma: no cover
"""Closes session exit stacks and main exit stack upon completion."""
# Only close the main exit stack if we created it
if self._owns_exit_stack:
await self._exit_stack.aclose()
# Concurrently close session stacks.
async with anyio.create_task_group() as tg:
for exit_stack in self._session_exit_stacks.values():
tg.start_soon(exit_stack.aclose)
@property
def sessions(self) -> list[mcp.ClientSession]:
"""Returns the list of sessions being managed."""
return list(self._sessions.keys()) # pragma: no cover
@property
def prompts(self) -> dict[str, types.Prompt]:
"""Returns the prompts as a dictionary of names to prompts."""
return self._prompts
@property
def resources(self) -> dict[str, types.Resource]:
"""Returns the resources as a dictionary of names to resources."""
return self._resources
@property
def tools(self) -> dict[str, types.Tool]:
"""Returns the tools as a dictionary of names to tools."""
return self._tools
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
"""Executes a tool given its name and arguments."""
session = self._tool_to_session[name]
session_tool_name = self.tools[name].name
return await session.call_tool(session_tool_name, args)
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
"""Disconnects from a single MCP server."""
session_known_for_components = session in self._sessions
session_known_for_stack = session in self._session_exit_stacks
if not session_known_for_components and not session_known_for_stack:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message="Provided session is not managed or already disconnected.",
)
)
if session_known_for_components: # pragma: no cover
component_names = self._sessions.pop(session) # Pop from _sessions tracking
# Remove prompts associated with the session.
for name in component_names.prompts:
if name in self._prompts:
del self._prompts[name]
# Remove resources associated with the session.
for name in component_names.resources:
if name in self._resources:
del self._resources[name]
# Remove tools associated with the session.
for name in component_names.tools:
if name in self._tools:
del self._tools[name]
if name in self._tool_to_session:
del self._tool_to_session[name]
# Clean up the session's resources via its dedicated exit stack
if session_known_for_stack:
session_stack_to_close = self._session_exit_stacks.pop(session) # pragma: no cover
await session_stack_to_close.aclose() # pragma: no cover
async def connect_with_session(
self, server_info: types.Implementation, session: mcp.ClientSession
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
await self._aggregate_components(server_info, session)
return session
async def connect_to_server(
self,
server_params: ServerParameters,
) -> mcp.ClientSession:
"""Connects to a single MCP server."""
server_info, session = await self._establish_session(server_params)
return await self.connect_with_session(server_info, session)
async def _establish_session(
self, server_params: ServerParameters
) -> tuple[types.Implementation, mcp.ClientSession]:
"""Establish a client session to an MCP server."""
session_stack = contextlib.AsyncExitStack()
try:
# Create read and write streams that facilitate io with the server.
if isinstance(server_params, StdioServerParameters):
client = mcp.stdio_client(server_params)
read, write = await session_stack.enter_async_context(client)
elif isinstance(server_params, SseServerParameters):
client = sse_client(
url=server_params.url,
headers=server_params.headers,
timeout=server_params.timeout,
sse_read_timeout=server_params.sse_read_timeout,
)
read, write = await session_stack.enter_async_context(client)
else:
client = streamablehttp_client(
url=server_params.url,
headers=server_params.headers,
timeout=server_params.timeout,
sse_read_timeout=server_params.sse_read_timeout,
terminate_on_close=server_params.terminate_on_close,
)
read, write, _ = await session_stack.enter_async_context(client)
session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
result = await session.initialize()
# Session successfully initialized.
# Store its stack and register the stack with the main group stack.
self._session_exit_stacks[session] = session_stack
# session_stack itself becomes a resource managed by the
# main _exit_stack.
await self._exit_stack.enter_async_context(session_stack)
return result.serverInfo, session
except Exception: # pragma: no cover
# If anything during this setup fails, ensure the session-specific
# stack is closed.
await session_stack.aclose()
raise
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
"""Aggregates prompts, resources, and tools from a given session."""
# Create a reverse index so we can find all prompts, resources, and
# tools belonging to this session. Used for removing components from
# the session group via self.disconnect_from_server.
component_names = self._ComponentNames()
# Temporary components dicts. We do not want to modify the aggregate
# lists in case of an intermediate failure.
prompts_temp: dict[str, types.Prompt] = {}
resources_temp: dict[str, types.Resource] = {}
tools_temp: dict[str, types.Tool] = {}
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")
# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")
# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except McpError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")
# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
if not any((prompts_temp, resources_temp, tools_temp)):
del self._session_exit_stacks[session] # pragma: no cover
# Check for duplicates.
matching_prompts = prompts_temp.keys() & self._prompts.keys()
if matching_prompts:
raise McpError( # pragma: no cover
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_prompts} already exist in group prompts.",
)
)
matching_resources = resources_temp.keys() & self._resources.keys()
if matching_resources:
raise McpError( # pragma: no cover
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_resources} already exist in group resources.",
)
)
matching_tools = tools_temp.keys() & self._tools.keys()
if matching_tools:
raise McpError(
types.ErrorData(
code=types.INVALID_PARAMS,
message=f"{matching_tools} already exist in group tools.",
)
)
# Aggregate components.
self._sessions[session] = component_names
self._prompts.update(prompts_temp)
self._resources.update(resources_temp)
self._tools.update(tools_temp)
self._tool_to_session.update(tool_to_session_temp)
def _component_name(self, name: str, server_info: types.Implementation) -> str:
if self._component_name_hook:
return self._component_name_hook(name, server_info)
return name

View File

@@ -0,0 +1,148 @@
import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
return urljoin(url, urlparse(url).path)
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover
task_status.started(endpoint_url)
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
continue # pragma: no cover
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: no cover
logger.exception("Encountered SSE exception") # pragma: no cover
raise sse_exc # pragma: no cover
except Exception as exc: # pragma: no cover
logger.exception("Error in sse_reader") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception: # pragma: no cover
logger.exception("Error in post_writer") # pragma: no cover
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()

View File

@@ -0,0 +1,278 @@
import logging
import os
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Literal, TextIO
import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field
import mcp.types as types
from mcp.os.posix.utilities import terminate_posix_process_tree
from mcp.os.win32.utilities import (
FallbackProcess,
create_windows_process,
get_windows_executable_command,
terminate_windows_process_tree,
)
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
# Environment variables to inherit by default
DEFAULT_INHERITED_ENV_VARS = (
[
"APPDATA",
"HOMEDRIVE",
"HOMEPATH",
"LOCALAPPDATA",
"PATH",
"PATHEXT",
"PROCESSOR_ARCHITECTURE",
"SYSTEMDRIVE",
"SYSTEMROOT",
"TEMP",
"USERNAME",
"USERPROFILE",
]
if sys.platform == "win32"
else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)
# Timeout for process termination before falling back to force kill
PROCESS_TERMINATION_TIMEOUT = 2.0
def get_default_environment() -> dict[str, str]:
"""
Returns a default environment object including only environment variables deemed
safe to inherit.
"""
env: dict[str, str] = {}
for key in DEFAULT_INHERITED_ENV_VARS:
value = os.environ.get(key)
if value is None:
continue # pragma: no cover
if value.startswith("()"): # pragma: no cover
# Skip functions, which are a security risk
continue # pragma: no cover
env[key] = value
return env
class StdioServerParameters(BaseModel):
command: str
"""The executable to run to start the server."""
args: list[str] = Field(default_factory=list)
"""Command line arguments to pass to the executable."""
env: dict[str, str] | None = None
"""
The environment to use when spawning the process.
If not specified, the result of get_default_environment() will be used.
"""
cwd: str | Path | None = None
"""The working directory to use when spawning the process."""
encoding: str = "utf-8"
"""
The text encoding used when sending/receiving messages to the server
defaults to utf-8
"""
encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
"""
The text encoding error handler.
See https://docs.python.org/3/library/codecs.html#codec-base-classes for
explanations of possible values
"""
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
"""
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
try:
command = _get_executable_command(server.command)
# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise
async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
try:
async with read_stream_writer:
buffer = ""
async for chunk in TextReceiveStream(
process.stdout,
encoding=server.encoding,
errors=server.encoding_error_handler,
):
lines = (buffer + chunk).split("\n")
buffer = lines.pop()
for line in lines:
try:
message = types.JSONRPCMessage.model_validate_json(line)
except Exception as exc: # pragma: no cover
logger.exception("Failed to parse JSONRPC message from server")
await read_stream_writer.send(exc)
continue
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()
async def stdin_writer():
assert process.stdin, "Opened process is missing stdin"
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
await process.stdin.send(
(json + "\n").encode(
encoding=server.encoding,
errors=server.encoding_error_handler,
)
)
except anyio.ClosedResourceError: # pragma: no cover
await anyio.lowlevel.checkpoint()
async with (
anyio.create_task_group() as tg,
process,
):
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)
try:
yield read_stream, write_stream
finally:
# MCP spec: stdio shutdown sequence
# 1. Close input stream to server
# 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
# 3. Send SIGKILL if still not exited
if process.stdin: # pragma: no branch
try:
await process.stdin.aclose()
except Exception: # pragma: no cover
# stdin might already be closed, which is fine
pass
try:
# Give the process time to exit gracefully after stdin closes
with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
await process.wait()
except TimeoutError:
# Process didn't exit from stdin closure, use platform-specific termination
# which handles SIGTERM -> SIGKILL escalation
await _terminate_process_tree(process)
except ProcessLookupError: # pragma: no cover
# Process already exited, which is fine
pass
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
def _get_executable_command(command: str) -> str:
"""
Get the correct executable command normalized for the current platform.
Args:
command: Base command (e.g., 'uvx', 'npx')
Returns:
str: Platform-appropriate command
"""
if sys.platform == "win32": # pragma: no cover
return get_windows_executable_command(command)
else:
return command # pragma: no cover
async def _create_platform_compatible_process(
command: str,
args: list[str],
env: dict[str, str] | None = None,
errlog: TextIO = sys.stderr,
cwd: Path | str | None = None,
):
"""
Creates a subprocess in a platform-compatible way.
Unix: Creates process in a new session/process group for killpg support
Windows: Creates process in a Job Object for reliable child termination
"""
if sys.platform == "win32": # pragma: no cover
process = await create_windows_process(command, args, env, errlog, cwd)
else:
process = await anyio.open_process(
[command, *args],
env=env,
stderr=errlog,
cwd=cwd,
start_new_session=True,
) # pragma: no cover
return process
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
"""
Terminate a process and all its children using platform-specific methods.
Unix: Uses os.killpg() for atomic process group termination
Windows: Uses Job Objects via pywin32 for reliable child process cleanup
Args:
process: The process to terminate
timeout_seconds: Timeout in seconds before force killing (default: 2.0)
"""
if sys.platform == "win32": # pragma: no cover
await terminate_windows_process_tree(process, timeout_seconds)
else: # pragma: no cover
# FallbackProcess should only be used for Windows compatibility
assert isinstance(process, Process)
await terminate_posix_process_tree(process, timeout_seconds)

View File

@@ -0,0 +1,515 @@
"""
StreamableHTTP Client Transport Module
This module implements the StreamableHTTP transport for MCP clients,
providing support for HTTP POST requests with optional SSE streaming responses
and session management.
"""
import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import ClientMessageMetadata, SessionMessage
from mcp.types import (
ErrorData,
InitializeResult,
JSONRPCError,
JSONRPCMessage,
JSONRPCNotification,
JSONRPCRequest,
JSONRPCResponse,
RequestId,
)
logger = logging.getLogger(__name__)
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
GetSessionIdCallback = Callable[[], str | None]
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "accept"
JSON = "application/json"
SSE = "text/event-stream"
class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""
class ResumptionError(StreamableHTTPError):
"""Raised when resumption request is invalid."""
@dataclass
class RequestContext:
"""Context for a request operation."""
client: httpx.AsyncClient
headers: dict[str, str]
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
sse_read_timeout: float
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""
def __init__(
self,
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
auth: httpx.Auth | None = None,
) -> None:
"""Initialize the StreamableHTTP transport.
Args:
url: The endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
self.url = url
self.headers = headers or {}
self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
self.sse_read_timeout = (
sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
)
self.auth = auth
self.session_id = None
self.protocol_version = None
self.request_headers = {
ACCEPT: f"{JSON}, {SSE}",
CONTENT_TYPE: JSON,
**self.headers,
}
def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
"""Update headers with session ID and protocol version if available."""
headers = base_headers.copy()
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers
def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
def _maybe_extract_session_id_from_response(
self,
response: httpx.Response,
) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
def _maybe_extract_protocol_version_from_message(
self,
message: JSONRPCMessage,
) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result)
self.protocol_version = str(init_result.protocolVersion)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception as exc: # pragma: no cover
logger.warning(
f"Failed to parse initialization response as InitializeResult: {exc}"
) # pragma: no cover
logger.warning(f"Raw result: {message.root.result}")
async def _handle_sse_event(
self,
sse: ServerSentEvent,
read_stream_writer: StreamWriter,
original_request_id: RequestId | None = None,
resumption_callback: Callable[[str], Awaitable[None]] | None = None,
is_initialization: bool = False,
) -> bool:
"""Handle an SSE event, returning True if the response is complete."""
if sse.event == "message":
try:
message = JSONRPCMessage.model_validate_json(sse.data)
logger.debug(f"SSE message: {message}")
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
# Call resumption token callback if we have an ID
if sse.id and resumption_callback:
await resumption_callback(sse.id)
# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
except Exception as exc: # pragma: no cover
logger.exception("Error parsing SSE message")
await read_stream_writer.send(exc)
return False
else: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}")
return False
async def handle_get_stream(
self,
client: httpx.AsyncClient,
read_stream_writer: StreamWriter,
) -> None:
"""Handle GET stream for server-initiated messages."""
try:
if not self.session_id:
return
headers = self._prepare_request_headers(self.request_headers)
async with aconnect_sse(
client,
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("GET SSE connection established")
async for sse in event_source.aiter_sse():
await self._handle_sse_event(sse, read_stream_writer)
except Exception as exc:
logger.debug(f"GET stream error (non-fatal): {exc}") # pragma: no cover
async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._prepare_request_headers(ctx.headers)
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
raise ResumptionError("Resumption request requires a resumption token") # pragma: no cover
# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id
async with aconnect_sse(
ctx.client,
"GET",
self.url,
headers=headers,
timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
) as event_source:
event_source.response.raise_for_status()
logger.debug("Resumption GET SSE connection established")
async for sse in event_source.aiter_sse(): # pragma: no branch
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
original_request_id,
ctx.metadata.on_resumption_token_update if ctx.metadata else None,
)
if is_complete:
await event_source.response.aclose()
break
async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_request_headers(ctx.headers)
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)
async with ctx.client.stream(
"POST",
self.url,
json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
headers=headers,
) as response:
if response.status_code == 202:
logger.debug("Received 202 Accepted")
return
if response.status_code == 404: # pragma: no branch
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover
response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
content_type = response.headers.get(CONTENT_TYPE, "").lower()
if content_type.startswith(JSON):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
elif content_type.startswith(SSE):
await self._handle_sse_response(response, ctx, is_initialization)
else:
await self._handle_unexpected_content_type( # pragma: no cover
content_type, # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
) # pragma: no cover
async def _handle_json_response(
self,
response: httpx.Response,
read_stream_writer: StreamWriter,
is_initialization: bool = False,
) -> None:
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content)
# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except Exception as exc: # pragma: no cover
logger.exception("Error parsing JSON response")
await read_stream_writer.send(exc)
async def _handle_sse_response(
self,
response: httpx.Response,
ctx: RequestContext,
is_initialization: bool = False,
) -> None:
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse(): # pragma: no branch
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
is_initialization=is_initialization,
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
await response.aclose()
break
except Exception as e:
logger.exception("Error reading SSE stream:") # pragma: no cover
await ctx.read_stream_writer.send(e) # pragma: no cover
async def _handle_unexpected_content_type(
self,
content_type: str,
read_stream_writer: StreamWriter,
) -> None: # pragma: no cover
"""Handle unexpected content type in response."""
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
logger.error(error_msg) # pragma: no cover
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover
async def _send_session_terminated_error(
self,
read_stream_writer: StreamWriter,
request_id: RequestId,
) -> None:
"""Send a session terminated error response."""
jsonrpc_error = JSONRPCError(
jsonrpc="2.0",
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
await read_stream_writer.send(session_message)
async def post_writer(
self,
client: httpx.AsyncClient,
write_stream_reader: StreamReader,
read_stream_writer: StreamWriter,
write_stream: MemoryObjectSendStream[SessionMessage],
start_get_stream: Callable[[], None],
tg: TaskGroup,
) -> None:
"""Handle writing requests to the server."""
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
message = session_message.message
metadata = (
session_message.metadata
if isinstance(session_message.metadata, ClientMessageMetadata)
else None
)
# Check if this is a resumption request
is_resumption = bool(metadata and metadata.resumption_token)
logger.debug(f"Sending client message: {message}")
# Handle initialized notification
if self._is_initialized_notification(message):
start_get_stream()
ctx = RequestContext(
client=client,
headers=self.request_headers,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
read_stream_writer=read_stream_writer,
sse_read_timeout=self.sse_read_timeout,
)
async def handle_request_async():
if is_resumption:
await self._handle_resumption_request(ctx)
else:
await self._handle_post_request(ctx)
# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await handle_request_async()
except Exception:
logger.exception("Error in post_writer") # pragma: no cover
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
async def terminate_session(self, client: httpx.AsyncClient) -> None: # pragma: no cover
"""Terminate the session by sending a DELETE request."""
if not self.session_id:
return
try:
headers = self._prepare_request_headers(self.request_headers)
response = await client.delete(self.url, headers=headers)
if response.status_code == 405:
logger.debug("Server does not allow session termination")
elif response.status_code not in (200, 204):
logger.warning(f"Session termination failed: {response.status_code}")
except Exception as exc:
logger.warning(f"Session termination failed: {exc}")
def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id
@asynccontextmanager
async def streamablehttp_client(
url: str,
headers: dict[str, str] | None = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
terminate_on_close: bool = True,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
) -> AsyncGenerator[
tuple[
MemoryObjectReceiveStream[SessionMessage | Exception],
MemoryObjectSendStream[SessionMessage],
GetSessionIdCallback,
],
None,
]:
"""
Client transport for StreamableHTTP.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Yields:
Tuple containing:
- read_stream: Stream for reading messages from the server
- write_stream: Stream for sending messages to the server
- get_session_id_callback: Function to retrieve the current session ID
"""
transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
tg.start_soon(
transport.post_writer,
client,
write_stream_reader,
read_stream_writer,
write_stream,
start_get_stream,
tg,
)
try:
yield (
read_stream,
write_stream,
transport.get_session_id,
)
finally:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()

View File

@@ -0,0 +1,86 @@
import json
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import anyio
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import ValidationError
from websockets.asyncio.client import connect as ws_connect
from websockets.typing import Subprotocol
import mcp.types as types
from mcp.shared.message import SessionMessage
logger = logging.getLogger(__name__)
@asynccontextmanager
async def websocket_client(
url: str,
) -> AsyncGenerator[
tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
None,
]:
"""
WebSocket client transport for MCP, symmetrical to the server version.
Connects to 'url' using the 'mcp' subprotocol, then yields:
(read_stream, write_stream)
- read_stream: As you read from this stream, you'll receive either valid
JSONRPCMessage objects or Exception objects (when validation fails).
- write_stream: Write JSONRPCMessage objects to this stream to send them
over the WebSocket to the server.
"""
# Create two in-memory streams:
# - One for incoming messages (read_stream, written by ws_reader)
# - One for outgoing messages (write_stream, read by ws_writer)
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
# Connect using websockets, requesting the "mcp" subprotocol
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
async def ws_reader():
"""
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
and sends them into read_stream_writer.
"""
async with read_stream_writer:
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc: # pragma: no cover
# If JSON parse or model validation fails, send the exception
await read_stream_writer.send(exc)
async def ws_writer():
"""
Reads JSON-RPC messages from write_stream_reader and
sends them to the server.
"""
async with write_stream_reader:
async for session_message in write_stream_reader:
# Convert to a dict, then to JSON
msg_dict = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
await ws.send(json.dumps(msg_dict))
async with anyio.create_task_group() as tg:
# Start reader and writer tasks
tg.start_soon(ws_reader)
tg.start_soon(ws_writer)
# Yield the receive/send streams
yield (read_stream, write_stream)
# Once the caller's 'async with' block exits, we shut down
tg.cancel_scope.cancel()