From bca59f2230294ef483612c8cc19171161ab3c3b1 Mon Sep 17 00:00:00 2001 From: hanhan123 Date: Sun, 13 Jul 2025 01:05:19 +0800 Subject: [PATCH] feat: add client connection manager to manage multiple sessions without requiring async with --- .../clients/client-connection-manager/main.py | 37 +++ src/mcp/client/client_connection_manager.py | 294 ++++++++++++++++++ src/mcp/client/exceptions.py | 2 + src/mcp/types.py | 9 + .../client/test_client_connection_manager.py | 229 ++++++++++++++ 5 files changed, 571 insertions(+) create mode 100644 examples/clients/client-connection-manager/main.py create mode 100644 src/mcp/client/client_connection_manager.py create mode 100644 src/mcp/client/exceptions.py create mode 100644 tests/client/test_client_connection_manager.py diff --git a/examples/clients/client-connection-manager/main.py b/examples/clients/client-connection-manager/main.py new file mode 100644 index 000000000..a5bc60554 --- /dev/null +++ b/examples/clients/client-connection-manager/main.py @@ -0,0 +1,37 @@ +import asyncio + +from mcp.client.client_connection_manager import ClientConnectionManager, StreamalbeHttpClientParams + + +async def main(): + s1_name = "s1_name" + s2_name = "s2_name" + s1 = StreamalbeHttpClientParams(name=s1_name, url="http://localhost:8910/mcp/") + s2 = StreamalbeHttpClientParams(name=s2_name, url="http://localhost:8910/mcp/") + + m = ClientConnectionManager() + + await m.connect(s1) + await m.connect(s2) + + print("---session initialize---") + + await m.session_initialize(s1_name) + await m.session_initialize(s2_name) + await asyncio.sleep(1) + + print("---session list tools---") + res = await m.session_list_tools(s1_name) + + await asyncio.sleep(1) + print("---session call tool---") + res = await m.session_call_tool(s1_name, "create_user") + print(res) + await asyncio.sleep(3) + print("---session disconnect---") + await m.disconnect(s1_name) + # await m.cleanup(s2_name) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/mcp/client/client_connection_manager.py b/src/mcp/client/client_connection_manager.py new file mode 100644 index 000000000..3dbbe0309 --- /dev/null +++ b/src/mcp/client/client_connection_manager.py @@ -0,0 +1,294 @@ +import asyncio +import logging +from collections.abc import Coroutine +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import Any, TypeVar + +from pydantic import AnyUrl, BaseModel, ConfigDict, Field + +import mcp +from mcp import types +from mcp.client.exceptions import ConnectTimeOut +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.exceptions import McpError +from mcp.shared.session import ProgressFnT +from mcp.types import StreamalbeHttpClientParams + +logger = logging.getLogger(__name__) + +R = TypeVar("R") + + +class ClientSessionState(BaseModel): + session: mcp.ClientSession | None = None + lifespan_task: asyncio.Task[Any] | None = None + running_event: asyncio.Event = Field(default_factory=asyncio.Event) + error: Exception | None = None + request_task: dict[str, asyncio.Task[Any]] = Field(default_factory=dict) + model_config = ConfigDict(arbitrary_types_allowed=True) + + @property + def lifespan(self) -> asyncio.Task[Any]: + if self.lifespan_task is None: + raise RuntimeError("lifespan_task is not set") + return self.lifespan_task + + @property + def active_session(self) -> mcp.ClientSession: + if self.session is None: + raise RuntimeError("session is not set") + return self.session + + +class ClientConnectionManager: + def __init__( + self, + ): + self._session: dict[str, ClientSessionState] = {} + + async def connect(self, parameter: StreamalbeHttpClientParams): + logger.info(f"Attempting to connect to MCP server: {parameter.name} ({parameter.url})") + state = ClientSessionState() + if not self._is_session_exists(parameter.name): + self._session[parameter.name] = state + logger.debug(f"Session state created for: {parameter.name}") + else: + raise McpError( + types.ErrorData( + code=types.CONNECTION_CLOSED, + message=f"Session with name '{parameter.name}' already exists. \ + Duplicate connections are not allowed.", + ) + ) + ready_future = asyncio.get_running_loop().create_future() + + task = asyncio.create_task(self._maintain_session(parameter, ready_future)) + state.lifespan_task = task + + try: + await asyncio.wait_for(ready_future, timeout=5) + except asyncio.TimeoutError: + task.cancel() + try: + await task # 等待 task 真正結束或取消 + except asyncio.CancelledError: + pass + state.error = ConnectTimeOut(f"Connection to {parameter.name} timed out") + raise state.error + except Exception as e: + task.cancel() + state.error = e + raise e + + async def _maintain_session(self, parameter: StreamalbeHttpClientParams, connect_res: asyncio.Future[Any]): + try: + async with self._session_context(parameter): + if not connect_res.done(): + connect_res.set_result(True) + + logger.debug(f"Session maintenance started for: {parameter.name}. Waiting for shutdown event...") + await self._session[parameter.name].running_event.wait() + logger.info(f"Graceful shutdown initiated for session: {parameter.name}") + + except Exception as e: + if not connect_res.done(): + connect_res.set_exception(e) + self._session[parameter.name].running_event.set() + self._session[parameter.name].error = e + raise e + + @asynccontextmanager + async def _session_context(self, parameter: StreamalbeHttpClientParams): + try: + async with streamablehttp_client(parameter.url) as streams: + read_stream, write_stream, _ = streams + async with mcp.ClientSession(read_stream, write_stream) as session: + state = self._session[parameter.name] + state.session = session + + logger.info(f"Connected to MCP server: {parameter.name} ({parameter.url})") + yield + logger.info(f"MCP server {parameter.name} ({parameter.url}): disconnected") + + except Exception as e: + raise e + + def _is_session_exists(self, session_name: str) -> bool: + if session_name in self._session: + return True + return False + + def _validate_session(self, session_name: str) -> ClientSessionState: + if self._is_session_exists(session_name): + state = self._session[session_name] + if state.error: + raise McpError( + types.ErrorData( + code=types.CONNECTION_CLOSED, + message=f"Session with name '{session_name}' has error. {state.error}", + ) + ) + return state + else: + raise McpError( + types.ErrorData( + code=types.CONNECTION_CLOSED, + message=f"Session with name '{session_name}' does not exist. Please establish a connection first.", + ) + ) + + async def _safe_run_task(self, session_name: str, task_cor: Coroutine[Any, Any, R]) -> R: + actived_task = asyncio.create_task(task_cor) + + async def monitor(): + await asyncio.sleep(0.1) + while not actived_task.done(): + if self._session[session_name].error is not None: + actived_task.cancel() + break + + await asyncio.sleep(2) + + asyncio.create_task(monitor()) + try: + res = await actived_task + except asyncio.exceptions.CancelledError as err: + session_err = self._session[session_name].error + if session_err is not None: + raise session_err + raise err + # except Exception as err: + # raise err + return res + + async def session_initialize(self, session_name: str) -> types.InitializeResult: + session_state = self._validate_session(session_name) + + try: + res = await self._safe_run_task(session_name, session_state.active_session.initialize()) + + except Exception as e: + raise e + + return res + + async def session_send_pings(self, session_name: str) -> types.EmptyResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task(session_name, session_state.active_session.send_ping()) + + async def session_send_progress_notification( + self, + session_name: str, + progress_token: str | int, + progress: float, + total: float | None = None, + message: str | None = None, + ) -> None: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.send_progress_notification(progress_token, progress, total, message), + ) + + async def session_set_logging_level(self, session_name: str, level: types.LoggingLevel) -> types.EmptyResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task(session_name, session_state.active_session.set_logging_level(level)) + + async def session_list_resources(self, session_name: str, cursor: str | None = None) -> types.ListResourcesResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.list_resources(cursor), + ) + + async def session_list_resource_templates( + self, session_name: str, cursor: str | None = None + ) -> types.ListResourceTemplatesResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.list_resource_templates(cursor), + ) + + async def session_read_resource(self, session_name: str, uri: AnyUrl) -> types.ReadResourceResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.read_resource(uri), + ) + + async def session_subscribe_resource(self, session_name: str, uri: AnyUrl) -> types.EmptyResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.subscribe_resource(uri), + ) + + async def session_unsubscribe_resource(self, session_name: str, uri: AnyUrl) -> types.EmptyResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.unsubscribe_resource(uri), + ) + + async def session_call_tool( + self, + session_name: str, + name: str, + arguments: dict[str, Any] | None = None, + read_timeout_seconds: timedelta | None = None, + progress_callback: ProgressFnT | None = None, + ) -> types.CallToolResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.call_tool(name, arguments, read_timeout_seconds, progress_callback), + ) + + async def session_list_prompts(self, session_name: str, cursor: str | None = None) -> types.ListPromptsResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.list_prompts(cursor), + ) + + async def session_get_prompt( + self, session_name: str, name: str, arguments: dict[str, str] | None = None + ) -> types.GetPromptResult: + session_state = self._validate_session(session_name) + return await self._safe_run_task( + session_name, + session_state.active_session.get_prompt(name, arguments), + ) + + async def session_list_tools(self, session_name: str, cursor: str | None = None) -> types.ListToolsResult: + session_state = self._validate_session(session_name) + + return await self._safe_run_task(session_name, session_state.active_session.list_tools(cursor)) + + async def session_send_roots_list_changed(self, session_name: str) -> None: + session_state = self._validate_session(session_name) + + return await self._safe_run_task(session_name, session_state.active_session.send_roots_list_changed()) + + async def disconnect(self, name: str) -> None: + session = self._session[name] + if not session.session: + return + + if session.lifespan_task and not session.lifespan_task.done(): + session.running_event.set() + + try: + await session.lifespan + except Exception as e: + raise McpError( + types.ErrorData( + code=types.CONNECTION_CLOSED, + message=f"MCP server {name} disconnect failed {e}", + ) + ) + finally: + session.session = None + session.lifespan_task = None diff --git a/src/mcp/client/exceptions.py b/src/mcp/client/exceptions.py new file mode 100644 index 000000000..ae198e923 --- /dev/null +++ b/src/mcp/client/exceptions.py @@ -0,0 +1,2 @@ +class ConnectTimeOut(Exception): + """Failed to connect: timeout""" diff --git a/src/mcp/types.py b/src/mcp/types.py index 4a9c2bf1a..faf97e88d 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1,4 +1,5 @@ from collections.abc import Callable +from datetime import timedelta from typing import Annotated, Any, Generic, Literal, TypeAlias, TypeVar from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel @@ -1310,3 +1311,11 @@ class ServerResult( ] ): pass + + +class StreamalbeHttpClientParams(BaseModel): + name: str + url: str + headers: dict[str, Any] | None = None + timeout: timedelta = timedelta(seconds=30) + terminate_on_close: bool = True diff --git a/tests/client/test_client_connection_manager.py b/tests/client/test_client_connection_manager.py new file mode 100644 index 000000000..f0655372e --- /dev/null +++ b/tests/client/test_client_connection_manager.py @@ -0,0 +1,229 @@ +import asyncio +from unittest.mock import AsyncMock, patch + +import pytest + +import mcp +from mcp import types +from mcp.client.client_connection_manager import ClientConnectionManager, ClientSessionState +from mcp.client.exceptions import ConnectTimeOut +from mcp.shared.exceptions import McpError +from mcp.types import StreamalbeHttpClientParams + + +@pytest.fixture +def manager(): + return ClientConnectionManager() + + +@pytest.fixture +def mock_mcp_client_session(): + mock_session = AsyncMock(spec=mcp.ClientSession) + mock_session.initialize.return_value = types.InitializeResult( + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="1", version="2"), + protocolVersion="1", + ) + mock_session.send_ping.return_value = types.EmptyResult() + mock_session.set_logging_level.return_value = types.EmptyResult() + mock_session.list_resources.return_value = types.ListResourcesResult(resources=[]) + mock_session.read_resource.return_value = types.ReadResourceResult(contents=[]) + mock_session.call_tool.return_value = types.CallToolResult(content=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + mock_session.get_prompt.return_value = types.GetPromptResult(messages=[]) + mock_session.send_roots_list_changed.return_value = None + mock_session.subscribe_resource.return_value = types.EmptyResult() + mock_session.unsubscribe_resource.return_value = types.EmptyResult() + mock_session.send_progress_notification.return_value = None + mock_session.__aenter__.return_value = mock_session + mock_session.__aexit__.return_value = None + + return mock_session + + +@pytest.fixture +def mock_streamable_http_client(): + mock_read_stream = AsyncMock() + mock_write_stream = AsyncMock() + get_session_id_callback = AsyncMock() + mock_streams = (mock_read_stream, mock_write_stream, get_session_id_callback) + + mock_client_context = AsyncMock() + mock_client_context.__aenter__.return_value = mock_streams + mock_client_context.__aexit__.return_value = None + + with patch( + "mcp.client.client_connection_manager.streamablehttp_client", return_value=mock_client_context + ) as mock_streamable_http_client: + yield mock_streamable_http_client + + +@pytest.mark.anyio +async def test_connect_success(manager, mock_streamable_http_client, mock_mcp_client_session): + session_name = "test_session_1" + url = "http://mock1:8000/mcp/" + + param = StreamalbeHttpClientParams(name=session_name, url=url) + + with patch("mcp.client.client_connection_manager.mcp.ClientSession", return_value=mock_mcp_client_session): + await manager.connect(param) + + assert session_name in manager._session + + state = manager._session[session_name] + assert state.session is mock_mcp_client_session + assert state.lifespan_task is not None + assert not state.running_event.is_set() + assert state.error is None + + mock_streamable_http_client.assert_called_once_with(param.url) + + await manager.disconnect(session_name) + assert manager._session[session_name].session is None + assert manager._session[session_name].lifespan_task is None + + +@pytest.mark.anyio +async def test_connect_duplicate_session_fails(manager, mock_streamable_http_client): + session_name = "test_session_duplicate" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + manager._session[session_name] = ClientSessionState() + + with pytest.raises(McpError) as excinfo: + await manager.connect(param) + + assert "already exists" in str(excinfo.value) + + mock_streamable_http_client.assert_not_called() + + +@pytest.mark.anyio +async def test_connect_timeout_during_startup(manager, mock_streamable_http_client): + session_name = "test_session_timeout" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + async def never_set_result(*args, **kwargs): + await asyncio.sleep(10) + + with patch.object(manager, "_maintain_session", AsyncMock(side_effect=never_set_result)): + with pytest.raises(ConnectTimeOut): + await manager.connect(param) + + assert session_name in manager._session + + assert manager._session[session_name].lifespan_task.done() + assert str(manager._session[session_name].error) == f"Connection to {param.name} timed out" + + assert manager._session[session_name].lifespan_task.cancelled() + + +@pytest.mark.anyio +async def test_disconnect_success(manager, mock_streamable_http_client, mock_mcp_client_session): + session_name = "test_session_disconnect" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + with patch("mcp.client.client_connection_manager.mcp.ClientSession", return_value=mock_mcp_client_session): + await manager.connect(param) + + assert session_name in manager._session + + assert manager._session[session_name].session is mock_mcp_client_session + + assert manager._session[session_name].lifespan_task is not None + + # disconnnect + + await manager.disconnect(session_name) + assert manager._session[session_name].session is None + assert manager._session[session_name].lifespan_task is None + assert manager._session[session_name].running_event.is_set() + + mock_mcp_client_session.__aexit__.assert_called_once() + mock_streamable_http_client.return_value.__aexit__.assert_called_once() + + +@pytest.mark.anyio +async def test_session_initialize_success(manager, mock_streamable_http_client, mock_mcp_client_session): + session_name = "test_session_init" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + with patch("mcp.client.client_connection_manager.mcp.ClientSession", return_value=mock_mcp_client_session): + await manager.connect(param) + await manager.session_initialize(session_name) + + mock_mcp_client_session.initialize.assert_called_once() + + +@pytest.mark.anyio +async def test_session_initialize_no_session(manager): + with pytest.raises(McpError) as excinfo: + await manager.session_initialize("non_existent_session") + + assert "does not exist" in str(excinfo.value) + + +@pytest.mark.anyio +async def test_session_initialize_with_error_state(manager, mock_streamable_http_client, mock_mcp_client_session): + session_name = "test_session_error_state" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + with patch("mcp.client.client_connection_manager.mcp.ClientSession", return_value=mock_mcp_client_session): + await manager.connect(param) + + manager._session[session_name].error = RuntimeError("Simulated session error") + + with pytest.raises(McpError) as excinfo: + await manager.session_initialize(session_name) + + assert "has error" in str(excinfo.value) + + mock_mcp_client_session.initialize.assert_not_called() + + +@pytest.mark.anyio +async def test_safe_run_task_propagates_session_error(manager, mock_streamable_http_client, mock_mcp_client_session): + session_name = "test_safe_run_task_error" + + state = ClientSessionState() + state.session = mock_mcp_client_session + manager._session[session_name] = state + + async def mock_long_running_task(): + await asyncio.sleep(100) + + task_to_test = mock_long_running_task() + + safe_run_task_handle = asyncio.create_task(manager._safe_run_task(session_name, task_to_test)) + + await asyncio.sleep(0.2) + simulated_error = McpError(types.ErrorData(code=types.CONNECTION_CLOSED, message="Simulated network error")) + manager._session[session_name].error = simulated_error + + with pytest.raises(McpError) as excinfo: + await safe_run_task_handle + + assert excinfo.value == simulated_error + assert safe_run_task_handle.done() + assert manager._session[session_name].error is simulated_error + print(manager._session[session_name].running_event.is_set()) + + +@pytest.mark.anyio +async def test_maintain_session_handles_context_exception(manager): + session_name = "test_session_error_state" + + param = StreamalbeHttpClientParams(name=session_name, url="http://localhost:8080") + + await manager.connect(param) + + with pytest.raises(Exception): + await manager.session_initialize(session_name) + + assert manager._session[session_name].lifespan_task.done()