Skip to content

Commit d06cd07

Browse files
authored
Allow passing headers to client.subscribe() (#64)
1 parent db2b5e5 commit d06cd07

File tree

5 files changed

+41
-11
lines changed

5 files changed

+41
-11
lines changed

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ await client.subscribe("DLQ", handle_message_from_dlq, ack="client", on_suppress
100100
await client.subscribe("DLQ", handle_message_from_dlq, ack="auto", on_suppressed_exception=print)
101101
```
102102

103+
You can pass custom headers to `client.subscribe()`:
104+
105+
```python
106+
await client.subscribe("DLQ", handle_message_from_dlq, ack="client", headers={"selector": "location = 'Europe'"}, on_suppressed_exception=print)
107+
```
108+
103109
### Cleaning Up
104110

105111
stompman takes care of cleaning up resources automatically. When you leave the context of async context managers `stompman.Client()`, or `client.begin()`, the necessary frames will be sent to the server.

stompman/client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,12 +139,14 @@ async def subscribe(
139139
handler: Callable[[MessageFrame], Coroutine[None, None, None]],
140140
*,
141141
ack: AckMode = "client-individual",
142+
headers: dict[str, str] | None = None,
142143
on_suppressed_exception: Callable[[Exception, MessageFrame], None],
143144
supressed_exception_classes: tuple[type[Exception], ...] = (Exception,),
144145
) -> "Subscription":
145146
subscription = Subscription(
146147
destination=destination,
147148
handler=handler,
149+
headers=headers,
148150
ack=ack,
149151
on_suppressed_exception=on_suppressed_exception,
150152
supressed_exception_classes=supressed_exception_classes,

stompman/frames.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,20 +151,26 @@ def build(
151151
content_type: str | None,
152152
headers: dict[str, str] | None,
153153
) -> Self:
154-
full_headers: SendHeaders = headers or {} # type: ignore[assignment]
155-
full_headers["destination"] = destination
156-
full_headers["content-length"] = str(len(body))
154+
all_headers: SendHeaders = headers or {} # type: ignore[assignment]
155+
all_headers["destination"] = destination
156+
all_headers["content-length"] = str(len(body))
157157
if content_type is not None:
158-
full_headers["content-type"] = content_type
158+
all_headers["content-type"] = content_type
159159
if transaction is not None:
160-
full_headers["transaction"] = transaction
161-
return cls(headers=full_headers, body=body)
160+
all_headers["transaction"] = transaction
161+
return cls(headers=all_headers, body=body)
162162

163163

164164
@dataclass(frozen=True, kw_only=True, slots=True)
165165
class SubscribeFrame:
166166
headers: SubscribeHeaders
167167

168+
@classmethod
169+
def build(cls, *, subscription_id: str, destination: str, ack: AckMode, headers: dict[str, str] | None) -> Self:
170+
all_headers: SubscribeHeaders = headers.copy() if headers else {} # type: ignore[assignment, typeddict-item]
171+
all_headers.update({"id": subscription_id, "destination": destination, "ack": ack})
172+
return cls(headers=all_headers)
173+
168174

169175
@dataclass(frozen=True, kw_only=True, slots=True)
170176
class UnsubscribeFrame:

stompman/subscription.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
class Subscription:
2121
id: str = field(default_factory=lambda: _make_subscription_id(), init=False) # noqa: PLW0108
2222
destination: str
23+
headers: dict[str, str] | None
2324
handler: Callable[[MessageFrame], Coroutine[None, None, None]]
2425
ack: AckMode
2526
on_suppressed_exception: Callable[[Exception, MessageFrame], None]
@@ -34,7 +35,9 @@ def __post_init__(self) -> None:
3435

3536
async def _subscribe(self) -> None:
3637
await self._connection_manager.write_frame_reconnecting(
37-
SubscribeFrame(headers={"id": self.id, "destination": self.destination, "ack": self.ack})
38+
SubscribeFrame.build(
39+
subscription_id=self.id, destination=self.destination, ack=self.ack, headers=self.headers
40+
)
3841
)
3942
self._active_subscriptions[self.id] = self
4043

@@ -71,8 +74,11 @@ async def resubscribe_to_active_subscriptions(
7174
) -> None:
7275
for subscription in active_subscriptions.values():
7376
await connection.write_frame(
74-
SubscribeFrame(
75-
headers={"id": subscription.id, "destination": subscription.destination, "ack": subscription.ack}
77+
SubscribeFrame.build(
78+
subscription_id=subscription.id,
79+
destination=subscription.destination,
80+
ack=subscription.ack,
81+
headers=subscription.headers,
7682
)
7783
)
7884

tests/test_subscription.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,14 @@ async def test_client_subscribtions_lifespan_resubscribe(ack: AckMode) -> None:
4343
connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([CONNECTED_FRAME], []))
4444
client = EnrichedClient(connection_class=connection_class)
4545
sub_destination, message_destination, message_body = FAKER.pystr(), FAKER.pystr(), FAKER.binary(length=10)
46+
sub_extra_headers = FAKER.pydict(value_types=[str])
4647

4748
async with client:
4849
subscription = await client.subscribe(
4950
destination=sub_destination,
5051
handler=noop_message_handler,
5152
ack=ack,
53+
headers=sub_extra_headers,
5254
on_suppressed_exception=noop_error_handler,
5355
)
5456
client._connection_manager._clear_active_connection_state()
@@ -57,11 +59,19 @@ async def test_client_subscribtions_lifespan_resubscribe(ack: AckMode) -> None:
5759
await asyncio.sleep(0)
5860
await asyncio.sleep(0)
5961

62+
subscribe_frame = SubscribeFrame(
63+
headers={
64+
"id": subscription.id,
65+
"destination": sub_destination,
66+
"ack": ack,
67+
**sub_extra_headers, # type: ignore[typeddict-item]
68+
}
69+
)
6070
assert collected_frames == enrich_expected_frames(
61-
SubscribeFrame(headers={"id": subscription.id, "destination": sub_destination, "ack": ack}),
71+
subscribe_frame,
6272
CONNECT_FRAME,
6373
CONNECTED_FRAME,
64-
SubscribeFrame(headers={"id": subscription.id, "destination": sub_destination, "ack": ack}),
74+
subscribe_frame,
6575
SendFrame(
6676
headers={"destination": message_destination, "content-length": str(len(message_body))}, body=message_body
6777
),

0 commit comments

Comments
 (0)