12
12
import sentry_sdk
13
13
from arroyo .backends .kafka .consumer import KafkaPayload
14
14
from arroyo .processing .strategies import ProcessingStrategy
15
+ from arroyo .processing .strategies .abstract import MessageRejected
15
16
from arroyo .types import BrokerValue , FilteredPayload , Message , Partition
16
17
17
18
from sentry .utils import metrics
21
22
T = TypeVar ("T" )
22
23
23
24
25
+ class UnassignedPartitionError (Exception ):
26
+ """Raised when trying to track offsets for an unassigned partition."""
27
+
28
+ pass
29
+
30
+
24
31
@dataclass
25
32
class WorkItem (Generic [T ]):
26
33
"""Work item that includes the original message for offset tracking."""
@@ -47,20 +54,25 @@ def __init__(self) -> None:
47
54
self .partition_locks : dict [Partition , threading .Lock ] = {}
48
55
49
56
def _get_partition_lock (self , partition : Partition ) -> threading .Lock :
50
- """Get or create a lock for a partition."""
51
- lock = self .partition_locks .get (partition )
52
- if lock :
53
- return lock
54
- return self .partition_locks .setdefault (partition , threading .Lock ())
57
+ """Get the lock for a partition."""
58
+ return self .partition_locks [partition ]
55
59
56
60
def add_offset (self , partition : Partition , offset : int ) -> None :
57
61
"""Record that we've started processing an offset."""
62
+ if partition not in self .partition_locks :
63
+ raise UnassignedPartitionError (
64
+ f"Partition { partition } is not assigned to this consumer"
65
+ )
66
+
58
67
with self ._get_partition_lock (partition ):
59
68
self .all_offsets [partition ].add (offset )
60
69
self .outstanding [partition ].add (offset )
61
70
62
71
def complete_offset (self , partition : Partition , offset : int ) -> None :
63
72
"""Mark an offset as completed."""
73
+ if partition not in self .partition_locks :
74
+ return
75
+
64
76
with self ._get_partition_lock (partition ):
65
77
self .outstanding [partition ].discard (offset )
66
78
@@ -104,6 +116,18 @@ def mark_committed(self, partition: Partition, offset: int) -> None:
104
116
# Remove all offsets <= committed offset
105
117
self .all_offsets [partition ] = {o for o in self .all_offsets [partition ] if o > offset }
106
118
119
+ def clear (self ) -> None :
120
+ """Clear all offset tracking state."""
121
+ self .all_offsets .clear ()
122
+ self .outstanding .clear ()
123
+ self .last_committed .clear ()
124
+ self .partition_locks .clear ()
125
+
126
+ def update_assignments (self , partitions : set [Partition ]) -> None :
127
+ """Update partition assignments and reset all tracking state."""
128
+ self .clear ()
129
+ self .partition_locks = {partition : threading .Lock () for partition in partitions }
130
+
107
131
108
132
class OrderedQueueWorker (threading .Thread , Generic [T ]):
109
133
"""Worker thread that processes items from a queue in order."""
@@ -138,9 +162,6 @@ def run(self) -> None:
138
162
name = f"monitors.{ self .identifier } .worker_{ self .worker_id } " ,
139
163
):
140
164
self .result_processor (self .identifier , work_item .result )
141
-
142
- except queue .ShutDown :
143
- break
144
165
except Exception :
145
166
logger .exception (
146
167
"Unexpected error in queue worker" , extra = {"worker_id" : self .worker_id }
@@ -173,13 +194,20 @@ def __init__(
173
194
result_processor : Callable [[str , T ], None ],
174
195
identifier : str ,
175
196
num_queues : int = 20 ,
197
+ commit_interval : float = 1.0 ,
176
198
) -> None :
177
199
self .result_processor = result_processor
178
200
self .identifier = identifier
179
201
self .num_queues = num_queues
202
+ self .commit_interval = commit_interval
180
203
self .offset_tracker = OffsetTracker ()
181
204
self .queues : list [queue .Queue [WorkItem [T ]]] = []
182
205
self .workers : list [OrderedQueueWorker [T ]] = []
206
+ self .commit_function : Callable [[dict [Partition , int ]], None ] | None = None
207
+ self .commit_shutdown_event = threading .Event ()
208
+
209
+ self .commit_thread = threading .Thread (target = self ._commit_loop , daemon = True )
210
+ self .commit_thread .start ()
183
211
184
212
for i in range (num_queues ):
185
213
work_queue : queue .Queue [WorkItem [T ]] = queue .Queue ()
@@ -195,6 +223,29 @@ def __init__(
195
223
worker .start ()
196
224
self .workers .append (worker )
197
225
226
+ def _commit_loop (self ) -> None :
227
+ """Background thread that periodically commits offsets."""
228
+ while not self .commit_shutdown_event .is_set ():
229
+ try :
230
+ self .commit_shutdown_event .wait (self .commit_interval )
231
+ if self .commit_shutdown_event .is_set ():
232
+ break
233
+
234
+ committable = self .offset_tracker .get_committable_offsets ()
235
+
236
+ if committable and self .commit_function :
237
+ metrics .incr (
238
+ "remote_subscriptions.queue_pool.offsets_committed" ,
239
+ len (committable ),
240
+ tags = {"identifier" : self .identifier },
241
+ )
242
+
243
+ self .commit_function (committable )
244
+ for partition , offset in committable .items ():
245
+ self .offset_tracker .mark_committed (partition , offset )
246
+ except Exception :
247
+ logger .exception ("Error in commit loop" )
248
+
198
249
def get_queue_for_group (self , group_key : str ) -> int :
199
250
"""
200
251
Get queue index for a group using consistent hashing.
@@ -205,10 +256,25 @@ def submit(self, group_key: str, work_item: WorkItem[T]) -> None:
205
256
"""
206
257
Submit a work item to the appropriate queue.
207
258
"""
259
+ try :
260
+ self .offset_tracker .add_offset (work_item .partition , work_item .offset )
261
+ except UnassignedPartitionError :
262
+ logger .exception (
263
+ "Received message for unassigned partition, skipping" ,
264
+ extra = {
265
+ "partition" : work_item .partition ,
266
+ "offset" : work_item .offset ,
267
+ "identifier" : self .identifier ,
268
+ },
269
+ )
270
+ metrics .incr (
271
+ "remote_subscriptions.queue_pool.submit.unassigned_partition" ,
272
+ tags = {"identifier" : self .identifier },
273
+ )
274
+ return
275
+
208
276
queue_index = self .get_queue_for_group (group_key )
209
277
work_queue = self .queues [queue_index ]
210
-
211
- self .offset_tracker .add_offset (work_item .partition , work_item .offset )
212
278
work_queue .put (work_item )
213
279
214
280
def get_stats (self ) -> dict [str , Any ]:
@@ -219,7 +285,7 @@ def get_stats(self) -> dict[str, Any]:
219
285
"total_items" : sum (queue_depths ),
220
286
}
221
287
222
- def wait_until_empty (self , timeout : float = 5.0 ) -> bool :
288
+ def wait_until_empty (self , timeout : float ) -> bool :
223
289
"""Wait until all queues are empty. Returns True if successful, False if timeout."""
224
290
start_time = time .time ()
225
291
while time .time () - start_time < timeout :
@@ -228,8 +294,61 @@ def wait_until_empty(self, timeout: float = 5.0) -> bool:
228
294
time .sleep (0.01 )
229
295
return False
230
296
297
+ def flush (self , timeout : float | None = None ) -> bool :
298
+ """
299
+ Wait for all queues to be empty. Returns True if successful, False if timeout.
300
+ If timeout is None, immediately flush without waiting.
301
+ If timeout is reached, flushes all remaining work.
302
+ """
303
+ if timeout is None :
304
+ success = False
305
+ else :
306
+ success = self .wait_until_empty (timeout )
307
+ if not success :
308
+ metrics .incr (
309
+ "remote_subscriptions.queue_pool.flush.timeout" ,
310
+ tags = {"identifier" : self .identifier },
311
+ )
312
+ cleared_count = 0
313
+ for q in self .queues :
314
+ while not q .empty ():
315
+ try :
316
+ q .get_nowait ()
317
+ cleared_count += 1
318
+ except queue .Empty :
319
+ break
320
+ except Exception :
321
+ logger .exception ("Error clearing queue" )
322
+ if cleared_count > 0 :
323
+ metrics .incr (
324
+ "remote_subscriptions.queue_pool.timeout_queue_size" ,
325
+ cleared_count ,
326
+ tags = {"identifier" : self .identifier },
327
+ )
328
+
329
+ self .offset_tracker .clear ()
330
+ return success
331
+
332
+ def update_assignments (
333
+ self ,
334
+ partitions : set [Partition ],
335
+ commit_function : Callable [[dict [Partition , int ]], None ],
336
+ ) -> None :
337
+ """
338
+ Update partition assignments and commit function atomically.
339
+ """
340
+ self .offset_tracker .update_assignments (partitions )
341
+ self .commit_function = commit_function
342
+
343
+ logger .info (
344
+ "Updated partition assignments" ,
345
+ extra = {
346
+ "identifier" : self .identifier ,
347
+ "partitions" : len (partitions ),
348
+ },
349
+ )
350
+
231
351
def shutdown (self ) -> None :
232
- """Gracefully shutdown all workers."""
233
352
for worker in self .workers :
234
353
worker .shutdown = True
235
354
@@ -240,7 +359,10 @@ def shutdown(self) -> None:
240
359
logger .exception ("Error shutting down queue" )
241
360
242
361
for worker in self .workers :
243
- worker .join (timeout = 5.0 )
362
+ worker .join (timeout = 1.0 )
363
+
364
+ self .commit_shutdown_event .set ()
365
+ self .commit_thread .join (timeout = 1.0 )
244
366
245
367
246
368
class SimpleQueueProcessingStrategy (ProcessingStrategy [KafkaPayload ], Generic [T ]):
@@ -260,37 +382,18 @@ def __init__(
260
382
decoder : Callable [[KafkaPayload | FilteredPayload ], T | None ],
261
383
grouping_fn : Callable [[T ], str ],
262
384
commit_function : Callable [[dict [Partition , int ]], None ],
385
+ partitions : set [Partition ],
263
386
) -> None :
264
387
self .queue_pool = queue_pool
265
388
self .decoder = decoder
266
389
self .grouping_fn = grouping_fn
267
- self .commit_function = commit_function
268
390
self .shutdown_event = threading .Event ()
269
-
270
- self .commit_thread = threading .Thread (target = self ._commit_loop , daemon = True )
271
- self .commit_thread .start ()
272
-
273
- def _commit_loop (self ) -> None :
274
- while not self .shutdown_event .is_set ():
275
- try :
276
- self .shutdown_event .wait (1.0 )
277
-
278
- committable = self .queue_pool .offset_tracker .get_committable_offsets ()
279
-
280
- if committable :
281
- metrics .incr (
282
- "remote_subscriptions.queue_pool.offsets_committed" ,
283
- len (committable ),
284
- tags = {"identifier" : self .queue_pool .identifier },
285
- )
286
-
287
- self .commit_function (committable )
288
- for partition , offset in committable .items ():
289
- self .queue_pool .offset_tracker .mark_committed (partition , offset )
290
- except Exception :
291
- logger .exception ("Error in commit loop" )
391
+ self .queue_pool .update_assignments (partitions , commit_function )
292
392
293
393
def submit (self , message : Message [KafkaPayload | FilteredPayload ]) -> None :
394
+ if self .shutdown_event .is_set ():
395
+ raise MessageRejected ("Strategy is shutting down" )
396
+
294
397
try :
295
398
result = self .decoder (message .payload )
296
399
@@ -299,8 +402,11 @@ def submit(self, message: Message[KafkaPayload | FilteredPayload]) -> None:
299
402
offset = message .value .offset
300
403
301
404
if result is None :
302
- self .queue_pool .offset_tracker .add_offset (partition , offset )
303
- self .queue_pool .offset_tracker .complete_offset (partition , offset )
405
+ try :
406
+ self .queue_pool .offset_tracker .add_offset (partition , offset )
407
+ self .queue_pool .offset_tracker .complete_offset (partition , offset )
408
+ except UnassignedPartitionError :
409
+ pass
304
410
return
305
411
306
412
group_key = self .grouping_fn (result )
@@ -334,12 +440,10 @@ def poll(self) -> None:
334
440
335
441
def close (self ) -> None :
336
442
self .shutdown_event .set ()
337
- self .commit_thread .join (timeout = 5.0 )
338
- self .queue_pool .shutdown ()
339
443
340
444
def terminate (self ) -> None :
341
445
self .shutdown_event .set ()
342
- self .queue_pool .shutdown ( )
446
+ self .queue_pool .flush ( timeout = 0 )
343
447
344
448
def join (self , timeout : float | None = None ) -> None :
345
- self .close ( )
449
+ self .queue_pool . flush ( timeout = timeout or 0 )
0 commit comments