chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,85 @@
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from functools import partial
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters, stdio_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import RequestResponder
|
||||
|
||||
# Suppress all warnings unless the user explicitly enabled them
# via -W / PYTHONWARNINGS (sys.warnoptions is non-empty in that case).
if not sys.warnoptions:
    import warnings

    warnings.simplefilter("ignore")

# Module-level logging setup for the client script.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
async def message_handler(
    message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Log messages arriving from the server; exceptions go to the error log."""
    if not isinstance(message, Exception):
        logger.info("Received message from server: %s", message)
        return
    logger.error("Error: %s", message)
|
||||
|
||||
|
||||
async def run_session(
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
    write_stream: MemoryObjectSendStream[SessionMessage],
    client_info: types.Implementation | None = None,
):
    """Open a ClientSession over the given streams and perform the MCP handshake."""
    session_cm = ClientSession(
        read_stream,
        write_stream,
        message_handler=message_handler,
        client_info=client_info,
    )
    async with session_cm as session:
        logger.info("Initializing session")
        await session.initialize()
        logger.info("Initialized")
|
||||
|
||||
|
||||
async def main(command_or_url: str, args: list[str], env: list[tuple[str, str]]):
    """Connect to an MCP server and run a session.

    HTTP(S) URLs are reached via the SSE transport; anything else is treated
    as a command to spawn and talked to over stdio.
    """
    scheme = urlparse(command_or_url).scheme
    if scheme not in ("http", "https"):
        # Use stdio client for commands
        params = StdioServerParameters(command=command_or_url, args=args, env=dict(env))
        async with stdio_client(params) as streams:
            await run_session(*streams)
        return

    # Use SSE client for HTTP(S) URLs
    async with sse_client(command_or_url) as streams:
        await run_session(*streams)
|
||||
|
||||
|
||||
def cli():
    """Command-line entry point: parse arguments and run the async client on trio."""
    parser = argparse.ArgumentParser()
    parser.add_argument("command_or_url", help="Command or URL to connect to")
    parser.add_argument("args", nargs="*", help="Additional arguments")
    parser.add_argument(
        "-e",
        "--env",
        nargs=2,
        action="append",
        metavar=("KEY", "VALUE"),
        help="Environment variables to set. Can be used multiple times.",
        default=[],
    )

    parsed = parser.parse_args()
    runner = partial(main, parsed.command_or_url, parsed.args, parsed.env)
    anyio.run(runner, backend="trio")
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    cli()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.auth.oauth2 import (
|
||||
OAuthClientProvider,
|
||||
PKCEParameters,
|
||||
TokenStorage,
|
||||
)
|
||||
|
||||
# Public API of the mcp.client.auth package.
__all__ = [
    "OAuthClientProvider",
    "OAuthFlowError",
    "OAuthRegistrationError",
    "OAuthTokenError",
    "PKCEParameters",
    "TokenStorage",
]
|
||||
@@ -0,0 +1,10 @@
|
||||
# Root of the OAuth client exception hierarchy; catch this to handle any flow failure.
class OAuthFlowError(Exception):
    """Base exception for OAuth flow errors."""
|
||||
|
||||
|
||||
# Raised by token exchange/refresh handling; subclass of OAuthFlowError.
class OAuthTokenError(OAuthFlowError):
    """Raised when token operations fail."""
|
||||
|
||||
|
||||
# Raised during dynamic client registration; subclass of OAuthFlowError.
class OAuthRegistrationError(OAuthFlowError):
    """Raised when client registration fails."""
|
||||
@@ -0,0 +1,148 @@
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
|
||||
|
||||
class JWTParameters(BaseModel):
    """JWT parameters.

    Inputs for producing the JWT assertion used by RFC 7523 flows: either a
    prebuilt ``assertion`` taken verbatim, or the claim values and signing
    material needed to generate a fresh one in :meth:`to_assertion`.
    """

    # When set, to_assertion() returns this value and skips signing entirely.
    assertion: str | None = Field(
        default=None,
        description="JWT assertion for JWT authentication. "
        "Will be used instead of generating a new assertion if provided.",
    )

    issuer: str | None = Field(default=None, description="Issuer for JWT assertions.")
    subject: str | None = Field(default=None, description="Subject identifier for JWT assertions.")
    audience: str | None = Field(default=None, description="Audience for JWT assertions.")
    claims: dict[str, Any] | None = Field(default=None, description="Additional claims for JWT assertions.")
    jwt_signing_algorithm: str | None = Field(default="RS256", description="Algorithm for signing JWT assertions.")
    jwt_signing_key: str | None = Field(default=None, description="Private key for JWT signing.")
    jwt_lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.")

    def to_assertion(self, with_audience_fallback: str | None = None) -> str:
        """Return a JWT assertion string.

        Uses ``self.assertion`` verbatim when set; otherwise signs a new JWT
        built from issuer/subject/audience (``with_audience_fallback`` is used
        when ``self.audience`` is unset) plus any extra ``claims``.

        Raises:
            OAuthFlowError: if the signing key, issuer, subject, or audience
                needed to generate an assertion is missing.
        """
        if self.assertion is not None:
            # Prebuilt JWT (e.g. acquired out-of-band)
            assertion = self.assertion
        else:
            if not self.jwt_signing_key:
                raise OAuthFlowError("Missing signing key for JWT bearer grant")  # pragma: no cover
            if not self.issuer:
                raise OAuthFlowError("Missing issuer for JWT bearer grant")  # pragma: no cover
            if not self.subject:
                raise OAuthFlowError("Missing subject for JWT bearer grant")  # pragma: no cover

            audience = self.audience if self.audience else with_audience_fallback
            if not audience:
                raise OAuthFlowError("Missing audience for JWT bearer grant")  # pragma: no cover

            now = int(time.time())
            # Standard registered claims (RFC 7519); user-supplied `claims`
            # below may deliberately override any of them.
            claims: dict[str, Any] = {
                "iss": self.issuer,
                "sub": self.subject,
                "aud": audience,
                "exp": now + self.jwt_lifetime_seconds,
                "iat": now,
                "jti": str(uuid4()),
            }
            claims.update(self.claims or {})

            assertion = jwt.encode(
                claims,
                self.jwt_signing_key,
                algorithm=self.jwt_signing_algorithm or "RS256",
            )
        return assertion
|
||||
|
||||
|
||||
class RFC7523OAuthClientProvider(OAuthClientProvider):
    """OAuth client provider for RFC 7523 clients.

    Extends the base provider with JWT-based client authentication
    (``private_key_jwt``) and the JWT bearer authorization grant
    (``urn:ietf:params:oauth:grant-type:jwt-bearer``).
    """

    # JWT claims/signing material used to build assertions; None disables JWT flows.
    jwt_parameters: JWTParameters | None = None

    def __init__(
        self,
        server_url: str,
        client_metadata: OAuthClientMetadata,
        storage: TokenStorage,
        redirect_handler: Callable[[str], Awaitable[None]] | None = None,
        callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
        timeout: float = 300.0,
        jwt_parameters: JWTParameters | None = None,
    ) -> None:
        """Initialize like OAuthClientProvider, additionally storing JWT parameters."""
        super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout)
        self.jwt_parameters = jwt_parameters

    async def _exchange_token_authorization_code(
        self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
    ) -> httpx.Request:  # pragma: no cover
        """Build token exchange request for authorization_code flow."""
        # Inject client-authentication JWT fields before delegating to the base
        # implementation, which merges token_data into its request form data.
        token_data = token_data or {}
        if self.context.client_metadata.token_endpoint_auth_method == "private_key_jwt":
            self._add_client_authentication_jwt(token_data=token_data)
        return await super()._exchange_token_authorization_code(auth_code, code_verifier, token_data=token_data)

    async def _perform_authorization(self) -> httpx.Request:  # pragma: no cover
        """Perform the authorization flow."""
        # JWT bearer grant skips the interactive redirect entirely.
        if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self.context.client_metadata.grant_types:
            token_request = await self._exchange_token_jwt_bearer()
            return token_request
        else:
            return await super()._perform_authorization()

    def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]):  # pragma: no cover
        """Add JWT assertion for client authentication to token endpoint parameters.

        Raises:
            OAuthTokenError: if JWT parameters or OAuth metadata are missing.
        """
        if not self.jwt_parameters:
            raise OAuthTokenError("Missing JWT parameters for private_key_jwt flow")
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata for private_key_jwt flow")

        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)

        # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
        token_data["client_assertion"] = assertion
        token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
        # We need to set the audience to the resource server; this audience is
        # different from the one in the JWT claims — it represents the resource
        # server that will validate the token.
        token_data["audience"] = self.context.get_resource_url()

    async def _exchange_token_jwt_bearer(self) -> httpx.Request:
        """Build token exchange request for JWT bearer grant.

        Raises:
            OAuthFlowError: if client info or JWT parameters are missing.
            OAuthTokenError: if OAuth metadata is missing.
        """
        if not self.context.client_info:
            raise OAuthFlowError("Missing client info")  # pragma: no cover
        if not self.jwt_parameters:
            raise OAuthFlowError("Missing JWT parameters")  # pragma: no cover
        if not self.context.oauth_metadata:
            raise OAuthTokenError("Missing OAuth metadata")  # pragma: no cover

        # We need to set the audience to the issuer identifier of the authorization server
        # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01#name-updates-to-rfc-7523
        issuer = str(self.context.oauth_metadata.issuer)
        assertion = self.jwt_parameters.to_assertion(with_audience_fallback=issuer)

        token_data = {
            "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
            "assertion": assertion,
        }

        if self.context.should_include_resource_param(self.context.protocol_version):  # pragma: no branch
            token_data["resource"] = self.context.get_resource_url()

        if self.context.client_metadata.scope:  # pragma: no branch
            token_data["scope"] = self.context.client_metadata.scope

        token_url = self._get_token_endpoint()
        return httpx.Request(
            "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
        )
|
||||
@@ -0,0 +1,557 @@
|
||||
"""
|
||||
OAuth2 Authentication implementation for HTTPX.
|
||||
|
||||
Implements authorization code flow with PKCE and automatic token refresh.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol
|
||||
from urllib.parse import urlencode, urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthFlowError, OAuthTokenError
|
||||
from mcp.client.auth.utils import (
|
||||
build_oauth_authorization_server_metadata_discovery_urls,
|
||||
build_protected_resource_metadata_discovery_urls,
|
||||
create_client_registration_request,
|
||||
create_oauth_metadata_request,
|
||||
extract_field_from_www_auth,
|
||||
extract_resource_metadata_from_www_auth,
|
||||
extract_scope_from_www_auth,
|
||||
get_client_metadata_scopes,
|
||||
handle_auth_metadata_response,
|
||||
handle_protected_resource_response,
|
||||
handle_registration_response,
|
||||
handle_token_response_scopes,
|
||||
)
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.shared.auth_utils import (
|
||||
calculate_token_expiry,
|
||||
check_resource_allowed,
|
||||
resource_url_from_server_url,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PKCEParameters(BaseModel):
    """PKCE (Proof Key for Code Exchange) parameters."""

    code_verifier: str = Field(..., min_length=43, max_length=128)
    code_challenge: str = Field(..., min_length=43, max_length=128)

    @classmethod
    def generate(cls) -> "PKCEParameters":
        """Generate new PKCE parameters.

        Produces a random 128-char verifier over the RFC 7636 unreserved
        alphabet and derives the S256 challenge (base64url SHA-256, no padding).
        """
        alphabet = string.ascii_letters + string.digits + "-._~"
        verifier = "".join(secrets.choice(alphabet) for _ in range(128))
        sha = hashlib.sha256(verifier.encode())
        challenge = base64.urlsafe_b64encode(sha.digest()).decode().rstrip("=")
        return cls(code_verifier=verifier, code_challenge=challenge)
|
||||
|
||||
|
||||
class TokenStorage(Protocol):
    """Protocol for token storage implementations.

    Structural interface: any object providing these four async methods can
    be used by OAuthClientProvider to persist tokens and client registration.
    """

    async def get_tokens(self) -> OAuthToken | None:
        """Get stored tokens."""
        ...

    async def set_tokens(self, tokens: OAuthToken) -> None:
        """Store tokens."""
        ...

    async def get_client_info(self) -> OAuthClientInformationFull | None:
        """Get stored client information."""
        ...

    async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
        """Store client information."""
        ...
|
||||
|
||||
|
||||
@dataclass
class OAuthContext:
    """OAuth flow context.

    Mutable state shared across the provider's auth flow: configuration,
    discovered metadata, registered client info, and the current tokens.
    """

    # Configuration supplied by the caller
    server_url: str
    client_metadata: OAuthClientMetadata
    storage: TokenStorage
    redirect_handler: Callable[[str], Awaitable[None]] | None
    callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None
    timeout: float = 300.0

    # Discovered metadata
    protected_resource_metadata: ProtectedResourceMetadata | None = None
    oauth_metadata: OAuthMetadata | None = None
    auth_server_url: str | None = None
    protocol_version: str | None = None

    # Client registration
    client_info: OAuthClientInformationFull | None = None

    # Token management
    current_tokens: OAuthToken | None = None
    token_expiry_time: float | None = None

    # State — serializes the auth flow inside async_auth_flow
    lock: anyio.Lock = field(default_factory=anyio.Lock)

    def get_authorization_base_url(self, server_url: str) -> str:
        """Extract base URL by removing path component."""
        parsed = urlparse(server_url)
        return f"{parsed.scheme}://{parsed.netloc}"

    def update_token_expiry(self, token: OAuthToken) -> None:
        """Update token expiry time using shared util function."""
        self.token_expiry_time = calculate_token_expiry(token.expires_in)

    def is_token_valid(self) -> bool:
        """Check if current token is valid (present and not past its expiry time)."""
        return bool(
            self.current_tokens
            and self.current_tokens.access_token
            and (not self.token_expiry_time or time.time() <= self.token_expiry_time)
        )

    def can_refresh_token(self) -> bool:
        """Check if token can be refreshed (refresh token + client info available)."""
        return bool(self.current_tokens and self.current_tokens.refresh_token and self.client_info)

    def clear_tokens(self) -> None:
        """Clear current tokens."""
        self.current_tokens = None
        self.token_expiry_time = None

    def get_resource_url(self) -> str:
        """Get resource URL for RFC 8707.

        Uses PRM resource if it's a valid parent, otherwise uses canonical server URL.
        """
        resource = resource_url_from_server_url(self.server_url)

        # If PRM provides a resource that's a valid parent, use it
        if self.protected_resource_metadata and self.protected_resource_metadata.resource:
            prm_resource = str(self.protected_resource_metadata.resource)
            if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
                resource = prm_resource

        return resource

    def should_include_resource_param(self, protocol_version: str | None = None) -> bool:
        """Determine if the resource parameter should be included in OAuth requests.

        Returns True if:
        - Protected resource metadata is available, OR
        - MCP-Protocol-Version header is 2025-06-18 or later
        """
        # If we have protected resource metadata, include the resource param
        if self.protected_resource_metadata is not None:
            return True

        # If no protocol version provided, don't include resource param
        if not protocol_version:
            return False

        # Check if protocol version is 2025-06-18 or later
        # Version format is YYYY-MM-DD, so string comparison works
        return protocol_version >= "2025-06-18"
|
||||
|
||||
|
||||
class OAuthClientProvider(httpx.Auth):
|
||||
"""
|
||||
OAuth2 authentication for httpx.
|
||||
Handles OAuth flow with automatic client registration and token storage.
|
||||
"""
|
||||
|
||||
requires_response_body = True
|
||||
|
||||
def __init__(
    self,
    server_url: str,
    client_metadata: OAuthClientMetadata,
    storage: TokenStorage,
    redirect_handler: Callable[[str], Awaitable[None]] | None = None,
    callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
    timeout: float = 300.0,
):
    """Initialize OAuth2 authentication.

    Args:
        server_url: Base URL of the server to authenticate against.
        client_metadata: OAuth client metadata used for registration/authorization.
        storage: Backend used to persist tokens and client registration info.
        redirect_handler: Awaitable invoked with the authorization URL.
        callback_handler: Awaitable returning (auth_code, state) from the redirect.
        timeout: Flow timeout in seconds.
    """
    self.context = OAuthContext(
        server_url=server_url,
        client_metadata=client_metadata,
        storage=storage,
        redirect_handler=redirect_handler,
        callback_handler=callback_handler,
        timeout=timeout,
    )
    # Stored tokens/client info are loaded lazily by _initialize() on first use.
    self._initialized = False
|
||||
|
||||
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
    """
    Handle protected resource metadata discovery response.

    Per SEP-985, supports fallback when discovery fails at one URL.

    Returns:
        True if metadata was successfully discovered, False if we should try next URL

    Raises:
        OAuthFlowError: on any status other than 200 or 404.
    """
    if response.status_code == 200:
        try:
            content = await response.aread()
            metadata = ProtectedResourceMetadata.model_validate_json(content)
            self.context.protected_resource_metadata = metadata
            # First advertised authorization server wins.
            if metadata.authorization_servers:  # pragma: no branch
                self.context.auth_server_url = str(metadata.authorization_servers[0])
            return True

        except ValidationError:  # pragma: no cover
            # Invalid metadata - try next URL
            logger.warning(f"Invalid protected resource metadata at {response.request.url}")
            return False
    elif response.status_code == 404:  # pragma: no cover
        # Not found - try next URL in fallback chain
        logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL")
        return False
    else:
        # Other error - fail immediately
        raise OAuthFlowError(
            f"Protected Resource Metadata request failed: {response.status_code}"
        )  # pragma: no cover
|
||||
|
||||
async def _register_client(self) -> httpx.Request | None:
    """Build registration request or skip if already registered.

    Returns:
        None when client info already exists; otherwise a POST request for
        dynamic client registration, targeting the discovered registration
        endpoint or the conventional /register path as fallback.
    """
    if self.context.client_info:
        return None

    if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
        registration_url = str(self.context.oauth_metadata.registration_endpoint)  # pragma: no cover
    else:
        # No discovered endpoint: fall back to /register on the auth base URL.
        auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
        registration_url = urljoin(auth_base_url, "/register")

    registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

    return httpx.Request(
        "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
    )
|
||||
|
||||
async def _perform_authorization(self) -> httpx.Request:
    """Perform the authorization flow.

    Drives the interactive authorization-code grant, then builds the token
    exchange request from the resulting code and PKCE verifier.
    """
    code, verifier = await self._perform_authorization_code_grant()
    return await self._exchange_token_authorization_code(code, verifier)
|
||||
|
||||
async def _perform_authorization_code_grant(self) -> tuple[str, str]:
    """Perform the authorization redirect and get auth code.

    Builds the authorization URL (with PKCE S256 challenge and CSRF state),
    hands it to the redirect handler, then waits on the callback handler for
    the code and verifies the returned state.

    Returns:
        (auth_code, code_verifier) for the subsequent token exchange.

    Raises:
        OAuthFlowError: if handlers/URIs/client info are missing, the state
            does not match, or no authorization code is received.
    """
    if self.context.client_metadata.redirect_uris is None:
        raise OAuthFlowError("No redirect URIs provided for authorization code grant")  # pragma: no cover
    if not self.context.redirect_handler:
        raise OAuthFlowError("No redirect handler provided for authorization code grant")  # pragma: no cover
    if not self.context.callback_handler:
        raise OAuthFlowError("No callback handler provided for authorization code grant")  # pragma: no cover

    if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint:
        auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint)  # pragma: no cover
    else:
        # No discovered endpoint: fall back to /authorize on the auth base URL.
        auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
        auth_endpoint = urljoin(auth_base_url, "/authorize")

    if not self.context.client_info:
        raise OAuthFlowError("No client info available for authorization")  # pragma: no cover

    # Generate PKCE parameters
    pkce_params = PKCEParameters.generate()
    state = secrets.token_urlsafe(32)

    auth_params = {
        "response_type": "code",
        "client_id": self.context.client_info.client_id,
        "redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
        "state": state,
        "code_challenge": pkce_params.code_challenge,
        "code_challenge_method": "S256",
    }

    # Only include resource param if conditions are met
    if self.context.should_include_resource_param(self.context.protocol_version):
        auth_params["resource"] = self.context.get_resource_url()  # RFC 8707 # pragma: no cover

    if self.context.client_metadata.scope:  # pragma: no branch
        auth_params["scope"] = self.context.client_metadata.scope

    authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}"
    await self.context.redirect_handler(authorization_url)

    # Wait for callback
    auth_code, returned_state = await self.context.callback_handler()

    # Constant-time comparison guards the CSRF state check.
    if returned_state is None or not secrets.compare_digest(returned_state, state):
        raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}")  # pragma: no cover

    if not auth_code:
        raise OAuthFlowError("No authorization code received")  # pragma: no cover

    # Return auth code and code verifier for token exchange
    return auth_code, pkce_params.code_verifier
