Skip to content

Commit 7859a58

Browse files
lorenzobenvenutiLorenzo Benvenutivrslev
authored
Allow subscriptions to be removed asynchronously (#121)
Co-authored-by: Lorenzo Benvenuti <lbenvenu@redhat.com> Co-authored-by: Lev Vereshchagin <levwint@gmail.com>
1 parent 52d9352 commit 7859a58

File tree

4 files changed

+79
-12
lines changed

4 files changed

+79
-12
lines changed

packages/stompman/stompman/client.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class Client:
4747
connection_class: type[AbstractConnection] = Connection
4848

4949
_connection_manager: ConnectionManager = field(init=False)
50-
_active_subscriptions: ActiveSubscriptions = field(default_factory=dict, init=False)
50+
_active_subscriptions: ActiveSubscriptions = field(default_factory=ActiveSubscriptions, init=False)
5151
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
5252
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
5353
_listen_task: asyncio.Task[None] = field(init=False)
@@ -86,8 +86,8 @@ async def __aexit__(
8686
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
8787
) -> None:
8888
try:
89-
if self._active_subscriptions and not exc_value:
90-
await asyncio.Future()
89+
if not exc_value:
90+
await self._active_subscriptions.wait_until_empty()
9191
finally:
9292
self._listen_task.cancel()
9393
await asyncio.wait([self._listen_task])
@@ -98,7 +98,7 @@ async def _listen_to_frames(self) -> None:
9898
async for frame in self._connection_manager.read_frames_reconnecting():
9999
match frame:
100100
case MessageFrame():
101-
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
101+
if subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"]):
102102
task_group.create_task(
103103
subscription._run_handler(frame=frame) # noqa: SLF001
104104
if isinstance(subscription, AutoAckSubscription)

packages/stompman/stompman/subscription.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from collections.abc import Awaitable, Callable, Coroutine
23
from dataclasses import dataclass, field
34
from typing import Any
@@ -14,7 +15,35 @@
1415
UnsubscribeFrame,
1516
)
1617

17-
ActiveSubscriptions = dict[str, "AutoAckSubscription | ManualAckSubscription"]
18+
19+
@dataclass(kw_only=True, slots=True, frozen=True)
20+
class ActiveSubscriptions:
21+
subscriptions: dict[str, "AutoAckSubscription | ManualAckSubscription"] = field(default_factory=dict, init=False)
22+
event: asyncio.Event = field(default_factory=asyncio.Event, init=False)
23+
24+
def __post_init__(self) -> None:
25+
self.event.set()
26+
27+
def get_by_id(self, subscription_id: str) -> "AutoAckSubscription | ManualAckSubscription | None":
28+
return self.subscriptions.get(subscription_id)
29+
30+
def get_all(self) -> list["AutoAckSubscription | ManualAckSubscription"]:
31+
return list(self.subscriptions.values())
32+
33+
def delete_by_id(self, subscription_id: str) -> None:
34+
del self.subscriptions[subscription_id]
35+
if not self.subscriptions:
36+
self.event.set()
37+
38+
def add(self, subscription: "AutoAckSubscription | ManualAckSubscription") -> None:
39+
self.subscriptions[subscription.id] = subscription
40+
self.event.clear()
41+
42+
def contains_by_id(self, subscription_id: str) -> bool:
43+
return subscription_id in self.subscriptions
44+
45+
async def wait_until_empty(self) -> bool:
46+
return await self.event.wait()
1847

1948

2049
@dataclass(kw_only=True, slots=True)
@@ -32,20 +61,20 @@ async def _subscribe(self) -> None:
3261
subscription_id=self.id, destination=self.destination, ack=self.ack, headers=self.headers
3362
)
3463
)
35-
self._active_subscriptions[self.id] = self # type: ignore[assignment]
64+
self._active_subscriptions.add(self) # type: ignore[arg-type]
3665

3766
async def unsubscribe(self) -> None:
38-
del self._active_subscriptions[self.id]
67+
self._active_subscriptions.delete_by_id(self.id)
3968
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))
4069

