Skip to content

Commit c994f48

Browse files
authored
app: Add option to select store backend (#544)
Add support for multiple store backend
1 parent 0209388 commit c994f48

File tree

13 files changed

+448
-375
lines changed

13 files changed

+448
-375
lines changed

quixstreams/state/manager.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import shutil
33
from pathlib import Path
4-
from typing import List, Dict, Optional, Union
4+
from typing import List, Dict, Optional, Union, Type, get_args
55

66
from quixstreams.rowproducer import RowProducer
77
from .exceptions import (
@@ -14,12 +14,15 @@
1414
from .rocksdb.windowed.store import WindowedRocksDBStore
1515
from .base import Store, StorePartition
1616

17-
__all__ = ("StateStoreManager", "DEFAULT_STATE_STORE_NAME")
17+
__all__ = ("StateStoreManager", "DEFAULT_STATE_STORE_NAME", "StoreTypes")
1818

1919
logger = logging.getLogger(__name__)
2020

2121
DEFAULT_STATE_STORE_NAME = "default"
2222

23+
StoreTypes = Union[Type[RocksDBStore]]
24+
SUPPORTED_STORES = get_args(StoreTypes)
25+
2326

2427
class StateStoreManager:
2528
"""
@@ -38,12 +41,14 @@ def __init__(
3841
rocksdb_options: Optional[RocksDBOptionsType] = None,
3942
producer: Optional[RowProducer] = None,
4043
recovery_manager: Optional[RecoveryManager] = None,
44+
default_store_type: StoreTypes = RocksDBStore,
4145
):
4246
self._state_dir = (Path(state_dir) / group_id).absolute()
4347
self._rocksdb_options = rocksdb_options
4448
self._stores: Dict[str, Dict[str, Store]] = {}
4549
self._producer = producer
4650
self._recovery_manager = recovery_manager
51+
self._default_store_type = default_store_type
4752

4853
def _init_state_dir(self):
4954
logger.info(f'Initializing state directory at "{self._state_dir}"')
@@ -84,6 +89,10 @@ def using_changelogs(self) -> bool:
8489
"""
8590
return bool(self._recovery_manager)
8691

92+
@property
93+
def default_store_type(self) -> StoreTypes:
94+
return self._default_store_type
95+
8796
def do_recovery(self):
8897
"""
8998
Perform a state recovery, if necessary.
@@ -130,7 +139,10 @@ def _setup_changelogs(
130139
)
131140

132141
def register_store(
133-
self, topic_name: str, store_name: str = DEFAULT_STATE_STORE_NAME
142+
self,
143+
topic_name: str,
144+
store_name: str = DEFAULT_STATE_STORE_NAME,
145+
store_type: Optional[StoreTypes] = None,
134146
):
135147
"""
136148
Register a state store to be managed by StateStoreManager.
@@ -142,17 +154,25 @@ def register_store(
142154
143155
:param topic_name: topic name
144156
:param store_name: store name
157+
:param store_type: the storage type used for this store.
158+
Default to StateStoreManager `default_store_type`
145159
"""
146160
if self._stores.get(topic_name, {}).get(store_name) is None:
147-
self._stores.setdefault(topic_name, {})[store_name] = RocksDBStore(
148-
name=store_name,
149-
topic=topic_name,
150-
base_dir=str(self._state_dir),
151-
changelog_producer_factory=self._setup_changelogs(
152-
topic_name, store_name
153-
),
154-
options=self._rocksdb_options,
155-
)
161+
changelog_producer_factory = self._setup_changelogs(topic_name, store_name)
162+
163+
store_type = store_type or self.default_store_type
164+
if store_type == RocksDBStore:
165+
factory = RocksDBStore(
166+
name=store_name,
167+
topic=topic_name,
168+
base_dir=str(self._state_dir),
169+
changelog_producer_factory=changelog_producer_factory,
170+
options=self._rocksdb_options,
171+
)
172+
else:
173+
raise ValueError(f"invalid store type: {store_type}")
174+
175+
self._stores.setdefault(topic_name, {})[store_name] = factory
156176

157177
def register_windowed_store(self, topic_name: str, store_name: str):
158178
"""

tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
"tests.test_quixstreams.test_models.test_serializers.fixtures",
1616
"tests.test_quixstreams.test_platforms.test_quix.fixtures",
1717
"tests.test_quixstreams.test_state.fixtures",
18-
"tests.test_quixstreams.test_state.test_rocksdb.fixtures",
1918
"tests.test_quixstreams.test_state.test_rocksdb.test_windowed.fixtures",
2019
]
2120

tests/test_quixstreams/fixtures.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import uuid
22
from concurrent.futures import ThreadPoolExecutor
33
from typing import Optional, Union
4-
from unittest.mock import create_autospec, patch
4+
from unittest.mock import create_autospec, patch, PropertyMock
55

66
import pytest
77
from confluent_kafka.admin import (
@@ -47,6 +47,7 @@
4747
from quixstreams.rowconsumer import RowConsumer
4848
from quixstreams.rowproducer import RowProducer
4949
from quixstreams.state import StateStoreManager
50+
from quixstreams.state.manager import StoreTypes
5051
from quixstreams.state.recovery import RecoveryManager
5152

5253

@@ -278,7 +279,7 @@ def factory(
278279

279280

280281
@pytest.fixture()
281-
def app_factory(kafka_container, random_consumer_group, tmp_path):
282+
def app_factory(kafka_container, random_consumer_group, tmp_path, store_type):
282283
def factory(
283284
consumer_group: Optional[str] = None,
284285
auto_offset_reset: AutoOffsetReset = "latest",
@@ -296,6 +297,7 @@ def factory(
296297
topic_manager: Optional[TopicManager] = None,
297298
processing_guarantee: ProcessingGuarantee = "at-least-once",
298299
request_timeout: float = 30,
300+
store_type: StoreTypes = store_type,
299301
) -> Application:
300302
state_dir = state_dir or (tmp_path / "state").absolute()
301303
return Application(
@@ -318,16 +320,22 @@ def factory(
318320
request_timeout=request_timeout,
319321
)
320322

321-
return factory
323+
with patch(
324+
"quixstreams.state.manager.StateStoreManager.default_store_type",
325+
new_callable=PropertyMock,
326+
) as m:
327+
m.return_value = store_type
328+
yield factory
322329

323330

324331
@pytest.fixture()
325-
def state_manager_factory(tmp_path):
332+
def state_manager_factory(store_type, tmp_path):
326333
def factory(
327334
group_id: Optional[str] = None,
328335
state_dir: Optional[str] = None,
329336
producer: Optional[RowProducer] = None,
330337
recovery_manager: Optional[RecoveryManager] = None,
338+
default_store_type: StoreTypes = store_type,
331339
) -> StateStoreManager:
332340
group_id = group_id or str(uuid.uuid4())
333341
state_dir = state_dir or str(uuid.uuid4())
@@ -336,6 +344,7 @@ def factory(
336344
state_dir=str(tmp_path / state_dir),
337345
producer=producer,
338346
recovery_manager=recovery_manager,
347+
default_store_type=default_store_type,
339348
)
340349

341350
return factory
@@ -434,6 +443,7 @@ def quix_app_factory(
434443
topic_admin,
435444
quix_mock_config_builder_factory,
436445
quix_topic_manager_factory,
446+
store_type,
437447
):
438448
"""
439449
For doing testing with Quix Applications against a local cluster.
@@ -454,6 +464,7 @@ def factory(
454464
auto_create_topics: bool = True,
455465
use_changelog_topics: bool = True,
456466
workspace_id: str = "my_ws",
467+
store_type: Optional[StoreTypes] = store_type,
457468
) -> Application:
458469
state_dir = state_dir or (tmp_path / "state").absolute()
459470
return Application(
@@ -474,7 +485,12 @@ def factory(
474485
),
475486
)
476487

477-
return factory
488+
with patch(
489+
"quixstreams.state.manager.StateStoreManager.default_store_type",
490+
new_callable=PropertyMock,
491+
) as m:
492+
m.return_value = store_type
493+
yield factory
478494

479495

480496
@pytest.fixture()

tests/test_quixstreams/test_app.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from quixstreams.rowproducer import RowProducer
3131
from quixstreams.sinks import SinkBatch, SinkBackpressureError
3232
from quixstreams.state import State
33+
from quixstreams.state.manager import SUPPORTED_STORES
3334
from quixstreams.sources import SourceException, multiprocessing
3435
from tests.utils import DummySink, DummySource
3536

@@ -1036,6 +1037,7 @@ def test_topic_name_and_config(self, quix_app_factory):
10361037
assert expected_topic.config.num_partitions == topic_partitions
10371038

10381039

1040+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
10391041
class TestQuixApplicationWithState:
10401042
def test_quix_app_no_state_management_warning(
10411043
self, quix_app_factory, monkeypatch, topic_factory, executor
@@ -1086,6 +1088,7 @@ def test_quix_app_state_dir_mismatch_warning(
10861088
assert "does not match the state directory" in warning
10871089

10881090

1091+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
10891092
class TestApplicationWithState:
10901093
def test_run_stateful_success(
10911094
self,
@@ -1393,6 +1396,7 @@ def test_app_use_changelog_false(self, app_factory):
13931396
assert not app._state_manager.using_changelogs
13941397

13951398

1399+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
13961400
class TestApplicationRecovery:
13971401
def test_changelog_recovery_default_store(
13981402
self,
@@ -1672,6 +1676,7 @@ def validate_state():
16721676
@pytest.mark.parametrize("processing_guarantee", ["at-least-once", "exactly-once"])
16731677
def test_changelog_recovery_consistent_after_failed_commit(
16741678
self,
1679+
store_type,
16751680
app_factory,
16761681
executor,
16771682
tmp_path,
@@ -1734,6 +1739,7 @@ def get_app():
17341739
consumer_group=consumer_group,
17351740
state_dir=state_dir,
17361741
processing_guarantee=processing_guarantee,
1742+
store_type=store_type,
17371743
)
17381744
topic = app.topic(topic_name)
17391745
sdf = app.dataframe(topic)
@@ -2376,6 +2382,7 @@ def on_message_processed(*_):
23762382
"groupby_timestamp": timestamp,
23772383
}
23782384

2385+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
23792386
def test_stateful(
23802387
self,
23812388
app_factory,
@@ -2466,6 +2473,7 @@ def count(_, state: State):
24662473
# All keys in state must be prefixed with the message key
24672474
assert tx.get("total", prefix=message_key) == messages_per_topic
24682475

2476+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
24692477
def test_changelog_recovery(
24702478
self,
24712479
app_factory,

tests/test_quixstreams/test_checkpointing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from quixstreams.sinks import SinkManager, BatchingSink, SinkBackpressureError
1717
from quixstreams.sinks.base import SinkBatch
1818
from quixstreams.state import StateStoreManager
19+
from quixstreams.state.manager import SUPPORTED_STORES
1920
from quixstreams.state.base import PartitionTransaction
2021
from quixstreams.state.exceptions import StoreNotRegisteredError, StoreTransactionFailed
2122
from tests.utils import DummySink
@@ -71,6 +72,7 @@ def write(self, batch: SinkBatch):
7172
raise ValueError("Sink write failed")
7273

7374

75+
@pytest.mark.parametrize("store_type", SUPPORTED_STORES, indirect=True)
7476
class TestCheckpoint:
7577
def test_empty_true(self, checkpoint_factory):
7678
checkpoint = checkpoint_factory()

tests/test_quixstreams/test_state/fixtures.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
import uuid
2-
from typing import Optional
3-
from unittest.mock import MagicMock
2+
from typing import Optional, Generator
3+
from unittest.mock import MagicMock, PropertyMock
44

55
import pytest
6+
from quixstreams.state.base import StorePartition
67

78
from quixstreams.kafka import Consumer
89
from quixstreams.models import TopicManager
9-
from quixstreams.state.recovery import RecoveryPartition, RecoveryManager
10-
from quixstreams.state.base import StorePartition
10+
from quixstreams.state.recovery import (
11+
RecoveryPartition,
12+
RecoveryManager,
13+
ChangelogProducerFactory,
14+
ChangelogProducer,
15+
)
16+
from quixstreams.state.rocksdb import (
17+
RocksDBStore,
18+
RocksDBStorePartition,
19+
RocksDBOptions,
20+
)
1121

1222

1323
@pytest.fixture()
@@ -45,3 +55,83 @@ def factory(
4555
return recovery_partition
4656

4757
return factory
58+
59+
60+
@pytest.fixture()
61+
def store_type(request):
62+
if hasattr(request, "param"):
63+
return request.param
64+
else:
65+
return RocksDBStore
66+
67+
68+
def rocksdb_store_factory(tmp_path):
69+
def factory(
70+
topic: Optional[str] = None,
71+
name: str = "default",
72+
changelog_producer_factory: Optional[ChangelogProducerFactory] = None,
73+
) -> RocksDBStore:
74+
topic = topic or str(uuid.uuid4())
75+
return RocksDBStore(
76+
topic=topic,
77+
name=name,
78+
base_dir=str(tmp_path),
79+
changelog_producer_factory=changelog_producer_factory,
80+
)
81+
82+
return factory
83+
84+
85+
@pytest.fixture()
86+
def store_factory(store_type, tmp_path):
87+
if store_type == RocksDBStore:
88+
return rocksdb_store_factory(tmp_path)
89+
else:
90+
raise ValueError(f"invalid store type {store_type}")
91+
92+
93+
@pytest.fixture()
94+
def store(store_factory):
95+
store = store_factory()
96+
yield store
97+
store.close()
98+
99+
100+
def rocksdb_partition_factory(tmp_path, changelog_producer_mock):
101+
def factory(
102+
name: str = "db",
103+
options: Optional[RocksDBOptions] = None,
104+
changelog_producer: Optional[ChangelogProducer] = None,
105+
) -> RocksDBStorePartition:
106+
path = (tmp_path / name).as_posix()
107+
_options = options or RocksDBOptions(open_max_retries=0, open_retry_backoff=3.0)
108+
return RocksDBStorePartition(
109+
path,
110+
changelog_producer=changelog_producer or changelog_producer_mock,
111+
options=_options,
112+
)
113+
114+
return factory
115+
116+
117+
@pytest.fixture()
118+
def store_partition_factory(store_type, tmp_path, changelog_producer_mock):
119+
if store_type == RocksDBStore:
120+
return rocksdb_partition_factory(tmp_path, changelog_producer_mock)
121+
else:
122+
raise ValueError(f"invalid store type {store_type}")
123+
124+
125+
@pytest.fixture()
126+
def store_partition(store_partition_factory) -> Generator[StorePartition, None, None]:
127+
partition = store_partition_factory()
128+
yield partition
129+
partition.close()
130+
131+
132+
@pytest.fixture()
133+
def changelog_producer_mock():
134+
producer = MagicMock(spec_set=ChangelogProducer)
135+
type(producer).changelog_name = PropertyMock(return_value="test-changelog-topic")
136+
type(producer).partition = PropertyMock(return_value=0)
137+
return producer

0 commit comments

Comments
 (0)