1
1
import logging
2
2
import time
3
- from typing import Dict , List , Literal , Optional , Union , cast
3
+ from typing import (
4
+ Dict ,
5
+ Iterator ,
6
+ List ,
7
+ Literal ,
8
+ Optional ,
9
+ Union ,
10
+ cast ,
11
+ )
4
12
5
- from rocksdict import AccessType , ColumnFamily , Rdict , WriteBatch
13
+ from rocksdict import AccessType , ColumnFamily , Rdict , ReadOptions , WriteBatch
6
14
7
- from quixstreams .state .base import PartitionTransactionCache , StorePartition
15
+ from quixstreams .state .base import (
16
+ PartitionTransaction ,
17
+ PartitionTransactionCache ,
18
+ StorePartition ,
19
+ )
8
20
from quixstreams .state .exceptions import ColumnFamilyDoesNotExist
9
21
from quixstreams .state .metadata import METADATA_CF_NAME , Marker
10
22
from quixstreams .state .recovery import ChangelogProducer
11
- from quixstreams .state .serialization import (
12
- int_from_int64_bytes ,
13
- int_to_int64_bytes ,
14
- )
23
+ from quixstreams .state .serialization import int_from_int64_bytes , int_to_int64_bytes
15
24
16
25
from .exceptions import ColumnFamilyAlreadyExists
17
26
from .metadata import (
22
31
23
32
__all__ = ("RocksDBStorePartition" ,)
24
33
34
+
25
35
logger = logging .getLogger (__name__ )
26
36
27
37
@@ -42,6 +52,8 @@ class RocksDBStorePartition(StorePartition):
42
52
:param options: RocksDB options. If `None`, the default options will be used.
43
53
"""
44
54
55
+ additional_column_families : tuple [str , ...] = ()
56
+
45
57
def __init__ (
46
58
self ,
47
59
path : str ,
@@ -60,6 +72,8 @@ def __init__(
60
72
self ._db = self ._init_rocksdb ()
61
73
self ._cf_cache : Dict [str , Rdict ] = {}
62
74
self ._cf_handle_cache : Dict [str , ColumnFamily ] = {}
75
+ for cf_name in self .additional_column_families :
76
+ self ._ensure_column_family (cf_name )
63
77
64
78
def recover_from_changelog_message (
65
79
self , key : bytes , value : Optional [bytes ], cf_name : str , offset : int
@@ -139,6 +153,61 @@ def get(
139
153
# RDict accept Any type as value but we only write bytes so we should only get bytes back.
140
154
return cast (Union [bytes , Literal [Marker .UNDEFINED ]], result )
141
155
156
    def iter_items(
        self,
        lower_bound: bytes,  # inclusive
        upper_bound: bytes,  # exclusive
        backwards: bool = False,
        cf_name: str = "default",
    ) -> Iterator[tuple[bytes, bytes]]:
        """
        Iterate over key-value pairs within a specified range in a column family.

        :param lower_bound: The lower bound key (inclusive) for the iteration range.
        :param upper_bound: The upper bound key (exclusive) for the iteration range.
        :param backwards: If `True`, iterate in reverse order (descending).
            Default is `False` (ascending).
        :param cf_name: The name of the column family to iterate over.
            Default is "default".
        :return: An iterator yielding (key, value) tuples.
        """
        cf = self.get_column_family(cf_name=cf_name)

        # Set iterator bounds to reduce IO by limiting the range of keys fetched
        read_opt = ReadOptions()
        read_opt.set_iterate_lower_bound(lower_bound)
        read_opt.set_iterate_upper_bound(upper_bound)

        # Backwards iteration must seek to the top of the range and walk down;
        # forwards iteration starts at the bottom.
        from_key = upper_bound if backwards else lower_bound

        # RDict accepts Any type as value but we only write bytes so we should
        # only get bytes back.
        items = cast(
            Iterator[tuple[bytes, bytes]],
            cf.items(from_key=from_key, read_opt=read_opt, backwards=backwards),
        )

        if not backwards:
            # NOTE: Forward iteration respects bounds correctly.
            # Also, we need to use yield from notation to replace RdictItems
            # with Python-native generator or else garbage collection
            # will make the result unpredictable.
            yield from items
        else:
            # NOTE: When iterating backwards, the `read_opt` lower bound
            # is not respected by Rdict for some reason. We need to manually
            # filter it here.
            for key, value in items:
                if lower_bound <= key:
                    yield key, value
202
+
203
+ def begin (self ) -> PartitionTransaction :
204
+ return PartitionTransaction (
205
+ partition = self ,
206
+ dumps = self ._dumps ,
207
+ loads = self ._loads ,
208
+ changelog_producer = self ._changelog_producer ,
209
+ )
210
+
142
211
def exists (self , key : bytes , cf_name : str = "default" ) -> bool :
143
212
"""
144
213
Check if a key is present in the DB.
@@ -328,3 +397,9 @@ def _update_changelog_offset(self, batch: WriteBatch, offset: int):
328
397
int_to_int64_bytes (offset ),
329
398
self .get_column_family_handle (METADATA_CF_NAME ),
330
399
)
400
+
401
+ def _ensure_column_family (self , cf_name : str ) -> None :
402
+ try :
403
+ self .get_column_family (cf_name )
404
+ except ColumnFamilyDoesNotExist :
405
+ self .create_column_family (cf_name )
0 commit comments