|
||||
|
||||
def _get_token_endpoint(self) -> str:
    """Resolve the token endpoint URL, preferring discovered OAuth metadata
    and falling back to /token on the authorization base URL."""
    metadata = self.context.oauth_metadata
    if metadata and metadata.token_endpoint:
        return str(metadata.token_endpoint)
    base = self.context.get_authorization_base_url(self.context.server_url)
    return urljoin(base, "/token")
|
||||
|
||||
async def _exchange_token_authorization_code(
    self, auth_code: str, code_verifier: str, *, token_data: dict[str, Any] | None = None
) -> httpx.Request:
    """Build token exchange request for authorization_code flow.

    Args:
        auth_code: Authorization code returned via the redirect callback.
        code_verifier: PKCE verifier matching the challenge sent earlier.
        token_data: Optional extra form fields to merge in (used by
            subclasses to add client authentication parameters).

    Returns:
        The POST request for the token endpoint.

    Raises:
        OAuthFlowError: if redirect URIs or client info are missing.
    """
    if self.context.client_metadata.redirect_uris is None:
        raise OAuthFlowError("No redirect URIs provided for authorization code grant")  # pragma: no cover
    if not self.context.client_info:
        raise OAuthFlowError("Missing client info")  # pragma: no cover

    token_url = self._get_token_endpoint()
    # Default is None rather than a mutable `{}` so calls can never share
    # (and mutate) a single default dict; also matches the RFC 7523
    # subclass override's signature.
    token_data = token_data or {}
    token_data.update(
        {
            "grant_type": "authorization_code",
            "code": auth_code,
            "redirect_uri": str(self.context.client_metadata.redirect_uris[0]),
            "client_id": self.context.client_info.client_id,
            "code_verifier": code_verifier,
        }
    )

    # Only include resource param if conditions are met
    if self.context.should_include_resource_param(self.context.protocol_version):
        token_data["resource"] = self.context.get_resource_url()  # RFC 8707

    if self.context.client_info.client_secret:
        token_data["client_secret"] = self.context.client_info.client_secret

    return httpx.Request(
        "POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
    )
|
||||
|
||||
async def _handle_token_response(self, response: httpx.Response) -> None:
    """Handle token exchange response.

    On 200: validate scopes, store tokens in the context, update expiry,
    and persist them to storage.

    Raises:
        OAuthTokenError: on any non-200 response.
    """
    if response.status_code != 200:
        body = await response.aread()  # pragma: no cover
        body_text = body.decode("utf-8")  # pragma: no cover
        raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}")  # pragma: no cover

    # Parse and validate response with scope validation
    token_response = await handle_token_response_scopes(response)

    # Store tokens in context
    self.context.current_tokens = token_response
    self.context.update_token_expiry(token_response)
    await self.context.storage.set_tokens(token_response)
