50
50
from quixstreams .sinks import BaseSink
51
51
from quixstreams .state .base import State
52
52
from quixstreams .state .base .transaction import PartitionTransaction
53
+ from quixstreams .state .manager import StoreTypes
53
54
from quixstreams .state .rocksdb .timestamped import TimestampedStore
54
55
from quixstreams .utils .printing import (
55
56
DEFAULT_COLUMN_NAME ,
@@ -277,7 +278,7 @@ def func(d: dict, state: State):
277
278
Default - `False`.
278
279
"""
279
280
if stateful :
280
- self ._register_store ()
281
+ self .register_store ()
281
282
# Force the callback to accept metadata
282
283
if metadata :
283
284
with_metadata_func = cast (ApplyWithMetadataCallbackStateful , func )
@@ -382,7 +383,7 @@ def func(values: list, state: State):
382
383
:return: the updated StreamingDataFrame instance (reassignment NOT required).
383
384
"""
384
385
if stateful :
385
- self ._register_store ()
386
+ self .register_store ()
386
387
# Force the callback to accept metadata
387
388
if metadata :
388
389
with_metadata_func = cast (UpdateWithMetadataCallbackStateful , func )
@@ -480,7 +481,7 @@ def func(d: dict, state: State):
480
481
"""
481
482
482
483
if stateful :
483
- self ._register_store ()
484
+ self .register_store ()
484
485
# Force the callback to accept metadata
485
486
if metadata :
486
487
with_metadata_func = cast (FilterWithMetadataCallbackStateful , func )
@@ -1648,11 +1649,7 @@ def concat(self, other: "StreamingDataFrame") -> "StreamingDataFrame":
1648
1649
1649
1650
def join (self , right : "StreamingDataFrame" ) -> "StreamingDataFrame" :
1650
1651
# TODO: ensure copartitioning of left and right?
1651
- right .processing_context .state_manager .register_store (
1652
- stream_id = right .stream_id ,
1653
- store_type = TimestampedStore ,
1654
- changelog_config = self ._topic_manager .derive_topic_config (right .topics ),
1655
- )
1652
+ right .register_store (store_type = TimestampedStore )
1656
1653
1657
1654
def left_func (value , key , timestamp , headers ):
1658
1655
right_tx = _get_transaction (right )
@@ -1700,7 +1697,7 @@ def _add_update(
1700
1697
self ._stream = self ._stream .add_update (func , metadata = metadata ) # type: ignore[call-overload]
1701
1698
return self
1702
1699
1703
- def _register_store (self ):
1700
+ def register_store (self , store_type : Optional [ StoreTypes ] = None ):
1704
1701
"""
1705
1702
Register the default store for the current stream_id in StateStoreManager.
1706
1703
"""
@@ -1710,7 +1707,9 @@ def _register_store(self):
1710
1707
changelog_topic_config = self ._topic_manager .derive_topic_config (self ._topics )
1711
1708
1712
1709
self ._processing_context .state_manager .register_store (
1713
- stream_id = self .stream_id , changelog_config = changelog_topic_config
1710
+ stream_id = self .stream_id ,
1711
+ store_type = store_type ,
1712
+ changelog_config = changelog_topic_config ,
1714
1713
)
1715
1714
1716
1715
def _groupby_key (
0 commit comments