1
+ import asyncio
1
2
from collections .abc import Awaitable , Callable , Coroutine
2
3
from dataclasses import dataclass , field
3
4
from typing import Any
14
15
UnsubscribeFrame ,
15
16
)
16
17
17
- ActiveSubscriptions = dict [str , "AutoAckSubscription | ManualAckSubscription" ]
18
+
19
+ @dataclass (kw_only = True , slots = True , frozen = True )
20
+ class ActiveSubscriptions :
21
+ subscriptions : dict [str , "AutoAckSubscription | ManualAckSubscription" ] = field (default_factory = dict , init = False )
22
+ event : asyncio .Event = field (default_factory = asyncio .Event , init = False )
23
+
24
+ def __post_init__ (self ) -> None :
25
+ self .event .set ()
26
+
27
+ def get_by_id (self , subscription_id : str ) -> "AutoAckSubscription | ManualAckSubscription | None" :
28
+ return self .subscriptions .get (subscription_id )
29
+
30
+ def get_all (self ) -> list ["AutoAckSubscription | ManualAckSubscription" ]:
31
+ return list (self .subscriptions .values ())
32
+
33
+ def delete_by_id (self , subscription_id : str ) -> None :
34
+ del self .subscriptions [subscription_id ]
35
+ if not self .subscriptions :
36
+ self .event .set ()
37
+
38
+ def add (self , subscription : "AutoAckSubscription | ManualAckSubscription" ) -> None :
39
+ self .subscriptions [subscription .id ] = subscription
40
+ self .event .clear ()
41
+
42
+ def contains_by_id (self , subscription_id : str ) -> bool :
43
+ return subscription_id in self .subscriptions
44
+
45
+ async def wait_until_empty (self ) -> bool :
46
+ return await self .event .wait ()
18
47
19
48
20
49
@dataclass (kw_only = True , slots = True )
@@ -32,20 +61,20 @@ async def _subscribe(self) -> None:
32
61
subscription_id = self .id , destination = self .destination , ack = self .ack , headers = self .headers
33
62
)
34
63
)
35
- self ._active_subscriptions [ self . id ] = self # type: ignore[assignment ]
64
+ self ._active_subscriptions . add ( self ) # type: ignore[arg-type ]
36
65
37
66
async def unsubscribe (self ) -> None :
38
- del self ._active_subscriptions [ self .id ]
67
+ self ._active_subscriptions . delete_by_id ( self .id )
39
68
await self ._connection_manager .maybe_write_frame (UnsubscribeFrame (headers = {"id" : self .id }))
40
69
41
70
async def _nack (self , frame : MessageFrame ) -> None :
42
- if self .id in self ._active_subscriptions and (ack_id := frame .headers .get ("ack" )):
71
+ if self ._active_subscriptions . contains_by_id ( self .id ) and (ack_id := frame .headers .get ("ack" )):
43
72
await self ._connection_manager .maybe_write_frame (
44
73
NackFrame (headers = {"id" : ack_id , "subscription" : frame .headers ["subscription" ]})
45
74
)
46
75
47
76
async def _ack (self , frame : MessageFrame ) -> None :
48
- if self .id in self ._active_subscriptions and (ack_id := frame .headers ["ack" ]):
77
+ if self ._active_subscriptions . contains_by_id ( self .id ) and (ack_id := frame .headers ["ack" ]):
49
78
await self ._connection_manager .maybe_write_frame (
50
79
AckFrame (headers = {"id" : ack_id , "subscription" : frame .headers ["subscription" ]})
51
80
)
@@ -96,7 +125,7 @@ def _make_subscription_id() -> str:
96
125
async def resubscribe_to_active_subscriptions (
97
126
* , connection : AbstractConnection , active_subscriptions : ActiveSubscriptions
98
127
) -> None :
99
- for subscription in active_subscriptions .values ():
128
+ for subscription in active_subscriptions .get_all ():
100
129
await connection .write_frame (
101
130
SubscribeFrame .build (
102
131
subscription_id = subscription .id ,
@@ -108,5 +137,5 @@ async def resubscribe_to_active_subscriptions(
108
137
109
138
110
139
async def unsubscribe_from_all_active_subscriptions (* , active_subscriptions : ActiveSubscriptions ) -> None :
111
- for subscription in active_subscriptions .copy (). values ():
140
+ for subscription in active_subscriptions .get_all ():
112
141
await subscription .unsubscribe ()
0 commit comments