|
||||
|
||||
async def _refresh_token(self) -> httpx.Request:
    """Build token refresh request.

    Returns:
        The POST request for the token endpoint with grant_type=refresh_token.

    Raises:
        OAuthTokenError: if no refresh token or client info is available.
    """
    if not self.context.current_tokens or not self.context.current_tokens.refresh_token:
        raise OAuthTokenError("No refresh token available")  # pragma: no cover

    if not self.context.client_info:
        raise OAuthTokenError("No client info available")  # pragma: no cover

    # Reuse the shared endpoint resolution instead of duplicating the
    # metadata-vs-default-URL branch inline; keeps this consistent with
    # _exchange_token_authorization_code.
    token_url = self._get_token_endpoint()

    refresh_data = {
        "grant_type": "refresh_token",
        "refresh_token": self.context.current_tokens.refresh_token,
        "client_id": self.context.client_info.client_id,
    }

    # Only include resource param if conditions are met
    if self.context.should_include_resource_param(self.context.protocol_version):
        refresh_data["resource"] = self.context.get_resource_url()  # RFC 8707

    if self.context.client_info.client_secret:  # pragma: no branch
        refresh_data["client_secret"] = self.context.client_info.client_secret

    return httpx.Request(
        "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
    )
|
||||
|
||||
async def _handle_refresh_response(self, response: httpx.Response) -> bool:  # pragma: no cover
    """Handle token refresh response. Returns True if successful.

    Any failure (non-200 status or invalid token payload) clears the
    stored tokens so the caller can fall back to a full re-authentication.
    """
    if response.status_code != 200:
        logger.warning(f"Token refresh failed: {response.status_code}")
        self.context.clear_tokens()
        return False

    try:
        content = await response.aread()
        token_response = OAuthToken.model_validate_json(content)

        # Store the refreshed tokens in context and persist them.
        self.context.current_tokens = token_response
        self.context.update_token_expiry(token_response)
        await self.context.storage.set_tokens(token_response)

        return True
    except ValidationError:
        logger.exception("Invalid refresh response")
        self.context.clear_tokens()
        return False
|
||||
|
||||
async def _initialize(self) -> None:  # pragma: no cover
    """Load stored tokens and client info.

    Populates the context from storage and marks the provider as
    initialized so async_auth_flow skips this on subsequent requests.
    """
    self.context.current_tokens = await self.context.storage.get_tokens()
    self.context.client_info = await self.context.storage.get_client_info()
    self._initialized = True
|
||||
|
||||
def _add_auth_header(self, request: httpx.Request) -> None:
    """Add authorization header to request if we have valid tokens."""
    tokens = self.context.current_tokens
    if tokens and tokens.access_token:  # pragma: no branch
        request.headers["Authorization"] = f"Bearer {tokens.access_token}"
|
||||
|
||||
async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None:
    """Parse an authorization-server metadata response into the context.

    NOTE(review): assumes a successful metadata document; a pydantic
    ValidationError from malformed JSON propagates to the caller.
    """
    content = await response.aread()
    metadata = OAuthMetadata.model_validate_json(content)
    self.context.oauth_metadata = metadata
