Files
DronePlanning/backend_service/venv/lib/python3.13/site-packages/mcp/client/sse.py
huangfu c4f851d387 chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境
- 包含所有Python依赖包
- 注意:虚拟环境约393MB,包含12655个文件
2025-12-03 10:19:25 +08:00

149 lines
7.2 KiB
Python

import logging
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import urljoin, urlparse
import anyio
import httpx
from anyio.abc import TaskStatus
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from httpx_sse import aconnect_sse
from httpx_sse._exceptions import SSEError
import mcp.types as types
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
from mcp.shared.message import SessionMessage
# Module-level logger named after this module; the SSE transport below reports
# connection progress and errors through it.
logger = logging.getLogger(__name__)
def remove_request_params(url: str) -> str:
    """Return *url* without its query string and fragment.

    Used to keep request parameters (which may carry credentials) out of log
    output.

    Args:
        url: The URL to sanitize.

    Returns:
        The URL reduced to scheme, host and path.
    """
    parsed = urlparse(url)
    if not parsed.path:
        # urljoin(base, "") hands back the base untouched, which would leave
        # the query string in place for URLs such as "http://host?token=...".
        # Strip the trailing components explicitly in that case.
        return parsed._replace(params="", query="", fragment="").geturl()
    return urljoin(url, parsed.path)
@asynccontextmanager
async def sse_client(
url: str,
headers: dict[str, Any] | None = None,
timeout: float = 5,
sse_read_timeout: float = 60 * 5,
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
auth: httpx.Auth | None = None,
):
"""
Client transport for SSE.
`sse_read_timeout` determines how long (in seconds) the client will wait for a new
event before disconnecting. All other HTTP operations are controlled by `timeout`.
Args:
url: The SSE endpoint URL.
headers: Optional headers to include in requests.
timeout: HTTP timeout for regular operations.
sse_read_timeout: Timeout for SSE read operations.
auth: Optional HTTPX authentication handler.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
async with httpx_client_factory(
headers=headers, auth=auth, timeout=httpx.Timeout(timeout, read=sse_read_timeout)
) as client:
async with aconnect_sse(
client,
"GET",
url,
) as event_source:
event_source.response.raise_for_status()
logger.debug("SSE connection established")
async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")
url_parsed = urlparse(url)
endpoint_parsed = urlparse(endpoint_url)
if ( # pragma: no cover
url_parsed.netloc != endpoint_parsed.netloc
or url_parsed.scheme != endpoint_parsed.scheme
):
error_msg = ( # pragma: no cover
f"Endpoint origin does not match connection origin: {endpoint_url}"
)
logger.error(error_msg) # pragma: no cover
raise ValueError(error_msg) # pragma: no cover
task_status.started(endpoint_url)
case "message":
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
sse.data
)
logger.debug(f"Received server message: {message}")
except Exception as exc: # pragma: no cover
logger.exception("Error parsing server message") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
continue # pragma: no cover
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
case _: # pragma: no cover
logger.warning(f"Unknown SSE event: {sse.event}") # pragma: no cover
except SSEError as sse_exc: # pragma: no cover
logger.exception("Encountered SSE exception") # pragma: no cover
raise sse_exc # pragma: no cover
except Exception as exc: # pragma: no cover
logger.exception("Error in sse_reader") # pragma: no cover
await read_stream_writer.send(exc) # pragma: no cover
finally:
await read_stream_writer.aclose()
async def post_writer(endpoint_url: str):
try:
async with write_stream_reader:
async for session_message in write_stream_reader:
logger.debug(f"Sending client message: {session_message}")
response = await client.post(
endpoint_url,
json=session_message.message.model_dump(
by_alias=True,
mode="json",
exclude_none=True,
),
)
response.raise_for_status()
logger.debug(f"Client message sent successfully: {response.status_code}")
except Exception: # pragma: no cover
logger.exception("Error in post_writer") # pragma: no cover
finally:
await write_stream.aclose()
endpoint_url = await tg.start(sse_reader)
logger.debug(f"Starting post writer with endpoint URL: {endpoint_url}")
tg.start_soon(post_writer, endpoint_url)
try:
yield read_stream, write_stream
finally:
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()