Skip to content

Commit 40f98f4

Browse files
committed
Use sdf.register_store method
1 parent 53c36ee commit 40f98f4

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

quixstreams/dataframe/dataframe.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
from quixstreams.sinks import BaseSink
5151
from quixstreams.state.base import State
5252
from quixstreams.state.base.transaction import PartitionTransaction
53+
from quixstreams.state.manager import StoreTypes
5354
from quixstreams.state.rocksdb.timestamped import TimestampedStore
5455
from quixstreams.utils.printing import (
5556
DEFAULT_COLUMN_NAME,
@@ -277,7 +278,7 @@ def func(d: dict, state: State):
277278
Default - `False`.
278279
"""
279280
if stateful:
280-
self._register_store()
281+
self.register_store()
281282
# Force the callback to accept metadata
282283
if metadata:
283284
with_metadata_func = cast(ApplyWithMetadataCallbackStateful, func)
@@ -382,7 +383,7 @@ def func(values: list, state: State):
382383
:return: the updated StreamingDataFrame instance (reassignment NOT required).
383384
"""
384385
if stateful:
385-
self._register_store()
386+
self.register_store()
386387
# Force the callback to accept metadata
387388
if metadata:
388389
with_metadata_func = cast(UpdateWithMetadataCallbackStateful, func)
@@ -480,7 +481,7 @@ def func(d: dict, state: State):
480481
"""
481482

482483
if stateful:
483-
self._register_store()
484+
self.register_store()
484485
# Force the callback to accept metadata
485486
if metadata:
486487
with_metadata_func = cast(FilterWithMetadataCallbackStateful, func)
@@ -1648,11 +1649,7 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
16481649

16491650
def join(self, right: "StreamingDataFrame") -> "StreamingDataFrame":
16501651
# TODO: ensure copartitioning of left and right?
1651-
right.processing_context.state_manager.register_store(
1652-
stream_id=right.stream_id,
1653-
store_type=TimestampedStore,
1654-
changelog_config=self._topic_manager.derive_topic_config(right.topics),
1655-
)
1652+
right.register_store(store_type=TimestampedStore)
16561653

16571654
def left_func(value, key, timestamp, headers):
16581655
right_tx = _get_transaction(right)
@@ -1700,7 +1697,7 @@ def _add_update(
17001697
self._stream = self._stream.add_update(func, metadata=metadata) # type: ignore[call-overload]
17011698
return self
17021699

1703-
def _register_store(self):
1700+
def register_store(self, store_type: Optional[StoreTypes] = None):
17041701
"""
17051702
Register the default store for the current stream_id in StateStoreManager.
17061703
"""
@@ -1710,7 +1707,9 @@ def _register_store(self):
17101707
changelog_topic_config = self._topic_manager.derive_topic_config(self._topics)
17111708

17121709
self._processing_context.state_manager.register_store(
1713-
stream_id=self.stream_id, changelog_config=changelog_topic_config
1710+
stream_id=self.stream_id,
1711+
store_type=store_type,
1712+
changelog_config=changelog_topic_config,
17141713
)
17151714

17161715
def _groupby_key(

0 commit comments

Comments
 (0)