chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
ServerSession Module
|
||||
|
||||
This module provides the ServerSession class, which manages communication between the
|
||||
server and client in the MCP (Model Context Protocol) framework. It is most commonly
|
||||
used in MCP servers to interact with the client.
|
||||
|
||||
Common usage pattern:
|
||||
```
|
||||
server = Server(name)
|
||||
|
||||
@server.call_tool()
|
||||
async def handle_tool_call(ctx: RequestContext, arguments: dict[str, Any]) -> Any:
|
||||
# Check client capabilities before proceeding
|
||||
if ctx.session.check_client_capability(
|
||||
types.ClientCapabilities(experimental={"advanced_tools": dict()})
|
||||
):
|
||||
# Perform advanced tool operations
|
||||
result = await perform_advanced_tool_operation(arguments)
|
||||
else:
|
||||
# Fall back to basic tool operations
|
||||
result = await perform_basic_tool_operation(arguments)
|
||||
|
||||
return result
|
||||
|
||||
@server.list_prompts()
|
||||
async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
|
||||
# Access session for any necessary checks or operations
|
||||
if ctx.session.client_params:
|
||||
# Customize prompts based on client initialization parameters
|
||||
return generate_custom_prompts(ctx.session.client_params)
|
||||
else:
|
||||
return default_prompts
|
||||
```
|
||||
|
||||
The ServerSession class is typically used internally by the Server class and should not
|
||||
be instantiated directly by users of the MCP framework.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from pydantic import AnyUrl
|
||||
|
||||
import mcp.types as types
|
||||
from mcp.server.models import InitializationOptions
|
||||
from mcp.shared.message import ServerMessageMetadata, SessionMessage
|
||||
from mcp.shared.session import (
|
||||
BaseSession,
|
||||
RequestResponder,
|
||||
)
|
||||
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
||||
|
||||
|
||||
class InitializationState(Enum):
|
||||
NotInitialized = 1
|
||||
Initializing = 2
|
||||
Initialized = 3
|
||||
|
||||
|
||||
ServerSessionT = TypeVar("ServerSessionT", bound="ServerSession")
|
||||
|
||||
ServerRequestResponder = (
|
||||
RequestResponder[types.ClientRequest, types.ServerResult] | types.ClientNotification | Exception
|
||||
)
|
||||
|
||||
|
||||
class ServerSession(
|
||||
BaseSession[
|
||||
types.ServerRequest,
|
||||
types.ServerNotification,
|
||||
types.ServerResult,
|
||||
types.ClientRequest,
|
||||
types.ClientNotification,
|
||||
]
|
||||
):
|
||||
_initialized: InitializationState = InitializationState.NotInitialized
|
||||
_client_params: types.InitializeRequestParams | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception],
|
||||
write_stream: MemoryObjectSendStream[SessionMessage],
|
||||
init_options: InitializationOptions,
|
||||
stateless: bool = False,
|
||||
) -> None:
|
||||
super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification)
|
||||
self._initialization_state = (
|
||||
InitializationState.Initialized if stateless else InitializationState.NotInitialized
|
||||
)
|
||||
|
||||
self._init_options = init_options
|
||||
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
|
||||
ServerRequestResponder
|
||||
](0)
|
||||
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
|
||||
|
||||
@property
|
||||
def client_params(self) -> types.InitializeRequestParams | None:
|
||||
return self._client_params # pragma: no cover
|
||||
|
||||
def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover
|
||||
"""Check if the client supports a specific capability."""
|
||||
if self._client_params is None:
|
||||
return False
|
||||
|
||||
# Get client capabilities from initialization params
|
||||
client_caps = self._client_params.capabilities
|
||||
|
||||
# Check each specified capability in the passed in capability object
|
||||
if capability.roots is not None:
|
||||
if client_caps.roots is None:
|
||||
return False
|
||||
if capability.roots.listChanged and not client_caps.roots.listChanged:
|
||||
return False
|
||||
|
||||
if capability.sampling is not None:
|
||||
if client_caps.sampling is None:
|
||||
return False
|
||||
|
||||
if capability.elicitation is not None:
|
||||
if client_caps.elicitation is None:
|
||||
return False
|
||||
|
||||
if capability.experimental is not None:
|
||||
if client_caps.experimental is None:
|
||||
return False
|
||||
# Check each experimental capability
|
||||
for exp_key, exp_value in capability.experimental.items():
|
||||
if exp_key not in client_caps.experimental or client_caps.experimental[exp_key] != exp_value:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def _receive_loop(self) -> None:
|
||||
async with self._incoming_message_stream_writer:
|
||||
await super()._receive_loop()
|
||||
|
||||
async def _received_request(self, responder: RequestResponder[types.ClientRequest, types.ServerResult]):
|
||||
match responder.request.root:
|
||||
case types.InitializeRequest(params=params):
|
||||
requested_version = params.protocolVersion
|
||||
self._initialization_state = InitializationState.Initializing
|
||||
self._client_params = params
|
||||
with responder:
|
||||
await responder.respond(
|
||||
types.ServerResult(
|
||||
types.InitializeResult(
|
||||
protocolVersion=requested_version
|
||||
if requested_version in SUPPORTED_PROTOCOL_VERSIONS
|
||||
else types.LATEST_PROTOCOL_VERSION,
|
||||
capabilities=self._init_options.capabilities,
|
||||
serverInfo=types.Implementation(
|
||||
name=self._init_options.server_name,
|
||||
version=self._init_options.server_version,
|
||||
websiteUrl=self._init_options.website_url,
|
||||
icons=self._init_options.icons,
|
||||
),
|
||||
instructions=self._init_options.instructions,
|
||||
)
|
||||
)
|
||||
)
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case types.PingRequest():
|
||||
# Ping requests are allowed at any time
|
||||
pass
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized:
|
||||
raise RuntimeError("Received request before initialization was complete")
|
||||
|
||||
async def _received_notification(self, notification: types.ClientNotification) -> None:
|
||||
# Need this to avoid ASYNC910
|
||||
await anyio.lowlevel.checkpoint()
|
||||
match notification.root:
|
||||
case types.InitializedNotification():
|
||||
self._initialization_state = InitializationState.Initialized
|
||||
case _:
|
||||
if self._initialization_state != InitializationState.Initialized: # pragma: no cover
|
||||
raise RuntimeError("Received notification before initialization was complete")
|
||||
|
||||
async def send_log_message(
|
||||
self,
|
||||
level: types.LoggingLevel,
|
||||
data: Any,
|
||||
logger: str | None = None,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> None:
|
||||
"""Send a log message notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.LoggingMessageNotification(
|
||||
params=types.LoggingMessageNotificationParams(
|
||||
level=level,
|
||||
data=data,
|
||||
logger=logger,
|
||||
),
|
||||
)
|
||||
),
|
||||
related_request_id,
|
||||
)
|
||||
|
||||
async def send_resource_updated(self, uri: AnyUrl) -> None: # pragma: no cover
|
||||
"""Send a resource updated notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ResourceUpdatedNotification(
|
||||
params=types.ResourceUpdatedNotificationParams(uri=uri),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
messages: list[types.SamplingMessage],
|
||||
*,
|
||||
max_tokens: int,
|
||||
system_prompt: str | None = None,
|
||||
include_context: types.IncludeContext | None = None,
|
||||
temperature: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
model_preferences: types.ModelPreferences | None = None,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> types.CreateMessageResult:
|
||||
"""Send a sampling/create_message request."""
|
||||
return await self.send_request(
|
||||
request=types.ServerRequest(
|
||||
types.CreateMessageRequest(
|
||||
params=types.CreateMessageRequestParams(
|
||||
messages=messages,
|
||||
systemPrompt=system_prompt,
|
||||
includeContext=include_context,
|
||||
temperature=temperature,
|
||||
maxTokens=max_tokens,
|
||||
stopSequences=stop_sequences,
|
||||
metadata=metadata,
|
||||
modelPreferences=model_preferences,
|
||||
),
|
||||
)
|
||||
),
|
||||
result_type=types.CreateMessageResult,
|
||||
metadata=ServerMessageMetadata(
|
||||
related_request_id=related_request_id,
|
||||
),
|
||||
)
|
||||
|
||||
async def list_roots(self) -> types.ListRootsResult:
|
||||
"""Send a roots/list request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(types.ListRootsRequest()),
|
||||
types.ListRootsResult,
|
||||
)
|
||||
|
||||
async def elicit(
|
||||
self,
|
||||
message: str,
|
||||
requestedSchema: types.ElicitRequestedSchema,
|
||||
related_request_id: types.RequestId | None = None,
|
||||
) -> types.ElicitResult:
|
||||
"""Send an elicitation/create request.
|
||||
|
||||
Args:
|
||||
message: The message to present to the user
|
||||
requestedSchema: Schema defining the expected response structure
|
||||
|
||||
Returns:
|
||||
The client's response
|
||||
"""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(
|
||||
types.ElicitRequest(
|
||||
params=types.ElicitRequestParams(
|
||||
message=message,
|
||||
requestedSchema=requestedSchema,
|
||||
),
|
||||
)
|
||||
),
|
||||
types.ElicitResult,
|
||||
metadata=ServerMessageMetadata(related_request_id=related_request_id),
|
||||
)
|
||||
|
||||
async def send_ping(self) -> types.EmptyResult: # pragma: no cover
|
||||
"""Send a ping request."""
|
||||
return await self.send_request(
|
||||
types.ServerRequest(types.PingRequest()),
|
||||
types.EmptyResult,
|
||||
)
|
||||
|
||||
async def send_progress_notification(
|
||||
self,
|
||||
progress_token: str | int,
|
||||
progress: float,
|
||||
total: float | None = None,
|
||||
message: str | None = None,
|
||||
related_request_id: str | None = None,
|
||||
) -> None:
|
||||
"""Send a progress notification."""
|
||||
await self.send_notification(
|
||||
types.ServerNotification(
|
||||
types.ProgressNotification(
|
||||
params=types.ProgressNotificationParams(
|
||||
progressToken=progress_token,
|
||||
progress=progress,
|
||||
total=total,
|
||||
message=message,
|
||||
),
|
||||
)
|
||||
),
|
||||
related_request_id,
|
||||
)
|
||||
|
||||
async def send_resource_list_changed(self) -> None: # pragma: no cover
|
||||
"""Send a resource list changed notification."""
|
||||
await self.send_notification(types.ServerNotification(types.ResourceListChangedNotification()))
|
||||
|
||||
async def send_tool_list_changed(self) -> None: # pragma: no cover
|
||||
"""Send a tool list changed notification."""
|
||||
await self.send_notification(types.ServerNotification(types.ToolListChangedNotification()))
|
||||
|
||||
async def send_prompt_list_changed(self) -> None: # pragma: no cover
|
||||
"""Send a prompt list changed notification."""
|
||||
await self.send_notification(types.ServerNotification(types.PromptListChangedNotification()))
|
||||
|
||||
async def _handle_incoming(self, req: ServerRequestResponder) -> None:
|
||||
await self._incoming_message_stream_writer.send(req)
|
||||
|
||||
@property
|
||||
def incoming_messages(
|
||||
self,
|
||||
) -> MemoryObjectReceiveStream[ServerRequestResponder]:
|
||||
return self._incoming_message_stream_reader
|
||||
Reference in New Issue
Block a user