Skip to content

Simplify code on stdio_client #1181

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 13 additions & 30 deletions src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import anyio
import anyio.lowlevel
from anyio.abc import Process
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from anyio.streams.text import TextReceiveStream
from pydantic import BaseModel, Field

Expand Down Expand Up @@ -107,33 +106,19 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
Client transport for stdio: this will connect to a server by spawning a
process and communicating with it over stdin/stdout.
"""
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
read_stream_writer, read_stream = anyio.create_memory_object_stream[SessionMessage | Exception](0)
write_stream, write_stream_reader = anyio.create_memory_object_stream[SessionMessage](0)

write_stream: MemoryObjectSendStream[SessionMessage]
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
command = _get_executable_command(server.command)

read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

try:
command = _get_executable_command(server.command)

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)
except OSError:
# Clean up streams if process creation fails
await read_stream.aclose()
await write_stream.aclose()
await read_stream_writer.aclose()
await write_stream_reader.aclose()
raise
Comment on lines -130 to -136
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It actually never triggers OSError here. anyio.open_process doesn't trigger.

Maybe the Windows logic does, but even if it does... The streams don't need to be closed because they are not even open yet, so just removing the except is fine.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the test suite I'm wrong... I'm not sure how.

# Open process with stderr piped for capture
process = await _create_platform_compatible_process(
command=command,
args=server.args,
env=({**get_default_environment(), **server.env} if server.env is not None else get_default_environment()),
errlog=errlog,
cwd=server.cwd,
)

async def stdout_reader():
assert process.stdout, "Opened process is missing stdout"
Expand Down Expand Up @@ -177,12 +162,10 @@ async def stdin_writer():
except anyio.ClosedResourceError:
await anyio.lowlevel.checkpoint()

async with (
anyio.create_task_group() as tg,
process,
):
async with anyio.create_task_group() as tg, process:
tg.start_soon(stdout_reader)
tg.start_soon(stdin_writer)

try:
yield read_stream, write_stream
finally:
Expand Down
16 changes: 4 additions & 12 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,7 @@ async def send_request(
self._progress_callbacks[request_id] = progress_callback

try:
jsonrpc_request = JSONRPCRequest(
jsonrpc="2.0",
id=request_id,
**request_data,
)

jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))

# request read timeout takes precedence over session read timeout
Expand Down Expand Up @@ -329,10 +324,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er
await self._write_stream.send(session_message)

async def _receive_loop(self) -> None:
async with (
self._read_stream,
self._write_stream,
):
async with self._read_stream, self._write_stream:
try:
async for message in self._read_stream:
if isinstance(message, Exception):
Expand Down Expand Up @@ -418,10 +410,10 @@ async def _receive_loop(self) -> None:
# Without this handler, the exception would propagate up and
# crash the server's task group.
logging.debug("Read stream closed by client")
except Exception as e:
except Exception:
# Other exceptions are not expected and should be logged. We purposefully
# catch all exceptions here to avoid crashing the server.
logging.exception(f"Unhandled exception in receive loop: {e}")
logging.exception("Unhandled exception in receive loop")
finally:
# after the read stream is closed, we need to send errors
# to any pending requests
Expand Down
14 changes: 7 additions & 7 deletions tests/client/test_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def test_stdio_client_universal_cleanup():
"""
import time
import sys

# Simulate a long-running process
for i in range(100):
time.sleep(0.1)
Expand Down Expand Up @@ -532,7 +532,7 @@ async def test_stdio_client_graceful_stdin_exit():
script_content = textwrap.dedent(
"""
import sys

# Read from stdin until it's closed
try:
while True:
Expand All @@ -541,7 +541,7 @@ async def test_stdio_client_graceful_stdin_exit():
break
except:
pass

# Exit gracefully
sys.exit(0)
"""
Expand Down Expand Up @@ -590,16 +590,16 @@ async def test_stdio_client_stdin_close_ignored():
import signal
import sys
import time

# Set up SIGTERM handler to exit cleanly
def sigterm_handler(signum, frame):
sys.exit(0)

signal.signal(signal.SIGTERM, sigterm_handler)

# Close stdin immediately to simulate ignoring it
sys.stdin.close()

# Keep running until SIGTERM
while True:
time.sleep(0.1)
Expand Down
Loading