|
1 | 1 | import asyncio
|
2 |
| -from collections.abc import AsyncGenerator |
| 2 | +from collections.abc import AsyncGenerator, AsyncIterable |
3 | 3 | from types import SimpleNamespace
|
4 | 4 | from typing import Self
|
5 | 5 | from unittest import mock
|
@@ -235,20 +235,6 @@ class MockConnection(BaseMockConnection):
|
235 | 235 | await manager.write_frame_reconnecting(build_dataclass(ConnectFrame))
|
236 | 236 |
|
237 | 237 |
|
238 |
| -async def test_read_frames_reconnecting_raises() -> None: |
239 |
| - class MockConnection(BaseMockConnection): |
240 |
| - @staticmethod |
241 |
| - async def read_frames() -> AsyncGenerator[AnyServerFrame, None]: |
242 |
| - raise ConnectionLostError |
243 |
| - yield |
244 |
| - await asyncio.sleep(0) |
245 |
| - |
246 |
| - manager = EnrichedConnectionManager(connection_class=MockConnection) |
247 |
| - |
248 |
| - with pytest.raises(RepeatedConnectionLostError): |
249 |
| - [_ async for _ in manager.read_frames_reconnecting()] |
250 |
| - |
251 |
| - |
252 | 238 | SIDE_EFFECTS = [(None,), (ConnectionLostError(), None), (ConnectionLostError(), ConnectionLostError(), None)]
|
253 | 239 |
|
254 | 240 |
|
@@ -299,11 +285,15 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
|
299 | 285 | raise ConnectionLostError
|
300 | 286 | for frame in frames:
|
301 | 287 | yield frame
|
302 |
| - await asyncio.sleep(0) |
303 | 288 |
|
304 | 289 | manager = EnrichedConnectionManager(connection_class=MockConnection)
|
305 | 290 |
|
306 |
| - assert frames == [frame async for frame in manager.read_frames_reconnecting()] |
| 291 | + async def take_all_frames() -> AsyncIterable[AnyServerFrame]: |
| 292 | + iterator = manager.read_frames_reconnecting() |
| 293 | + for _ in frames: |
| 294 | + yield await anext(iterator) |
| 295 | + |
| 296 | + assert frames == [frame async for frame in take_all_frames()] |
307 | 297 |
|
308 | 298 |
|
309 | 299 | async def test_maybe_write_frame_connection_already_lost() -> None:
|
|
0 commit comments