Skip to content

Commit 0c31910

Browse files
fix: Improve exception handling in lowlevel server
- Handle exceptions in match statement as suggested by TODO - Use proper dict structure for error logging instead of ErrorData - Add comprehensive tests with parameterization - Include exception type, module, and args in log data Reported-by: AishwaryaKalloli Github-Issue: #786
1 parent 7678a07 commit 0c31910

File tree

2 files changed

+85
-8
lines changed

2 files changed

+85
-8
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -603,24 +603,25 @@ async def _handle_message(
603603
raise_exceptions: bool = False,
604604
):
605605
with warnings.catch_warnings(record=True) as w:
606-
# TODO(Marcelo): We should be checking if message is Exception here.
607-
match message: # type: ignore[reportMatchNotExhaustive]
606+
match message:
608607
case RequestResponder(request=types.ClientRequest(root=req)) as responder:
609608
with responder:
610609
await self._handle_request(message, req, session, lifespan_context, raise_exceptions)
611610
case types.ClientNotification(root=notify):
612611
await self._handle_notification(notify)
613612
case Exception():
614-
logger.error(f"Received error message: {message}")
613+
logger.error(f"Received exception from stream: {message}")
615614
if raise_exceptions:
616615
raise message
617-
# Send the error as a notification
618-
# as we don't have a request context
619616
await session.send_log_message(
620617
level="error",
621-
data=types.ErrorData(
622-
code=types.INTERNAL_ERROR, message=str(message), data=None
623-
),
618+
data={
619+
"message": str(message),
620+
"type": type(message).__name__,
621+
"module": type(message).__module__,
622+
"args": getattr(message, "args", None),
623+
},
624+
logger="mcp.server.exception_handler",
624625
)
625626

626627
for warning in w:
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from unittest.mock import AsyncMock, Mock
2+
3+
import pytest
4+
5+
import mcp.types as types
6+
from mcp.server.lowlevel.server import Server
7+
from mcp.server.session import ServerSession
8+
from mcp.shared.session import RequestResponder
9+
10+
11+
@pytest.mark.anyio
12+
async def test_exception_handling_with_raise_exceptions_true():
13+
"""Test that exceptions are re-raised when raise_exceptions=True"""
14+
server = Server("test-server")
15+
session = Mock(spec=ServerSession)
16+
session.send_log_message = AsyncMock()
17+
18+
test_exception = RuntimeError("Test error")
19+
20+
with pytest.raises(RuntimeError, match="Test error"):
21+
await server._handle_message(test_exception, session, {}, raise_exceptions=True)
22+
23+
# Should not send log message when re-raising
24+
session.send_log_message.assert_not_called()
25+
26+
27+
@pytest.mark.anyio
28+
@pytest.mark.parametrize(
29+
"exception_class,message",
30+
[
31+
(ValueError, "Test validation error"),
32+
(RuntimeError, "Test runtime error"),
33+
(KeyError, "Test key error"),
34+
(Exception, "Basic error"),
35+
],
36+
)
37+
async def test_exception_handling_with_raise_exceptions_false(exception_class, message):
38+
"""Test that exceptions are logged when raise_exceptions=False"""
39+
server = Server("test-server")
40+
session = Mock(spec=ServerSession)
41+
session.send_log_message = AsyncMock()
42+
43+
test_exception = exception_class(message)
44+
45+
await server._handle_message(test_exception, session, {}, raise_exceptions=False)
46+
47+
# Should send log message
48+
session.send_log_message.assert_called_once()
49+
call_args = session.send_log_message.call_args
50+
51+
assert call_args.kwargs["level"] == "error"
52+
assert call_args.kwargs["data"]["message"] == str(test_exception)
53+
assert call_args.kwargs["data"]["type"] == exception_class.__name__
54+
assert call_args.kwargs["logger"] == "mcp.server.exception_handler"
55+
56+
57+
@pytest.mark.anyio
58+
async def test_normal_message_handling_not_affected():
59+
"""Test that normal messages still work correctly"""
60+
server = Server("test-server")
61+
session = Mock(spec=ServerSession)
62+
63+
# Create a mock RequestResponder
64+
responder = Mock(spec=RequestResponder)
65+
responder.request = types.ClientRequest(root=types.PingRequest(method="ping"))
66+
responder.__enter__ = Mock(return_value=responder)
67+
responder.__exit__ = Mock(return_value=None)
68+
69+
# Mock the _handle_request method to avoid complex setup
70+
server._handle_request = AsyncMock()
71+
72+
# Should handle normally without any exception handling
73+
await server._handle_message(responder, session, {}, raise_exceptions=False)
74+
75+
# Verify _handle_request was called
76+
server._handle_request.assert_called_once()

0 commit comments

Comments
 (0)