Skip to content

Allow subscriptions to be removed asynchronously #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions packages/stompman/stompman/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class Client:
connection_class: type[AbstractConnection] = Connection

_connection_manager: ConnectionManager = field(init=False)
_active_subscriptions: ActiveSubscriptions = field(default_factory=dict, init=False)
_active_subscriptions: ActiveSubscriptions = field(default_factory=ActiveSubscriptions, init=False)
_active_transactions: set[Transaction] = field(default_factory=set, init=False)
_exit_stack: AsyncExitStack = field(default_factory=AsyncExitStack, init=False)
_listen_task: asyncio.Task[None] = field(init=False)
Expand Down Expand Up @@ -86,8 +86,8 @@ async def __aexit__(
self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None
) -> None:
try:
if self._active_subscriptions and not exc_value:
await asyncio.Future()
if not exc_value:
await self._active_subscriptions.wait_until_empty()
finally:
self._listen_task.cancel()
await asyncio.wait([self._listen_task])
Expand All @@ -98,7 +98,7 @@ async def _listen_to_frames(self) -> None:
async for frame in self._connection_manager.read_frames_reconnecting():
match frame:
case MessageFrame():
if subscription := self._active_subscriptions.get(frame.headers["subscription"]):
if subscription := self._active_subscriptions.get_by_id(frame.headers["subscription"]):
task_group.create_task(
subscription._run_handler(frame=frame) # noqa: SLF001
if isinstance(subscription, AutoAckSubscription)
Expand Down
43 changes: 36 additions & 7 deletions packages/stompman/stompman/subscription.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass, field
from typing import Any
Expand All @@ -14,7 +15,35 @@
UnsubscribeFrame,
)

ActiveSubscriptions = dict[str, "AutoAckSubscription | ManualAckSubscription"]

@dataclass(kw_only=True, slots=True, frozen=True)
class ActiveSubscriptions:
subscriptions: dict[str, "AutoAckSubscription | ManualAckSubscription"] = field(default_factory=dict, init=False)
event: asyncio.Event = field(default_factory=asyncio.Event, init=False)

def __post_init__(self) -> None:
self.event.set()

def get_by_id(self, subscription_id: str) -> "AutoAckSubscription | ManualAckSubscription | None":
return self.subscriptions.get(subscription_id)

def get_all(self) -> list["AutoAckSubscription | ManualAckSubscription"]:
return list(self.subscriptions.values())

def delete_by_id(self, subscription_id: str) -> None:
del self.subscriptions[subscription_id]
if not self.subscriptions:
self.event.set()

def add(self, subscription: "AutoAckSubscription | ManualAckSubscription") -> None:
self.subscriptions[subscription.id] = subscription
self.event.clear()

def contains_by_id(self, subscription_id: str) -> bool:
return subscription_id in self.subscriptions

async def wait_until_empty(self) -> bool:
return await self.event.wait()


@dataclass(kw_only=True, slots=True)
Expand All @@ -32,20 +61,20 @@ async def _subscribe(self) -> None:
subscription_id=self.id, destination=self.destination, ack=self.ack, headers=self.headers
)
)
self._active_subscriptions[self.id] = self # type: ignore[assignment]
self._active_subscriptions.add(self) # type: ignore[arg-type]

async def unsubscribe(self) -> None:
del self._active_subscriptions[self.id]
self._active_subscriptions.delete_by_id(self.id)
await self._connection_manager.maybe_write_frame(UnsubscribeFrame(headers={"id": self.id}))

async def _nack(self, frame: MessageFrame) -> None:
if self.id in self._active_subscriptions and (ack_id := frame.headers.get("ack")):
if self._active_subscriptions.contains_by_id(self.id) and (ack_id := frame.headers.get("ack")):
await self._connection_manager.maybe_write_frame(
NackFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
)

async def _ack(self, frame: MessageFrame) -> None:
if self.id in self._active_subscriptions and (ack_id := frame.headers["ack"]):
if self._active_subscriptions.contains_by_id(self.id) and (ack_id := frame.headers["ack"]):
await self._connection_manager.maybe_write_frame(
AckFrame(headers={"id": ack_id, "subscription": frame.headers["subscription"]})
)
Expand Down Expand Up @@ -96,7 +125,7 @@ def _make_subscription_id() -> str:
async def resubscribe_to_active_subscriptions(
*, connection: AbstractConnection, active_subscriptions: ActiveSubscriptions
) -> None:
for subscription in active_subscriptions.values():
for subscription in active_subscriptions.get_all():
await connection.write_frame(
SubscribeFrame.build(
subscription_id=subscription.id,
Expand All @@ -108,5 +137,5 @@ async def resubscribe_to_active_subscriptions(


async def unsubscribe_from_all_active_subscriptions(*, active_subscriptions: ActiveSubscriptions) -> None:
for subscription in active_subscriptions.copy().values():
for subscription in active_subscriptions.get_all():
await subscription.unsubscribe()
1 change: 0 additions & 1 deletion packages/stompman/test_stompman/test_send.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from typing import Any

import faker
import pytest
from stompman import (
SendFrame,
Expand Down
39 changes: 39 additions & 0 deletions packages/stompman/test_stompman/test_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,42 @@ async def close_connection_soon(client: stompman.Client) -> None:

def test_make_subscription_id() -> None:
stompman.subscription._make_subscription_id()


async def wait_and_unsubscribe(*subscriptions: stompman.subscription.BaseSubscription, wait_in_seconds: float) -> None:
await asyncio.sleep(wait_in_seconds)
for subscription in subscriptions:
await subscription.unsubscribe()


async def test_client_exits_when_subscriptions_are_unsubscribed(
monkeypatch: pytest.MonkeyPatch, faker: faker.Faker
) -> None:
monkeypatch.setattr(
stompman.subscription,
"_make_subscription_id",
mock.Mock(side_effect=[(first_id := faker.pystr()), (second_id := faker.pystr())]),
)
first_destination, second_destination = faker.pystr(), faker.pystr()
connection_class, collected_frames = create_spying_connection(*get_read_frames_with_lifespan([]))

async with EnrichedClient(connection_class=connection_class) as client:
first_subscription = await client.subscribe(
first_destination, handler=noop_message_handler, on_suppressed_exception=noop_error_handler
)
second_subscription = await client.subscribe(
second_destination, handler=noop_message_handler, on_suppressed_exception=noop_error_handler
)
await asyncio.sleep(0)
unsubscribe_task = asyncio.create_task(
wait_and_unsubscribe(first_subscription, second_subscription, wait_in_seconds=0.5)
)

assert unsubscribe_task.done(), "Client should exit context manager only when subscriptions are unsubscribed"

assert collected_frames == enrich_expected_frames(
SubscribeFrame(headers={"id": first_id, "destination": first_destination, "ack": "client-individual"}),
SubscribeFrame(headers={"id": second_id, "destination": second_destination, "ack": "client-individual"}),
UnsubscribeFrame(headers={"id": first_id}),
UnsubscribeFrame(headers={"id": second_id}),
)