Skip to content

Commit f248d6d

Browse files
fix: Improve exception handling in lowlevel server
- Handle exceptions in match statement as suggested by TODO - Use proper dict structure for error logging instead of ErrorData - Add comprehensive tests with parameterization - Include exception type, module, and args in log data Reported-by: AishwaryaKalloli Github-Issue: modelcontextprotocol#786
1 parent 7a35ca4 commit f248d6d

File tree

2 files changed

+85
-8
lines changed

2 files changed

+85
-8
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,7 @@ async def _handle_message(
525525
raise_exceptions: bool = False,
526526
):
527527
with warnings.catch_warnings(record=True) as w:
528-
# TODO(Marcelo): We should be checking if message is Exception here.
529-
match message: # type: ignore[reportMatchNotExhaustive]
528+
match message:
530529
case (
531530
RequestResponder(request=types.ClientRequest(root=req)) as responder
532531
):
@@ -537,16 +536,18 @@ async def _handle_message(
537536
case types.ClientNotification(root=notify):
538537
await self._handle_notification(notify)
539538
case Exception():
540-
logger.error(f"Received error message: {message}")
539+
logger.error(f"Received exception from stream: {message}")
541540
if raise_exceptions:
542541
raise message
543-
# Send the error as a notification
544-
# as we don't have a request context
545542
await session.send_log_message(
546543
level="error",
547-
data=types.ErrorData(
548-
code=types.INTERNAL_ERROR, message=str(message), data=None
549-
),
544+
data={
545+
"message": str(message),
546+
"type": type(message).__name__,
547+
"module": type(message).__module__,
548+
"args": getattr(message, "args", None),
549+
},
550+
logger="mcp.server.exception_handler",
550551
)
551552

552553
for warning in w:
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from unittest.mock import AsyncMock, Mock
2+
3+
import pytest
4+
5+
import mcp.types as types
6+
from mcp.server.lowlevel.server import Server
7+
from mcp.server.session import ServerSession
8+
from mcp.shared.session import RequestResponder
9+
10+
11+
@pytest.mark.anyio
12+
async def test_exception_handling_with_raise_exceptions_true():
13+
"""Test that exceptions are re-raised when raise_exceptions=True"""
14+
server = Server("test-server")
15+
session = Mock(spec=ServerSession)
16+
session.send_log_message = AsyncMock()
17+
18+
test_exception = RuntimeError("Test error")
19+
20+
with pytest.raises(RuntimeError, match="Test error"):
21+
await server._handle_message(test_exception, session, {}, raise_exceptions=True)
22+
23+
# Should not send log message when re-raising
24+
session.send_log_message.assert_not_called()
25+
26+
27+
@pytest.mark.anyio
28+
@pytest.mark.parametrize(
29+
"exception_class,message",
30+
[
31+
(ValueError, "Test validation error"),
32+
(RuntimeError, "Test runtime error"),
33+
(KeyError, "Test key error"),
34+
(Exception, "Basic error"),
35+
],
36+
)
37+
async def test_exception_handling_with_raise_exceptions_false(exception_class, message):
38+
"""Test that exceptions are logged when raise_exceptions=False"""
39+
server = Server("test-server")
40+
session = Mock(spec=ServerSession)
41+
session.send_log_message = AsyncMock()
42+
43+
test_exception = exception_class(message)
44+
45+
await server._handle_message(test_exception, session, {}, raise_exceptions=False)
46+
47+
# Should send log message
48+
session.send_log_message.assert_called_once()
49+
call_args = session.send_log_message.call_args
50+
51+
assert call_args.kwargs["level"] == "error"
52+
assert call_args.kwargs["data"]["message"] == str(test_exception)
53+
assert call_args.kwargs["data"]["type"] == exception_class.__name__
54+
assert call_args.kwargs["logger"] == "mcp.server.exception_handler"
55+
56+
57+
@pytest.mark.anyio
58+
async def test_normal_message_handling_not_affected():
59+
"""Test that normal messages still work correctly"""
60+
server = Server("test-server")
61+
session = Mock(spec=ServerSession)
62+
63+
# Create a mock RequestResponder
64+
responder = Mock(spec=RequestResponder)
65+
responder.request = types.ClientRequest(root=types.PingRequest(method="ping"))
66+
responder.__enter__ = Mock(return_value=responder)
67+
responder.__exit__ = Mock(return_value=None)
68+
69+
# Mock the _handle_request method to avoid complex setup
70+
server._handle_request = AsyncMock()
71+
72+
# Should handle normally without any exception handling
73+
await server._handle_message(responder, session, {}, raise_exceptions=False)
74+
75+
# Verify _handle_request was called
76+
server._handle_request.assert_called_once()

0 commit comments

Comments
 (0)