Skip to content

Commit fd826cc

Browse files
committed
Fix Websocket Client and Add Test
1 parent fc021ee commit fd826cc

File tree

4 files changed

+324
-42
lines changed

4 files changed

+324
-42
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
33+
"websockets>=15.0.1",
3334
]
3435

3536
[project.optional-dependencies]

src/mcp/client/websocket.py

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,23 @@
44
from typing import AsyncGenerator
55

66
import anyio
7-
import websockets
8-
from anyio.streams.memory import (
9-
MemoryObjectReceiveStream,
10-
MemoryObjectSendStream,
11-
create_memory_object_stream,
12-
)
7+
from pydantic import ValidationError
8+
from websockets.asyncio.client import connect as ws_connect
9+
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
10+
from websockets.typing import Subprotocol
1311

1412
import mcp.types as types
1513

1614
logger = logging.getLogger(__name__)
1715

16+
1817
@asynccontextmanager
19-
async def websocket_client(
20-
url: str
21-
) -> AsyncGenerator[
18+
async def websocket_client(url: str) -> AsyncGenerator[
2219
tuple[
2320
MemoryObjectReceiveStream[types.JSONRPCMessage | Exception],
2421
MemoryObjectSendStream[types.JSONRPCMessage],
2522
],
26-
None
23+
None,
2724
]:
2825
"""
2926
WebSocket client transport for MCP, symmetrical to the server version.
@@ -38,13 +35,13 @@ async def websocket_client(
3835
"""
3936

4037
# Create two in-memory streams:
41-
# - One for incoming messages (read_stream_recv, written by ws_reader)
42-
# - One for outgoing messages (write_stream_send, read by ws_writer)
43-
read_stream_send, read_stream_recv = create_memory_object_stream(0)
44-
write_stream_send, write_stream_recv = create_memory_object_stream(0)
38+
# - One for incoming messages (read_stream, written by ws_reader)
39+
# - One for outgoing messages (write_stream, read by ws_writer)
40+
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
41+
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
4542

4643
# Connect using websockets, requesting the "mcp" subprotocol
47-
async with websockets.connect(url, subprotocols=["mcp"]) as ws:
44+
async with ws_connect(url, subprotocols=[Subprotocol("mcp")]) as ws:
4845
# Optional check to ensure the server actually accepted "mcp"
4946
if ws.subprotocol != "mcp":
5047
raise ValueError(
@@ -54,46 +51,42 @@ async def websocket_client(
5451
async def ws_reader():
5552
"""
5653
Reads text messages from the WebSocket, parses them as JSON-RPC messages,
57-
and sends them into read_stream_send.
54+
and sends them into read_stream_writer.
5855
"""
5956
try:
60-
async for raw_text in ws:
61-
try:
62-
message = types.JSONRPCMessage.model_validate_json(raw_text)
63-
await read_stream_send.send(message)
64-
except Exception as exc:
65-
# If JSON parse or model validation fails, send the exception
66-
await read_stream_send.send(exc)
67-
except (anyio.ClosedResourceError, websockets.ConnectionClosed):
68-
pass
69-
finally:
70-
# Ensure our read stream is closed
71-
await read_stream_send.aclose()
57+
async with read_stream_writer:
58+
async for raw_text in ws:
59+
try:
60+
message = types.JSONRPCMessage.model_validate_json(raw_text)
61+
await read_stream_writer.send(message)
62+
except ValidationError as exc:
63+
# If JSON parse or model validation fails, send the exception
64+
await read_stream_writer.send(exc)
65+
except (anyio.ClosedResourceError, Exception):
66+
await ws.close()
7267

7368
async def ws_writer():
7469
"""
75-
Reads JSON-RPC messages from write_stream_recv and sends them to the server.
70+
Reads JSON-RPC messages from write_stream_reader and sends them to the server.
7671
"""
7772
try:
78-
async for message in write_stream_recv:
79-
# Convert to a dict, then to JSON
80-
msg_dict = message.model_dump(
81-
by_alias=True, mode="json", exclude_none=True
82-
)
83-
await ws.send(json.dumps(msg_dict))
84-
except (anyio.ClosedResourceError, websockets.ConnectionClosed):
85-
pass
86-
finally:
87-
# Ensure our write stream is closed
88-
await write_stream_recv.aclose()
73+
async with write_stream_reader:
74+
async for message in write_stream_reader:
75+
# Convert to a dict, then to JSON
76+
msg_dict = message.model_dump(
77+
by_alias=True, mode="json", exclude_none=True
78+
)
79+
await ws.send(json.dumps(msg_dict))
80+
except (anyio.ClosedResourceError, Exception):
81+
await ws.close()
8982

9083
async with anyio.create_task_group() as tg:
9184
# Start reader and writer tasks
9285
tg.start_soon(ws_reader)
9386
tg.start_soon(ws_writer)
9487

9588
# Yield the receive/send streams
96-
yield (read_stream_recv, write_stream_send)
89+
yield (read_stream, write_stream)
9790

9891
# Once the caller's 'async with' block exits, we shut down
9992
tg.cancel_scope.cancel()

tests/shared/test_ws.py

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import multiprocessing
2+
import socket
3+
import time
4+
from typing import AsyncGenerator, Generator
5+
6+
import anyio
7+
import pytest
8+
import uvicorn
9+
from pydantic import AnyUrl
10+
from starlette.applications import Starlette
11+
from starlette.requests import Request
12+
from starlette.routing import WebSocketRoute
13+
14+
from mcp.client.session import ClientSession
15+
from mcp.client.websocket import websocket_client
16+
from mcp.server import Server
17+
from mcp.server.websocket import websocket_server
18+
from mcp.shared.exceptions import McpError
19+
from mcp.types import (
20+
EmptyResult,
21+
ErrorData,
22+
InitializeResult,
23+
ReadResourceResult,
24+
TextContent,
25+
TextResourceContents,
26+
Tool,
27+
)
28+
29+
SERVER_NAME = "test_server_for_WS"
30+
31+
32+
@pytest.fixture
33+
def server_port() -> int:
34+
with socket.socket() as s:
35+
s.bind(("127.0.0.1", 0))
36+
return s.getsockname()[1]
37+
38+
39+
@pytest.fixture
40+
def server_url(server_port: int) -> str:
41+
return f"ws://127.0.0.1:{server_port}"
42+
43+
44+
# Test server implementation
45+
class ServerTest(Server):
46+
def __init__(self):
47+
super().__init__(SERVER_NAME)
48+
49+
@self.read_resource()
50+
async def handle_read_resource(uri: AnyUrl) -> str | bytes:
51+
if uri.scheme == "foobar":
52+
return f"Read {uri.host}"
53+
elif uri.scheme == "slow":
54+
# Simulate a slow resource
55+
await anyio.sleep(2.0)
56+
return f"Slow response from {uri.host}"
57+
58+
raise McpError(
59+
error=ErrorData(
60+
code=404, message="OOPS! no resource with that URI was found"
61+
)
62+
)
63+
64+
@self.list_tools()
65+
async def handle_list_tools() -> list[Tool]:
66+
return [
67+
Tool(
68+
name="test_tool",
69+
description="A test tool",
70+
inputSchema={"type": "object", "properties": {}},
71+
)
72+
]
73+
74+
@self.call_tool()
75+
async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
76+
return [TextContent(type="text", text=f"Called {name}")]
77+
78+
79+
# Test fixtures
80+
def make_server_app() -> Starlette:
81+
"""Create test Starlette app with WebSocket transport"""
82+
server = ServerTest()
83+
84+
async def handle_ws(websocket):
85+
async with websocket_server(
86+
websocket.scope, websocket.receive, websocket.send
87+
) as streams:
88+
await server.run(
89+
streams[0], streams[1], server.create_initialization_options()
90+
)
91+
92+
app = Starlette(
93+
routes=[
94+
WebSocketRoute("/ws", endpoint=handle_ws),
95+
]
96+
)
97+
98+
return app
99+
100+
101+
def run_server(server_port: int) -> None:
102+
app = make_server_app()
103+
server = uvicorn.Server(
104+
config=uvicorn.Config(
105+
app=app, host="127.0.0.1", port=server_port, log_level="error"
106+
)
107+
)
108+
print(f"starting server on {server_port}")
109+
server.run()
110+
111+
# Give server time to start
112+
while not server.started:
113+
print("waiting for server to start")
114+
time.sleep(0.5)
115+
116+
117+
@pytest.fixture()
118+
def server(server_port: int) -> Generator[None, None, None]:
119+
proc = multiprocessing.Process(
120+
target=run_server, kwargs={"server_port": server_port}, daemon=True
121+
)
122+
print("starting process")
123+
proc.start()
124+
125+
# Wait for server to be running
126+
max_attempts = 20
127+
attempt = 0
128+
print("waiting for server to start")
129+
while attempt < max_attempts:
130+
try:
131+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
132+
s.connect(("127.0.0.1", server_port))
133+
break
134+
except ConnectionRefusedError:
135+
time.sleep(0.1)
136+
attempt += 1
137+
else:
138+
raise RuntimeError(
139+
"Server failed to start after {} attempts".format(max_attempts)
140+
)
141+
142+
yield
143+
144+
print("killing server")
145+
# Signal the server to stop
146+
proc.kill()
147+
proc.join(timeout=2)
148+
if proc.is_alive():
149+
print("server process failed to terminate")
150+
151+
152+
@pytest.fixture()
153+
async def initialized_ws_client_session(
154+
server, server_url: str
155+
) -> AsyncGenerator[ClientSession, None]:
156+
"""Create and initialize a WebSocket client session"""
157+
async with websocket_client(server_url + "/ws") as streams:
158+
async with ClientSession(*streams) as session:
159+
# Test initialization
160+
result = await session.initialize()
161+
assert isinstance(result, InitializeResult)
162+
assert result.serverInfo.name == SERVER_NAME
163+
164+
# Test ping
165+
ping_result = await session.send_ping()
166+
assert isinstance(ping_result, EmptyResult)
167+
168+
yield session
169+
170+
171+
# Tests
172+
@pytest.mark.anyio
173+
async def test_ws_client_basic_connection(server: None, server_url: str) -> None:
174+
"""Test the WebSocket connection establishment"""
175+
async with websocket_client(server_url + "/ws") as streams:
176+
async with ClientSession(*streams) as session:
177+
# Test initialization
178+
result = await session.initialize()
179+
assert isinstance(result, InitializeResult)
180+
assert result.serverInfo.name == SERVER_NAME
181+
182+
# Test ping
183+
ping_result = await session.send_ping()
184+
assert isinstance(ping_result, EmptyResult)
185+
186+
187+
@pytest.mark.anyio
188+
async def test_ws_client_happy_request_and_response(
189+
initialized_ws_client_session: ClientSession,
190+
) -> None:
191+
"""Test a successful request and response via WebSocket"""
192+
result = await initialized_ws_client_session.read_resource("foobar://example")
193+
assert isinstance(result, ReadResourceResult)
194+
assert isinstance(result.contents, list)
195+
assert len(result.contents) > 0
196+
assert isinstance(result.contents[0], TextResourceContents)
197+
assert result.contents[0].text == "Read example"
198+
199+
200+
@pytest.mark.anyio
201+
async def test_ws_client_exception_handling(
202+
initialized_ws_client_session: ClientSession,
203+
) -> None:
204+
"""Test exception handling in WebSocket communication"""
205+
with pytest.raises(McpError) as exc_info:
206+
await initialized_ws_client_session.read_resource("unknown://example")
207+
assert exc_info.value.error.code == 404
208+
209+
210+
@pytest.mark.anyio
211+
async def test_ws_client_timeout(
212+
initialized_ws_client_session: ClientSession,
213+
) -> None:
214+
"""Test timeout handling in WebSocket communication"""
215+
# Set a very short timeout to trigger a timeout exception
216+
with pytest.raises(TimeoutError):
217+
with anyio.fail_after(0.1): # 100ms timeout
218+
await initialized_ws_client_session.read_resource("slow://example")
219+
220+
# Now test that we can still use the session after a timeout
221+
with anyio.fail_after(5): # Longer timeout to allow completion
222+
result = await initialized_ws_client_session.read_resource("foobar://example")
223+
assert isinstance(result, ReadResourceResult)
224+
assert isinstance(result.contents, list)
225+
assert len(result.contents) > 0
226+
assert isinstance(result.contents[0], TextResourceContents)
227+
assert result.contents[0].text == "Read example"

0 commit comments

Comments
 (0)