
fix flaky fix-test_streamablehttp_client_resumption test #1166


Merged · 2 commits · Jul 17, 2025
4 changes: 1 addition & 3 deletions src/mcp/server/streamable_http.py
@@ -837,9 +837,7 @@ async def message_router():
response_id = str(message.root.id)
# If this response is for an existing request stream,
# send it there
if response_id in self._request_streams:
target_request_id = response_id

target_request_id = response_id
else:
# Extract related_request_id from meta if it exists
if (
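
Note on the hunk above (1 addition, 3 deletions): the `if response_id in self._request_streams:` guard around the assignment is removed, so a JSON-RPC response now always gets `target_request_id = response_id`; the `else:` shown as context appears to belong to an enclosing check, the branch that extracts `related_request_id` from message metadata for non-responses. A standalone illustration of that routing decision — not the SDK's actual code, the names here are stand-ins:

def target_before(response_id: str, request_streams: set[str]) -> str | None:
    # Old behaviour: only pick the response's own stream when it is still registered,
    # otherwise leave the target unset and fall through to later handling.
    return response_id if response_id in request_streams else None


def target_after(response_id: str, request_streams: set[str]) -> str:
    # New behaviour: always target the stream keyed by the response's own ID.
    return response_id


open_streams = {"42"}
assert target_before("42", open_streams) == "42"
assert target_before("7", open_streams) is None   # previously left unset
assert target_after("7", open_streams) == "7"     # now routed by its own ID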
154 changes: 99 additions & 55 deletions tests/shared/test_streamable_http.py
@@ -98,32 +98,33 @@ async def replay_events_after(
send_callback: EventCallback,
) -> StreamId | None:
"""Replay events after the specified ID."""
# Find the index of the last event ID
start_index = None
for i, (_, event_id, _) in enumerate(self._events):
# Find the stream ID of the last event
target_stream_id = None
for stream_id, event_id, _ in self._events:
if event_id == last_event_id:
start_index = i + 1
target_stream_id = stream_id
break

if start_index is None:
# If event ID not found, start from beginning
start_index = 0
if target_stream_id is None:
# If event ID not found, return None
return None

stream_id = None
# Replay events
for _, event_id, message in self._events[start_index:]:
await send_callback(EventMessage(message, event_id))
# Capture the stream ID from the first replayed event
if stream_id is None and len(self._events) > start_index:
stream_id = self._events[start_index][0]
# Convert last_event_id to int for comparison
last_event_id_int = int(last_event_id)

return stream_id
# Replay only events from the same stream with ID > last_event_id
for stream_id, event_id, message in self._events:
if stream_id == target_stream_id and int(event_id) > last_event_id_int:
await send_callback(EventMessage(message, event_id))

return target_stream_id
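
The rewritten replay logic above changes two things: an unknown last_event_id now returns None instead of replaying from the beginning, and replay is filtered to the stream the last event belonged to, resending only events with a higher numeric ID. A condensed, self-contained sketch of that behaviour — the store class, tuple layout, and plain-string messages are stand-ins for the test's event-store and EventMessage types:

import anyio


class InMemoryEventStore:
    def __init__(self) -> None:
        # (stream_id, event_id, message); event IDs are stringified integers
        self._events: list[tuple[str, str, str]] = []

    async def replay_events_after(self, last_event_id, send_callback):
        # Find the stream the last-seen event belonged to
        target_stream_id = None
        for stream_id, event_id, _ in self._events:
            if event_id == last_event_id:
                target_stream_id = stream_id
                break
        if target_stream_id is None:
            return None  # unknown event ID: nothing to replay
        last = int(last_event_id)
        # Resend only newer events from that same stream
        for stream_id, event_id, message in self._events:
            if stream_id == target_stream_id and int(event_id) > last:
                await send_callback(event_id, message)
        return target_stream_id


async def main() -> None:
    store = InMemoryEventStore()
    store._events = [("s1", "1", "a"), ("s2", "2", "b"), ("s1", "3", "c")]
    replayed: list[tuple[str, str]] = []

    async def collect(event_id: str, message: str) -> None:
        replayed.append((event_id, message))

    assert await store.replay_events_after("1", collect) == "s1"
    assert replayed == [("3", "c")]  # same stream only, IDs greater than 1
    assert await store.replay_events_after("9", collect) is None  # unknown ID


anyio.run(main)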


# Test server implementation that follows MCP protocol
class ServerTest(Server):
def __init__(self):
super().__init__(SERVER_NAME)
self._lock = None # Will be initialized in async context

@self.read_resource()
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
@@ -159,6 +160,16 @@ async def handle_list_tools() -> list[Tool]:
description="A tool that triggers server-side sampling",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="wait_for_lock_with_notification",
description="A tool that sends a notification and waits for lock",
inputSchema={"type": "object", "properties": {}},
),
Tool(
name="release_lock",
description="A tool that releases the lock",
inputSchema={"type": "object", "properties": {}},
),
]

@self.call_tool()
@@ -214,6 +225,39 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
)
]

elif name == "wait_for_lock_with_notification":
# Initialize lock if not already done
if self._lock is None:
self._lock = anyio.Event()

# First send a notification
await ctx.session.send_log_message(
level="info",
data="First notification before lock",
logger="lock_tool",
related_request_id=ctx.request_id,
)

# Now wait for the lock to be released
await self._lock.wait()

# Send second notification after lock is released
await ctx.session.send_log_message(
level="info",
data="Second notification after lock",
logger="lock_tool",
related_request_id=ctx.request_id,
)

return [TextContent(type="text", text="Completed")]

elif name == "release_lock":
assert self._lock is not None, "Lock must be initialized before releasing"

# Release the lock
self._lock.set()
return [TextContent(type="text", text="Lock released")]

Review comment from a Contributor on lines +228 to +260:

really cool idea to have the tool calls create internal locks for this!
return [TextContent(type="text", text=f"Called {name}")]
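
The two new tools above give the test a deterministic handshake: wait_for_lock_with_notification sends one log notification, blocks on an anyio.Event, and sends a second notification once release_lock sets the event. Stripped of the MCP plumbing, the coordination pattern is roughly this (a standalone sketch; a plain list stands in for the notification channel):

import anyio


async def main() -> None:
    lock = anyio.Event()
    log: list[str] = []

    async def wait_for_lock_with_notification() -> None:
        log.append("First notification before lock")
        await lock.wait()  # blocks until release_lock() sets the event
        log.append("Second notification after lock")

    async def release_lock() -> None:
        lock.set()  # wakes every task waiting on the event

    async with anyio.create_task_group() as tg:
        tg.start_soon(wait_for_lock_with_notification)
        await anyio.sleep(0.1)  # give the first notification time to happen
        await release_lock()

    assert log == [
        "First notification before lock",
        "Second notification after lock",
    ]


anyio.run(main)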


@@ -825,7 +869,7 @@ async def test_streamablehttp_client_tool_invocation(initialized_client_session)
"""Test client tool invocation."""
# First list tools
tools = await initialized_client_session.list_tools()
assert len(tools.tools) == 4
assert len(tools.tools) == 6
assert tools.tools[0].name == "test_tool"

# Call the tool
@@ -862,7 +906,7 @@ async def test_streamablehttp_client_session_persistence(basic_server, basic_ser

# Make multiple requests to verify session persistence
tools = await session.list_tools()
assert len(tools.tools) == 4
assert len(tools.tools) == 6

# Read a resource
resource = await session.read_resource(uri=AnyUrl("foobar://test-persist"))
@@ -891,7 +935,7 @@ async def test_streamablehttp_client_json_response(json_response_server, json_se

# Check tool listing
tools = await session.list_tools()
assert len(tools.tools) == 4
assert len(tools.tools) == 6

# Call a tool and verify JSON response handling
result = await session.call_tool("test_tool", {})
@@ -962,7 +1006,7 @@ async def test_streamablehttp_client_session_termination(basic_server, basic_ser

# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 4
assert len(tools.tools) == 6

headers = {}
if captured_session_id:
@@ -1026,7 +1070,7 @@ async def mock_delete(self, *args, **kwargs):

# Make a request to confirm session is working
tools = await session.list_tools()
assert len(tools.tools) == 4
assert len(tools.tools) == 6

headers = {}
if captured_session_id:
@@ -1048,32 +1092,32 @@ async def mock_delete(self, *args, **kwargs):

@pytest.mark.anyio
async def test_streamablehttp_client_resumption(event_server):
"""Test client session to resume a long running tool."""
"""Test client session resumption using sync primitives for reliable coordination."""
_, server_url = event_server

# Variables to track the state
captured_session_id = None
captured_resumption_token = None
captured_notifications = []
tool_started = False
captured_protocol_version = None
first_notification_received = False

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
) -> None:
if isinstance(message, types.ServerNotification):
captured_notifications.append(message)
# Look for our special notification that indicates the tool is running
# Look for our first notification
if isinstance(message.root, types.LoggingMessageNotification):
if message.root.params.data == "Tool started":
nonlocal tool_started
tool_started = True
if message.root.params.data == "First notification before lock":
nonlocal first_notification_received
first_notification_received = True

async def on_resumption_token_update(token: str) -> None:
nonlocal captured_resumption_token
captured_resumption_token = token

# First, start the client session and begin the long-running tool
# First, start the client session and begin the tool that waits on lock
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
read_stream,
write_stream,
@@ -1088,7 +1132,7 @@ async def on_resumption_token_update(token: str) -> None:
# Capture the negotiated protocol version
captured_protocol_version = result.protocolVersion

# Start a long-running tool in a task
# Start the tool that will wait on lock in a task
async with anyio.create_task_group() as tg:

async def run_tool():
@@ -1099,7 +1143,9 @@ async def run_tool():
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
params=types.CallToolRequestParams(
name="wait_for_lock_with_notification", arguments={}
),
)
),
types.CallToolResult,
@@ -1108,15 +1154,19 @@

tg.start_soon(run_tool)

# Wait for the tool to start and at least one notification
# and then kill the task group
while not tool_started or not captured_resumption_token:
# Wait for the first notification and resumption token
while not first_notification_received or not captured_resumption_token:
await anyio.sleep(0.1)

# Kill the client session while tool is waiting on lock
tg.cancel_scope.cancel()

# Store pre notifications and clear the captured notifications
# for the post-resumption check
captured_notifications_pre = captured_notifications.copy()
# Verify we received exactly one notification
assert len(captured_notifications) == 1
assert isinstance(captured_notifications[0].root, types.LoggingMessageNotification)
assert captured_notifications[0].root.params.data == "First notification before lock"

# Clear notifications for the second phase
captured_notifications = []

# Now resume the session with the same mcp-session-id and protocol version
@@ -1125,54 +1175,48 @@ async def run_tool():
headers[MCP_SESSION_ID_HEADER] = captured_session_id
if captured_protocol_version:
headers[MCP_PROTOCOL_VERSION_HEADER] = captured_protocol_version

async with streamablehttp_client(f"{server_url}/mcp", headers=headers) as (
read_stream,
write_stream,
_,
):
async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session:
# Don't initialize - just use the existing session

# Resume the tool with the resumption token
assert captured_resumption_token is not None

result = await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name="release_lock", arguments={}),
)
),
types.CallToolResult,
)
metadata = ClientMessageMetadata(
resumption_token=captured_resumption_token,
)

result = await session.send_request(
types.ClientRequest(
types.CallToolRequest(
method="tools/call",
params=types.CallToolRequestParams(name="long_running_with_checkpoints", arguments={}),
params=types.CallToolRequestParams(name="wait_for_lock_with_notification", arguments={}),
)
),
types.CallToolResult,
metadata=metadata,
)

# We should get a complete result
assert len(result.content) == 1
assert result.content[0].type == "text"
assert "Completed" in result.content[0].text
assert result.content[0].text == "Completed"

# We should have received the remaining notifications
assert len(captured_notifications) > 0
assert len(captured_notifications) == 1

# Should not have the first notification
# Check that "Tool started" notification isn't repeated when resuming
assert not any(
isinstance(n.root, types.LoggingMessageNotification) and n.root.params.data == "Tool started"
for n in captured_notifications
)
# there is no intersection between pre and post notifications
assert not any(n in captured_notifications_pre for n in captured_notifications)
assert captured_notifications[0].root.params.data == "Second notification after lock"


@pytest.mark.anyio
async def test_streamablehttp_server_sampling(basic_server, basic_server_url):
"""Test server-initiated sampling request through streamable HTTP transport."""
print("Testing server sampling...")
# Variable to track if sampling callback was invoked
sampling_callback_invoked = False
captured_message_params = None