diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index 562de31b7..b2f203d6e 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -603,13 +603,26 @@ async def _handle_message( raise_exceptions: bool = False, ): with warnings.catch_warnings(record=True) as w: - # TODO(Marcelo): We should be checking if message is Exception here. - match message: # type: ignore[reportMatchNotExhaustive] + match message: case RequestResponder(request=types.ClientRequest(root=req)) as responder: with responder: await self._handle_request(message, req, session, lifespan_context, raise_exceptions) case types.ClientNotification(root=notify): await self._handle_notification(notify) + case Exception(): + logger.error(f"Received exception from stream: {message}") + if raise_exceptions: + raise message + await session.send_log_message( + level="error", + data={ + "message": str(message), + "type": type(message).__name__, + "module": type(message).__module__, + "args": getattr(message, "args", None), + }, + logger="mcp.server.exception_handler", + ) for warning in w: logger.info("Warning: %s: %s", warning.category.__name__, warning.message) diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py new file mode 100644 index 000000000..35e7b1e1d --- /dev/null +++ b/tests/server/test_lowlevel_exception_handling.py @@ -0,0 +1,76 @@ +from unittest.mock import AsyncMock, Mock + +import pytest + +import mcp.types as types +from mcp.server.lowlevel.server import Server +from mcp.server.session import ServerSession +from mcp.shared.session import RequestResponder + + +@pytest.mark.anyio +async def test_exception_handling_with_raise_exceptions_true(): + """Test that exceptions are re-raised when raise_exceptions=True""" + server = Server("test-server") + session = Mock(spec=ServerSession) + session.send_log_message = AsyncMock() + + test_exception = RuntimeError("Test error") + + with pytest.raises(RuntimeError, match="Test error"): + await server._handle_message(test_exception, session, {}, raise_exceptions=True) + + # Should not send log message when re-raising + session.send_log_message.assert_not_called() + + +@pytest.mark.anyio +@pytest.mark.parametrize( + "exception_class,message", + [ + (ValueError, "Test validation error"), + (RuntimeError, "Test runtime error"), + (KeyError, "Test key error"), + (Exception, "Basic error"), + ], +) +async def test_exception_handling_with_raise_exceptions_false(exception_class, message): + """Test that exceptions are logged when raise_exceptions=False""" + server = Server("test-server") + session = Mock(spec=ServerSession) + session.send_log_message = AsyncMock() + + test_exception = exception_class(message) + + await server._handle_message(test_exception, session, {}, raise_exceptions=False) + + # Should send log message + session.send_log_message.assert_called_once() + call_args = session.send_log_message.call_args + + assert call_args.kwargs["level"] == "error" + assert call_args.kwargs["data"]["message"] == str(test_exception) + assert call_args.kwargs["data"]["type"] == exception_class.__name__ + assert call_args.kwargs["logger"] == "mcp.server.exception_handler" + + +@pytest.mark.anyio +async def test_normal_message_handling_not_affected(): + """Test that normal messages still work correctly""" + server = Server("test-server") + session = Mock(spec=ServerSession) + + # Create a mock RequestResponder + responder = Mock(spec=RequestResponder) + responder.request = types.ClientRequest(root=types.PingRequest(method="ping")) + responder.__enter__ = Mock(return_value=responder) + responder.__exit__ = Mock(return_value=None) + + # Mock the _handle_request method to avoid complex setup + server._handle_request = AsyncMock() + + # Should handle normally without any exception handling + await server._handle_message(responder, session, {}, raise_exceptions=False) + + # Verify _handle_request was called + server._handle_request.assert_called_once()