Skip to content

Commit 7b1078b

Browse files
sobyfelixweinbergerihrpr
authored
Fix: Prevent session manager shutdown on individual session crash (#841)
Co-authored-by: Felix Weinberger <fweinberger@anthropic.com> Co-authored-by: Inna Harper <inna.hrpr@gmail.com>
1 parent 3abefee commit 7b1078b

File tree

3 files changed

+165
-13
lines changed

3 files changed

+165
-13
lines changed

src/mcp/server/streamable_http.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def __init__(
173173
] = {}
174174
self._terminated = False
175175

176+
@property
177+
def is_terminated(self) -> bool:
178+
"""Check if this transport has been explicitly terminated."""
179+
return self._terminated
180+
176181
def _create_error_response(
177182
self,
178183
error_message: str,

src/mcp/server/streamable_http_manager.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ class StreamableHTTPSessionManager:
5252
json_response: Whether to use JSON responses instead of SSE streams
5353
stateless: If True, creates a completely fresh transport for each request
5454
with no session tracking or state persistence between requests.
55-
5655
"""
5756

5857
def __init__(
@@ -173,12 +172,15 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA
173172
async with http_transport.connect() as streams:
174173
read_stream, write_stream = streams
175174
task_status.started()
176-
await self.app.run(
177-
read_stream,
178-
write_stream,
179-
self.app.create_initialization_options(),
180-
stateless=True,
181-
)
175+
try:
176+
await self.app.run(
177+
read_stream,
178+
write_stream,
179+
self.app.create_initialization_options(),
180+
stateless=True,
181+
)
182+
except Exception:
183+
logger.exception("Stateless session crashed")
182184

183185
# Assert task group is not None for type checking
184186
assert self._task_group is not None
@@ -233,12 +235,31 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE
233235
async with http_transport.connect() as streams:
234236
read_stream, write_stream = streams
235237
task_status.started()
236-
await self.app.run(
237-
read_stream,
238-
write_stream,
239-
self.app.create_initialization_options(),
240-
stateless=False, # Stateful mode
241-
)
238+
try:
239+
await self.app.run(
240+
read_stream,
241+
write_stream,
242+
self.app.create_initialization_options(),
243+
stateless=False, # Stateful mode
244+
)
245+
except Exception as e:
246+
logger.error(
247+
f"Session {http_transport.mcp_session_id} crashed: {e}",
248+
exc_info=True,
249+
)
250+
finally:
251+
# Only remove from instances if not terminated
252+
if (
253+
http_transport.mcp_session_id
254+
and http_transport.mcp_session_id in self._server_instances
255+
and not http_transport.is_terminated
256+
):
257+
logger.info(
258+
"Cleaning up crashed session "
259+
f"{http_transport.mcp_session_id} from "
260+
"active instances."
261+
)
262+
del self._server_instances[http_transport.mcp_session_id]
242263

243264
# Assert task group is not None for type checking
244265
assert self._task_group is not None

tests/server/test_streamable_http_manager.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
"""Tests for StreamableHTTPSessionManager."""
22

3+
from unittest.mock import AsyncMock
4+
35
import anyio
46
import pytest
57

68
from mcp.server.lowlevel import Server
9+
from mcp.server.streamable_http import MCP_SESSION_ID_HEADER
710
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
811

912

@@ -71,3 +74,126 @@ async def send(message):
7174
await manager.handle_request(scope, receive, send)
7275

7376
assert "Task group is not initialized. Make sure to use run()." in str(excinfo.value)
77+
78+
79+
class TestException(Exception):
80+
__test__ = False # Prevent pytest from collecting this as a test class
81+
pass
82+
83+
84+
@pytest.fixture
85+
async def running_manager():
86+
app = Server("test-cleanup-server")
87+
# It's important that the app instance used by the manager is the one we can patch
88+
manager = StreamableHTTPSessionManager(app=app)
89+
async with manager.run():
90+
# Patch app.run here if it's simpler, or patch it within the test
91+
yield manager, app
92+
93+
94+
@pytest.mark.anyio
95+
async def test_stateful_session_cleanup_on_graceful_exit(running_manager):
96+
manager, app = running_manager
97+
98+
mock_mcp_run = AsyncMock(return_value=None)
99+
# This will be called by StreamableHTTPSessionManager's run_server -> self.app.run
100+
app.run = mock_mcp_run
101+
102+
sent_messages = []
103+
104+
async def mock_send(message):
105+
sent_messages.append(message)
106+
107+
scope = {
108+
"type": "http",
109+
"method": "POST",
110+
"path": "/mcp",
111+
"headers": [(b"content-type", b"application/json")],
112+
}
113+
114+
async def mock_receive():
115+
return {"type": "http.request", "body": b"", "more_body": False}
116+
117+
# Trigger session creation
118+
await manager.handle_request(scope, mock_receive, mock_send)
119+
120+
# Extract session ID from response headers
121+
session_id = None
122+
for msg in sent_messages:
123+
if msg["type"] == "http.response.start":
124+
for header_name, header_value in msg.get("headers", []):
125+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
126+
session_id = header_value.decode()
127+
break
128+
if session_id: # Break outer loop if session_id is found
129+
break
130+
131+
assert session_id is not None, "Session ID not found in response headers"
132+
133+
# Ensure MCPServer.run was called
134+
mock_mcp_run.assert_called_once()
135+
136+
# At this point, mock_mcp_run has completed, and the finally block in
137+
# StreamableHTTPSessionManager's run_server should have executed.
138+
139+
# To ensure the task spawned by handle_request finishes and cleanup occurs:
140+
# Give other tasks a chance to run. This is important for the finally block.
141+
await anyio.sleep(0.01)
142+
143+
assert session_id not in manager._server_instances, (
144+
"Session ID should be removed from _server_instances after graceful exit"
145+
)
146+
assert not manager._server_instances, "No sessions should be tracked after the only session exits gracefully"
147+
148+
149+
@pytest.mark.anyio
150+
async def test_stateful_session_cleanup_on_exception(running_manager):
151+
manager, app = running_manager
152+
153+
mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash"))
154+
app.run = mock_mcp_run
155+
156+
sent_messages = []
157+
158+
async def mock_send(message):
159+
sent_messages.append(message)
160+
# If an exception occurs, the transport might try to send an error response
161+
# For this test, we mostly care that the session is established enough
162+
# to get an ID
163+
if message["type"] == "http.response.start" and message["status"] >= 500:
164+
pass # Expected if TestException propagates that far up the transport
165+
166+
scope = {
167+
"type": "http",
168+
"method": "POST",
169+
"path": "/mcp",
170+
"headers": [(b"content-type", b"application/json")],
171+
}
172+
173+
async def mock_receive():
174+
return {"type": "http.request", "body": b"", "more_body": False}
175+
176+
# Trigger session creation
177+
await manager.handle_request(scope, mock_receive, mock_send)
178+
179+
session_id = None
180+
for msg in sent_messages:
181+
if msg["type"] == "http.response.start":
182+
for header_name, header_value in msg.get("headers", []):
183+
if header_name.decode().lower() == MCP_SESSION_ID_HEADER.lower():
184+
session_id = header_value.decode()
185+
break
186+
if session_id: # Break outer loop if session_id is found
187+
break
188+
189+
assert session_id is not None, "Session ID not found in response headers"
190+
191+
mock_mcp_run.assert_called_once()
192+
193+
# Give other tasks a chance to run to ensure the finally block executes
194+
await anyio.sleep(0.01)
195+
196+
assert session_id not in manager._server_instances, (
197+
"Session ID should be removed from _server_instances after an exception"
198+
)
199+
assert not manager._server_instances, "No sessions should be tracked after the only session crashes"

0 commit comments

Comments
 (0)