Skip to content

Commit 591b103

Browse files
committed
review comments
1 parent f4978b6 commit 591b103

File tree

7 files changed

+114
-52
lines changed

7 files changed

+114
-52
lines changed

quixstreams/app.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,8 @@ def __init__(
281281
extra_config=consumer_extra_config,
282282
on_error=on_consumer_error,
283283
)
284-
self._producer = RowProducer(
285-
broker_address=broker_address,
286-
extra_config=producer_extra_config,
287-
on_error=on_producer_error,
288-
flush_timeout=consumer_extra_config.get(
289-
"max.poll.interval.ms", _default_max_poll_interval_ms
290-
)
291-
/ 1000, # convert to seconds
292-
transactional=self._uses_exactly_once,
284+
self._producer = self._get_rowproducer(
285+
on_error=on_producer_error, transactional=self._uses_exactly_once
293286
)
294287
self._consumer_poll_timeout = consumer_poll_timeout
295288
self._producer_poll_timeout = producer_poll_timeout
@@ -603,7 +596,7 @@ def dataframe(
603596
:return: `StreamingDataFrame` object
604597
"""
605598
if not source and not topic:
606-
raise TypeError("one of `source` or `topic` is required")
599+
raise ValueError("one of `source` or `topic` is required")
607600
elif source and not topic:
608601
topic = self._topic_manager.source_topic(source)
609602

@@ -638,6 +631,26 @@ def stop(self, fail: bool = False):
638631
if self._state_manager.using_changelogs:
639632
self._state_manager.stop_recovery()
640633

634+
def _get_rowproducer(
635+
self,
636+
on_error: Optional[ProducerErrorCallback] = None,
637+
transactional: bool = False,
638+
) -> RowProducer:
639+
flush_timeout = (
640+
self._consumer_extra_config.get(
641+
"max.poll.interval.ms", _default_max_poll_interval_ms
642+
)
643+
/ 1000
644+
)
645+
646+
return RowProducer(
647+
broker_address=self._broker_address,
648+
extra_config=self._producer_extra_config,
649+
flush_timeout=flush_timeout,
650+
on_error=on_error,
651+
transactional=transactional,
652+
)
653+
641654
def get_producer(self) -> Producer:
642655
"""
643656
Create and return a pre-configured Producer instance.
@@ -721,7 +734,7 @@ def clear_state(self):
721734
"""
722735
self._state_manager.clear_stores()
723736

724-
def source(self, source: BaseSource, topic: Optional[Topic] = None) -> Topic:
737+
def add_source(self, source: BaseSource, topic: Optional[Topic] = None) -> Topic:
725738
"""
726739
Add a source to the application.
727740
@@ -734,15 +747,7 @@ def source(self, source: BaseSource, topic: Optional[Topic] = None) -> Topic:
734747
if not topic:
735748
topic = self._topic_manager.source_topic(source)
736749

737-
producer = RowProducer(
738-
broker_address=self._broker_address,
739-
extra_config=self._producer_extra_config,
740-
flush_timeout=self._consumer_extra_config.get(
741-
"max.poll.interval.ms", _default_max_poll_interval_ms
742-
)
743-
/ 1000, # convert to seconds
744-
transactional=False,
745-
)
750+
producer = self._get_rowproducer(transactional=False)
746751
source.configure(topic, producer)
747752
self._source_manager.register(source)
748753
return topic
@@ -804,7 +809,7 @@ def _run(self, dataframe: Optional[StreamingDataFrame] = None):
804809
self._setup_topics()
805810

806811
if dataframe is not None and dataframe.source:
807-
self.source(dataframe.source, dataframe.topic)
812+
self.add_source(dataframe.source, dataframe.topic)
808813

809814
exit_stack = contextlib.ExitStack()
810815
exit_stack.enter_context(self._processing_context)
@@ -852,7 +857,7 @@ def _run_sources(self):
852857
while self._running:
853858
self._source_manager.raise_for_error()
854859

855-
if not self._source_manager.alives():
860+
if not self._source_manager.is_alive():
856861
self.stop()
857862

858863
time.sleep(1)

quixstreams/dataframe/dataframe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
from quixstreams.sinks import BaseSink
5050
from quixstreams.state.types import State
5151
from quixstreams.sources import BaseSource
52-
from quixstreams.rowproducer import RowProducer
5352
from .base import BaseStreaming
5453
from .exceptions import InvalidOperation, GroupByLimitExceeded, DataFrameLocked
5554
from .series import StreamingSeries

quixstreams/models/topics/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def topic(
289289

290290
def source_topic(self, source: "BaseSource") -> Topic:
291291
topic_args = source.default_topic()
292-
return self.topic(**topic_args.asdict())
292+
return self.topic(**topic_args.asargs())
293293

294294
def repartition_topic(
295295
self,

quixstreams/sources/base.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class SourceTopic:
4646
key_serializer: Optional[SerializerType] = BytesSerializer()
4747
timestamp_extractor: Optional[TimestampExtractor] = None
4848

49-
def asdict(self):
49+
def asargs(self):
5050
return {
5151
field.name: getattr(self, field.name) for field in dataclasses.fields(self)
5252
}
@@ -257,15 +257,20 @@ def __init__(
257257

258258
self._checkpoint_interval = checkpoint_interval
259259
self._checkpoint_error: Optional[BaseException] = None
260+
self._checkpoint_timer: Optional[threading.Timer] = None
260261

261-
self.__checkpoint_timer = self.__make_timer()
262+
@abstractmethod
263+
def run(self):
264+
self._checkpoint_timer = self._make_timer()
265+
self._checkpoint_timer.start()
266+
super().run()
262267

263-
def __make_timer(self):
268+
def _make_timer(self) -> threading.Timer:
264269
return threading.Timer(
265270
interval=self._checkpoint_interval, function=self.__checkpoint
266271
)
267272

268-
def __checkpoint(self):
273+
def _checkpoint(self):
269274
try:
270275
self.checkpoint()
271276
except BaseException as err:
@@ -274,16 +279,11 @@ def __checkpoint(self):
274279
self.stop()
275280
return
276281

277-
self.__checkpoint_timer = self.__make_timer()
278-
self.__checkpoint_timer.start()
279-
280-
@abstractmethod
281-
def run(self):
282-
self.__checkpoint_timer.start()
283-
super().run()
282+
self._checkpoint_timer = self._make_timer()
283+
self._checkpoint_timer.start()
284284

285285
def cleanup(self, failed):
286-
self.__checkpoint_timer.cancel()
286+
self._checkpoint_timer.cancel()
287287
if self._checkpoint_error:
288288
super().cleanup(True)
289289
raise self._checkpoint_error
@@ -321,10 +321,12 @@ def __init__(
321321
super().__init__(name, shutdown_timeout)
322322

323323
self._polling_delay = polling_delay
324-
self._stopping = threading.Event()
324+
self._stopping: Optional[threading.Event] = None
325325

326326
def run(self) -> None:
327327
super().run()
328+
329+
self._stopping = threading.Event()
328330
while not self._stopping.is_set():
329331
try:
330332
msg = self.poll()
@@ -350,7 +352,8 @@ def sleep(self, seconds: float):
350352
raise PollingSourceShutdown("shutdown")
351353

352354
def stop(self) -> None:
353-
self._stopping.set()
355+
if self._stopping is not None:
356+
self._stopping.set()
354357
super().stop()
355358

356359
def cleanup(self, failed: bool) -> None:

quixstreams/sources/iterable.py

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
from typing import Iterable, Optional, Tuple
3+
from typing import Iterable, Optional, Tuple, Callable, Generator
44

55
from quixstreams.models.messages import KafkaMessage
66

@@ -10,6 +10,7 @@
1010

1111

1212
__all__ = (
13+
"GeneratorSource",
1314
"ValueIterableSource",
1415
"KeyValueIterableSource",
1516
)
@@ -26,7 +27,7 @@ class ValueIterableSource(PollingSource):
2627
from quixstreams.sources import ValueIterableSource
2728
2829
app = Application(broker_address='localhost:9092', consumer_group='group')
29-
source = ValueIterableSource(name="my_source", values=iter(range(10)))
30+
source = ValueIterableSource(name="my_source", values=range(10))
3031
3132
sdf = app.dataframe(source=source)
3233
sdf.print()
@@ -51,7 +52,7 @@ def __init__(
5152
super().__init__(name, shutdown_timeout)
5253

5354
self._key = key
54-
self._values = values
55+
self._values = iter(values)
5556

5657
def poll(self) -> KafkaMessage:
5758
try:
@@ -75,14 +76,9 @@ class KeyValueIterableSource(PollingSource):
7576
from quixstreams import Application
7677
from quixstreams.sources import KeyValueIterableSource
7778
78-
def messages():
79-
yield "one", 1
80-
yield "two", 2
81-
yield "three", 3
82-
yield "four", 4
83-
79+
keys = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
8480
app = Application(broker_address='localhost:9092', consumer_group='group')
85-
source = KeyValueIterableSource(name="my_source", iterable=messages())
81+
source = KeyValueIterableSource(name="my_source", iterable=zip(keys, range(10)))
8682
8783
sdf = app.dataframe(source=source)
8884
sdf.print()
@@ -104,11 +100,70 @@ def __init__(
104100
"""
105101
super().__init__(name, shutdown_timeout)
106102

107-
self._iterable = iterable
103+
self._iterator = iter(iterable)
104+
105+
def poll(self) -> KafkaMessage:
106+
try:
107+
data = next(self._iterator)
108+
except StopIteration:
109+
raise PollingSourceShutdown()
110+
111+
if data is None:
112+
return data
113+
114+
key, value = data
115+
return self.serialize(key=key, value=value)
116+
117+
118+
class GeneratorSource(PollingSource):
119+
"""
120+
PollingSource implementation that iterates over a generator of (key, value)
121+
122+
Example Snippet:
123+
124+
```python
125+
from quixstreams import Application
126+
from quixstreams.sources import GeneratorSource
127+
128+
def messages():
129+
yield "one", 1
130+
yield "two", 2
131+
yield "three", 3
132+
yield "four", 4
133+
134+
app = Application(broker_address='localhost:9092', consumer_group='group')
135+
source = GeneratorSource(name="my_source", generator=messages)
136+
137+
sdf = app.dataframe(source=source)
138+
sdf.print()
139+
140+
app.run(sdf)
141+
```
142+
"""
143+
144+
def __init__(
145+
self,
146+
name: str,
147+
generator: Callable[[], Generator[Optional[Tuple[any, any]], None, None]],
148+
polling_delay: float = 1,
149+
shutdown_timeout: float = 10,
150+
) -> None:
151+
super().__init__(name, polling_delay, shutdown_timeout)
152+
153+
self._generator = generator
154+
self._generator_instance: Optional[
155+
Generator[Optional[Tuple[any, any]], None, None]
156+
] = None
157+
158+
def run(self):
159+
self._generator_instance: Generator[Optional[Tuple[any, any]], None, None] = (
160+
self._generator()
161+
)
162+
super().run()
108163

109164
def poll(self) -> KafkaMessage:
110165
try:
111-
data = next(self._iterable)
166+
data = next(self._generator_instance)
112167
except StopIteration:
113168
raise PollingSourceShutdown()
114169

quixstreams/sources/manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def raise_for_error(self) -> None:
211211
for process in self.processes:
212212
process.raise_for_error()
213213

214-
def alives(self) -> bool:
214+
def is_alive(self) -> bool:
215215
"""
216216
Check if any process is alive
217217

tests/test_quixstreams/test_app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2145,7 +2145,7 @@ def values():
21452145
yield 2
21462146

21472147
source = ValueIterableSource(name="foo", key="foo", values=values())
2148-
app.source(source, topic=app.topic(topic_name))
2148+
app.add_source(source, topic=app.topic(topic_name))
21492149
app._run()
21502150

21512151
results = []

0 commit comments

Comments
 (0)