Skip to content

Commit cbace92

Browse files
authored
Force keyword arguments in some cases to avoid ambiguity (#47)
1 parent 3db311a commit cbace92

13 files changed

+76
-67
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ line-length = 120
5454
[tool.ruff.lint]
5555
preview = true
5656
select = ["ALL"]
57-
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001"]
57+
ignore = ["D1", "D203", "D213", "COM812", "ISC001", "CPY001", "PLR0913", "PLC2801"]
5858
extend-per-file-ignores = { "tests/*" = ["S101", "SLF001", "ARG"] }
5959

6060
[tool.pytest.ini_options]

stompman/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232

3333
@asynccontextmanager
34-
async def connection_lifespan( # noqa: PLR0913
34+
async def connection_lifespan(
3535
*,
3636
connection: AbstractConnection,
3737
connection_parameters: ConnectionParameters,
@@ -122,7 +122,7 @@ async def unsubscribe(self) -> None:
122122
del self._active_subscriptions[self.id]
123123
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))
124124

125-
async def _run_handler(self, frame: MessageFrame) -> None:
125+
async def _run_handler(self, *, frame: MessageFrame) -> None:
126126
try:
127127
await self.handler(frame)
128128
except self.supressed_exception_classes as exception:
@@ -185,7 +185,7 @@ async def __aexit__(
185185
self._active_transactions.remove(self)
186186

187187
async def send(
188-
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
188+
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
189189
) -> None:
190190
frame = SendFrame.build(
191191
body=body, destination=destination, transaction=self.id, content_type=content_type, headers=headers
@@ -295,7 +295,7 @@ async def _listen_to_frames(self) -> None:
295295
match frame:
296296
case MessageFrame():
297297
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
298-
task_group.create_task(subscription._run_handler(frame)) # noqa: SLF001
298+
task_group.create_task(subscription._run_handler(frame=frame)) # noqa: SLF001
299299
elif self.on_unhandled_message_frame:
300300
self.on_unhandled_message_frame(frame)
301301
case ErrorFrame():
@@ -308,7 +308,7 @@ async def _listen_to_frames(self) -> None:
308308
pass
309309

310310
async def send(
311-
self, body: bytes, destination: str, content_type: str | None = None, headers: dict[str, str] | None = None
311+
self, body: bytes, destination: str, *, content_type: str | None = None, headers: dict[str, str] | None = None
312312
) -> None:
313313
await self._connection_manager.write_frame_reconnecting(
314314
SendFrame.build(
@@ -323,7 +323,7 @@ async def begin(self) -> AsyncGenerator[Transaction, None]:
323323
) as transaction:
324324
yield transaction
325325

326-
async def subscribe( # noqa: PLR0913
326+
async def subscribe(
327327
self,
328328
destination: str,
329329
handler: Callable[[MessageFrame], Coroutine[None, None, None]],

stompman/config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from dataclasses import dataclass, field
2-
from typing import NamedTuple, Self, TypedDict
2+
from typing import Self, TypedDict
33
from urllib.parse import unquote
44

55

6-
class Heartbeat(NamedTuple):
6+
@dataclass(frozen=True, slots=True)
7+
class Heartbeat:
78
will_send_interval_ms: int
89
want_to_receive_interval_ms: int
910

stompman/connection.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
@dataclass(kw_only=True)
1414
class AbstractConnection(Protocol):
1515
@classmethod
16-
async def connect( # noqa: PLR0913
17-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
16+
async def connect(
17+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
1818
) -> Self | None: ...
1919
async def close(self) -> None: ...
2020
def write_heartbeat(self) -> None: ...
@@ -38,8 +38,8 @@ class Connection(AbstractConnection):
3838
read_timeout: int
3939

4040
@classmethod
41-
async def connect( # noqa: PLR0913
42-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
41+
async def connect(
42+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
4343
) -> Self | None:
4444
try:
4545
reader, writer = await asyncio.wait_for(asyncio.open_connection(host, port), timeout=timeout)

stompman/connection_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class ConnectionManager:
2727
connect_timeout: int
2828
read_timeout: int
2929
read_max_chunk_size: int
30+
3031
_active_connection_state: ActiveConnectionState | None = field(default=None, init=False)
3132
_reconnect_lock: asyncio.Lock = field(default_factory=asyncio.Lock)
3233

@@ -87,7 +88,7 @@ async def _get_active_connection_state(self) -> ActiveConnectionState:
8788
self._active_connection_state = ActiveConnectionState(connection=connection, lifespan=lifespan)
8889

8990
try:
90-
await lifespan.__aenter__() # noqa: PLC2801
91+
await lifespan.__aenter__()
9192
except ConnectionLostError:
9293
self._clear_active_connection_state()
9394
else:

stompman/errors.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ def __str__(self) -> str:
1313
return self.__repr__()
1414

1515

16+
@dataclass(frozen=True, kw_only=True, slots=True)
17+
class ConnectionLostError(Error):
18+
"""Raised in stompman.AbstractConnection—and handled in stompman.ConnectionManager, therefore is private."""
19+
20+
1621
@dataclass(frozen=True, kw_only=True, slots=True)
1722
class ConnectionConfirmationTimeoutError(Error):
1823
timeout: int
@@ -36,7 +41,3 @@ class FailedAllConnectAttemptsError(Error):
3641
@dataclass(frozen=True, kw_only=True, slots=True)
3742
class RepeatedConnectionLostError(Error):
3843
retry_attempts: int
39-
40-
41-
@dataclass(frozen=True, kw_only=True, slots=True)
42-
class ConnectionLostError(Error): ...

stompman/frames.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class SendFrame:
142142
body: bytes = b""
143143

144144
@classmethod
145-
def build( # noqa: PLR0913
145+
def build(
146146
cls,
147147
*,
148148
body: bytes,

stompman/serde.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def dump_frame(frame: AnyClientFrame | AnyRealServerFrame) -> bytes:
9898
return b"".join(lines)
9999

100100

101-
def unescape_byte(byte: bytes, previous_byte: bytes | None) -> bytes | None:
101+
def unescape_byte(*, byte: bytes, previous_byte: bytes | None) -> bytes | None:
102102
if previous_byte == BACKSLASH:
103103
return HEADER_UNESCAPE_CHARS.get(byte)
104104
if byte == BACKSLASH:
@@ -123,7 +123,7 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
123123
just_escaped_line = False
124124
if byte != BACKSLASH:
125125
(value_buffer if key_parsed else key_buffer).extend(byte)
126-
elif unescaped_byte := unescape_byte(byte, previous_byte):
126+
elif unescaped_byte := unescape_byte(byte=byte, previous_byte=previous_byte):
127127
just_escaped_line = True
128128
(value_buffer if key_parsed else key_buffer).extend(unescaped_byte)
129129

@@ -136,13 +136,10 @@ def parse_header(buffer: bytearray) -> tuple[str, str] | None:
136136
return None
137137

138138

139-
def make_frame_from_parts(command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
139+
def make_frame_from_parts(*, command: bytes, headers: dict[str, str], body: bytes) -> AnyClientFrame | AnyServerFrame:
140140
frame_type = COMMANDS_TO_FRAMES[command]
141-
return (
142-
frame_type(headers=cast(Any, headers), body=body) # type: ignore[call-arg]
143-
if frame_type in FRAMES_WITH_BODY
144-
else frame_type(headers=cast(Any, headers)) # type: ignore[call-arg]
145-
)
141+
headers_ = cast(Any, headers)
142+
return frame_type(headers=headers_, body=body) if frame_type in FRAMES_WITH_BODY else frame_type(headers=headers_) # type: ignore[call-arg]
146143

147144

148145
def parse_lines_into_frame(lines: deque[bytearray]) -> AnyClientFrame | AnyServerFrame:

testing/consumer.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66

77
async def main() -> None:
8-
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:
9-
10-
async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
11-
print(frame) # noqa: T201
8+
async def handle_message(frame: stompman.MessageFrame) -> None: # noqa: RUF029
9+
print(frame) # noqa: T201
1210

11+
async with stompman.Client(servers=[CONNECTION_PARAMETERS]) as client:
1312
await client.subscribe("DLQ", handler=handle_message, on_suppressed_exception=print)
1413

1514

tests/conftest.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from polyfactory.factories.dataclass_factory import DataclassFactory
1010

1111
import stompman
12+
from stompman.frames import HeartbeatFrame
1213

1314

1415
@pytest.fixture(
@@ -34,8 +35,8 @@ def noop_error_handler(exception: Exception, frame: stompman.MessageFrame) -> No
3435

3536
class BaseMockConnection(stompman.AbstractConnection):
3637
@classmethod
37-
async def connect( # noqa: PLR0913
38-
cls, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
38+
async def connect(
39+
cls, *, host: str, port: int, timeout: int, read_max_chunk_size: int, read_timeout: int
3940
) -> Self | None:
4041
return cls()
4142

@@ -45,7 +46,7 @@ async def write_frame(self, frame: stompman.AnyClientFrame) -> None: ...
4546
@staticmethod
4647
async def read_frames() -> AsyncGenerator[stompman.AnyServerFrame, None]: # pragma: no cover
4748
await asyncio.Future()
48-
yield # type: ignore[misc]
49+
yield HeartbeatFrame()
4950

5051

5152
@dataclass(kw_only=True, slots=True)

0 commit comments

Comments
 (0)