
Commit 81cbc83

Abstract away the state update cache (#576)
1 parent ab6d437 commit 81cbc83

7 files changed: +317, -116 lines changed


quixstreams/state/base/partition.py

Lines changed: 4 additions & 4 deletions

@@ -12,7 +12,7 @@
 from quixstreams.state.serialization import DumpsFunc, LoadsFunc
 from quixstreams.utils.json import loads as json_loads
 
-from .transaction import PartitionTransaction, CACHE_TYPE
+from .transaction import PartitionTransaction, PartitionTransactionCache
 
 if TYPE_CHECKING:
     from quixstreams.state.recovery import ChangelogProducer
@@ -72,14 +72,14 @@ def get_changelog_offset(self) -> Optional[int]:
     @abstractmethod
     def write(
         self,
-        data: CACHE_TYPE,
+        cache: PartitionTransactionCache,
         processed_offset: Optional[int],
         changelog_offset: Optional[int],
    ):
        """
-        Update the state with data
+        Update the state with data from the update cache
 
-        :param data: The modified data
+        :param cache: The modified data
         :param processed_offset: The offset processed to generate the data.
         :param changelog_offset: The changelog message offset of the data.
         """

quixstreams/state/base/transaction.py

Lines changed: 157 additions & 49 deletions

@@ -1,21 +1,11 @@
 import enum
-import logging
 import functools
-from typing import (
-    Any,
-    Optional,
-    Dict,
-    Tuple,
-    Union,
-    TYPE_CHECKING,
-)
-
+import logging
 from abc import ABC
+from collections import defaultdict
+from typing import Any, Optional, Dict, Tuple, Union, TYPE_CHECKING, Set
 
-from quixstreams.state.exceptions import (
-    StateTransactionError,
-    InvalidChangelogOffset,
-)
+from quixstreams.state.exceptions import StateTransactionError, InvalidChangelogOffset
 from quixstreams.state.metadata import (
     CHANGELOG_CF_MESSAGE_HEADER,
     CHANGELOG_PROCESSED_OFFSET_MESSAGE_HEADER,
@@ -25,24 +15,129 @@
     Undefined,
     DEFAULT_PREFIX,
 )
-from quixstreams.state.serialization import (
-    serialize,
-    deserialize,
-    LoadsFunc,
-    DumpsFunc,
-)
+from quixstreams.state.serialization import serialize, deserialize, LoadsFunc, DumpsFunc
 from quixstreams.utils.json import dumps as json_dumps
-
 from .state import State, TransactionState
 
 if TYPE_CHECKING:
     from quixstreams.state.recovery import ChangelogProducer
     from .partition import StorePartition
 
-__all__ = ("PartitionTransactionStatus", "PartitionTransaction", "CACHE_TYPE")
+__all__ = (
+    "PartitionTransactionStatus",
+    "PartitionTransaction",
+    "PartitionTransactionCache",
+)
 
 logger = logging.getLogger(__name__)
-CACHE_TYPE = Dict[str, Dict[bytes, Dict[bytes, Union[bytes, Undefined]]]]
+
+
+class PartitionTransactionCache:
+    """
+    A cache with the data updated in the current PartitionTransaction.
+    It is used to read-your-own-writes before the transaction is committed to the Store.
+
+    Internally, updates and deletes are separated into two separate structures
+    to simplify the querying over them.
+    """
+
+    def __init__(self):
+        # A map with updated keys in format {<cf>: {<prefix>: {<key>: <value>}}}
+        # Note: "updates" are bucketed per prefix to speed up iterating over the
+        # specific set of keys when we merge updates with data from the stores.
+        # Using a prefix like that allows us to perform fewer iterations.
+        self._updated: dict[str, dict[bytes, dict[bytes, bytes]]] = defaultdict(
+            lambda: defaultdict(dict)
+        )
+        # Dict of sets with deleted keys in format {<cf>: set[<key1>, <key2>]}
+        # Deletes are stored without prefixes because we don't need to iterate over
+        # them.
+        self._deleted: dict[str, set[bytes]] = defaultdict(set)
+        self._empty = True
+
+    def get(
+        self,
+        key: bytes,
+        prefix: bytes,
+        cf_name: str = "default",
+    ) -> Union[bytes, Undefined]:
+        """
+        Get a value for the key.
+
+        Returns the key's value if it has been updated during the transaction.
+
+        If the key has already been deleted, returns the "DELETED" sentinel
+        (we don't need to check the actual store).
+        If the key is not present in the cache, returns the "UNDEFINED" sentinel
+        (we need to check the store).
+
+        :param key: key as bytes
+        :param prefix: key prefix as bytes
+        :param cf_name: column family name
+        """
+        # Check if the key has been deleted
+        if key in self._deleted[cf_name]:
+            # The key is deleted and the store doesn't need to be checked
+            return DELETED
+
+        # Check if the key has been updated
+        # If the key is not present in the cache, we need to check the store and return
+        # UNDEFINED to signify that
+        return self._updated[cf_name][prefix].get(key, UNDEFINED)
+
+    def set(self, key: bytes, value: bytes, prefix: bytes, cf_name: str = "default"):
+        """
+        Set a value for the key.
+
+        :param key: key as bytes
+        :param value: value as bytes
+        :param prefix: key prefix as bytes
+        :param cf_name: column family name
+        """
+        self._updated[cf_name][prefix][key] = value
+        self._deleted[cf_name].discard(key)
+        self._empty = False
+
+    def delete(self, key: Any, prefix: bytes, cf_name: str = "default"):
+        """
+        Delete a key.
+
+        :param key: key as bytes
+        :param prefix: key prefix as bytes
+        :param cf_name: column family name
+        """
+        self._updated[cf_name][prefix].pop(key, None)
+        self._deleted[cf_name].add(key)
+        self._empty = False
+
+    def is_empty(self) -> bool:
+        """
+        Return True if no changes have been made (no updates or deletes), otherwise
+        return False.
+        """
+        return self._empty
+
+    def get_column_families(self) -> Set[str]:
+        """
+        Get all updated column families.
+        """
+        return set(self._updated.keys()) | set(self._deleted.keys())
+
+    def get_updates(self, cf_name: str = "default") -> Dict[bytes, Dict[bytes, bytes]]:
+        """
+        Get all updated keys (excluding deleted)
+        in the format "{<prefix>: {<key>: <value>}}".
+
+        :param cf_name: column family name
+        """
+        return self._updated.get(cf_name, {})
+
+    def get_deletes(self, cf_name: str = "default") -> Set[bytes]:
+        """
+        Get all deleted keys (excluding updated) as a set.
+        """
+        return self._deleted[cf_name]
 
 
 class PartitionTransactionStatus(enum.Enum):
@@ -97,7 +192,7 @@ def __init__(
         self._loads = loads
         self._partition = partition
 
-        self._update_cache: CACHE_TYPE = {}
+        self._update_cache = PartitionTransactionCache()
 
     @property
     def changelog_producer(self) -> Optional["ChangelogProducer"]:
@@ -197,14 +292,13 @@ def get(
         :param key: key
         :param prefix: a key prefix
         :param default: default value to return if the key is not found
+        :param cf_name: column family name
         :return: value or None if the key is not found and `default` is not provided
         """
         key_serialized = self._serialize_key(key, prefix=prefix)
 
-        cached = (
-            self._update_cache.get(cf_name, {})
-            .get(prefix, {})
-            .get(key_serialized, UNDEFINED)
+        cached = self._update_cache.get(
+            key=key_serialized, prefix=prefix, cf_name=cf_name
         )
         if cached is DELETED:
             return default
@@ -225,14 +319,18 @@ def set(self, key: Any, value: Any, prefix: bytes, cf_name: str = "default"):
         :param key: key
         :param prefix: a key prefix
         :param value: value
+        :param cf_name: column family name
         """
 
         try:
             key_serialized = self._serialize_key(key, prefix=prefix)
             value_serialized = self._serialize_value(value)
-            self._update_cache.setdefault(cf_name, {}).setdefault(prefix, {})[
-                key_serialized
-            ] = value_serialized
+            self._update_cache.set(
+                key=key_serialized,
+                value=value_serialized,
+                prefix=prefix,
+                cf_name=cf_name,
+            )
         except Exception:
             self._status = PartitionTransactionStatus.FAILED
             raise
@@ -245,12 +343,13 @@ def delete(self, key: Any, prefix: bytes, cf_name: str = "default"):
         This function always returns `None`, even if value is not found.
         :param key: key
         :param prefix: a key prefix
+        :param cf_name: column family name
         """
         try:
             key_serialized = self._serialize_key(key, prefix=prefix)
-            self._update_cache.setdefault(cf_name, {}).setdefault(prefix, {})[
-                key_serialized
-            ] = DELETED
+            self._update_cache.delete(
+                key=key_serialized, prefix=prefix, cf_name=cf_name
+            )
         except Exception:
             self._status = PartitionTransactionStatus.FAILED
             raise
@@ -261,21 +360,19 @@ def exists(self, key: Any, prefix: bytes, cf_name: str = "default") -> bool:
         Check if the key exists in state.
         :param key: key
         :param prefix: a key prefix
+        :param cf_name: column family name
         :return: True if key exists, False otherwise
         """
         key_serialized = self._serialize_key(key, prefix=prefix)
-        cached = (
-            self._update_cache.get(cf_name, {})
-            .get(prefix, {})
-            .get(key_serialized, UNDEFINED)
+        cached = self._update_cache.get(
+            key=key_serialized, prefix=prefix, cf_name=cf_name
        )
         if cached is DELETED:
             return False
-
-        if cached is not UNDEFINED:
+        elif cached is not UNDEFINED:
             return True
-
-        return self._partition.exists(key_serialized, cf_name=cf_name)
+        else:
+            return self._partition.exists(key_serialized, cf_name=cf_name)
 
     @validate_transaction_status(PartitionTransactionStatus.STARTED)
     def prepare(self, processed_offset: int):
@@ -310,21 +407,32 @@ def _prepare(self, processed_offset: int):
             f"partition={self._changelog_producer.partition} "
             f"processed_offset={processed_offset}"
         )
-        for cf_name, cf_update_cache in self._update_cache.items():
-            source_tp_offset_header = json_dumps(processed_offset)
+        source_tp_offset_header = json_dumps(processed_offset)
+        column_families = self._update_cache.get_column_families()
+
+        for cf_name in column_families:
             headers = {
                 CHANGELOG_CF_MESSAGE_HEADER: cf_name,
                 CHANGELOG_PROCESSED_OFFSET_MESSAGE_HEADER: source_tp_offset_header,
             }
-            for _, prefix_update_cache in cf_update_cache.items():
+
+            updates = self._update_cache.get_updates(cf_name=cf_name)
+            for prefix_update_cache in updates.values():
                 for key, value in prefix_update_cache.items():
-                    # Produce changes to the changelog topic
                     self._changelog_producer.produce(
                         key=key,
-                        value=value if value is not DELETED else None,
+                        value=value,
                         headers=headers,
                     )
 
+            deletes = self._update_cache.get_deletes(cf_name=cf_name)
+            for key in deletes:
+                self._changelog_producer.produce(
+                    key=key,
+                    value=None,
+                    headers=headers,
+                )
+
     @validate_transaction_status(
         PartitionTransactionStatus.STARTED, PartitionTransactionStatus.PREPARED
     )
@@ -357,7 +465,7 @@ def flush(
             raise
 
     def _flush(self, processed_offset: Optional[int], changelog_offset: Optional[int]):
-        if not self._update_cache:
+        if self._update_cache.is_empty():
             return
 
         if changelog_offset is not None:
@@ -371,7 +479,7 @@ def _flush(self, processed_offset: Optional[int], changelog_offset: Optional[int
             )
 
         self._partition.write(
-            data=self._update_cache,
+            cache=self._update_cache,
            processed_offset=processed_offset,
             changelog_offset=changelog_offset,
         )
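Note: the new cache is what gives the transaction its read-your-own-writes behaviour — get() and exists() consult it first and only fall back to the store when it answers UNDEFINED. A small sketch of the sentinel contract follows; it is illustrative only and assumes the DELETED and UNDEFINED sentinels are importable from quixstreams.state.metadata (the import removed from the RocksDB partition below suggests that location).

    from quixstreams.state.base.transaction import PartitionTransactionCache
    # Assumption: the sentinels live in quixstreams.state.metadata.
    from quixstreams.state.metadata import DELETED, UNDEFINED

    cache = PartitionTransactionCache()
    cache.set(key=b"k1", value=b"v1", prefix=b"p")
    cache.delete(key=b"k2", prefix=b"p")

    assert cache.get(key=b"k1", prefix=b"p") == b"v1"      # updated -> value
    assert cache.get(key=b"k2", prefix=b"p") is DELETED    # deleted -> skip the store
    assert cache.get(key=b"k3", prefix=b"p") is UNDEFINED  # unknown -> ask the store
    assert not cache.is_empty()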

quixstreams/state/rocksdb/partition.py

Lines changed: 15 additions & 11 deletions

@@ -5,9 +5,8 @@
 from rocksdict import WriteBatch, Rdict, ColumnFamily, AccessType
 
 from quixstreams.models import ConfluentKafkaMessageProto
-from quixstreams.state.metadata import DELETED
 from quixstreams.state.recovery import ChangelogProducer
-from quixstreams.state.base import StorePartition, CACHE_TYPE
+from quixstreams.state.base import StorePartition, PartitionTransactionCache
 from quixstreams.state.serialization import (
     int_from_int64_bytes,
     int_to_int64_bytes,
@@ -113,32 +112,37 @@ def _recover_from_changelog_message(
 
     def write(
         self,
-        data: CACHE_TYPE,
+        cache: PartitionTransactionCache,
         processed_offset: Optional[int],
         changelog_offset: Optional[int],
         batch: Optional[WriteBatch] = None,
     ):
         """
         Write data to RocksDB
 
-        :param data: The modified data
+        :param cache: The modified data
         :param processed_offset: The offset processed to generate the data.
         :param changelog_offset: The changelog message offset of the data.
+        :param batch: prefilled `rocksdict.WriteBatch`, optional.
         """
         if batch is None:
             batch = WriteBatch(raw_mode=True)
 
         meta_cf_handle = self.get_column_family_handle(METADATA_CF_NAME)
+
         # Iterate over the transaction update cache
-        for cf_name, cf_update_cache in data.items():
+        column_families = cache.get_column_families()
+        for cf_name in column_families:
             cf_handle = self.get_column_family_handle(cf_name)
-            for _prefix, prefix_update_cache in cf_update_cache.items():
+
+            updates = cache.get_updates(cf_name=cf_name)
+            for prefix_update_cache in updates.values():
                 for key, value in prefix_update_cache.items():
-                    # Apply changes to the Writebatch
-                    if value is DELETED:
-                        batch.delete(key, cf_handle)
-                    else:
-                        batch.put(key, value, cf_handle)
+                    batch.put(key, value, cf_handle)
+
+            deletes = cache.get_deletes(cf_name=cf_name)
+            for key in deletes:
+                batch.delete(key, cf_handle)
 
         # Save the latest processed input topic offset
         if processed_offset is not None:
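Note: the same updates/deletes split also drives changelog production in _prepare() (see the transaction.py hunk above): updated keys are produced with their values and deleted keys as None tombstones, so neither the producer nor the RocksDB writer has to compare values against the DELETED sentinel anymore. A rough stand-in sketch follows; StubProducer is hypothetical (not the real ChangelogProducer), while the header constant is taken from the diff.

    from quixstreams.state.base.transaction import PartitionTransactionCache
    from quixstreams.state.metadata import CHANGELOG_CF_MESSAGE_HEADER


    class StubProducer:
        # Hypothetical stand-in that just records what would be produced.
        def __init__(self):
            self.messages = []

        def produce(self, key, value, headers):
            self.messages.append((key, value, headers))


    cache = PartitionTransactionCache()
    cache.set(key=b"k1", value=b"v1", prefix=b"p")
    cache.delete(key=b"k2", prefix=b"p")

    producer = StubProducer()
    for cf_name in cache.get_column_families():
        headers = {CHANGELOG_CF_MESSAGE_HEADER: cf_name}
        for prefix_update_cache in cache.get_updates(cf_name=cf_name).values():
            for key, value in prefix_update_cache.items():
                producer.produce(key=key, value=value, headers=headers)
        for key in cache.get_deletes(cf_name=cf_name):
            producer.produce(key=key, value=None, headers=headers)

    # producer.messages == [(b"k1", b"v1", {...}), (b"k2", None, {...})]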
