Skip to content

Commit 00ac534

Browse files
authored
Show received frames in ConnectionConfirmationTimeoutError (#38)
1 parent 4ef70f4 commit 00ac534

File tree

6 files changed

+46
-25
lines changed

6 files changed

+46
-25
lines changed

Justfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ install:
55
uv -q sync
66

77
test *args:
8-
uv -q run pytest {{args}}
8+
.venv/bin/pytest {{args}}
99

1010
lint:
1111
uv -q run ruff check .

stompman/client.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,26 @@ async def _connect_to_any_server(self) -> None:
160160
timeout=self.connect_timeout,
161161
)
162162

163+
async def _wait_for_connected_frame(self) -> ConnectedFrame:
164+
collected_frames = []
165+
166+
async def inner() -> ConnectedFrame:
167+
async for frame in self._connection.read_frames(
168+
max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
169+
):
170+
if isinstance(frame, ConnectedFrame):
171+
return frame
172+
collected_frames.append(frame)
173+
msg = "unreachable" # pragma: no cover
174+
raise AssertionError(msg) # pragma: no cover
175+
176+
try:
177+
return await asyncio.wait_for(inner(), timeout=self.connection_confirmation_timeout)
178+
except TimeoutError as exception:
179+
raise ConnectionConfirmationTimeoutError(
180+
timeout=self.connection_confirmation_timeout, frames=collected_frames
181+
) from exception
182+
163183
@asynccontextmanager
164184
async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
165185
# On startup:
@@ -182,15 +202,7 @@ async def _connection_lifespan(self) -> AsyncGenerator[None, None]:
182202
},
183203
)
184204
)
185-
try:
186-
connected_frame = await asyncio.wait_for(
187-
self._connection.read_frame_of_type(
188-
ConnectedFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
189-
),
190-
timeout=self.connection_confirmation_timeout,
191-
)
192-
except TimeoutError as exception:
193-
raise ConnectionConfirmationTimeoutError(self.connection_confirmation_timeout) from exception
205+
connected_frame = await self._wait_for_connected_frame()
194206

195207
if connected_frame.headers["version"] != self.PROTOCOL_VERSION:
196208
raise UnsupportedProtocolVersionError(
@@ -222,9 +234,11 @@ async def send_heartbeats_forever() -> None:
222234
if self._connection.active:
223235
await self._connection.write_frame(DisconnectFrame(headers={"receipt": str(uuid4())}))
224236
if self._connection.active:
225-
await self._connection.read_frame_of_type(
226-
ReceiptFrame, max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
227-
)
237+
async for frame in self._connection.read_frames(
238+
max_chunk_size=self.read_max_chunk_size, timeout=self.read_timeout
239+
):
240+
if isinstance(frame, ReceiptFrame):
241+
break
228242

229243
@asynccontextmanager
230244
async def enter_transaction(self) -> AsyncGenerator[str, None]:

stompman/connection.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,6 @@ def write_heartbeat(self) -> None: ...
2323
async def write_frame(self, frame: AnyClientFrame) -> None: ...
2424
def read_frames(self, max_chunk_size: int, timeout: int) -> AsyncGenerator[AnyServerFrame, None]: ...
2525

26-
async def read_frame_of_type(self, type_: type[FrameType], max_chunk_size: int, timeout: int) -> FrameType:
27-
while True:
28-
async for frame in self.read_frames(max_chunk_size=max_chunk_size, timeout=timeout):
29-
if isinstance(frame, type_):
30-
return frame
31-
3226

3327
@dataclass(kw_only=True)
3428
class Connection(AbstractConnection):

stompman/errors.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from dataclasses import dataclass
22
from typing import TYPE_CHECKING
33

4+
from stompman.frames import ErrorFrame, MessageFrame, ReceiptFrame
5+
46
if TYPE_CHECKING:
57
from stompman.client import ConnectionParameters
68

@@ -14,6 +16,7 @@ def __str__(self) -> str:
1416
@dataclass
1517
class ConnectionConfirmationTimeoutError(Error):
1618
timeout: int
19+
frames: list[MessageFrame | ReceiptFrame | ErrorFrame]
1720

1821

1922
@dataclass

tests/test_client.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import asyncio
2-
from collections.abc import AsyncGenerator, Awaitable
2+
from collections.abc import AsyncGenerator, Coroutine
33
from contextlib import suppress
44
from dataclasses import dataclass, field
55
from typing import Any, Self
@@ -227,18 +227,28 @@ class MockConnection(connection_class): # type: ignore[valid-type, misc]
227227

228228

229229
async def test_client_lifespan_connection_not_confirmed(monkeypatch: pytest.MonkeyPatch) -> None:
230-
async def timeout(future: Awaitable[Any], timeout: float) -> object:
230+
async def timeout(future: Coroutine[Any, Any, Any], timeout: float) -> object:
231231
assert timeout == client.connection_confirmation_timeout
232-
return await original_wait_for(future, 0)
232+
task = asyncio.create_task(future)
233+
await asyncio.sleep(0)
234+
return await original_wait_for(task, 0)
233235

234236
original_wait_for = asyncio.wait_for
235237
monkeypatch.setattr("asyncio.wait_for", timeout)
236238

237-
client = EnrichedClient(connection_class=BaseMockConnection)
239+
class MockConnection(BaseMockConnection):
240+
@staticmethod
241+
async def read_frames(max_chunk_size: int, timeout: int) -> AsyncGenerator[AnyServerFrame, None]:
242+
yield ErrorFrame(headers={"message": "hi"})
243+
await asyncio.sleep(0)
244+
245+
client = EnrichedClient(connection_class=MockConnection)
238246
with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info:
239247
await client.__aenter__() # noqa: PLC2801
240248

241-
assert exc_info.value == ConnectionConfirmationTimeoutError(client.connection_confirmation_timeout)
249+
assert exc_info.value == ConnectionConfirmationTimeoutError(
250+
timeout=client.connection_confirmation_timeout, frames=[ErrorFrame(headers={"message": "hi"})]
251+
)
242252

243253

244254
async def test_client_lifespan_unsupported_protocol_version() -> None:

tests/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22

33

44
def test_error_str() -> None:
5-
error = ConnectionConfirmationTimeoutError(timeout=1)
5+
error = ConnectionConfirmationTimeoutError(timeout=1, frames=[])
66
assert str(error) == repr(error)

0 commit comments

Comments
 (0)