From a988d3713fceadf5710333fdf70aa80e9258bbf3 Mon Sep 17 00:00:00 2001 From: bzsurbhi Date: Tue, 15 Jul 2025 18:01:36 -0700 Subject: [PATCH] fix: Return JSON-RPC protocol errors for unknown tools --- src/mcp/server/lowlevel/server.py | 10 ++ .../server/test_lowlevel_input_validation.py | 148 +++++++++++++++--- 2 files changed, 137 insertions(+), 21 deletions(-) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 562de31b7..61a70899a 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -452,6 +452,14 @@ async def handler(req: types.CallToolRequest): arguments = req.params.arguments or {} tool = await self._get_cached_tool_definition(tool_name) + # Check if tool exists - return protocol error if not found + if tool is None: + raise McpError( + types.ErrorData( + code=types.METHOD_NOT_FOUND, + message=f"Unknown tool: {tool_name}", + ) + ) # input validation if validate_input and tool: try: @@ -499,6 +507,8 @@ async def handler(req: types.CallToolRequest): isError=False, ) ) + except McpError: + raise except Exception as e: return self._make_error_result(str(e)) diff --git a/tests/server/test_lowlevel_input_validation.py b/tests/server/test_lowlevel_input_validation.py index 250159733..6ff2f01e7 100644 --- a/tests/server/test_lowlevel_input_validation.py +++ b/tests/server/test_lowlevel_input_validation.py @@ -12,16 +12,26 @@ from mcp.server.lowlevel import NotificationOptions from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession +from mcp.shared.exceptions import McpError from mcp.shared.message import SessionMessage from mcp.shared.session import RequestResponder -from mcp.types import CallToolResult, ClientResult, ServerNotification, ServerRequest, TextContent, Tool +from mcp.types import ( + METHOD_NOT_FOUND, + CallToolResult, + ClientResult, + ErrorData, + ServerNotification, + ServerRequest, + TextContent, + Tool, +) async def run_tool_test( tools: list[Tool], call_tool_handler: Callable[[str, dict[str, Any]], Awaitable[list[TextContent]]], - test_callback: Callable[[ClientSession], Awaitable[CallToolResult]], -) -> CallToolResult: + test_callback: Callable[[ClientSession], Awaitable[Any]], +) -> Any: """Helper to run a tool test with minimal boilerplate. Args: @@ -263,8 +273,9 @@ async def test_callback(client_session: ClientSession) -> CallToolResult: @pytest.mark.anyio -async def test_tool_not_in_list_logs_warning(caplog): - """Test that calling a tool not in list_tools logs a warning and skips validation.""" +async def test_tool_not_in_list_logs_warning_before_protocol_error(caplog): + """Test that calling a tool not in list_tools logs a warning before returning protocol error.""" + tools = [ Tool( name="add", @@ -281,30 +292,125 @@ async def test_tool_not_in_list_logs_warning(caplog): ] async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: - # This should be reached since validation is skipped for unknown tools - if name == "unknown_tool": - # Even with invalid arguments, this should execute since validation is skipped - return [TextContent(type="text", text="Unknown tool executed without validation")] + # This should not be reached due to protocol error for unknown tools + if name == "add": + result = arguments["a"] + arguments["b"] + return [TextContent(type="text", text=f"Result: {result}")] else: raise ValueError(f"Unknown tool: {name}") - async def test_callback(client_session: ClientSession) -> CallToolResult: - # Call a tool that's not in the list with invalid arguments - # This should trigger the warning about validation not being performed - return await client_session.call_tool("unknown_tool", {"invalid": "args"}) + async def test_callback(client_session: ClientSession): + # Call a tool that's not in the list - should now raise McpError + try: + return await client_session.call_tool("unknown_tool", {"invalid": "args"}) + except McpError as e: + return e with caplog.at_level(logging.WARNING): result = await run_tool_test(tools, call_tool_handler, test_callback) - # Verify results - should succeed because validation is skipped for unknown tools - assert result is not None - assert not result.isError - assert len(result.content) == 1 - assert result.content[0].type == "text" - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Unknown tool executed without validation" + # Verify it's the correct protocol error + assert isinstance(result, McpError), f"Expected McpError but got {type(result)}" + assert isinstance(result.error, ErrorData) + assert result.error.code == METHOD_NOT_FOUND + assert "Unknown tool: unknown_tool" in result.error.message - # Verify warning was logged + # Verify warning was still logged during the tool lookup process assert any( "Tool 'unknown_tool' not listed, no validation will be performed" in record.message for record in caplog.records ) + + +@pytest.mark.anyio +async def test_unknown_tool_returns_protocol_error(): + """Test that calling an unknown tool returns a proper JSON-RPC protocol error.""" + + tools = [ + Tool( + name="add", + description="Add two numbers", + inputSchema={ + "type": "object", + "properties": { + "a": {"type": "number"}, + "b": {"type": "number"}, + }, + "required": ["a", "b"], + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + # This should not be reached for unknown tools due to protocol error + if name == "add": + result = arguments["a"] + arguments["b"] + return [TextContent(type="text", text=f"Result: {result}")] + else: + raise ValueError(f"Unknown tool: {name}") + + async def test_callback(client_session: ClientSession): + # Try to call a tool that doesn't exist - should raise McpError + try: + return await client_session.call_tool("unknown_tool", {"invalid": "args"}) + except McpError as e: + return e + + result = await run_tool_test(tools, call_tool_handler, test_callback) + + # Verify it's the correct protocol error + assert isinstance(result, McpError), f"Expected McpError but got {type(result)}" + assert isinstance(result.error, ErrorData) + assert result.error.code == METHOD_NOT_FOUND + assert "Unknown tool: unknown_tool" in result.error.message + + +@pytest.mark.anyio +async def test_tool_execution_error_vs_protocol_error(): + """Test the difference between tool execution errors and protocol errors.""" + + tools = [ + Tool( + name="failing_tool", + description="A tool that always fails during execution", + inputSchema={ + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + }, + ) + ] + + async def call_tool_handler(name: str, arguments: dict[str, Any]) -> list[TextContent]: + if name == "failing_tool": + # This should cause a tool execution error (not a protocol error) + raise RuntimeError("Tool execution failed") + else: + raise ValueError(f"Unknown tool: {name}") + + # Test 1: Tool execution error (valid tool that fails) + async def test_execution_error(client_session: ClientSession): + return await client_session.call_tool("failing_tool", {"input": "test"}) + + execution_result = await run_tool_test(tools, call_tool_handler, test_execution_error) + + # Should return CallToolResult with isError=True (tool execution error) + assert isinstance(execution_result, CallToolResult) + assert execution_result.isError + assert isinstance(execution_result.content[0], TextContent) + assert "Tool execution failed" in execution_result.content[0].text + + # Test 2: Protocol error (unknown tool) + async def test_protocol_error(client_session: ClientSession): + try: + return await client_session.call_tool("nonexistent_tool", {"input": "test"}) + except McpError as e: + return e + + protocol_result = await run_tool_test(tools, call_tool_handler, test_protocol_error) + + # Should return McpError (protocol error) + assert isinstance(protocol_result, McpError), f"Expected McpError but got {type(protocol_result)}" + assert isinstance(protocol_result.error, ErrorData) + assert protocol_result.error.code == METHOD_NOT_FOUND + assert "Unknown tool: nonexistent_tool" in protocol_result.error.message