|
||||
|
||||
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
|
||||
"""HTTPX auth flow integration."""
|
||||
async with self.context.lock:
|
||||
if not self._initialized:
|
||||
await self._initialize() # pragma: no cover
|
||||
|
||||
# Capture protocol version from request headers
|
||||
self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION)
|
||||
|
||||
if not self.context.is_token_valid() and self.context.can_refresh_token():
|
||||
# Try to refresh token
|
||||
refresh_request = await self._refresh_token() # pragma: no cover
|
||||
refresh_response = yield refresh_request # pragma: no cover
|
||||
|
||||
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
|
||||
# Refresh failed, need full re-authentication
|
||||
self._initialized = False
|
||||
|
||||
if self.context.is_token_valid():
|
||||
self._add_auth_header(request)
|
||||
|
||||
response = yield request
|
||||
|
||||
if response.status_code == 401:
|
||||
# Perform full OAuth flow
|
||||
try:
|
||||
# OAuth flow must be inline due to generator constraints
|
||||
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)
|
||||
|
||||
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
|
||||
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
|
||||
www_auth_resource_metadata_url, self.context.server_url
|
||||
)
|
||||
|
||||
for url in prm_discovery_urls: # pragma: no branch
|
||||
discovery_request = create_oauth_metadata_request(url)
|
||||
|
||||
discovery_response = yield discovery_request # sending request
|
||||
|
||||
prm = await handle_protected_resource_response(discovery_response)
|
||||
if prm:
|
||||
self.context.protected_resource_metadata = prm
|
||||
|
||||
# todo: try all authorization_servers to find the OASM
|
||||
assert (
|
||||
len(prm.authorization_servers) > 0
|
||||
) # this is always true as authorization_servers has a min length of 1
|
||||
|
||||
self.context.auth_server_url = str(prm.authorization_servers[0])
|
||||
break
|
||||
else:
|
||||
logger.debug(f"Protected resource metadata discovery failed: {url}")
|
||||
|
||||
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
|
||||
self.context.auth_server_url, self.context.server_url
|
||||
)
|
||||
|
||||
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
|
||||
for url in asm_discovery_urls: # pragma: no cover
|
||||
oauth_metadata_request = create_oauth_metadata_request(url)
|
||||
oauth_metadata_response = yield oauth_metadata_request
|
||||
|
||||
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
|
||||
if not ok:
|
||||
break
|
||||
if ok and asm:
|
||||
self.context.oauth_metadata = asm
|
||||
break
|
||||
else:
|
||||
logger.debug(f"OAuth metadata discovery failed: {url}")
|
||||
|
||||
# Step 3: Apply scope selection strategy
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
www_auth_resource_metadata_url,
|
||||
self.context.protected_resource_metadata,
|
||||
self.context.oauth_metadata,
|
||||
)
|
||||
|
||||
# Step 4: Register client if needed
|
||||
registration_request = create_client_registration_request(
|
||||
self.context.oauth_metadata,
|
||||
self.context.client_metadata,
|
||||
self.context.get_authorization_base_url(self.context.server_url),
|
||||
)
|
||||
if not self.context.client_info:
|
||||
registration_response = yield registration_request
|
||||
client_information = await handle_registration_response(registration_response)
|
||||
self.context.client_info = client_information
|
||||
await self.context.storage.set_client_info(client_information)
|
||||
|
||||
# Step 5: Perform authorization and complete token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
elif response.status_code == 403:
|
||||
# Step 1: Extract error field from WWW-Authenticate header
|
||||
error = extract_field_from_www_auth(response, "error")
|
||||
|
||||
# Step 2: Check if we need to step-up authorization
|
||||
if error == "insufficient_scope": # pragma: no branch
|
||||
try:
|
||||
# Step 2a: Update the required scopes
|
||||
self.context.client_metadata.scope = get_client_metadata_scopes(
|
||||
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
|
||||
)
|
||||
|
||||
# Step 2b: Perform (re-)authorization and token exchange
|
||||
token_response = yield await self._perform_authorization()
|
||||
await self._handle_token_response(token_response)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("OAuth flow error")
|
||||
raise
|
||||
|
||||
# Retry with new tokens
|
||||
self._add_auth_header(request)
|
||||
yield request
|
||||
@@ -0,0 +1,267 @@
|
||||
import logging
|
||||
import re
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from httpx import Request, Response
|
||||
from pydantic import ValidationError
|
||||
|
||||
from mcp.client.auth import OAuthRegistrationError, OAuthTokenError
|
||||
from mcp.client.streamable_http import MCP_PROTOCOL_VERSION
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthMetadata,
|
||||
OAuthToken,
|
||||
ProtectedResourceMetadata,
|
||||
)
|
||||
from mcp.types import LATEST_PROTOCOL_VERSION
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_field_from_www_auth(response: Response, field_name: str) -> str | None:
|
||||
"""
|
||||
Extract field from WWW-Authenticate header.
|
||||
|
||||
Returns:
|
||||
Field value if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
www_auth_header = response.headers.get("WWW-Authenticate")
|
||||
if not www_auth_header:
|
||||
return None
|
||||
|
||||
# Pattern matches: field_name="value" or field_name=value (unquoted)
|
||||
pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))'
|
||||
match = re.search(pattern, www_auth_header)
|
||||
|
||||
if match:
|
||||
# Return quoted value if present, otherwise unquoted value
|
||||
return match.group(1) or match.group(2)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_scope_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract scope parameter from WWW-Authenticate header as per RFC6750.
|
||||
|
||||
Returns:
|
||||
Scope string if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
return extract_field_from_www_auth(response, "scope")
|
||||
|
||||
|
||||
def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
|
||||
"""
|
||||
Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728.
|
||||
|
||||
Returns:
|
||||
Resource metadata URL if found in WWW-Authenticate header, None otherwise
|
||||
"""
|
||||
if not response or response.status_code != 401:
|
||||
return None # pragma: no cover
|
||||
|
||||
return extract_field_from_www_auth(response, "resource_metadata")
|
||||
|
||||
|
||||
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Build ordered list of URLs to try for protected resource metadata discovery.
|
||||
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Args:
|
||||
www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header
|
||||
server_url: server url
|
||||
|
||||
Returns:
|
||||
Ordered list of URLs to try for discovery
|
||||
"""
|
||||
urls: list[str] = []
|
||||
|
||||
# Priority 1: WWW-Authenticate header with resource_metadata parameter
|
||||
if www_auth_url:
|
||||
urls.append(www_auth_url)
|
||||
|
||||
# Priority 2-3: Well-known URIs (RFC 9728)
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# Priority 2: Path-based well-known URI (if server has a path component)
|
||||
if parsed.path and parsed.path != "/":
|
||||
path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}")
|
||||
urls.append(path_based_url)
|
||||
|
||||
# Priority 3: Root-based well-known URI
|
||||
root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource")
|
||||
urls.append(root_based_url)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def get_client_metadata_scopes(
|
||||
www_authenticate_scope: str | None,
|
||||
protected_resource_metadata: ProtectedResourceMetadata | None,
|
||||
authorization_server_metadata: OAuthMetadata | None = None,
|
||||
) -> str | None:
|
||||
"""Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec."""
|
||||
# Per MCP spec, scope selection priority order:
|
||||
# 1. Use scope from WWW-Authenticate header (if provided)
|
||||
# 2. Use all scopes from PRM scopes_supported (if available)
|
||||
# 3. Omit scope parameter if neither is available
|
||||
|
||||
if www_authenticate_scope is not None:
|
||||
# Priority 1: WWW-Authenticate header scope
|
||||
return www_authenticate_scope
|
||||
elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None:
|
||||
# Priority 2: PRM scopes_supported
|
||||
return " ".join(protected_resource_metadata.scopes_supported)
|
||||
elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None:
|
||||
return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover
|
||||
else:
|
||||
# Priority 3: Omit scope parameter
|
||||
return None
|
||||
|
||||
|
||||
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
|
||||
"""
|
||||
Generate ordered list of (url, type) tuples for discovery attempts.
|
||||
|
||||
Args:
|
||||
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
|
||||
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
|
||||
"""
|
||||
|
||||
if not auth_server_url:
|
||||
# Legacy path using the 2025-03-26 spec:
|
||||
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
|
||||
parsed = urlparse(server_url)
|
||||
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]
|
||||
|
||||
urls: list[str] = []
|
||||
parsed = urlparse(auth_server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# RFC 8414: Path-aware OAuth discovery
|
||||
if parsed.path and parsed.path != "/":
|
||||
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oauth_path))
|
||||
|
||||
# RFC 8414 section 5: Path-aware OIDC discovery
|
||||
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
|
||||
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
|
||||
urls.append(urljoin(base_url, oidc_path))
|
||||
return urls
|
||||
|
||||
# OAuth root
|
||||
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))
|
||||
|
||||
# OIDC 1.0 fallback (appends to full URL per OIDC spec)
|
||||
# https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
async def handle_protected_resource_response(
|
||||
response: Response,
|
||||
) -> ProtectedResourceMetadata | None:
|
||||
"""
|
||||
Handle protected resource metadata discovery response.
|
||||
|
||||
Per SEP-985, supports fallback when discovery fails at one URL.
|
||||
|
||||
Returns:
|
||||
True if metadata was successfully discovered, False if we should try next URL
|
||||
"""
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
metadata = ProtectedResourceMetadata.model_validate_json(content)
|
||||
return metadata
|
||||
|
||||
except ValidationError: # pragma: no cover
|
||||
# Invalid metadata - try next URL
|
||||
return None
|
||||
else:
|
||||
# Not found - try next URL in fallback chain
|
||||
return None
|
||||
|
||||
|
||||
async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]:
|
||||
if response.status_code == 200:
|
||||
try:
|
||||
content = await response.aread()
|
||||
asm = OAuthMetadata.model_validate_json(content)
|
||||
return True, asm
|
||||
except ValidationError: # pragma: no cover
|
||||
return True, None
|
||||
elif response.status_code < 400 or response.status_code >= 500:
|
||||
return False, None # Non-4XX error, stop trying
|
||||
return True, None
|
||||
|
||||
|
||||
def create_oauth_metadata_request(url: str) -> Request:
|
||||
return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION})
|
||||
|
||||
|
||||
def create_client_registration_request(
|
||||
auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str
|
||||
) -> Request:
|
||||
"""Build registration request or skip if already registered."""
|
||||
|
||||
if auth_server_metadata and auth_server_metadata.registration_endpoint:
|
||||
registration_url = str(auth_server_metadata.registration_endpoint)
|
||||
else:
|
||||
registration_url = urljoin(auth_base_url, "/register")
|
||||
|
||||
registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
|
||||
|
||||
return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"})
|
||||
|
||||
|
||||
async def handle_registration_response(response: Response) -> OAuthClientInformationFull:
|
||||
"""Handle registration response."""
|
||||
if response.status_code not in (200, 201):
|
||||
await response.aread()
|
||||
raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}")
|
||||
|
||||
try:
|
||||
content = await response.aread()
|
||||
client_info = OAuthClientInformationFull.model_validate_json(content)
|
||||
return client_info
|
||||
# self.context.client_info = client_info
|
||||
# await self.context.storage.set_client_info(client_info)
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthRegistrationError(f"Invalid registration response: {e}")
|
||||
|
||||
|
||||
async def handle_token_response_scopes(
|
||||
response: Response,
|
||||
) -> OAuthToken:
|
||||
"""Parse and validate token response with optional scope validation.
|
||||
|
||||
Parses token response JSON. Callers should check response.status_code before calling.
|
||||
|
||||
Args:
|
||||
response: HTTP response from token endpoint (status already checked by caller)
|
||||
|
||||
Returns:
|
||||
Validated OAuthToken model
|
||||
|
||||
Raises:
|
||||
OAuthTokenError: If response JSON is invalid
|
||||
"""
|
||||
try:
|
||||
content = await response.aread()
|
||||
token_response = OAuthToken.model_validate_json(content)
|
||||
return token_response
|
||||
except ValidationError as e: # pragma: no cover
|
||||
raise OAuthTokenError(f"Invalid token response: {e}")
|
||||
@@ -0,0 +1,555 @@
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from typing import Any, Protocol, overload
|
||||
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from jsonschema import SchemaError, ValidationError, validate
|
||||
from pydantic import AnyUrl, TypeAdapter
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.context import RequestContext
|
||||
from mcp.shared.message import SessionMessage
|
||||
from mcp.shared.session import BaseSession, ProgressFnT, RequestResponder
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
|
||||
|
||||
logger = logging.getLogger("client")
|
||||
|
||||
|
||||
class SamplingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ElicitationFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class ListRootsFnT(Protocol):
|
||||
async def __call__(
|
||||
self, context: RequestContext["ClientSession", Any]
|
||||
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
|
||||
|
||||
|
||||
class LoggingFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
class MessageHandlerFnT(Protocol):
|
||||
async def __call__(
|
||||
self,
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None: ... # pragma: no branch
|
||||
|
||||
|
||||
async def _default_message_handler(
|
||||
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
|
||||
) -> None:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def _default_sampling_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.CreateMessageRequestParams,
|
||||
) -> types.CreateMessageResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Sampling not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_elicitation_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
params: types.ElicitRequestParams,
|
||||
) -> types.ElicitResult | types.ErrorData:
|
||||
return types.ErrorData( # pragma: no cover
|
||||
code=types.INVALID_REQUEST,
|
||||
message="Elicitation not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_list_roots_callback(
|
||||
context: RequestContext["ClientSession", Any],
|
||||
) -> types.ListRootsResult | types.ErrorData:
|
||||
return types.ErrorData(
|
||||
code=types.INVALID_REQUEST,
|
||||
message="List roots not supported",
|
||||
)
|
||||
|
||||
|
||||
async def _default_logging_callback(
|
||||
params: types.LoggingMessageNotificationParams,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
|
||||
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
|
||||
|
||||
|
||||
class ClientSession(
|
||||
BaseSession[
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
types.ClientResult,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
sampling_callback: SamplingFnT | None = None,
|
||||
elicitation_callback: ElicitationFnT | None = None,
|
||||
list_roots_callback: ListRootsFnT | None = None,
|
||||
logging_callback: LoggingFnT | None = None,
|
||||
message_handler: MessageHandlerFnT | None = None,
|
||||
client_info: types.Implementation | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
read_stream,
|
||||
write_stream,
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
read_timeout_seconds=read_timeout_seconds,
|
||||
)
|
||||
self._client_info = client_info or DEFAULT_CLIENT_INFO
|
||||
self._sampling_callback = sampling_callback or _default_sampling_callback
|
||||
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
|
||||
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
|
||||
self._logging_callback = logging_callback or _default_logging_callback
|
||||
self._message_handler = message_handler or _default_message_handler
|
||||
self._tool_output_schemas: dict[str, dict[str, Any] | None] = {}
|
||||
self._server_capabilities: types.ServerCapabilities | None = None
|
||||
|
||||
async def initialize(self) -> types.InitializeResult:
|
||||
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
|
||||
elicitation = (
|
||||
types.ElicitationCapability() if self._elicitation_callback is not _default_elicitation_callback else None
|
||||
)
|
||||
roots = (
|
||||
# TODO: Should this be based on whether we
|
||||
# _will_ send notifications, or only whether
|
||||
# they're supported?
|
||||
types.RootsCapability(listChanged=True)
|
||||
if self._list_roots_callback is not _default_list_roots_callback
|
||||
else None
|
||||
)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.InitializeRequest(
|
||||
params=types.InitializeRequestParams(
|
||||
protocolVersion=types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=types.ClientCapabilities(
|
||||
sampling=sampling,
|
||||
elicitation=elicitation,
|
||||
experimental=None,
|
||||
roots=roots,
|
||||
),
|
||||
clientInfo=self._client_info,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.InitializeResult,
|
||||
)
|
||||
|
||||
if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
|
||||
raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
|
||||
|
||||
self._server_capabilities = result.capabilities
|
||||
|
||||
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
|
||||
|
||||
return result
|
||||
|
||||
def get_server_capabilities(self) -> types.ServerCapabilities | None:
|
||||
"""Return the server capabilities received during initialization.
|
||||
|
||||
Returns None if the session has not been initialized yet.
|
||||
"""
|
||||
return self._server_capabilities
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult:
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.PingRequest()),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ClientNotification(
|
||||
types.ProgressNotification(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
|
||||
"""Send a logging/setLevel request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.SetLevelRequest(
|
||||
params=types.SetLevelRequestParams(level=level),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_resources(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_resources(self, cursor: str | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self, *, params: types.PaginatedRequestParams | None) -> types.ListResourcesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resources(self) -> types.ListResourcesResult: ...
|
||||
|
||||
async def list_resources(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListResourcesResult:
|
||||
"""Send a resources/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListResourcesRequest(params=request_params)),
|
||||
types.ListResourcesResult,
|
||||
)
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_resource_templates(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_resource_templates(self, cursor: str | None) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resource_templates(
|
||||
self, *, params: types.PaginatedRequestParams | None
|
||||
) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
@overload
|
||||
async def list_resource_templates(self) -> types.ListResourceTemplatesResult: ...
|
||||
|
||||
async def list_resource_templates(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListResourceTemplatesResult:
|
||||
"""Send a resources/templates/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListResourceTemplatesRequest(params=request_params)),
|
||||
types.ListResourceTemplatesResult,
|
||||
)
|
||||
|
||||
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
|
||||
"""Send a resources/read request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.ReadResourceRequest(
|
||||
params=types.ReadResourceRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.ReadResourceResult,
|
||||
)
|
||||
|
||||
async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/subscribe request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.SubscribeRequest(
|
||||
params=types.SubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
|
||||
"""Send a resources/unsubscribe request."""
|
||||
return await self.send_request( # pragma: no cover
|
||||
types.ClientRequest(
|
||||
types.UnsubscribeRequest(
|
||||
params=types.UnsubscribeRequestParams(uri=uri),
|
||||
)
|
||||
),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
name: str,
|
||||
arguments: dict[str, Any] | None = None,
|
||||
read_timeout_seconds: timedelta | None = None,
|
||||
progress_callback: ProgressFnT | None = None,
|
||||
*,
|
||||
meta: dict[str, Any] | None = None,
|
||||
) -> types.CallToolResult:
|
||||
"""Send a tools/call request with optional progress callback support."""
|
||||
|
||||
_meta: types.RequestParams.Meta | None = None
|
||||
if meta is not None:
|
||||
_meta = types.RequestParams.Meta(**meta)
|
||||
|
||||
result = await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CallToolRequest(
|
||||
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
|
||||
)
|
||||
),
|
||||
types.CallToolResult,
|
||||
request_read_timeout_seconds=read_timeout_seconds,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
if not result.isError:
|
||||
await self._validate_tool_result(name, result)
|
||||
|
||||
return result
|
||||
|
||||
async def _validate_tool_result(self, name: str, result: types.CallToolResult) -> None:
|
||||
"""Validate the structured content of a tool result against its output schema."""
|
||||
if name not in self._tool_output_schemas:
|
||||
# refresh output schema cache
|
||||
await self.list_tools()
|
||||
|
||||
output_schema = None
|
||||
if name in self._tool_output_schemas:
|
||||
output_schema = self._tool_output_schemas.get(name)
|
||||
else:
|
||||
logger.warning(f"Tool {name} not listed by server, cannot validate any structured content")
|
||||
|
||||
if output_schema is not None:
|
||||
if result.structuredContent is None:
|
||||
raise RuntimeError(
|
||||
f"Tool {name} has an output schema but did not return structured content"
|
||||
) # pragma: no cover
|
||||
try:
|
||||
validate(result.structuredContent, output_schema)
|
||||
except ValidationError as e:
|
||||
raise RuntimeError(f"Invalid structured content returned by tool {name}: {e}") # pragma: no cover
|
||||
except SchemaError as e: # pragma: no cover
|
||||
raise RuntimeError(f"Invalid schema for tool {name}: {e}") # pragma: no cover
|
||||
|
||||
@overload
|
||||
@deprecated("Use list_prompts(params=PaginatedRequestParams(...)) instead")
|
||||
async def list_prompts(self, cursor: str | None) -> types.ListPromptsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_prompts(self, *, params: types.PaginatedRequestParams | None) -> types.ListPromptsResult: ...
|
||||
|
||||
@overload
|
||||
async def list_prompts(self) -> types.ListPromptsResult: ...
|
||||
|
||||
async def list_prompts(
|
||||
self,
|
||||
cursor: str | None = None,
|
||||
*,
|
||||
params: types.PaginatedRequestParams | None = None,
|
||||
) -> types.ListPromptsResult:
|
||||
"""Send a prompts/list request.
|
||||
|
||||
Args:
|
||||
cursor: Simple cursor string for pagination (deprecated, use params instead)
|
||||
params: Full pagination parameters including cursor and any future fields
|
||||
"""
|
||||
if params is not None and cursor is not None:
|
||||
raise ValueError("Cannot specify both cursor and params")
|
||||
|
||||
if params is not None:
|
||||
request_params = params
|
||||
elif cursor is not None:
|
||||
request_params = types.PaginatedRequestParams(cursor=cursor)
|
||||
else:
|
||||
request_params = None
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(types.ListPromptsRequest(params=request_params)),
|
||||
types.ListPromptsResult,
|
||||
)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
|
||||
"""Send a prompts/get request."""
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.GetPromptRequest(
|
||||
params=types.GetPromptRequestParams(name=name, arguments=arguments),
|
||||
)
|
||||
),
|
||||
types.GetPromptResult,
|
||||
)
|
||||
|
||||
async def complete(
|
||||
self,
|
||||
ref: types.ResourceTemplateReference | types.PromptReference,
|
||||
argument: dict[str, str],
|
||||
context_arguments: dict[str, str] | None = None,
|
||||
) -> types.CompleteResult:
|
||||
"""Send a completion/complete request."""
|
||||
context = None
|
||||
if context_arguments is not None:
|
||||
context = types.CompletionContext(arguments=context_arguments)
|
||||
|
||||
return await self.send_request(
|
||||
types.ClientRequest(
|
||||
types.CompleteRequest(
|
||||
params=types.CompleteRequestParams(
|
||||
ref=ref,
|
||||
argument=types.CompletionArgument(**argument),
|
||||
context=context,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.CompleteResult,
|
||||
)
|
||||
|
||||
@overload
@deprecated("Use list_tools(params=PaginatedRequestParams(...)) instead")
async def list_tools(self, cursor: str | None) -> types.ListToolsResult: ...

@overload
async def list_tools(self, *, params: types.PaginatedRequestParams | None) -> types.ListToolsResult: ...

@overload
async def list_tools(self) -> types.ListToolsResult: ...

async def list_tools(
    self,
    cursor: str | None = None,
    *,
    params: types.PaginatedRequestParams | None = None,
) -> types.ListToolsResult:
    """Send a tools/list request.

    Args:
        cursor: Simple cursor string for pagination (deprecated, use params instead)
        params: Full pagination parameters including cursor and any future fields

    Raises:
        ValueError: If both ``cursor`` and ``params`` are given.
    """
    if params is not None and cursor is not None:
        raise ValueError("Cannot specify both cursor and params")

    if params is not None:
        request_params: types.PaginatedRequestParams | None = params
    else:
        request_params = types.PaginatedRequestParams(cursor=cursor) if cursor is not None else None

    result = await self.send_request(
        types.ClientRequest(types.ListToolsRequest(params=request_params)),
        types.ListToolsResult,
    )

    # Cache tool output schemas for future validation
    # Note: don't clear the cache, as we may be using a cursor
    self._tool_output_schemas.update((tool.name, tool.outputSchema) for tool in result.tools)

    return result
|
||||
|
||||
async def send_roots_list_changed(self) -> None:  # pragma: no cover
    """Send a roots/list_changed notification."""
    notification = types.ClientNotification(types.RootsListChangedNotification())
    await self.send_notification(notification)
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
    """Dispatch a server-initiated request to the matching client callback.

    Builds a RequestContext for the callback, then routes on the concrete
    request type. Each branch enters the responder as a context manager so
    the request is answered (or cancelled) exactly once.
    """
    ctx = RequestContext[ClientSession, Any](
        request_id=responder.request_id,
        meta=responder.request_meta,
        session=self,
        # No lifespan state exists on the client side.
        lifespan_context=None,
    )

    match responder.request.root:
        case types.CreateMessageRequest(params=params):
            # Server asked the client to produce an LLM sampling result.
            with responder:
                response = await self._sampling_callback(ctx, params)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.ElicitRequest(params=params):
            # Server asked the client to elicit additional user input.
            with responder:
                response = await self._elicitation_callback(ctx, params)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.ListRootsRequest():
            # Server asked for the client's configured roots.
            with responder:
                response = await self._list_roots_callback(ctx)
                client_response = ClientResponse.validate_python(response)
                await responder.respond(client_response)

        case types.PingRequest():  # pragma: no cover
            # Pings are answered directly with an empty result.
            with responder:
                return await responder.respond(types.ClientResult(root=types.EmptyResult()))
|
||||
|
||||
async def _handle_incoming(
    self,
    req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
    """Handle incoming messages by forwarding to the message handler."""
    handler = self._message_handler
    await handler(req)
|
||||
|
||||
async def _received_notification(self, notification: types.ServerNotification) -> None:
    """Handle notifications from the server."""
    root = notification.root
    # Only logging notifications receive special handling here; all other
    # notification types are intentionally ignored by this method.
    if isinstance(root, types.LoggingMessageNotification):
        await self._logging_callback(root.params)
|
||||
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
SessionGroup concurrently manages multiple MCP session connections.
|
||||
|
||||
Tools, resources, and prompts are aggregated across servers. Servers may
|
||||
be connected to or disconnected from at any point after initialization.
|
||||
|
||||
This abstraction can handle naming collisions using a custom user-provided
|
||||
hook.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from types import TracebackType
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
import anyio
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
import mcp
|
||||
from mcp import types
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import StdioServerParameters
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.shared.exceptions import McpError
|
||||
|
||||
|
||||
class SseServerParameters(BaseModel):
    """Parameters for initializing a sse_client."""

    # The endpoint URL.
    url: str

    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None

    # HTTP timeout for regular operations, in seconds.
    timeout: float = 5

    # Timeout for SSE read operations, in seconds.
    sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
class StreamableHttpParameters(BaseModel):
    """Parameters for initializing a streamablehttp_client."""

    # The endpoint URL.
    url: str

    # Optional headers to include in requests.
    headers: dict[str, Any] | None = None

    # HTTP timeout for regular operations.
    timeout: timedelta = timedelta(seconds=30)

    # Timeout for SSE read operations.
    sse_read_timeout: timedelta = timedelta(seconds=60 * 5)

    # Close the client session when the transport closes.
    terminate_on_close: bool = True
|
||||
|
||||
|
||||
ServerParameters: TypeAlias = StdioServerParameters | SseServerParameters | StreamableHttpParameters
|
||||
|
||||
|
||||
class ClientSessionGroup:
    """Client for managing connections to multiple MCP servers.

    This class is responsible for encapsulating management of server connections.
    It aggregates tools, resources, and prompts from all connected servers.

    For auxiliary handlers, such as resource subscription, this is delegated to
    the client and can be accessed via the session.

    Example Usage:
        name_fn = lambda name, server_info: f"{(server_info.name)}_{name}"
        async with ClientSessionGroup(component_name_hook=name_fn) as group:
            for server_param in server_params:
                await group.connect_to_server(server_param)
            ...

    """

    class _ComponentNames(BaseModel):
        """Used for reverse index to find components."""

        prompts: set[str] = set()
        resources: set[str] = set()
        tools: set[str] = set()

    # Standard MCP components, keyed by (possibly hook-renamed) aggregate name.
    _prompts: dict[str, types.Prompt]
    _resources: dict[str, types.Resource]
    _tools: dict[str, types.Tool]

    # Client-server connection management.
    _sessions: dict[mcp.ClientSession, _ComponentNames]
    _tool_to_session: dict[str, mcp.ClientSession]
    _exit_stack: contextlib.AsyncExitStack
    _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]

    # Optional fn consuming (component_name, serverInfo) for custom names.
    # This provides a means to mitigate naming conflicts across servers.
    # Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
    _ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
    _component_name_hook: _ComponentNameHook | None

    def __init__(
        self,
        exit_stack: contextlib.AsyncExitStack | None = None,
        component_name_hook: _ComponentNameHook | None = None,
    ) -> None:
        """Initializes the MCP client.

        Args:
            exit_stack: Optional externally-owned exit stack. When omitted the
                group creates its own and is responsible for closing it.
            component_name_hook: Optional callable mapping
                (component_name, server_info) to a unique aggregate name.
        """

        self._tools = {}
        self._resources = {}
        self._prompts = {}

        self._sessions = {}
        self._tool_to_session = {}
        if exit_stack is None:
            self._exit_stack = contextlib.AsyncExitStack()
            self._owns_exit_stack = True
        else:
            self._exit_stack = exit_stack
            self._owns_exit_stack = False
        self._session_exit_stacks = {}
        self._component_name_hook = component_name_hook

    async def __aenter__(self) -> Self:  # pragma: no cover
        # Enter the exit stack only if we created it ourselves
        if self._owns_exit_stack:
            await self._exit_stack.__aenter__()
        return self

    async def __aexit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_val: BaseException | None,
        _exc_tb: TracebackType | None,
    ) -> bool | None:  # pragma: no cover
        """Closes session exit stacks and main exit stack upon completion."""

        # Only close the main exit stack if we created it
        if self._owns_exit_stack:
            await self._exit_stack.aclose()

        # Concurrently close session stacks.
        async with anyio.create_task_group() as tg:
            for exit_stack in self._session_exit_stacks.values():
                tg.start_soon(exit_stack.aclose)

    @property
    def sessions(self) -> list[mcp.ClientSession]:
        """Returns the list of sessions being managed."""
        return list(self._sessions.keys())  # pragma: no cover

    @property
    def prompts(self) -> dict[str, types.Prompt]:
        """Returns the prompts as a dictionary of names to prompts."""
        return self._prompts

    @property
    def resources(self) -> dict[str, types.Resource]:
        """Returns the resources as a dictionary of names to resources."""
        return self._resources

    @property
    def tools(self) -> dict[str, types.Tool]:
        """Returns the tools as a dictionary of names to tools."""
        return self._tools

    async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
        """Executes a tool given its (aggregate) name and arguments."""
        session = self._tool_to_session[name]
        # Call with the server-side tool name, which may differ from the
        # aggregate key when a component_name_hook renamed it.
        session_tool_name = self.tools[name].name
        return await session.call_tool(session_tool_name, args)

    async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
        """Disconnects from a single MCP server.

        Raises:
            McpError: If the session is not managed by this group.
        """

        session_known_for_components = session in self._sessions
        session_known_for_stack = session in self._session_exit_stacks

        if not session_known_for_components and not session_known_for_stack:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message="Provided session is not managed or already disconnected.",
                )
            )

        if session_known_for_components:  # pragma: no cover
            component_names = self._sessions.pop(session)  # Pop from _sessions tracking

            # Remove prompts associated with the session.
            for name in component_names.prompts:
                if name in self._prompts:
                    del self._prompts[name]
            # Remove resources associated with the session.
            for name in component_names.resources:
                if name in self._resources:
                    del self._resources[name]
            # Remove tools associated with the session.
            for name in component_names.tools:
                if name in self._tools:
                    del self._tools[name]
                if name in self._tool_to_session:
                    del self._tool_to_session[name]

        # Clean up the session's resources via its dedicated exit stack
        if session_known_for_stack:
            session_stack_to_close = self._session_exit_stacks.pop(session)  # pragma: no cover
            await session_stack_to_close.aclose()  # pragma: no cover

    async def connect_with_session(
        self, server_info: types.Implementation, session: mcp.ClientSession
    ) -> mcp.ClientSession:
        """Registers an already-established session with the group."""
        await self._aggregate_components(server_info, session)
        return session

    async def connect_to_server(
        self,
        server_params: ServerParameters,
    ) -> mcp.ClientSession:
        """Connects to a single MCP server."""
        server_info, session = await self._establish_session(server_params)
        return await self.connect_with_session(server_info, session)

    async def _establish_session(
        self, server_params: ServerParameters
    ) -> tuple[types.Implementation, mcp.ClientSession]:
        """Establish a client session to an MCP server.

        The transport is chosen from the concrete parameter type:
        stdio, SSE, or (by default) streamable HTTP.
        """

        session_stack = contextlib.AsyncExitStack()
        try:
            # Create read and write streams that facilitate io with the server.
            if isinstance(server_params, StdioServerParameters):
                client = mcp.stdio_client(server_params)
                read, write = await session_stack.enter_async_context(client)
            elif isinstance(server_params, SseServerParameters):
                client = sse_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                )
                read, write = await session_stack.enter_async_context(client)
            else:
                client = streamablehttp_client(
                    url=server_params.url,
                    headers=server_params.headers,
                    timeout=server_params.timeout,
                    sse_read_timeout=server_params.sse_read_timeout,
                    terminate_on_close=server_params.terminate_on_close,
                )
                # Streamable HTTP yields a third value (ignored here).
                read, write, _ = await session_stack.enter_async_context(client)

            session = await session_stack.enter_async_context(mcp.ClientSession(read, write))
            result = await session.initialize()

            # Session successfully initialized.
            # Store its stack and register the stack with the main group stack.
            self._session_exit_stacks[session] = session_stack
            # session_stack itself becomes a resource managed by the
            # main _exit_stack.
            await self._exit_stack.enter_async_context(session_stack)

            return result.serverInfo, session
        except Exception:  # pragma: no cover
            # If anything during this setup fails, ensure the session-specific
            # stack is closed.
            await session_stack.aclose()
            raise

    async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
        """Aggregates prompts, resources, and tools from a given session.

        Raises:
            McpError: If any fetched component name collides with one already
                held by the group (no partial aggregation occurs in that case).
        """

        # Create a reverse index so we can find all prompts, resources, and
        # tools belonging to this session. Used for removing components from
        # the session group via self.disconnect_from_server.
        component_names = self._ComponentNames()

        # Temporary components dicts. We do not want to modify the aggregate
        # lists in case of an intermediate failure.
        prompts_temp: dict[str, types.Prompt] = {}
        resources_temp: dict[str, types.Resource] = {}
        tools_temp: dict[str, types.Tool] = {}
        tool_to_session_temp: dict[str, mcp.ClientSession] = {}

        # Query the server for its prompts and aggregate to list.
        try:
            prompts = (await session.list_prompts()).prompts
            for prompt in prompts:
                name = self._component_name(prompt.name, server_info)
                prompts_temp[name] = prompt
                component_names.prompts.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch prompts: {err}")

        # Query the server for its resources and aggregate to list.
        try:
            resources = (await session.list_resources()).resources
            for resource in resources:
                name = self._component_name(resource.name, server_info)
                resources_temp[name] = resource
                component_names.resources.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch resources: {err}")

        # Query the server for its tools and aggregate to list.
        try:
            tools = (await session.list_tools()).tools
            for tool in tools:
                name = self._component_name(tool.name, server_info)
                tools_temp[name] = tool
                tool_to_session_temp[name] = session
                component_names.tools.add(name)
        except McpError as err:  # pragma: no cover
            logging.warning(f"Could not fetch tools: {err}")

        # Clean up exit stack for session if we couldn't retrieve anything
        # from the server.
        if not any((prompts_temp, resources_temp, tools_temp)):
            del self._session_exit_stacks[session]  # pragma: no cover

        # Check for duplicates.
        matching_prompts = prompts_temp.keys() & self._prompts.keys()
        if matching_prompts:
            raise McpError(  # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_prompts} already exist in group prompts.",
                )
            )
        matching_resources = resources_temp.keys() & self._resources.keys()
        if matching_resources:
            raise McpError(  # pragma: no cover
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_resources} already exist in group resources.",
                )
            )
        matching_tools = tools_temp.keys() & self._tools.keys()
        if matching_tools:
            raise McpError(
                types.ErrorData(
                    code=types.INVALID_PARAMS,
                    message=f"{matching_tools} already exist in group tools.",
                )
            )

        # Aggregate components.
        self._sessions[session] = component_names
        self._prompts.update(prompts_temp)
        self._resources.update(resources_temp)
        self._tools.update(tools_temp)
        self._tool_to_session.update(tool_to_session_temp)

    def _component_name(self, name: str, server_info: types.Implementation) -> str:
        # Apply the user-supplied renaming hook when present; otherwise use the
        # server's own component name unchanged.
        if self._component_name_hook:
            return self._component_name_hook(name, server_info)
        return name
