
Commit 08a2e7f

Merge PausingManager into RowConsumer (#853)
Co-authored-by: Remy Gwaramadze <gwaramadze@users.noreply.github.com>
1 parent: 8d134c4
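
Summary: the standalone PausingManager is removed and its pause/seek/resume logic is folded into RowConsumer as trigger_backpressure(), resume_backpressured(), and reset_backpressure(). Checkpoint now calls the consumer directly when a sink raises SinkBackpressureError, and Application resumes or resets backpressure on the consumer instead of routing through ProcessingContext.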

15 files changed, 284 additions and 305 deletions

quixstreams/app.py

Lines changed: 4 additions & 8 deletions
@@ -49,7 +49,7 @@
     check_state_management_enabled,
     is_quix_deployment,
 )
-from .processing import PausingManager, ProcessingContext
+from .processing import ProcessingContext
 from .rowconsumer import RowConsumer
 from .rowproducer import RowProducer
 from .runtracker import RunTracker
@@ -349,9 +349,6 @@ def __init__(

         self._source_manager = SourceManager()
         self._sink_manager = SinkManager()
-        self._pausing_manager = PausingManager(
-            consumer=self._consumer, topic_manager=self._topic_manager
-        )
         self._dataframe_registry = DataFrameRegistry()
         self._processing_context = ProcessingContext(
             commit_interval=self._config.commit_interval,
@@ -361,7 +358,6 @@ def __init__(
             state_manager=self._state_manager,
             exactly_once=self._config.exactly_once,
             sink_manager=self._sink_manager,
-            pausing_manager=self._pausing_manager,
             dataframe_registry=self._dataframe_registry,
         )
         self._run_tracker = RunTracker(processing_context=self._processing_context)
@@ -876,7 +872,7 @@ def _run_dataframe(self):
                 else:
                     process_message(dataframes_composed)
                 processing_context.commit_checkpoint()
-                processing_context.resume_ready_partitions()
+                self._consumer.resume_backpressured()
                 source_manager.raise_for_error()
                 printer.print()
                 run_tracker.update_status()
@@ -1029,7 +1025,7 @@ def _on_revoke(self, _, topic_partitions: List[TopicPartition]):
             self._processing_context.commit_checkpoint(force=True)

         self._revoke_state_partitions(topic_partitions=topic_partitions)
-        self._processing_context.on_partition_revoke()
+        self._consumer.reset_backpressure()

     def _on_lost(self, _, topic_partitions: List[TopicPartition]):
         """
@@ -1038,7 +1034,7 @@ def _on_lost(self, _, topic_partitions: List[TopicPartition]):
         logger.debug("Rebalancing: dropping lost partitions")

         self._revoke_state_partitions(topic_partitions=topic_partitions)
-        self._processing_context.on_partition_revoke()
+        self._consumer.reset_backpressure()

     def _revoke_state_partitions(self, topic_partitions: List[TopicPartition]):
         non_changelog_topics = self._topic_manager.non_changelog_topics

quixstreams/checkpointing/checkpoint.py

Lines changed: 3 additions & 6 deletions
@@ -6,8 +6,7 @@
 from confluent_kafka import KafkaException, TopicPartition

 from quixstreams.dataframe import DataFrameRegistry
-from quixstreams.kafka import BaseConsumer
-from quixstreams.processing.pausing import PausingManager
+from quixstreams.rowconsumer import RowConsumer
 from quixstreams.rowproducer import RowProducer
 from quixstreams.sinks import SinkManager
 from quixstreams.sinks.base import SinkBackpressureError
@@ -124,10 +123,9 @@ def __init__(
         self,
         commit_interval: float,
         producer: RowProducer,
-        consumer: BaseConsumer,
+        consumer: RowConsumer,
         state_manager: StateStoreManager,
         sink_manager: SinkManager,
-        pausing_manager: PausingManager,
         dataframe_registry: DataFrameRegistry,
         exactly_once: bool = False,
         commit_every: int = 0,
@@ -141,7 +139,6 @@ def __init__(
         self._consumer = consumer
         self._producer = producer
         self._sink_manager = sink_manager
-        self._pausing_manager = pausing_manager
         self._dataframe_registry = dataframe_registry
         self._exactly_once = exactly_once
         if self._exactly_once:
@@ -216,7 +213,7 @@ def commit(self):
             # Pause the assignment to let it cool down and seek it back to
             # the first processed offsets of this Checkpoint (it must be equal
             # to the last committed offset).
-            self._pausing_manager.pause(
+            self._consumer.trigger_backpressure(
                 resume_after=exc.retry_after,
                 offsets_to_seek=self._starting_tp_offsets.copy(),
             )
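
The call above runs inside Checkpoint.commit() when a sink signals backpressure. Below is a minimal sketch of that hand-off, assuming only what the diff shows (the trigger_backpressure() signature and SinkBackpressureError.retry_after); the handle_sink_backpressure helper itself is hypothetical, not part of the library:

    from quixstreams.rowconsumer import RowConsumer
    from quixstreams.sinks.base import SinkBackpressureError

    def handle_sink_backpressure(
        consumer: RowConsumer,
        exc: SinkBackpressureError,
        starting_tp_offsets: dict[tuple[str, int], int],
    ) -> None:
        # Hypothetical helper: pause all assigned data partitions and seek
        # them back to the first offsets processed in this checkpoint, so the
        # same batch is re-consumed once the sink's cooldown (retry_after)
        # elapses.
        consumer.trigger_backpressure(
            resume_after=exc.retry_after,
            offsets_to_seek=starting_tp_offsets,
        )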

quixstreams/kafka/consumer.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@
     "Consumer",
     "AutoOffsetReset",
     "RebalancingCallback",
+    "raise_for_msg_error",
 )

 RebalancingCallback = Callable[[ConfluentConsumer, List[TopicPartition]], None]

quixstreams/processing/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
 from .context import ProcessingContext as ProcessingContext
-from .pausing import PausingManager as PausingManager

quixstreams/processing/context.py

Lines changed: 0 additions & 9 deletions
@@ -6,7 +6,6 @@
 from quixstreams.checkpointing import Checkpoint
 from quixstreams.dataframe import DataFrameRegistry
 from quixstreams.exceptions import QuixException
-from quixstreams.processing.pausing import PausingManager
 from quixstreams.rowconsumer import RowConsumer
 from quixstreams.rowproducer import RowProducer
 from quixstreams.sinks import SinkManager
@@ -33,7 +32,6 @@ class ProcessingContext:
     consumer: RowConsumer
     state_manager: StateStoreManager
     sink_manager: SinkManager
-    pausing_manager: PausingManager
     dataframe_registry: DataFrameRegistry
     commit_every: int = 0
     exactly_once: bool = False
@@ -70,7 +68,6 @@ def init_checkpoint(self):
             producer=self.producer,
             consumer=self.consumer,
             sink_manager=self.sink_manager,
-            pausing_manager=self.pausing_manager,
             dataframe_registry=self.dataframe_registry,
             exactly_once=self.exactly_once,
         )
@@ -98,12 +95,6 @@ def commit_checkpoint(self, force: bool = False):
             )
         self.init_checkpoint()

-    def resume_ready_partitions(self):
-        self.pausing_manager.resume_if_ready()
-
-    def on_partition_revoke(self):
-        self.pausing_manager.reset()
-
     def __enter__(self):
         self.sink_manager.start_sinks()
         return self

quixstreams/processing/pausing.py

Lines changed: 0 additions & 96 deletions
This file was deleted.
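
(Per its former call sites, the deleted manager's pause(), resume_if_ready(), and reset() duties now map onto RowConsumer.trigger_backpressure(), resume_backpressured(), and reset_backpressure(); see the sketch after the rowconsumer.py diff below.)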

quixstreams/rowconsumer.py

Lines changed: 79 additions & 2 deletions
@@ -1,5 +1,7 @@
 import logging
-from typing import Callable, List, Mapping, Optional, Union
+import sys
+from time import monotonic
+from typing import Callable, List, Optional, Union

 from confluent_kafka import KafkaError, TopicPartition

@@ -14,8 +16,12 @@

 __all__ = ("RowConsumer",)

+_MAX_FLOAT = sys.float_info.max
+

 class RowConsumer(BaseConsumer):
+    _backpressure_resume_at: float
+
     def __init__(
         self,
         broker_address: Union[str, ConnectionConfig],
@@ -29,6 +35,7 @@ def __init__(
         on_error: Optional[ConsumerErrorCallback] = None,
     ):
         """
+
         A consumer class that is capable of deserializing Kafka messages to Rows
         according to the Topics deserialization settings.

@@ -66,7 +73,9 @@ def __init__(
             extra_config=extra_config,
         )
         self._on_error: ConsumerErrorCallback = on_error or default_on_consumer_error
-        self._topics: Mapping[str, Topic] = {}
+        self._topics: dict[str, Topic] = {}
+        self._backpressured_tps: set[TopicPartition] = set()
+        self.reset_backpressure()

     def subscribe(
         self,
@@ -149,4 +158,72 @@ def poll_row(self, timeout: Optional[float] = None) -> Union[Row, List[Row], None]:

     def close(self):
         super().close()
+        self.reset_backpressure()
         self._inner_consumer = None
+
+    @property
+    def backpressured_tps(self) -> set[TopicPartition]:
+        return self._backpressured_tps
+
+    def trigger_backpressure(
+        self,
+        offsets_to_seek: dict[tuple[str, int], int],
+        resume_after: float,
+    ):
+        """
+        Pause all partitions for a certain period of time and seek the
+        partitions provided in the `offsets_to_seek` dict.
+
+        This method is supposed to be called in case of backpressure from Sinks.
+        """
+        resume_at = monotonic() + resume_after
+        self._backpressure_resume_at = min(self._backpressure_resume_at, resume_at)
+
+        changelog_topics = {k for k, v in self._topics.items() if v.is_changelog}
+        for tp in self.assignment():
+            # Pause only data TPs, excluding changelog TPs
+            if tp.topic in changelog_topics:
+                continue
+
+            position, *_ = self.position([tp])
+            logger.debug(
+                f'Pausing topic partition "{tp.topic}[{tp.partition}]" for {resume_after}s; '
+                f"position={position.offset}"
+            )
+            self.pause(partitions=[tp])
+
+            # Seek the TP back to the "offset_to_seek" to start from it on resume.
+            # The "offset_to_seek" is provided by the Checkpoint and is expected
+            # to be the first offset processed in the checkpoint.
+            # There may be no offset for the TP if no message has been processed yet.
+            seek_offset = offsets_to_seek.get((tp.topic, tp.partition))
+            if seek_offset is not None:
+                logger.debug(
+                    f'Seeking the paused partition "{tp.topic}[{tp.partition}]" back to '
+                    f"offset {seek_offset}"
+                )
+                self.seek(
+                    partition=TopicPartition(
+                        topic=tp.topic, partition=tp.partition, offset=seek_offset
+                    )
+                )
+
+            self._backpressured_tps.add(tp)
+
+    def resume_backpressured(self):
+        """
+        Resume consuming from assigned data partitions after the wait period has elapsed.
+        """
+        if self._backpressure_resume_at > monotonic():
+            return
+
+        # Resume the previously backpressured TPs
+        for tp in self._backpressured_tps:
+            logger.debug(f'Resuming topic partition "{tp.topic}[{tp.partition}]"')
+            self.resume(partitions=[tp])
+        self.reset_backpressure()
+
+    def reset_backpressure(self):
+        # Reset the resume deadline back to its initial state
+        self._backpressure_resume_at = _MAX_FLOAT
+        self._backpressured_tps.clear()
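
Taken together with the app.py changes above, the lifecycle is: Checkpoint.commit() triggers backpressure on SinkBackpressureError, the run loop calls resume_backpressured() once per iteration (a no-op until the deadline passes, since _backpressure_resume_at starts at _MAX_FLOAT), and the rebalance callbacks call reset_backpressure(). A rough sketch of such a loop under those assumptions; run_loop and the process callback are illustrative, not the library's actual Application loop:

    from typing import Callable

    from quixstreams.rowconsumer import RowConsumer

    def run_loop(consumer: RowConsumer, process: Callable[[object], None]) -> None:
        # Illustrative poll loop, not Application._run_dataframe() itself.
        while True:
            rows = consumer.poll_row(timeout=1.0)
            if rows is not None:
                # poll_row may return a single Row or a list of Rows
                for row in rows if isinstance(rows, list) else [rows]:
                    process(row)
            # Safe to call every iteration: it returns immediately while
            # monotonic() is still below the resume deadline, then resumes
            # every backpressured partition and resets the deadline.
            consumer.resume_backpressured()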