4170
async def _nack(self, frame: MessageFrame) -> None:
42-
if self.id in self._active_subscriptions and (ack_id := frame.headers.get("ack")):
71+
if self._active_subscriptions.contains_by_id(self.id) and (ack_id := frame.headers.get("ack")):
4372
await self._connection_manager.maybe_write_frame(
4473
NackFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
4574
)
4675

4776
async def _ack(self, frame: MessageFrame) -> None:
48-
if self.id in self._active_subscriptions and (ack_id := frame.headers["ack"]):
77+
if self._active_subscriptions.contains_by_id(self.id) and (ack_id := frame.headers["ack"]):
4978
await self._connection_manager.maybe_write_frame(
5079
AckFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
5180
)
@@ -96,7 +125,7 @@ def _make_subscription_id() -> str:
96125
async def resubscribe_to_active_subscriptions(
97126
*, connection: AbstractConnection, active_subscriptions: ActiveSubscriptions
98127
) -> None:
99-
for subscription in active_subscriptions.values():
128+
for subscription in active_subscriptions.get_all():
100129
await connection.write_frame(
101130
SubscribeFrame.build(
102131
subscription_id=subscription.id,
@@ -108,5 +137,5 @@ async def resubscribe_to_active_subscriptions(
108137

109138

110139
async def unsubscribe_from_all_active_subscriptions(*, active_subscriptions: ActiveSubscriptions) -> None:
111-
for subscription in active_subscriptions.copy().values():
140+
for subscription in active_subscriptions.get_all():
112141
await subscription.unsubscribe()

packages/stompman/test_stompman/test_send.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import asyncio
22
from typing import Any
33

4-
import faker
54
import pytest
65
from stompman import (
76
SendFrame,

packages/stompman/test_stompman/test_subscription.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,42 @@ async def close_connection_soon(client: stompman.Client) -> None:
334334

335335
def test_make_subscription_id() -> None:
336336
stompman.subscription._make_subscription_id()
337+
338+
339+
async def wait_and_unsubscribe(*subscriptions: stompman.subscription.BaseSubscription, wait_in_seconds: float) -> None:
340+
await asyncio.sleep(wait_in_seconds)
341+
for subscription in subscriptions:
342+
await subscription.unsubscribe()
343+
344+
345+
async def test_client_exits_when_subscriptions_are_unsubscribed(
346+
monkeypatch: pytest.MonkeyPatch, faker: faker.Faker
347+
) -> None:
348+
monkeypatch.setattr(
349+
stompman.subscription,
350+
"_make_subscription_id",
351+
mock.Mock(side_effect=[(first_id := faker.pystr()), (second_id := faker.pystr())]),
352+
)
353+
first_destination, second_destination = faker.pystr(), faker.pystr()
354+
connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([]))
355+
356+
async with EnrichedClient(connection_class=connection_class) as client:
357+
first_subscription = await client.subscribe(
358+
first_destination, handler=noop_message_handler, on_suppressed_exception=noop_error_handler
359+
)
360+
second_subscription = await client.subscribe(
361+
second_destination, handler=noop_message_handler, on_suppressed_exception=noop_error_handler
362+
)
363+
await asyncio.sleep(0)
364+
unsubscribe_task = asyncio.create_task(
365+
wait_and_unsubscribe(first_subscription, second_subscription, wait_in_seconds=0.5)
366+
)
367+
368+
assert unsubscribe_task.done(), "Client should exit context manager only when subscriptions are unsubscribed"
369+
370+
assert collected_frames == enrich_expected_frames(
371+
SubscribeFrame(headers={"id": first_id, "destination": first_destination, "ack": "client-individual"}),
372+
SubscribeFrame(headers={"id": second_id, "destination": second_destination, "ack": "client-individual"}),
373+
UnsubscribeFrame(headers={"id": first_id}),
374+
UnsubscribeFrame(headers={"id": second_id}),
375+
)

0 commit comments

Comments
 (0)