|
||||
@@ -0,0 +1,148 @@
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskStatus
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import aconnect_sse
|
||||
from httpx_sse._exceptions import SSEError
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def remove_request_params(url: str) -> str:
    """Return *url* stripped of its query string and fragment.

    Keeps only scheme, host, and path; used to avoid logging request
    parameters (which may carry tokens).
    """
    bare_path = urlparse(url).path
    return urljoin(url, bare_path)
|
||||
|
||||
|
||||
@asynccontextmanager
async def sse_client(
    url: str,
    headers: dict[str, Any] | None = None,
    timeout: float = 5,
    sse_read_timeout: float = 60 * 5,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
):
    """
    Client transport for SSE.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Args:
        url: The SSE endpoint URL.
        headers: Optional headers to include in requests.
        timeout: HTTP timeout for regular operations.
        sse_read_timeout: Timeout for SSE read operations.
        auth: Optional HTTPX authentication handler.

    Yields:
        A (read_stream, write_stream) pair of memory streams carrying
        SessionMessages (the read side may also carry Exceptions).
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: each send blocks until the consumer is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
            async with httpx_client_factory(
                headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
            ) as client:
                async with aconnect_sse(
                    client,
                    "GET",
                    url,
                ) as event_source:
                    event_source.response.raise_for_status()
                    logger.debug("SSE connection established")

                    async def sse_reader(
                        task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
                    ):
                        # Consumes server events: the first "endpoint" event
                        # reveals the POST URL; "message" events carry JSON-RPC.
                        try:
                            async for sse in event_source.aiter_sse():  # pragma: no branch
                                logger.debug(f"Received SSE event: {sse.event}")
                                match sse.event:
                                    case "endpoint":
                                        endpoint_url = urljoin(url, sse.data)
                                        logger.debug(f"Received endpoint URL: {endpoint_url}")

                                        # Reject endpoints whose origin differs
                                        # from the SSE connection's origin.
                                        url_parsed = urlparse(url)
                                        endpoint_parsed = urlparse(endpoint_url)
                                        if (  # pragma: no cover
                                            url_parsed.netloc != endpoint_parsed.netloc
                                            or url_parsed.scheme != endpoint_parsed.scheme
                                        ):
                                            error_msg = (  # pragma: no cover
                                                f"Endpoint origin does not match connection origin: {endpoint_url}"
                                            )
                                            logger.error(error_msg)  # pragma: no cover
                                            raise ValueError(error_msg)  # pragma: no cover

                                        # Unblocks the tg.start(sse_reader) call
                                        # below, handing it the endpoint URL.
                                        task_status.started(endpoint_url)

                                    case "message":
                                        try:
                                            message = types.JSONRPCMessage.model_validate_json(  # noqa: E501
                                                sse.data
                                            )
                                            logger.debug(f"Received server message: {message}")
                                        except Exception as exc:  # pragma: no cover
                                            logger.exception("Error parsing server message")  # pragma: no cover
                                            await read_stream_writer.send(exc)  # pragma: no cover
                                            continue  # pragma: no cover

                                        session_message = SessionMessage(message)
                                        await read_stream_writer.send(session_message)
                                    case _:  # pragma: no cover
                                        logger.warning(f"Unknown SSE event: {sse.event}")  # pragma: no cover
                        except SSEError as sse_exc:  # pragma: no cover
                            logger.exception("Encountered SSE exception")  # pragma: no cover
                            raise sse_exc  # pragma: no cover
                        except Exception as exc:  # pragma: no cover
                            logger.exception("Error in sse_reader")  # pragma: no cover
                            await read_stream_writer.send(exc)  # pragma: no cover
                        finally:
                            await read_stream_writer.aclose()

                    async def post_writer(endpoint_url: str):
                        # Drains the write stream, POSTing each outgoing
                        # message to the server-provided endpoint.
                        try:
                            async with write_stream_reader:
                                async for session_message in write_stream_reader:
                                    logger.debug(f"Sending client message: {session_message}")
                                    response = await client.post(
                                        endpoint_url,
                                        json=session_message.message.model_dump(
                                            by_alias=True,
                                            mode="json",
                                            exclude_none=True,
                                        ),
                                    )
                                    response.raise_for_status()
                                    logger.debug(f"Client message sent successfully: {response.status_code}")
                        except Exception:  # pragma: no cover
                            logger.exception("Error in post_writer")  # pragma: no cover
                        finally:
                            await write_stream.aclose()

                    # Wait for the reader to discover the POST endpoint before
                    # starting the writer task.
                    endpoint_url = await tg.start(sse_reader)
                    logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
                    tg.start_soon(post_writer, endpoint_url)

                    try:
                        yield read_stream, write_stream
                    finally:
                        # Tear down reader/writer tasks when the caller exits.
                        tg.cancel_scope.cancel()
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()
|
||||
@@ -0,0 +1,278 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Literal, TextIO
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.abc import Process
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.os.posix.utilities import terminate_posix_process_tree
|
||||
from mcp.os.win32.utilities import (
|
||||
FallbackProcess,
|
||||
create_windows_process,
|
||||
get_windows_executable_command,
|
||||
terminate_windows_process_tree,
|
||||
)
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Environment variables to inherit by default when spawning server processes.
# The Windows list is larger because many tools need the APPDATA/SYSTEMROOT
# family to function; POSIX needs only the basic login/shell variables.
DEFAULT_INHERITED_ENV_VARS = (
    [
        "APPDATA",
        "HOMEDRIVE",
        "HOMEPATH",
        "LOCALAPPDATA",
        "PATH",
        "PATHEXT",
        "PROCESSOR_ARCHITECTURE",
        "SYSTEMDRIVE",
        "SYSTEMROOT",
        "TEMP",
        "USERNAME",
        "USERPROFILE",
    ]
    if sys.platform == "win32"
    else ["HOME", "LOGNAME", "PATH", "SHELL", "TERM", "USER"]
)

# Timeout (seconds) for graceful process termination before falling back
# to a force kill.
PROCESS_TERMINATION_TIMEOUT = 2.0
|
||||
|
||||
|
||||
def get_default_environment() -> dict[str, str]:
    """
    Returns a default environment object including only environment variables deemed
    safe to inherit.
    """
    safe_env: dict[str, str] = {}

    for var_name in DEFAULT_INHERITED_ENV_VARS:
        var_value = os.environ.get(var_name)
        if var_value is None:  # pragma: no cover
            continue  # pragma: no cover
        if var_value.startswith("()"):  # pragma: no cover
            # Skip exported shell functions, which are a security risk
            continue  # pragma: no cover
        safe_env[var_name] = var_value

    return safe_env
|
||||
|
||||
|
||||
class StdioServerParameters(BaseModel):
    """Configuration for spawning an MCP server as a stdio subprocess."""

    command: str
    """The executable to run to start the server."""

    args: list[str] = Field(default_factory=list)
    """Command line arguments to pass to the executable."""

    env: dict[str, str] | None = None
    """
    The environment to use when spawning the process.

    If not specified, the result of get_default_environment() will be used.
    """

    cwd: str | Path | None = None
    """The working directory to use when spawning the process."""

    encoding: str = "utf-8"
    """
    The text encoding used when sending/receiving messages to the server

    defaults to utf-8
    """

    encoding_error_handler: Literal["strict", "ignore", "replace"] = "strict"
    """
    The text encoding error handler.

    See https://docs.python.org/3/library/codecs.html#codec-base-classes for
    explanations of possible values
    """
|
||||
|
||||
|
||||
@asynccontextmanager
async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Yields:
        A (read_stream, write_stream) pair of memory streams carrying
        SessionMessages (the read side may also carry Exceptions).
    """
    read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
    read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]

    write_stream: MemoryObjectSendStream[SessionMessage]
    write_stream_reader: MemoryObjectReceiveStream[SessionMessage]

    # Zero-capacity streams: each send blocks until the consumer is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        command = _get_executable_command(server.command)

        # Open process with stderr piped for capture
        process = await _create_platform_compatible_process(
            command=command,
            args=server.args,
            env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
            errlog=errlog,
            cwd=server.cwd,
        )
    except OSError:
        # Clean up streams if process creation fails
        await read_stream.aclose()
        await write_stream.aclose()
        await read_stream_writer.aclose()
        await write_stream_reader.aclose()
        raise

    async def stdout_reader():
        # Parse newline-delimited JSON-RPC from the child's stdout and forward
        # each message (or a parse Exception) onto the read stream.
        assert process.stdout, "Opened process is missing stdout"

        try:
            async with read_stream_writer:
                buffer = ""
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    # Hold back a trailing partial line until more data arrives.
                    lines = (buffer + chunk).split("\n")
                    buffer = lines.pop()

                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:  # pragma: no cover
                            logger.exception("Failed to parse JSONRPC message from server")
                            await read_stream_writer.send(exc)
                            continue

                        session_message = SessionMessage(message)
                        await read_stream_writer.send(session_message)
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        # Serialize outgoing messages as newline-delimited JSON onto the
        # child's stdin, using the configured encoding.
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    json = session_message.message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:  # pragma: no cover
            await anyio.lowlevel.checkpoint()

    async with (
        anyio.create_task_group() as tg,
        process,
    ):
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        try:
            yield read_stream, write_stream
        finally:
            # MCP spec: stdio shutdown sequence
            # 1. Close input stream to server
            # 2. Wait for server to exit, or send SIGTERM if it doesn't exit in time
            # 3. Send SIGKILL if still not exited
            if process.stdin:  # pragma: no branch
                try:
                    await process.stdin.aclose()
                except Exception:  # pragma: no cover
                    # stdin might already be closed, which is fine
                    pass

            try:
                # Give the process time to exit gracefully after stdin closes
                with anyio.fail_after(PROCESS_TERMINATION_TIMEOUT):
                    await process.wait()
            except TimeoutError:
                # Process didn't exit from stdin closure, use platform-specific termination
                # which handles SIGTERM -> SIGKILL escalation
                await _terminate_process_tree(process)
            except ProcessLookupError:  # pragma: no cover
                # Process already exited, which is fine
                pass
            await read_stream.aclose()
            await write_stream.aclose()
            await read_stream_writer.aclose()
            await write_stream_reader.aclose()
