Skip to content

Improved Trio support #946

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"python-multipart>=0.0.9",
"sse-starlette>=1.6.1",
"pydantic-settings>=2.5.2",
"typing_extensions>=4.12",
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
]

Expand All @@ -48,10 +49,10 @@ required-version = ">=0.7.2"

[dependency-groups]
dev = [
"anyio[trio]",
"pyright>=1.1.391",
"pytest>=8.3.4",
"ruff>=0.8.5",
"trio>=0.26.2",
"pytest-flakefinder>=1.1.0",
"pytest-xdist>=3.6.1",
"pytest-examples>=0.0.14",
Expand Down Expand Up @@ -122,5 +123,8 @@ filterwarnings = [
# This should be fixed on Uvicorn's side.
"ignore::DeprecationWarning:websockets",
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
# This is to avoid test failures on Trio due to httpx's failure to explicitly close
# async generators
"ignore::pytest.PytestUnraisableExceptionWarning"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we solve those? Why this doesn't happen on asyncio?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trio considers implicit async generator finalization a bad practice and emits a warning. Pytest turns this into an unraisable exception warning.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only way to fix this is to fix the problem in httpcore.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would the fix in httpcore look like?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, this is not resource warning anyway, I'm okay with it. I read it wrongly.

Copy link
Author

@agronholm agronholm Jun 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The httpcore PR is ready, just waiting for review.

EDIT: the PR has been accepted, waiting for merge.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The httpx PR: encode/httpx#3593

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Everything seems to fall into place now. With all my local changes applied, the test suite runs flawlessly on Trio.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The httpx PR is now passing all tests with 100% coverage. Waiting for a review and merging.

]
46 changes: 24 additions & 22 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

import logging
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from contextlib import aclosing, asynccontextmanager
from dataclasses import dataclass
from datetime import timedelta
from typing import cast

import anyio
import httpx
Expand Down Expand Up @@ -284,16 +285,18 @@ async def _handle_sse_response(self, response: httpx.Response, ctx: RequestConte
"""Handle SSE response from the server."""
try:
event_source = EventSource(response)
async for sse in event_source.aiter_sse():
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
sse_iter = cast(AsyncGenerator[ServerSentEvent], event_source.aiter_sse())
async with aclosing(sse_iter) as items:
async for sse in items:
is_complete = await self._handle_sse_event(
sse,
ctx.read_stream_writer,
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
)
# If the SSE event indicates completion, like returning respose/error
# break the loop
if is_complete:
break
except Exception as e:
logger.exception("Error reading SSE stream:")
await ctx.read_stream_writer.send(e)
Expand Down Expand Up @@ -434,15 +437,14 @@ async def streamablehttp_client(
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

async with anyio.create_task_group() as tg:
try:
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
try:
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")

async with httpx_client_factory(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
auth=transport.auth,
) as client:
async with create_mcp_http_client(
headers=transport.request_headers,
timeout=httpx.Timeout(transport.timeout, read=transport.sse_read_timeout),
) as client:
async with anyio.create_task_group() as tg:
# Define callbacks that need access to tg
def start_get_stream() -> None:
tg.start_soon(transport.handle_get_stream, client, read_stream_writer)
Expand All @@ -467,6 +469,6 @@ def start_get_stream() -> None:
if transport.session_id and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
finally:
await read_stream_writer.aclose()
await write_stream.aclose()
7 changes: 6 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
import anyio.lowlevel
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import AnyUrl
from typing_extensions import Self

import mcp.types as types
from mcp.server.models import InitializationOptions
Expand Down Expand Up @@ -93,10 +94,14 @@ def __init__(
)

self._init_options = init_options

async def __aenter__(self) -> Self:
await super().__aenter__()
self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[
ServerRequestResponder
](0)
self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose())
self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose)
return self

@property
def client_params(self) -> types.InitializeRequestParams | None:
Expand Down
23 changes: 13 additions & 10 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import anyio
import httpx
from anyio.abc import TaskGroup
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from pydantic import BaseModel
from typing_extensions import Self
Expand Down Expand Up @@ -177,6 +178,8 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_exit_stack: AsyncExitStack
_task_group: TaskGroup

def __init__(
self,
Expand All @@ -196,12 +199,17 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._exit_stack = AsyncExitStack()

async def __aenter__(self) -> Self:
self._task_group = anyio.create_task_group()
await self._task_group.__aenter__()
self._task_group.start_soon(self._receive_loop)
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(anyio.create_task_group())
self._task_group.start_soon(self._receive_loop)
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
exit_stack.callback(self._task_group.cancel_scope.cancel)
self._exit_stack = exit_stack.pop_all()

return self

async def __aexit__(
Expand All @@ -210,12 +218,7 @@ async def __aexit__(
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool | None:
await self._exit_stack.aclose()
# Using BaseSession as a context manager should not block on exit (this
# would be very surprising behavior), so make sure to cancel the tasks
# in the task group.
self._task_group.cancel_scope.cancel()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
return await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)

async def send_request(
self,
Expand Down
2 changes: 1 addition & 1 deletion tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def oauth_token():


@pytest.fixture
async def oauth_provider(client_metadata, mock_storage):
def oauth_provider(client_metadata, mock_storage):
async def mock_redirect_handler(url: str) -> None:
pass

Expand Down
8 changes: 4 additions & 4 deletions tests/client/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,15 @@ async def mock_server():
)

async with (
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
ClientSession(
server_to_client_receive,
client_to_server_send,
) as session,
anyio.create_task_group() as tg,
client_to_server_send,
client_to_server_receive,
server_to_client_send,
server_to_client_receive,
):
tg.start_soon(mock_server)

Expand Down
6 changes: 0 additions & 6 deletions tests/conftest.py

This file was deleted.

16 changes: 9 additions & 7 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ async def replay_events_after(
"""Replay events after the specified ID."""
# Find the index of the last event ID
start_index = None
for i, (_, event_id, _) in enumerate(self._events):
stream_id = None
for i, (stream_id_, event_id, _) in enumerate(self._events):
if event_id == last_event_id:
start_index = i + 1
stream_id = stream_id_
break

if start_index is None:
# If event ID not found, start from beginning
start_index = 0

stream_id = None
# Replay events
for _, event_id, message in self._events[start_index:]:
await send_callback(EventMessage(message, event_id))
Expand Down Expand Up @@ -1003,7 +1004,8 @@ async def test_streamablehttp_client_resumption(event_server):
captured_session_id = None
captured_resumption_token = None
captured_notifications = []
tool_started = False
tool_started_event = anyio.Event()
session_resumption_token_received_event = anyio.Event()

async def message_handler(
message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
Expand All @@ -1013,12 +1015,12 @@ async def message_handler(
# Look for our special notification that indicates the tool is running
if isinstance(message.root, types.LoggingMessageNotification):
if message.root.params.data == "Tool started":
nonlocal tool_started
tool_started = True
tool_started_event.set()

async def on_resumption_token_update(token: str) -> None:
nonlocal captured_resumption_token
captured_resumption_token = token
session_resumption_token_received_event.set()

# First, start the client session and begin the long-running tool
async with streamablehttp_client(f"{server_url}/mcp", terminate_on_close=False) as (
Expand Down Expand Up @@ -1055,8 +1057,8 @@ async def run_tool():

# Wait for the tool to start and at least one notification
# and then kill the task group
while not tool_started or not captured_resumption_token:
await anyio.sleep(0.1)
await tool_started_event.wait()
await session_resumption_token_received_event.wait()
tg.cancel_scope.cancel()

# Store pre notifications and clear the captured notifications
Expand Down
2 changes: 2 additions & 0 deletions tests/shared/test_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

SERVER_NAME = "test_server_for_WS"

pytestmark = pytest.mark.parametrize("anyio_backend", ["asyncio"])


@pytest.fixture
def server_port() -> int:
Expand Down
11 changes: 9 additions & 2 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading