Skip to content

Commit cbace92

Browse files
authored
Force keyword arguments in some cases to avoid ambiguity (#47)
1 parent 3db311a commit cbace92

13 files changed

+76
-67
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ line-length = 120
5454
[tool.ruff.lint]
5555
preview = true
5656
select = ["ALL"]
57-
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001"]
57+
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001", "PLR0913", "PLC2801"]
5858
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] }
5959

6060
[tool.pytest.ini_options]

stompman/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
@asynccontextmanager
34-
async def connection_lifespan( # noqa: PLR0913
34+
async def connection_lifespan(
3535
*,
3636
connection: AbstractConnection,
3737
connection_parameters: ConnectionParameters,
@@ -122,7 +122,7 @@ async def unsubscribe(self) -> None:
122122
del self._active_subscriptions[self.id]
123123
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))
124124

125-
async def _run_handler(self, frame: MessageFrame) -> None:
125+
async def _run_handler(self, *, frame: MessageFrame) -> None:
126126
try:
127127
await self.handler(frame)
128128
except self.supressed_exception_classes as exception:
@@ -185,7 +185,7 @@ async def __aexit__(
185185
self._active_transactions.remove(self)
186186

187187
async def send(
188-
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
188+
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
189189
) -> None:
190190
frame = SendFrame.build(
191191
body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers
@@ -295,7 +295,7 @@ async def _listen_to_frames(self) -> None:
295295
match frame:
296296
case MessageFrame():
297297
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
298-
task_group.create_task(subscription._run_handler(frame)) # noqa: SLF001
298+
task_group.create_task(subscription._run_handler(frame=frame)) # noqa: SLF001
299299
elif self.on_unhandled_message_frame:
300300
self.on_unhandled_message_frame(frame)
301301
case ErrorFrame():
@@ -308,7 +308,7 @@ async def _listen_to_frames(self) -> None:
308308
pass
309309

310310
async def send(
311-
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
311+
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
312312
) -> None:
313313
await self._connection_manager.write_frame_reconnecting(
314314
SendFrame.build(
@@ -323,7 +323,7 @@ async def begin(self) -> AsyncGenerator[Transaction, None]:
323323
) as transaction:
324324
yield transaction
325325

326-
async def subscribe( # noqa: PLR0913
326+
async def subscribe(
327327
self,
328328
destination: str,
329329
handler: Callable[[MessageFrame], Coroutine[None, None, None]],

stompman/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from dataclasses import dataclass, field
2-
from typing import NamedTuple, Self, TypedDict
2+
from typing import Self, TypedDict
33
from urllib.parse import unquote
44

55

6-
class Heartbeat(NamedTuple):
6+
@dataclass(frozen=True, slots=True)
7+
class Heartbeat:
78
will_send_interval_ms: int
89
want_to_receive_interval_ms: int
910

stompman/connection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
@dataclass(kw_only=True)
1414
class AbstractConnection(Protocol):
1515
@classmethod
16-
async def connect( # noqa: PLR0913
17-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
16+
async def connect(
17+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
1818
) -> Self | None: ...
1919
async def close(self) -> None: ...
2020
def write_heartbeat(self) -> None: ...
@@ -38,8 +38,8 @@ class Connection(AbstractConnection):
3838
read_timeout: int
3939

4040
@classmethod
41-
async def connect( # noqa: PLR0913
42-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
41+
async def connect(
42+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
4343
) -> Self | None:
4444
try:
4545
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)

stompman/connection_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ConnectionManager:
2727
connect_timeout: int
2828
read_timeout: int
2929
read_max_chunk_size: int
30+
3031
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
3132
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
3233

@@ -87,7 +88,7 @@ async def _get_active_connection_state(self) -> ActiveConnectionState:
8788
self._active_connection_state = ActiveConnectionState(connection=connection, lifespan=lifespan)
8889

8990
try:
90-
await lifespan.__aenter__() # noqa: PLC2801
91+
await lifespan.__aenter__()
9192
except ConnectionLostError:
9293
self._clear_active_connection_state()
9394
else:

stompman/errors.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def __str__(self) -> str:
1313
return self.__repr__()
1414

1515

16+
@dataclass(frozen=True, kw_only=True, slots=True)
17+
class ConnectionLostError(Error):
18+
"""Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private."""
19+
20+
1621
@dataclass(frozen=True, kw_only=True, slots=True)
1722
class ConnectionConfirmationTimeoutError(Error):
1823
timeout: int
@@ -36,7 +41,3 @@ class FailedAllConnectAttemptsError(Error):
3641
@dataclass(frozen=True, kw_only=True, slots=True)
3742
class RepeatedConnectionLostError(Error):
3843
retry_attempts: int
39-
40-
41-
@dataclass(frozen=True, kw_only=True, slots=True)
42-
class ConnectionLostError(Error): ...

stompman/frames.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class SendFrame:
142142
body: bytes = b""
143143

144144
@classmethod
145-
def build( # noqa: PLR0913
145+
def build(
146146
cls,
147147
*,
148148
body: bytes,

stompman/serde.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def dump_frame(frame: AnyClientFrame | AnyRealServerFrame) -> bytes:
9898
return b"".join(lines)
9999

100100

101-
def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None:
101+
def unescape_byte(*, byte: bytes, previous_byte: bytes | None) -> bytes | None:
102102
if previous_byte == BACKSLASH:
103103
return HEADER_UNESCAPE_CHARS.get(byte)
104104
if byte == BACKSLASH:
@@ -123,7 +123,7 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
123123
just_escaped_line = False
124124
if byte != BACKSLASH:
125125
(value_buffer if key_parsed else key_buffer).extend(byte)
126-
elif unescaped_byte := unescape_byte(byte, previous_byte):
126+
elif unescaped_byte := unescape_byte(byte=byte, previous_byte=previous_byte):
127127
just_escaped_line = True
128128
(value_buffer if key_parsed else key_buffer).extend(unescaped_byte)
129129

@@ -136,13 +136,10 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
136136
return None
137137

138138

139-
def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
139+
def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
140140
frame_type = COMMANDS_TO_FRAMES[command]
141-
return (
142-
frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg]
143-
if frame_type in FRAMES_WITH_BODY
144-
else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg]
145-
)
141+
headers_ = cast(Any, headers)
142+
return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg]
146143

147144

148145
def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:

testing/consumer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66

77
async def main() -> None:
8-
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:
9-
10-
async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
11-
print(frame) # noqa: T201
8+
async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
9+
print(frame) # noqa: T201
1210

11+
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:
1312
await client.subscribe("DLQ", handler=handle_message, on_suppressed_exception=print)
1413

1514

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from polyfactory.factories.dataclass_factory import DataclassFactory
1010

1111
import stompman
12+
from stompman.frames import HeartbeatFrame
1213

1314

1415
@pytest.fixture(
@@ -34,8 +35,8 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No
3435

3536
class BaseMockConnection(stompman.AbstractConnection):
3637
@classmethod
37-
async def connect( # noqa: PLR0913
38-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
38+
async def connect(
39+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
3940
) -> Self | None:
4041
return cls()
4142

@@ -45,7 +46,7 @@ async def write_frame(self, frame: stompman.AnyClientFrame) -> None: ...
4546
@staticmethod
4647
async def read_frames() -> AsyncGenerator[stompman.AnyServerFrame, None]: # pragma: no cover
4748
await asyncio.Future()
48-
yield # type: ignore[misc]
49+
yield HeartbeatFrame()
4950

5051

5152
@dataclass(kw_only=True, slots=True)

tests/test_client.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
148148
await asyncio.sleep(0)
149149

150150
with pytest.raises(ConnectionConfirmationTimeoutError) as exc_info:
151-
await EnrichedClient( # noqa: PLC2801
151+
await EnrichedClient(
152152
connection_class=MockConnection, connection_confirmation_timeout=connection_confirmation_timeout
153153
).__aenter__()
154154

@@ -161,7 +161,7 @@ async def test_client_connection_lifespan_unsupported_protocol_version() -> None
161161
given_version = FAKER.pystr()
162162

163163
with pytest.raises(UnsupportedProtocolVersionError) as exc_info:
164-
await EnrichedClient( # noqa: PLC2801
164+
await EnrichedClient(
165165
connection_class=create_spying_connection(
166166
[build_dataclass(ConnectedFrame, headers={"version": given_version})]
167167
)[0]
@@ -279,7 +279,8 @@ async def test_client_subscribtions_lifespan_no_active_subs_in_aexit(monkeypatch
279279
@pytest.mark.parametrize("direct_error", [True, False])
280280
async def test_client_subscribtions_lifespan_with_active_subs_in_aexit(
281281
monkeypatch: pytest.MonkeyPatch,
282-
direct_error: bool, # noqa: FBT001
282+
*,
283+
direct_error: bool,
283284
) -> None:
284285
subscription_id, destination = FAKER.pystr(), FAKER.pystr()
285286
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))
@@ -388,7 +389,7 @@ async def test_client_listen_unsubscribe_before_ack_or_nack(
388389

389390
@pytest.mark.parametrize("ok", [True, False])
390391
@pytest.mark.parametrize("ack", ["client", "client-individual"])
391-
async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, ok: bool) -> None: # noqa: FBT001
392+
async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack: AckMode, *, ok: bool) -> None:
392393
subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
393394
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))
394395

@@ -418,7 +419,7 @@ async def test_client_listen_ack_nack_sent(monkeypatch: pytest.MonkeyPatch, ack:
418419

419420

420421
@pytest.mark.parametrize("ok", [True, False])
421-
async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, ok: bool) -> None: # noqa: FBT001
422+
async def test_client_listen_auto_ack_nack(monkeypatch: pytest.MonkeyPatch, *, ok: bool) -> None:
422423
subscription_id, destination, message_id = FAKER.pystr(), FAKER.pystr(), FAKER.pystr()
423424
monkeypatch.setattr(stompman.client, "_make_subscription_id", mock.Mock(return_value=subscription_id))
424425

tests/test_connection.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ class MockWriter:
7676
HeartbeatFrame(),
7777
ConnectedFrame(headers={"heart-beat": "0,0", "version": "1.2", "server": "some server"}),
7878
]
79-
max_chunk_size = 1024
8079

8180
class MockReader:
8281
read = mock.AsyncMock(side_effect=read_bytes)
@@ -100,7 +99,7 @@ async def take_frames(count: int) -> list[AnyServerFrame]:
10099
MockWriter.close.assert_called_once_with()
101100
MockWriter.wait_closed.assert_called_once_with()
102101
MockWriter.drain.assert_called_once_with()
103-
MockReader.read.mock_calls = [mock.call(max_chunk_size)] * len(read_bytes) # type: ignore[assignment]
102+
assert MockReader.read.mock_calls == [mock.call(connection.read_max_chunk_size)] * len(read_bytes)
104103
assert MockWriter.write.mock_calls == [mock.call(NEWLINE), mock.call(b"COMMIT\ntransaction:transaction\n\n\x00")]
105104

106105

tests/test_connection_manager.py

Lines changed: 35 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,21 @@ async def test_connect_to_one_server_ok(ok_on_attempt: int, monkeypatch: pytest.
2929

3030
class MockConnection(BaseMockConnection):
3131
@classmethod
32-
async def connect( # noqa: PLR0913
33-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
32+
async def connect(
33+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
3434
) -> Self | None:
3535
assert (host, port) == (manager.servers[0].host, manager.servers[0].port)
3636
nonlocal attempts
3737
attempts += 1
3838

3939
return (
40-
await super().connect(host, port, timeout, read_max_chunk_size, read_timeout)
40+
await super().connect(
41+
host=host,
42+
port=port,
43+
timeout=timeout,
44+
read_max_chunk_size=read_max_chunk_size,
45+
read_timeout=read_timeout,
46+
)
4147
if attempts == ok_on_attempt
4248
else None
4349
)
@@ -60,11 +66,17 @@ class MockConnection(BaseMockConnection):
6066
async def test_connect_to_any_server_ok() -> None:
6167
class MockConnection(BaseMockConnection):
6268
@classmethod
63-
async def connect( # noqa: PLR0913
64-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
69+
async def connect(
70+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
6571
) -> Self | None:
6672
return (
67-
await super().connect(host, port, timeout, read_max_chunk_size, read_timeout)
73+
await super().connect(
74+
host=host,
75+
port=port,
76+
timeout=timeout,
77+
read_max_chunk_size=read_max_chunk_size,
78+
read_timeout=read_timeout,
79+
)
6880
if port == successful_server.port
6981
else None
7082
)
@@ -224,19 +236,17 @@ class MockConnection(BaseMockConnection):
224236

225237

226238
async def test_read_frames_reconnecting_raises() -> None:
227-
async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]:
228-
raise ConnectionLostError
229-
yield
230-
await asyncio.sleep(0)
231-
232239
class MockConnection(BaseMockConnection):
233-
read_frames = read_frames_mock # type: ignore[assignment]
240+
@staticmethod
241+
async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
242+
raise ConnectionLostError
243+
yield
244+
await asyncio.sleep(0)
234245

235246
manager = EnrichedConnectionManager(connection_class=MockConnection)
236247

237-
with pytest.raises(RepeatedConnectionLostError): # noqa: PT012
238-
async for _ in manager.read_frames_reconnecting():
239-
pass # pragma: no cover
248+
with pytest.raises(RepeatedConnectionLostError):
249+
[_ async for _ in manager.read_frames_reconnecting()]
240250

241251

242252
SIDE_EFFECTS = [(None,), (ConnectionLostError(), None), (ConnectionLostError(), ConnectionLostError(), None)]
@@ -279,18 +289,17 @@ async def test_read_frames_reconnecting_ok(side_effect: tuple[None | ConnectionL
279289
]
280290
attempt = -1
281291

282-
async def read_frames_mock(self: object) -> AsyncGenerator[AnyServerFrame, None]:
283-
nonlocal attempt
284-
attempt += 1
285-
current_effect = side_effect[attempt]
286-
if isinstance(current_effect, ConnectionLostError):
287-
raise ConnectionLostError
288-
for frame in frames:
289-
yield frame
290-
await asyncio.sleep(0)
291-
292292
class MockConnection(BaseMockConnection):
293-
read_frames = read_frames_mock # type: ignore[assignment]
293+
@staticmethod
294+
async def read_frames() -> AsyncGenerator[AnyServerFrame, None]:
295+
nonlocal attempt
296+
attempt += 1
297+
current_effect = side_effect[attempt]
298+
if isinstance(current_effect, ConnectionLostError):
299+
raise ConnectionLostError
300+
for frame in frames:
301+
yield frame
302+
await asyncio.sleep(0)
294303

295304
manager = EnrichedConnectionManager(connection_class=MockConnection)
296305

0 commit comments

Comments
 (0)