|
||||
|
||||
|
||||
def _get_executable_command(command: str) -> str:
|
||||
"""
|
||||
Get the correct executable command normalized for the current platform.
|
||||
|
||||
Args:
|
||||
command: Base command (e.g., 'uvx', 'npx')
|
||||
|
||||
Returns:
|
||||
str: Platform-appropriate command
|
||||
"""
|
||||
if sys.platform == "win32": # pragma: no cover
|
||||
return get_windows_executable_command(command)
|
||||
else:
|
||||
return command # pragma: no cover
|
||||
|
||||
|
||||
async def _create_platform_compatible_process(
    command: str,
    args: list[str],
    env: dict[str, str] | None = None,
    errlog: TextIO = sys.stderr,
    cwd: Path | str | None = None,
):
    """
    Creates a subprocess in a platform-compatible way.

    Unix: Creates process in a new session/process group for killpg support
    Windows: Creates process in a Job Object for reliable child termination
    """
    if sys.platform == "win32":  # pragma: no cover
        return await create_windows_process(command, args, env, errlog, cwd)

    # POSIX: start_new_session=True places the child in its own process
    # group so the whole tree can be signalled at once later.
    return await anyio.open_process(  # pragma: no cover
        [command, *args],
        env=env,
        stderr=errlog,
        cwd=cwd,
        start_new_session=True,
    )
