
Commit 82e2139

Fix thread-queue-parallel consumer to handle Kafka rebalances correctly (#95753)
The thread-queue-parallel consumer was hitting errors during Kafka rebalances: the commit thread kept running after partition revocation and tried to commit offsets for partitions the consumer no longer owned. While fixing this I also cleaned up some related structural issues.

- Move commit thread ownership from individual strategies to FixedQueuePool so it survives rebalances
- Add partition assignment validation to reject messages from unassigned partitions
- Implement OffsetTracker.clear() to reset state during partition revocation
- Add update_assignments() to atomically update partitions and the commit function
1 parent aa02773 commit 82e2139
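
For readers skimming the diff below, the core of the fix is an assignment-aware OffsetTracker. The following is a minimal, self-contained sketch of that behaviour, not the code in this commit: it uses plain strings instead of arroyo Partition objects and omits the outstanding/committable bookkeeping. It only illustrates how update_assignments() resets state on a rebalance and how add_offset() rejects partitions the consumer no longer owns.

```python
import threading
from collections import defaultdict


class UnassignedPartitionError(Exception):
    """Raised when trying to track offsets for an unassigned partition."""


class MiniOffsetTracker:
    """Simplified stand-in for OffsetTracker, illustrating the rebalance guard."""

    def __init__(self) -> None:
        self.all_offsets: dict[str, set[int]] = defaultdict(set)
        self.partition_locks: dict[str, threading.Lock] = {}

    def update_assignments(self, partitions: set[str]) -> None:
        # A rebalance replaces the whole assignment and drops stale state.
        self.all_offsets.clear()
        self.partition_locks = {p: threading.Lock() for p in partitions}

    def add_offset(self, partition: str, offset: int) -> None:
        # Reject offsets for partitions this consumer does not own.
        if partition not in self.partition_locks:
            raise UnassignedPartitionError(f"{partition} is not assigned")
        with self.partition_locks[partition]:
            self.all_offsets[partition].add(offset)


tracker = MiniOffsetTracker()
tracker.update_assignments({"uptime-results-0"})
tracker.add_offset("uptime-results-0", 42)      # accepted
try:
    tracker.add_offset("uptime-results-7", 10)  # partition owned by another consumer
except UnassignedPartitionError:
    print("dropped message from unassigned partition")
```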

File tree

4 files changed (+451, -59 lines)

src/sentry/remote_subscriptions/consumers/queue_consumer.py

Lines changed: 147 additions & 43 deletions
@@ -12,6 +12,7 @@
 import sentry_sdk
 from arroyo.backends.kafka.consumer import KafkaPayload
 from arroyo.processing.strategies import ProcessingStrategy
+from arroyo.processing.strategies.abstract import MessageRejected
 from arroyo.types import BrokerValue, FilteredPayload, Message, Partition

 from sentry.utils import metrics
@@ -21,6 +22,12 @@
 T = TypeVar("T")


+class UnassignedPartitionError(Exception):
+    """Raised when trying to track offsets for an unassigned partition."""
+
+    pass
+
+
 @dataclass
 class WorkItem(Generic[T]):
     """Work item that includes the original message for offset tracking."""
@@ -47,20 +54,25 @@ def __init__(self) -> None:
         self.partition_locks: dict[Partition, threading.Lock] = {}

     def _get_partition_lock(self, partition: Partition) -> threading.Lock:
-        """Get or create a lock for a partition."""
-        lock = self.partition_locks.get(partition)
-        if lock:
-            return lock
-        return self.partition_locks.setdefault(partition, threading.Lock())
+        """Get the lock for a partition."""
+        return self.partition_locks[partition]

     def add_offset(self, partition: Partition, offset: int) -> None:
         """Record that we've started processing an offset."""
+        if partition not in self.partition_locks:
+            raise UnassignedPartitionError(
+                f"Partition {partition} is not assigned to this consumer"
+            )
+
         with self._get_partition_lock(partition):
             self.all_offsets[partition].add(offset)
             self.outstanding[partition].add(offset)

     def complete_offset(self, partition: Partition, offset: int) -> None:
         """Mark an offset as completed."""
+        if partition not in self.partition_locks:
+            return
+
         with self._get_partition_lock(partition):
             self.outstanding[partition].discard(offset)

@@ -104,6 +116,18 @@ def mark_committed(self, partition: Partition, offset: int) -> None:
         # Remove all offsets <= committed offset
         self.all_offsets[partition] = {o for o in self.all_offsets[partition] if o > offset}

+    def clear(self) -> None:
+        """Clear all offset tracking state."""
+        self.all_offsets.clear()
+        self.outstanding.clear()
+        self.last_committed.clear()
+        self.partition_locks.clear()
+
+    def update_assignments(self, partitions: set[Partition]) -> None:
+        """Update partition assignments and reset all tracking state."""
+        self.clear()
+        self.partition_locks = {partition: threading.Lock() for partition in partitions}
+

 class OrderedQueueWorker(threading.Thread, Generic[T]):
     """Worker thread that processes items from a queue in order."""
@@ -138,9 +162,6 @@ def run(self) -> None:
                     name=f"monitors.{self.identifier}.worker_{self.worker_id}",
                 ):
                     self.result_processor(self.identifier, work_item.result)
-
-            except queue.ShutDown:
-                break
             except Exception:
                 logger.exception(
                     "Unexpected error in queue worker", extra={"worker_id": self.worker_id}
@@ -173,13 +194,20 @@ def __init__(
         result_processor: Callable[[str, T], None],
         identifier: str,
         num_queues: int = 20,
+        commit_interval: float = 1.0,
     ) -> None:
         self.result_processor = result_processor
         self.identifier = identifier
         self.num_queues = num_queues
+        self.commit_interval = commit_interval
         self.offset_tracker = OffsetTracker()
         self.queues: list[queue.Queue[WorkItem[T]]] = []
        self.workers: list[OrderedQueueWorker[T]] = []
+        self.commit_function: Callable[[dict[Partition, int]], None] | None = None
+        self.commit_shutdown_event = threading.Event()
+
+        self.commit_thread = threading.Thread(target=self._commit_loop, daemon=True)
+        self.commit_thread.start()

         for i in range(num_queues):
             work_queue: queue.Queue[WorkItem[T]] = queue.Queue()
@@ -195,6 +223,29 @@ def __init__(
             worker.start()
             self.workers.append(worker)

+    def _commit_loop(self) -> None:
+        """Background thread that periodically commits offsets."""
+        while not self.commit_shutdown_event.is_set():
+            try:
+                self.commit_shutdown_event.wait(self.commit_interval)
+                if self.commit_shutdown_event.is_set():
+                    break
+
+                committable = self.offset_tracker.get_committable_offsets()
+
+                if committable and self.commit_function:
+                    metrics.incr(
+                        "remote_subscriptions.queue_pool.offsets_committed",
+                        len(committable),
+                        tags={"identifier": self.identifier},
+                    )
+
+                    self.commit_function(committable)
+                    for partition, offset in committable.items():
+                        self.offset_tracker.mark_committed(partition, offset)
+            except Exception:
+                logger.exception("Error in commit loop")
+
     def get_queue_for_group(self, group_key: str) -> int:
         """
         Get queue index for a group using consistent hashing.
@@ -205,10 +256,25 @@ def submit(self, group_key: str, work_item: WorkItem[T]) -> None:
         """
         Submit a work item to the appropriate queue.
         """
+        try:
+            self.offset_tracker.add_offset(work_item.partition, work_item.offset)
+        except UnassignedPartitionError:
+            logger.exception(
+                "Received message for unassigned partition, skipping",
+                extra={
+                    "partition": work_item.partition,
+                    "offset": work_item.offset,
+                    "identifier": self.identifier,
+                },
+            )
+            metrics.incr(
+                "remote_subscriptions.queue_pool.submit.unassigned_partition",
+                tags={"identifier": self.identifier},
+            )
+            return
+
         queue_index = self.get_queue_for_group(group_key)
         work_queue = self.queues[queue_index]
-
-        self.offset_tracker.add_offset(work_item.partition, work_item.offset)
         work_queue.put(work_item)

     def get_stats(self) -> dict[str, Any]:
@@ -219,7 +285,7 @@ def get_stats(self) -> dict[str, Any]:
             "total_items": sum(queue_depths),
         }

-    def wait_until_empty(self, timeout: float = 5.0) -> bool:
+    def wait_until_empty(self, timeout: float) -> bool:
         """Wait until all queues are empty. Returns True if successful, False if timeout."""
         start_time = time.time()
         while time.time() - start_time < timeout:
@@ -228,8 +294,61 @@ def wait_until_empty(self, timeout: float = 5.0) -> bool:
             time.sleep(0.01)
         return False

+    def flush(self, timeout: float | None = None) -> bool:
+        """
+        Wait for all queues to be empty. Returns True if successful, False if timeout.
+        If timeout is None, immediately flush without waiting.
+        If timeout is reached, flushes all remaining work.
+        """
+        if timeout is None:
+            success = False
+        else:
+            success = self.wait_until_empty(timeout)
+        if not success:
+            metrics.incr(
+                "remote_subscriptions.queue_pool.flush.timeout",
+                tags={"identifier": self.identifier},
+            )
+            cleared_count = 0
+            for q in self.queues:
+                while not q.empty():
+                    try:
+                        q.get_nowait()
+                        cleared_count += 1
+                    except queue.Empty:
+                        break
+                    except Exception:
+                        logger.exception("Error clearing queue")
+            if cleared_count > 0:
+                metrics.incr(
+                    "remote_subscriptions.queue_pool.timeout_queue_size",
+                    cleared_count,
+                    tags={"identifier": self.identifier},
+                )
+
+        self.offset_tracker.clear()
+        return success
+
+    def update_assignments(
+        self,
+        partitions: set[Partition],
+        commit_function: Callable[[dict[Partition, int]], None],
+    ) -> None:
+        """
+        Update partition assignments and commit function atomically.
+        """
+        self.offset_tracker.update_assignments(partitions)
+        self.commit_function = commit_function
+
+        logger.info(
+            "Updated partition assignments",
+            extra={
+                "identifier": self.identifier,
+                "partitions": len(partitions),
+            },
+        )
+
     def shutdown(self) -> None:
-        """Gracefully shutdown all workers."""
         for worker in self.workers:
             worker.shutdown = True

@@ -240,7 +359,10 @@ def shutdown(self) -> None:
                 logger.exception("Error shutting down queue")

         for worker in self.workers:
-            worker.join(timeout=5.0)
+            worker.join(timeout=1.0)
+
+        self.commit_shutdown_event.set()
+        self.commit_thread.join(timeout=1.0)


 class SimpleQueueProcessingStrategy(ProcessingStrategy[KafkaPayload], Generic[T]):
@@ -260,37 +382,18 @@ def __init__(
         decoder: Callable[[KafkaPayload | FilteredPayload], T | None],
         grouping_fn: Callable[[T], str],
         commit_function: Callable[[dict[Partition, int]], None],
+        partitions: set[Partition],
     ) -> None:
         self.queue_pool = queue_pool
         self.decoder = decoder
         self.grouping_fn = grouping_fn
-        self.commit_function = commit_function
         self.shutdown_event = threading.Event()
-
-        self.commit_thread = threading.Thread(target=self._commit_loop, daemon=True)
-        self.commit_thread.start()
-
-    def _commit_loop(self) -> None:
-        while not self.shutdown_event.is_set():
-            try:
-                self.shutdown_event.wait(1.0)
-
-                committable = self.queue_pool.offset_tracker.get_committable_offsets()
-
-                if committable:
-                    metrics.incr(
-                        "remote_subscriptions.queue_pool.offsets_committed",
-                        len(committable),
-                        tags={"identifier": self.queue_pool.identifier},
-                    )
-
-                    self.commit_function(committable)
-                    for partition, offset in committable.items():
-                        self.queue_pool.offset_tracker.mark_committed(partition, offset)
-            except Exception:
-                logger.exception("Error in commit loop")
+        self.queue_pool.update_assignments(partitions, commit_function)

     def submit(self, message: Message[KafkaPayload | FilteredPayload]) -> None:
+        if self.shutdown_event.is_set():
+            raise MessageRejected("Strategy is shutting down")
+
         try:
             result = self.decoder(message.payload)

@@ -299,8 +402,11 @@ def submit(self, message: Message[KafkaPayload | FilteredPayload]) -> None:
             offset = message.value.offset

             if result is None:
-                self.queue_pool.offset_tracker.add_offset(partition, offset)
-                self.queue_pool.offset_tracker.complete_offset(partition, offset)
+                try:
+                    self.queue_pool.offset_tracker.add_offset(partition, offset)
+                    self.queue_pool.offset_tracker.complete_offset(partition, offset)
+                except UnassignedPartitionError:
+                    pass
                 return

             group_key = self.grouping_fn(result)
@@ -334,12 +440,10 @@ def poll(self) -> None:

     def close(self) -> None:
         self.shutdown_event.set()
-        self.commit_thread.join(timeout=5.0)
-        self.queue_pool.shutdown()

     def terminate(self) -> None:
         self.shutdown_event.set()
-        self.queue_pool.shutdown()
+        self.queue_pool.flush(timeout=0)

     def join(self, timeout: float | None = None) -> None:
-        self.close()
+        self.queue_pool.flush(timeout=timeout or 0)
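
A hedged usage sketch of the new ownership model follows; it is not part of this commit. The commit thread is started once in FixedQueuePool.__init__ and outlives individual strategies, each new strategy generation installs its partitions and commit callable via update_assignments(), and flush() drains the queues and clears offset state during the handoff. The commit_a/commit_b callables, the topic name, and the no-op result_processor are placeholders; in production these objects are wired up by the consumer factory rather than by hand.

```python
from arroyo.types import Partition, Topic

from sentry.remote_subscriptions.consumers.queue_consumer import FixedQueuePool

pool = FixedQueuePool(
    result_processor=lambda identifier, result: None,  # no-op processor for the sketch
    identifier="uptime",
    num_queues=4,
    commit_interval=1.0,  # commit thread wakes up roughly once per second
)

topic = Topic("uptime-results")


def commit_a(offsets: dict[Partition, int]) -> None:
    print("generation 1 commit:", offsets)


def commit_b(offsets: dict[Partition, int]) -> None:
    print("generation 2 commit:", offsets)


# First strategy generation: assigned partitions 0 and 1.
pool.update_assignments({Partition(topic, 0), Partition(topic, 1)}, commit_a)
pool.offset_tracker.add_offset(Partition(topic, 0), 10)
pool.offset_tracker.complete_offset(Partition(topic, 0), 10)

# Rebalance: the old strategy's join() drains via flush(), which also clears
# offset state, then the new strategy installs the new assignment and commit.
pool.flush(timeout=5.0)
pool.update_assignments({Partition(topic, 2)}, commit_b)

# The commit thread keeps running across the rebalance; only the commit target
# and the tracked partitions change.
pool.shutdown()
```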

src/sentry/remote_subscriptions/consumers/result_consumer.py

Lines changed: 6 additions & 3 deletions
@@ -112,6 +112,7 @@ def __init__(
         num_processes: int | None = None,
         input_block_size: int | None = None,
         output_block_size: int | None = None,
+        commit_interval: float | None = None,
     ) -> None:
         self.mode = mode
         metric_tags = {"identifier": self.identifier, "mode": self.mode}
@@ -133,7 +134,8 @@
             self.queue_pool = FixedQueuePool(
                 result_processor=self.result_processor,
                 identifier=self.identifier,
-                num_queues=max_workers or 20,  # Number of parallel queues
+                num_queues=max_workers or 20,
+                commit_interval=commit_interval or 1.0,
             )

             metrics.incr(
@@ -207,7 +209,7 @@ def create_with_partitions(
         if self.parallel:
             return self.create_multiprocess_worker(commit)
         if self.thread_queue_parallel:
-            return self.create_thread_queue_parallel_worker(commit)
+            return self.create_thread_queue_parallel_worker(commit, partitions)
         else:
             return self.create_serial_worker(commit)

@@ -242,7 +244,7 @@ def create_thread_parallel_worker(self, commit: Commit) -> ProcessingStrategy[Ka
         )

     def create_thread_queue_parallel_worker(
-        self, commit: Commit
+        self, commit: Commit, partitions: Mapping[Partition, int]
     ) -> ProcessingStrategy[KafkaPayload]:
         assert self.queue_pool is not None

@@ -256,6 +258,7 @@ def commit_offsets(offsets: dict[Partition, int]):
             decoder=partial(self.decode_payload, self.topic_for_codec),
             grouping_fn=self.build_payload_grouping_key,
             commit_function=commit_offsets,
+            partitions=set(partitions.keys()),
        )

     def partition_message_batch(self, message: Message[ValuesBatch[KafkaPayload]]) -> list[list[T]]:
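
Tying the two files together, the sketch below walks a message through the strategy-level guard: create_with_partitions now hands the assignment to SimpleQueueProcessingStrategy, which forwards it to the pool, so a message from an unassigned partition is logged, counted, and dropped instead of being tracked. This is an illustrative sketch under assumptions: the keyword argument names follow the diff, the decoder/grouping/commit callables are toys, and it presumes these modules can be imported and exercised outside the full Sentry application.

```python
from datetime import datetime

from arroyo.backends.kafka.consumer import KafkaPayload
from arroyo.types import BrokerValue, Message, Partition, Topic

from sentry.remote_subscriptions.consumers.queue_consumer import (
    FixedQueuePool,
    SimpleQueueProcessingStrategy,
)

topic = Topic("uptime-results")
assigned = {Partition(topic, 0)}

pool = FixedQueuePool(
    result_processor=lambda identifier, result: None,  # no-op processor for the sketch
    identifier="uptime",
    num_queues=2,
)

strategy = SimpleQueueProcessingStrategy(
    queue_pool=pool,
    decoder=lambda payload: payload.value.decode(),    # toy decoder
    grouping_fn=lambda result: result,                 # group by decoded value
    commit_function=lambda offsets: print("commit", offsets),
    partitions=assigned,                               # installed via update_assignments()
)


def make_message(partition_index: int, offset: int) -> Message[KafkaPayload]:
    value = BrokerValue(
        KafkaPayload(None, b"check-1", []),
        Partition(topic, partition_index),
        offset,
        datetime.now(),
    )
    return Message(value)


strategy.submit(make_message(0, 1))  # assigned partition: queued and offset-tracked
strategy.submit(make_message(3, 1))  # unassigned partition: logged, counted, dropped

pool.flush(timeout=1.0)
pool.shutdown()
```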
