diff --git a/.circleci/config.yml b/.circleci/config.yml index 28a1f3f5aab5..6e51e9ac7141 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -785,7 +785,7 @@ jobs: pip install "pytest-asyncio==0.21.1" pip install "respx==0.22.0" pip install "pydantic==2.10.2" - pip install "mcp==1.5.0" + pip install "mcp==1.9.3" # Run pytest and generate JUnit XML report - run: name: Run tests @@ -920,7 +920,7 @@ jobs: pip install "respx==0.22.0" pip install "hypercorn==0.17.3" pip install "pydantic==2.10.2" - pip install "mcp==1.5.0" + pip install "mcp==1.9.3" pip install "requests-mock>=1.12.1" pip install "responses==0.25.7" pip install "pytest-xdist==3.6.1" diff --git a/.circleci/requirements.txt b/.circleci/requirements.txt index b720d15a7fda..dbd4fd9d2d58 100644 --- a/.circleci/requirements.txt +++ b/.circleci/requirements.txt @@ -12,4 +12,4 @@ pydantic==2.10.2 google-cloud-aiplatform==1.43.0 fastapi-sso==0.16.0 uvloop==0.21.0 -mcp==1.5.0 # for MCP server +mcp==1.9.3 # for MCP server diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index c29d88148195..e362cb303333 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -1,7 +1,7 @@ """ MCP Client Manager -This class is responsible for managing MCP SSE clients. +This class is responsible for managing MCP clients with support for both SSE and HTTP streamable transports. This is a Proxy """ @@ -25,6 +25,12 @@ MCPTransport, MCPTransportType, ) + +try: + from mcp.client.streamable_http import streamablehttp_client +except ImportError: + streamablehttp_client = None # type: ignore + from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer @@ -169,16 +175,43 @@ async def _get_tools_from_server(self, server: MCPServer) -> List[MCPTool]: # Update tool to server mapping for tool in tools_result.tools: - self.tool_name_to_mcp_server_name_mapping[ - tool.name - ] = server.name + self.tool_name_to_mcp_server_name_mapping[tool.name] = ( + server.name + ) return tools_result.tools elif server.transport == MCPTransport.http: - # TODO: implement http transport - return [] + if streamablehttp_client is None: + verbose_logger.error( + "streamablehttp_client not available - install mcp with HTTP support" + ) + raise ValueError( + "streamablehttp_client not available - please run `pip install mcp -U`" + ) + verbose_logger.debug(f"Using HTTP streamable transport for {server.url}") + async with streamablehttp_client( + url=server.url, + ) as (read_stream, write_stream, get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + if get_session_id is not None: + session_id = get_session_id() + if session_id: + verbose_logger.debug(f"HTTP session ID: {session_id}") + + tools_result = await session.list_tools() + verbose_logger.debug(f"Tools from {server.name}: {tools_result}") + + # Update tool to server mapping + for tool in tools_result.tools: + self.tool_name_to_mcp_server_name_mapping[tool.name] = ( + server.name + ) + + return tools_result.tools else: - # TODO: throw error on transport found or skip + verbose_logger.warning(f"Unsupported transport type: {server.transport}") return [] def initialize_tool_name_to_mcp_server_name_mapping(self): @@ -217,8 +250,30 @@ async def call_tool(self, name: str, arguments: Dict[str, Any]): await session.initialize() return await session.call_tool(name, arguments) elif mcp_server.transport == MCPTransport.http: - # TODO: implement http transport - raise NotImplementedError("HTTP transport is not implemented yet") + if streamablehttp_client is None: + verbose_logger.error( + "streamablehttp_client not available - install mcp with HTTP support" + ) + raise ValueError( + "streamablehttp_client not available - please run `pip install mcp -U`" + ) + verbose_logger.debug( + f"Using HTTP streamable transport for tool call: {name}" + ) + async with streamablehttp_client( + url=mcp_server.url, + ) as (read_stream, write_stream, get_session_id): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + + if get_session_id is not None: + session_id = get_session_id() + if session_id: + verbose_logger.debug( + f"HTTP session ID for tool call: {session_id}" + ) + + return await session.call_tool(name, arguments) else: return CallToolResult(content=[], isError=True) diff --git a/poetry.lock b/poetry.lock index b2c035972873..1ffcb8d5021c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1966,15 +1966,15 @@ files = [ [[package]] name = "mcp" -version = "1.5.0" +version = "1.9.3" description = "Model Context Protocol SDK" optional = true python-versions = ">=3.10" groups = ["main"] markers = "python_version >= \"3.10\" and extra == \"proxy\"" files = [ - {file = "mcp-1.5.0-py3-none-any.whl", hash = "sha256:51c3f35ce93cb702f7513c12406bbea9665ef75a08db909200b07da9db641527"}, - {file = "mcp-1.5.0.tar.gz", hash = "sha256:5b2766c05e68e01a2034875e250139839498c61792163a7b221fc170c12f5aa9"}, + {file = "mcp-1.9.3-py3-none-any.whl", hash = "sha256:69b0136d1ac9927402ed4cf221d4b8ff875e7132b0b06edd446448766f34f9b9"}, + {file = "mcp-1.9.3.tar.gz", hash = "sha256:587ba38448e81885e5d1b84055cfcc0ca56d35cd0c58f50941cab01109405388"}, ] [package.dependencies] @@ -1983,9 +1983,10 @@ httpx = ">=0.27" httpx-sse = ">=0.4" pydantic = ">=2.7.2,<3.0.0" pydantic-settings = ">=2.5.2" +python-multipart = ">=0.0.9" sse-starlette = ">=1.6.1" starlette = ">=0.27" -uvicorn = ">=0.23.1" +uvicorn = {version = ">=0.23.1", markers = "sys_platform != \"emscripten\""} [package.extras] cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"] @@ -4994,4 +4995,4 @@ utils = ["numpydoc"] [metadata] lock-version = "2.1" python-versions = ">=3.8.1,<4.0, !=3.9.7" -content-hash = "b681facfcabcb2085056838abec37e3c339fa940694b3b45485bc797d7dbfe1e" +content-hash = "55a9fa9dee2e3836205b692afb6429f0aa134fbfde15ea460bfb28f7dd0a85f1" diff --git a/pyproject.toml b/pyproject.toml index 5dcd37cf0643..32661d7f85b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ pynacl = {version = "^1.5.0", optional = true} websockets = {version = "^13.1.0", optional = true} boto3 = {version = "1.34.34", optional = true} redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.9' and python_version < '3.14'"} -mcp = {version = "1.5.0", optional = true, python = ">=3.10"} +mcp = {version = "1.9.3", optional = true, python = ">=3.10"} litellm-proxy-extras = {version = "0.2.3", optional = true} rich = {version = "13.7.1", optional = true} litellm-enterprise = {version = "0.1.7", optional = true} diff --git a/requirements.txt b/requirements.txt index 386decdeaaac..d99925cf7be3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,7 +15,7 @@ mangum==0.17.0 # for aws lambda functions pynacl==1.5.0 # for encrypting keys google-cloud-aiplatform==1.47.0 # for vertex ai calls anthropic[vertex]==0.21.3 -mcp==1.5.0 # for MCP server +mcp==1.9.3 # for MCP server google-generativeai==0.5.0 # for vertex ai calls async_generator==1.10.0 # for async ollama calls langfuse==2.45.0 # for langfuse self-hosted logging diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 9659ab329776..a78805ca1510 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -2,6 +2,8 @@ import os import sys import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from contextlib import asynccontextmanager sys.path.insert( 0, os.path.abspath("../../..") @@ -10,7 +12,10 @@ from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, MCPServer, + MCPTransport, ) +from mcp.types import Tool as MCPTool, CallToolResult, ListToolsResult +from mcp.types import TextContent mcp_server_manager = MCPServerManager() @@ -33,3 +38,286 @@ async def test_mcp_server_manager(): name="gmail_send_email", arguments={"body": "Test"} ) print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result) + + +@pytest.mark.asyncio +async def test_mcp_server_manager_https_server(): + mcp_server_manager.load_servers_from_config( + { + "zapier_mcp_server": { + "url": os.environ.get("ZAPIER_MCP_HTTPS_SERVER_URL"), + "transport": MCPTransport.http, + } + } + ) + tools = await mcp_server_manager.list_tools() + print("TOOLS FROM MCP SERVER MANAGER== ", tools) + + result = await mcp_server_manager.call_tool( + name="gmail_send_email", + arguments={ + "body": "Test", + "message": "Test", + "instructions": "Test", + }, + ) + print("RESULT FROM CALLING TOOL FROM MCP SERVER MANAGER== ", result) + + +@pytest.mark.asyncio +async def test_mcp_http_transport_list_tools_mock(): + """Test HTTP transport list_tools functionality with mocked dependencies""" + + # Create a fresh manager for testing + test_manager = MCPServerManager() + + # Mock tools that should be returned + mock_tools = [ + MCPTool( + name="gmail_send_email", + description="Send an email via Gmail", + inputSchema={ + "type": "object", + "properties": { + "to": {"type": "string"}, + "subject": {"type": "string"}, + "body": {"type": "string"} + }, + "required": ["to", "subject", "body"] + } + ), + MCPTool( + name="calendar_create_event", + description="Create a calendar event", + inputSchema={ + "type": "object", + "properties": { + "title": {"type": "string"}, + "date": {"type": "string"}, + "time": {"type": "string"} + }, + "required": ["title", "date"] + } + ) + ] + + # Mock the session and its methods + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.list_tools = AsyncMock(return_value=ListToolsResult(tools=mock_tools)) + + # Create an async context manager mock for streamablehttp_client + @asynccontextmanager + async def mock_streamablehttp_client(url): + read_stream = AsyncMock() + write_stream = AsyncMock() + get_session_id = MagicMock(return_value="test-session-123") + yield (read_stream, write_stream, get_session_id) + + # Create an async context manager mock for ClientSession + @asynccontextmanager + async def mock_client_session(read_stream, write_stream): + yield mock_session + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ + patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + + # Load server config with HTTP transport + test_manager.load_servers_from_config({ + "test_http_server": { + "url": "https://test-mcp-server.com/mcp", + "transport": MCPTransport.http, + "description": "Test HTTP MCP Server" + } + }) + + # Call list_tools + tools = await test_manager.list_tools() + + # Assertions + assert len(tools) == 2 + assert tools[0].name == "gmail_send_email" + assert tools[1].name == "calendar_create_event" + + # Verify session methods were called + mock_session.initialize.assert_called_once() + mock_session.list_tools.assert_called_once() + + # Verify tool mapping was updated + assert test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] == "test_http_server" + assert test_manager.tool_name_to_mcp_server_name_mapping["calendar_create_event"] == "test_http_server" + + +@pytest.mark.asyncio +async def test_mcp_http_transport_call_tool_mock(): + """Test HTTP transport call_tool functionality with mocked dependencies""" + + # Create a fresh manager for testing + test_manager = MCPServerManager() + + # Mock tool call result + mock_result = CallToolResult( + content=[ + TextContent( + type="text", + text="Email sent successfully to test@example.com" + ) + ], + isError=False + ) + + # Mock the session and its methods + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_result) + + # Create an async context manager mock for streamablehttp_client + @asynccontextmanager + async def mock_streamablehttp_client(url): + read_stream = AsyncMock() + write_stream = AsyncMock() + get_session_id = MagicMock(return_value="test-session-456") + yield (read_stream, write_stream, get_session_id) + + # Create an async context manager mock for ClientSession + @asynccontextmanager + async def mock_client_session(read_stream, write_stream): + yield mock_session + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ + patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + + # Load server config with HTTP transport + test_manager.load_servers_from_config({ + "test_http_server": { + "url": "https://test-mcp-server.com/mcp", + "transport": MCPTransport.http, + "description": "Test HTTP MCP Server" + } + }) + + # Manually set up tool mapping (normally done by list_tools) + test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server" + + # Call the tool + result = await test_manager.call_tool( + name="gmail_send_email", + arguments={ + "to": "test@example.com", + "subject": "Test Subject", + "body": "Test email body" + } + ) + + # Assertions + assert result.isError is False + assert len(result.content) == 1 + # Type check before accessing text attribute + assert isinstance(result.content[0], TextContent) + assert result.content[0].text == "Email sent successfully to test@example.com" + + # Verify session methods were called + mock_session.initialize.assert_called_once() + mock_session.call_tool.assert_called_once_with( + "gmail_send_email", + { + "to": "test@example.com", + "subject": "Test Subject", + "body": "Test email body" + } + ) + + +@pytest.mark.asyncio +async def test_mcp_http_transport_call_tool_error_mock(): + """Test HTTP transport call_tool error handling with mocked dependencies""" + + # Create a fresh manager for testing + test_manager = MCPServerManager() + + # Mock tool call error result + mock_error_result = CallToolResult( + content=[ + TextContent( + type="text", + text="Error: Invalid email address" + ) + ], + isError=True + ) + + # Mock the session and its methods + mock_session = AsyncMock() + mock_session.initialize = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=mock_error_result) + + # Create an async context manager mock for streamablehttp_client + @asynccontextmanager + async def mock_streamablehttp_client(url): + read_stream = AsyncMock() + write_stream = AsyncMock() + get_session_id = MagicMock(return_value="test-session-789") + yield (read_stream, write_stream, get_session_id) + + # Create an async context manager mock for ClientSession + @asynccontextmanager + async def mock_client_session(read_stream, write_stream): + yield mock_session + + with patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.streamablehttp_client', mock_streamablehttp_client), \ + patch('litellm.proxy._experimental.mcp_server.mcp_server_manager.ClientSession', mock_client_session): + + # Load server config with HTTP transport + test_manager.load_servers_from_config({ + "test_http_server": { + "url": "https://test-mcp-server.com/mcp", + "transport": MCPTransport.http, + "description": "Test HTTP MCP Server" + } + }) + + # Manually set up tool mapping + test_manager.tool_name_to_mcp_server_name_mapping["gmail_send_email"] = "test_http_server" + + # Call the tool with invalid data + result = await test_manager.call_tool( + name="gmail_send_email", + arguments={"to": "invalid-email", "subject": "Test", "body": "Test"} + ) + + # Assertions for error case + assert result.isError is True + assert len(result.content) == 1 + # Type check before accessing text attribute + assert isinstance(result.content[0], TextContent) + assert "Error: Invalid email address" in result.content[0].text + + # Verify session methods were called + mock_session.initialize.assert_called_once() + mock_session.call_tool.assert_called_once() + + +@pytest.mark.asyncio +async def test_mcp_http_transport_tool_not_found(): + """Test calling a tool that doesn't exist""" + + # Create a fresh manager for testing + test_manager = MCPServerManager() + + # Load server config + test_manager.load_servers_from_config({ + "test_http_server": { + "url": "https://test-mcp-server.com/mcp", + "transport": MCPTransport.http, + "description": "Test HTTP MCP Server" + } + }) + + # Try to call a tool that doesn't exist in mapping + with pytest.raises(ValueError, match="Tool nonexistent_tool not found"): + await test_manager.call_tool( + name="nonexistent_tool", + arguments={"param": "value"} + ) + +