|
||||
|
||||
|
||||
async def _terminate_process_tree(process: Process | FallbackProcess, timeout_seconds: float = 2.0) -> None:
    """
    Terminate a process and all of its children using platform-specific methods.

    Unix: Uses os.killpg() for atomic process group termination
    Windows: Uses Job Objects via pywin32 for reliable child process cleanup

    Args:
        process: The process to terminate
        timeout_seconds: Timeout in seconds before force killing (default: 2.0)
    """
    if sys.platform != "win32":  # pragma: no cover
        # FallbackProcess should only be used for Windows compatibility,
        # so on POSIX the process must be a real anyio Process.
        assert isinstance(process, Process)
        await terminate_posix_process_tree(process, timeout_seconds)
    else:  # pragma: no cover
        await terminate_windows_process_tree(process, timeout_seconds)
|
||||
Binary file not shown.
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
StreamableHTTP Client Transport Module
|
||||
|
||||
This module implements the StreamableHTTP transport for MCP clients,
|
||||
providing support for HTTP POST requests with optional SSE streaming responses
|
||||
and session management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Awaitable, Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
|
||||
import anyio
|
||||
import httpx
|
||||
from anyio.abc import TaskGroup
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
|
||||
|
||||
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
|
||||
from mcp.shared.message import ClientMessageMetadata, SessionMessage
|
||||
from mcp.types import (
|
||||
ErrorData,
|
||||
InitializeResult,
|
||||
JSONRPCError,
|
||||
JSONRPCMessage,
|
||||
JSONRPCNotification,
|
||||
JSONRPCRequest,
|
||||
JSONRPCResponse,
|
||||
RequestId,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)


# Type aliases for the in-memory channels shared between the transport tasks
# and the caller. The read side may carry Exception objects when parsing fails.
SessionMessageOrError = SessionMessage | Exception
StreamWriter = MemoryObjectSendStream[SessionMessageOrError]
StreamReader = MemoryObjectReceiveStream[SessionMessage]
# Callback handed to callers so they can query the server-assigned session ID.
GetSessionIdCallback = Callable[[], str | None]

# HTTP header names (lower-case) used by the StreamableHTTP transport.
MCP_SESSION_ID = "mcp-session-id"
MCP_PROTOCOL_VERSION = "mcp-protocol-version"
LAST_EVENT_ID = "last-event-id"
CONTENT_TYPE = "content-type"
ACCEPT = "accept"


# Content types used in request/response negotiation.
JSON = "application/json"
SSE = "text/event-stream"
|
||||
|
||||
|
||||
class StreamableHTTPError(Exception):
    """Base exception for StreamableHTTP transport errors."""
|
||||
|
||||
|
||||
class ResumptionError(StreamableHTTPError):
    """Raised when a resumption request is invalid (e.g. missing a resumption token)."""
|
||||
|
||||
|
||||
@dataclass
class RequestContext:
    """Context for a single outgoing request operation."""

    # HTTP client the request is sent on.
    client: httpx.AsyncClient
    # Base headers for the request (before session/protocol headers are added).
    headers: dict[str, str]
    # Session ID at the time the request was queued, if one has been assigned.
    session_id: str | None
    # The message being sent, with any session-level metadata.
    session_message: SessionMessage
    # Client metadata (e.g. resumption token/callback), when provided.
    metadata: ClientMessageMetadata | None
    # Where parsed server messages (or parse errors) are delivered.
    read_stream_writer: StreamWriter
    # Read timeout (seconds) for SSE streams opened for this request.
    sse_read_timeout: float
|
||||
|
||||
|
||||
class StreamableHTTPTransport:
    """StreamableHTTP client transport implementation.

    Sends JSON-RPC messages via HTTP POST and consumes responses either as
    plain JSON or as SSE streams, tracking the server-assigned session ID
    and the negotiated protocol version across requests.
    """

    def __init__(
        self,
        url: str,
        headers: dict[str, str] | None = None,
        timeout: float | timedelta = 30,
        sse_read_timeout: float | timedelta = 60 * 5,
        auth: httpx.Auth | None = None,
    ) -> None:
        """Initialize the StreamableHTTP transport.

        Args:
            url: The endpoint URL.
            headers: Optional headers to include in requests.
            timeout: HTTP timeout for regular operations.
            sse_read_timeout: Timeout for SSE read operations.
            auth: Optional HTTPX authentication handler.
        """
        self.url = url
        self.headers = headers or {}
        # timedelta inputs are normalized to float seconds.
        self.timeout = timeout.total_seconds() if isinstance(timeout, timedelta) else timeout
        self.sse_read_timeout = (
            sse_read_timeout.total_seconds() if isinstance(sse_read_timeout, timedelta) else sse_read_timeout
        )
        self.auth = auth
        # Assigned by the server on initialization (see
        # _maybe_extract_session_id_from_response).
        self.session_id = None
        # Negotiated during initialization (see
        # _maybe_extract_protocol_version_from_message).
        self.protocol_version = None
        # User-supplied headers are spread last, so they can override the
        # default ACCEPT/CONTENT_TYPE values.
        self.request_headers = {
            ACCEPT: f"{JSON}, {SSE}",
            CONTENT_TYPE: JSON,
            **self.headers,
        }

    def _prepare_request_headers(self, base_headers: dict[str, str]) -> dict[str, str]:
        """Return a copy of *base_headers* with session ID and protocol version added, if known."""
        headers = base_headers.copy()
        if self.session_id:
            headers[MCP_SESSION_ID] = self.session_id
        if self.protocol_version:
            headers[MCP_PROTOCOL_VERSION] = self.protocol_version
        return headers

    def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
        """Check if the message is an initialization request."""
        return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"

    def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
        """Check if the message is an initialized notification."""
        return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"

    def _maybe_extract_session_id_from_response(
        self,
        response: httpx.Response,
    ) -> None:
        """Extract and store the session ID from response headers, if present."""
        new_session_id = response.headers.get(MCP_SESSION_ID)
        if new_session_id:
            self.session_id = new_session_id
            logger.info(f"Received session ID: {self.session_id}")

    def _maybe_extract_protocol_version_from_message(
        self,
        message: JSONRPCMessage,
    ) -> None:
        """Extract the protocol version from an initialization response message."""
        if isinstance(message.root, JSONRPCResponse) and message.root.result:  # pragma: no branch
            try:
                # Parse the result as InitializeResult for type safety
                init_result = InitializeResult.model_validate(message.root.result)
                self.protocol_version = str(init_result.protocolVersion)
                logger.info(f"Negotiated protocol version: {self.protocol_version}")
            except Exception as exc:  # pragma: no cover
                # A malformed result is logged but not fatal; the transport
                # keeps operating without a known protocol version.
                logger.warning(
                    f"Failed to parse initialization response as InitializeResult: {exc}"
                )  # pragma: no cover
                logger.warning(f"Raw result: {message.root.result}")

    async def _handle_sse_event(
        self,
        sse: ServerSentEvent,
        read_stream_writer: StreamWriter,
        original_request_id: RequestId | None = None,
        resumption_callback: Callable[[str], Awaitable[None]] | None = None,
        is_initialization: bool = False,
    ) -> bool:
        """Handle one SSE event; return True when the response is complete.

        Args:
            sse: The incoming server-sent event.
            read_stream_writer: Destination for parsed messages (or errors).
            original_request_id: When resuming, the ID to restore on responses.
            resumption_callback: Invoked with each event ID for resumption tracking.
            is_initialization: Whether this event belongs to an initialize exchange.
        """
        if sse.event == "message":
            try:
                message = JSONRPCMessage.model_validate_json(sse.data)
                logger.debug(f"SSE message: {message}")

                # Extract protocol version from initialization response
                if is_initialization:
                    self._maybe_extract_protocol_version_from_message(message)

                # If this is a response and we have original_request_id, replace it
                if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
                    message.root.id = original_request_id

                session_message = SessionMessage(message)
                await read_stream_writer.send(session_message)

                # Call resumption token callback if we have an ID
                if sse.id and resumption_callback:
                    await resumption_callback(sse.id)

                # If this is a response or error return True indicating completion
                # Otherwise, return False to continue listening
                return isinstance(message.root, JSONRPCResponse | JSONRPCError)

            except Exception as exc:  # pragma: no cover
                # Parse errors are forwarded to the reader rather than raised,
                # so the caller decides how to react.
                logger.exception("Error parsing SSE message")
                await read_stream_writer.send(exc)
                return False
        else:  # pragma: no cover
            logger.warning(f"Unknown SSE event: {sse.event}")
            return False

    async def handle_get_stream(
        self,
        client: httpx.AsyncClient,
        read_stream_writer: StreamWriter,
    ) -> None:
        """Open the standing GET stream for server-initiated messages.

        No-op until a session ID has been assigned; errors are logged as
        non-fatal since this stream is optional.
        """
        try:
            if not self.session_id:
                return

            headers = self._prepare_request_headers(self.request_headers)

            async with aconnect_sse(
                client,
                "GET",
                self.url,
                headers=headers,
                timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
            ) as event_source:
                event_source.response.raise_for_status()
                logger.debug("GET SSE connection established")

                async for sse in event_source.aiter_sse():
                    await self._handle_sse_event(sse, read_stream_writer)

        except Exception as exc:
            logger.debug(f"GET stream error (non-fatal): {exc}")  # pragma: no cover

    async def _handle_resumption_request(self, ctx: RequestContext) -> None:
        """Handle a resumption request using GET with SSE.

        Raises:
            ResumptionError: If no resumption token is present in the metadata.
        """
        headers = self._prepare_request_headers(ctx.headers)
        if ctx.metadata and ctx.metadata.resumption_token:
            # Last-Event-ID tells the server where to resume the stream.
            headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
        else:
            raise ResumptionError("Resumption request requires a resumption token")  # pragma: no cover

        # Extract original request ID to map responses
        original_request_id = None
        if isinstance(ctx.session_message.message.root, JSONRPCRequest):  # pragma: no branch
            original_request_id = ctx.session_message.message.root.id

        async with aconnect_sse(
            ctx.client,
            "GET",
            self.url,
            headers=headers,
            timeout=httpx.Timeout(self.timeout, read=self.sse_read_timeout),
        ) as event_source:
            event_source.response.raise_for_status()
            logger.debug("Resumption GET SSE connection established")

            async for sse in event_source.aiter_sse():  # pragma: no branch
                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    original_request_id,
                    ctx.metadata.on_resumption_token_update if ctx.metadata else None,
                )
                if is_complete:
                    await event_source.response.aclose()
                    break

    async def _handle_post_request(self, ctx: RequestContext) -> None:
        """POST one message and process the response (JSON or SSE)."""
        headers = self._prepare_request_headers(ctx.headers)
        message = ctx.session_message.message
        is_initialization = self._is_initialization_request(message)

        async with ctx.client.stream(
            "POST",
            self.url,
            json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
            headers=headers,
        ) as response:
            # 202 Accepted: server took the message but sends no body back.
            if response.status_code == 202:
                logger.debug("Received 202 Accepted")
                return

            # 404 with a known session means the server terminated the session.
            if response.status_code == 404:  # pragma: no branch
                if isinstance(message.root, JSONRPCRequest):
                    await self._send_session_terminated_error(  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                        message.root.id,  # pragma: no cover
                    )  # pragma: no cover
                return  # pragma: no cover

            response.raise_for_status()
            if is_initialization:
                self._maybe_extract_session_id_from_response(response)

            # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
            # The server MUST NOT send a response to notifications.
            if isinstance(message.root, JSONRPCRequest):
                content_type = response.headers.get(CONTENT_TYPE, "").lower()
                if content_type.startswith(JSON):
                    await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
                elif content_type.startswith(SSE):
                    await self._handle_sse_response(response, ctx, is_initialization)
                else:
                    await self._handle_unexpected_content_type(  # pragma: no cover
                        content_type,  # pragma: no cover
                        ctx.read_stream_writer,  # pragma: no cover
                    )  # pragma: no cover

    async def _handle_json_response(
        self,
        response: httpx.Response,
        read_stream_writer: StreamWriter,
        is_initialization: bool = False,
    ) -> None:
        """Parse a JSON response body and forward it to the reader."""
        try:
            content = await response.aread()
            message = JSONRPCMessage.model_validate_json(content)

            # Extract protocol version from initialization response
            if is_initialization:
                self._maybe_extract_protocol_version_from_message(message)

            session_message = SessionMessage(message)
            await read_stream_writer.send(session_message)
        except Exception as exc:  # pragma: no cover
            logger.exception("Error parsing JSON response")
            await read_stream_writer.send(exc)

    async def _handle_sse_response(
        self,
        response: httpx.Response,
        ctx: RequestContext,
        is_initialization: bool = False,
    ) -> None:
        """Consume an SSE response stream until the matching response/error arrives."""
        try:
            event_source = EventSource(response)
            async for sse in event_source.aiter_sse():  # pragma: no branch
                is_complete = await self._handle_sse_event(
                    sse,
                    ctx.read_stream_writer,
                    resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
                    is_initialization=is_initialization,
                )
                # If the SSE event indicates completion, like returning response/error
                # break the loop
                if is_complete:
                    await response.aclose()
                    break
        except Exception as e:
            logger.exception("Error reading SSE stream:")  # pragma: no cover
            await ctx.read_stream_writer.send(e)  # pragma: no cover

    async def _handle_unexpected_content_type(
        self,
        content_type: str,
        read_stream_writer: StreamWriter,
    ) -> None:  # pragma: no cover
        """Report an unexpected response content type to the reader as an error."""
        error_msg = f"Unexpected content type: {content_type}"  # pragma: no cover
        logger.error(error_msg)  # pragma: no cover
        await read_stream_writer.send(ValueError(error_msg))  # pragma: no cover

    async def _send_session_terminated_error(
        self,
        read_stream_writer: StreamWriter,
        request_id: RequestId,
    ) -> None:
        """Send a session terminated error response for the given request."""
        jsonrpc_error = JSONRPCError(
            jsonrpc="2.0",
            id=request_id,
            # NOTE(review): JSON-RPC "Invalid Request" is -32600; the positive
            # 32600 here may be intentional upstream — confirm before changing.
            error=ErrorData(code=32600, message="Session terminated"),
        )
        session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
        await read_stream_writer.send(session_message)

    async def post_writer(
        self,
        client: httpx.AsyncClient,
        write_stream_reader: StreamReader,
        read_stream_writer: StreamWriter,
        write_stream: MemoryObjectSendStream[SessionMessage],
        start_get_stream: Callable[[], None],
        tg: TaskGroup,
    ) -> None:
        """Main write loop: drain outgoing messages and dispatch them to the server.

        Requests are handled in spawned tasks so the loop can keep draining;
        notifications/responses are awaited inline. Both ends of the reader
        pipeline are closed on exit.
        """
        try:
            async with write_stream_reader:
                async for session_message in write_stream_reader:
                    message = session_message.message
                    metadata = (
                        session_message.metadata
                        if isinstance(session_message.metadata, ClientMessageMetadata)
                        else None
                    )

                    # Check if this is a resumption request
                    is_resumption = bool(metadata and metadata.resumption_token)

                    logger.debug(f"Sending client message: {message}")

                    # Handle initialized notification
                    if self._is_initialized_notification(message):
                        start_get_stream()

                    ctx = RequestContext(
                        client=client,
                        headers=self.request_headers,
                        session_id=self.session_id,
                        session_message=session_message,
                        metadata=metadata,
                        read_stream_writer=read_stream_writer,
                        sse_read_timeout=self.sse_read_timeout,
                    )

                    # NOTE(review): this closure captures `ctx`/`is_resumption`
                    # by reference; if the loop reassigns them before a spawned
                    # task runs, the task could see a later message's context.
                    # Confirm whether the 0-capacity stream makes this safe.
                    async def handle_request_async():
                        if is_resumption:
                            await self._handle_resumption_request(ctx)
                        else:
                            await self._handle_post_request(ctx)

                    # If this is a request, start a new task to handle it
                    if isinstance(message.root, JSONRPCRequest):
                        tg.start_soon(handle_request_async)
                    else:
                        await handle_request_async()

        except Exception:
            logger.exception("Error in post_writer")  # pragma: no cover
        finally:
            await read_stream_writer.aclose()
            await write_stream.aclose()

    async def terminate_session(self, client: httpx.AsyncClient) -> None:  # pragma: no cover
        """Terminate the session by sending a DELETE request.

        Best-effort: failures and 405 (server forbids termination) are only
        logged, never raised.
        """
        if not self.session_id:
            return

        try:
            headers = self._prepare_request_headers(self.request_headers)
            response = await client.delete(self.url, headers=headers)

            if response.status_code == 405:
                logger.debug("Server does not allow session termination")
            elif response.status_code not in (200, 204):
                logger.warning(f"Session termination failed: {response.status_code}")
        except Exception as exc:
            logger.warning(f"Session termination failed: {exc}")

    def get_session_id(self) -> str | None:
        """Get the current session ID."""
        return self.session_id
|
||||
|
||||
|
||||
@asynccontextmanager
async def streamablehttp_client(
    url: str,
    headers: dict[str, str] | None = None,
    timeout: float | timedelta = 30,
    sse_read_timeout: float | timedelta = 60 * 5,
    terminate_on_close: bool = True,
    httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
    auth: httpx.Auth | None = None,
) -> AsyncGenerator[
    tuple[
        MemoryObjectReceiveStream[SessionMessage | Exception],
        MemoryObjectSendStream[SessionMessage],
        GetSessionIdCallback,
    ],
    None,
]:
    """
    Client transport for StreamableHTTP.

    `sse_read_timeout` determines how long (in seconds) the client will wait for a new
    event before disconnecting. All other HTTP operations are controlled by `timeout`.

    Yields:
        Tuple containing:
            - read_stream: Stream for reading messages from the server
            - write_stream: Stream for sending messages to the server
            - get_session_id_callback: Function to retrieve the current session ID
    """
    transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout, auth)

    # Zero-capacity channels: producers block until the consumer is ready.
    read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

    async with anyio.create_task_group() as tg:
        try:
            logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")

            async with httpx_client_factory(
                headers=transport.request_headers,
                timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
                auth=transport.auth,
            ) as client:
                # Define callbacks that need access to tg
                def start_get_stream() -> None:
                    tg.start_soon(transport.handle_get_stream, client, read_stream_writer)

                # The write loop runs in the background for the whole lifetime
                # of this context manager.
                tg.start_soon(
                    transport.post_writer,
                    client,
                    write_stream_reader,
                    read_stream_writer,
                    write_stream,
                    start_get_stream,
                    tg,
                )

                try:
                    yield (
                        read_stream,
                        write_stream,
                        transport.get_session_id,
                    )
                finally:
                    # Terminate the session BEFORE cancelling the task group,
                    # while the HTTP client is still usable.
                    if transport.session_id and terminate_on_close:
                        await transport.terminate_session(client)
                    tg.cancel_scope.cancel()
        finally:
            # Close our ends of the channels so any remaining readers/writers
            # see end-of-stream instead of hanging.
            await read_stream_writer.aclose()
            await write_stream.aclose()
|
||||
@@ -0,0 +1,86 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import ValidationError
|
||||
from websockets.asyncio.client import connect as ws_connect
|
||||
from websockets.typing import Subprotocol
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.shared.message import SessionMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def websocket_client(
    url: str,
) -> AsyncGenerator[
    tuple[MemoryObjectReceiveStream[SessionMessage | Exception], MemoryObjectSendStream[SessionMessage]],
    None,
]:
    """
    WebSocket client transport for MCP, symmetrical to the server version.

    Connects to *url* using the 'mcp' subprotocol and yields
    ``(read_stream, write_stream)``:

    - read_stream: yields valid JSONRPCMessage objects wrapped in
      SessionMessage, or Exception objects when validation fails.
    - write_stream: accepts SessionMessage objects and forwards them over
      the WebSocket to the server.
    """

    # In-memory plumbing between the caller and the two pump tasks below:
    # one channel for inbound traffic, one for outbound.
    incoming_send: MemoryObjectSendStream[SessionMessage | Exception]
    incoming_recv: MemoryObjectReceiveStream[SessionMessage | Exception]
    outgoing_send: MemoryObjectSendStream[SessionMessage]
    outgoing_recv: MemoryObjectReceiveStream[SessionMessage]

    incoming_send, incoming_recv = anyio.create_memory_object_stream(0)
    outgoing_send, outgoing_recv = anyio.create_memory_object_stream(0)

    # Establish the connection, requesting the "mcp" subprotocol.
    async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:

        async def pump_incoming():
            """Parse text frames as JSON-RPC messages and forward them to the caller."""
            async with incoming_send:
                async for raw_text in ws:
                    try:
                        parsed = types.JSONRPCMessage.model_validate_json(raw_text)
                        await incoming_send.send(SessionMessage(parsed))
                    except ValidationError as exc:  # pragma: no cover
                        # Surface validation failures to the caller rather
                        # than killing the reader task.
                        await incoming_send.send(exc)

        async def pump_outgoing():
            """Serialize caller messages to JSON and send them over the WebSocket."""
            async with outgoing_recv:
                async for session_message in outgoing_recv:
                    payload = session_message.message.model_dump(by_alias=True, mode="json", exclude_none=True)
                    await ws.send(json.dumps(payload))

        async with anyio.create_task_group() as tg:
            # Run both pumps for the lifetime of the caller's context.
            tg.start_soon(pump_incoming)
            tg.start_soon(pump_outgoing)

            # Hand the stream pair to the caller.
            yield (incoming_recv, outgoing_send)

            # Caller's 'async with' exited — tear down both pump tasks.
            tg.cancel_scope.cancel()
|
||||
Reference in New Issue
Block a user