Skip to content

Commit 16d8637

Browse files
authored
Calculate window watermarks by individual keys instead of partitions (#591)
1 parent fafc2f0 commit 16d8637

File tree

9 files changed

+131
-173
lines changed

9 files changed

+131
-173
lines changed

examples/bank_example/quix_platform_version/producer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66
from dotenv import load_dotenv
77

88
from quixstreams import Application
9-
from quixstreams.models.serializers import (
10-
QuixTimeseriesSerializer,
11-
)
9+
from quixstreams.models.serializers import QuixTimeseriesSerializer
1210

1311
load_dotenv("./bank_example/quix_platform_version/quix_vars.env")
1412

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY = b"__expired_start_gt__"
22

33
LATEST_EXPIRED_WINDOW_CF_NAME = "__expiration-index__"
4-
5-
LATEST_TIMESTAMP_KEY = b"__topic_latest_timestamp__"
4+
LATEST_TIMESTAMPS_CF_NAME = "__latest-timestamps__"
5+
LATEST_TIMESTAMP_KEY = b"__latest_timestamp__"
Lines changed: 6 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
import logging
22
from typing import Optional
33

4-
from rocksdict import RdictItems, ReadOptions, WriteBatch # type: ignore
4+
from rocksdict import RdictItems, ReadOptions # type: ignore
55

6-
from quixstreams.state.base import PartitionTransactionCache
76
from quixstreams.state.recovery import ChangelogProducer
8-
from quixstreams.state.serialization import int_from_int64_bytes, int_to_int64_bytes
97

108
from ..exceptions import ColumnFamilyDoesNotExist
11-
from ..metadata import METADATA_CF_NAME
129
from ..partition import RocksDBStorePartition
1310
from ..types import RocksDBOptionsType
14-
from .metadata import LATEST_EXPIRED_WINDOW_CF_NAME, LATEST_TIMESTAMP_KEY
11+
from .metadata import (
12+
LATEST_EXPIRED_WINDOW_CF_NAME,
13+
LATEST_TIMESTAMPS_CF_NAME,
14+
)
1515
from .transaction import WindowedRocksDBPartitionTransaction
1616

1717
logger = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def __init__(
3838
super().__init__(
3939
path=path, options=options, changelog_producer=changelog_producer
4040
)
41-
self._latest_timestamp_ms = self._get_latest_timestamp_from_db()
4241
self._ensure_column_family(LATEST_EXPIRED_WINDOW_CF_NAME)
42+
self._ensure_column_family(LATEST_TIMESTAMPS_CF_NAME)
4343

4444
def iter_items(
4545
self, from_key: bytes, read_opt: ReadOptions, cf_name: str = "default"
@@ -52,48 +52,11 @@ def begin(self) -> "WindowedRocksDBPartitionTransaction":
5252
partition=self,
5353
dumps=self._dumps,
5454
loads=self._loads,
55-
latest_timestamp_ms=self._latest_timestamp_ms,
5655
changelog_producer=self._changelog_producer,
5756
)
5857

59-
def set_latest_timestamp(self, timestamp_ms: int):
60-
self._latest_timestamp_ms = timestamp_ms
61-
62-
def _get_latest_timestamp_from_db(self) -> int:
63-
value = self.get(LATEST_TIMESTAMP_KEY, cf_name=METADATA_CF_NAME)
64-
if value is None:
65-
return 0
66-
return int_from_int64_bytes(value)
67-
6858
def _ensure_column_family(self, cf_name: str):
6959
try:
7060
self.get_column_family(cf_name)
7161
except ColumnFamilyDoesNotExist:
7262
self.create_column_family(cf_name)
73-
74-
def write(
75-
self,
76-
cache: PartitionTransactionCache,
77-
processed_offset: Optional[int],
78-
changelog_offset: Optional[int],
79-
batch: Optional[WriteBatch] = None,
80-
latest_timestamp_ms: Optional[int] = None,
81-
):
82-
batch = WriteBatch(raw_mode=True)
83-
84-
if latest_timestamp_ms is not None:
85-
cf_handle = self.get_column_family_handle(METADATA_CF_NAME)
86-
batch.put(
87-
LATEST_TIMESTAMP_KEY,
88-
int_to_int64_bytes(latest_timestamp_ms),
89-
cf_handle,
90-
)
91-
super().write(
92-
cache=cache,
93-
processed_offset=processed_offset,
94-
changelog_offset=changelog_offset,
95-
batch=batch,
96-
)
97-
98-
if latest_timestamp_ms is not None:
99-
self.set_latest_timestamp(latest_timestamp_ms)

quixstreams/state/rocksdb/windowed/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ def update_window(self, start_ms: int, end_ms: int, value: Any, timestamp_ms: in
5959

6060
def get_latest_timestamp(self) -> int:
6161
"""
62-
Get the latest observed timestamp for the current state partition.
62+
Get the latest observed timestamp for the current message key.
6363
6464
Use this timestamp to determine if the arriving event is late and should be
6565
discarded from the processing.
6666
6767
:return: latest observed event timestamp in milliseconds
6868
"""
6969

70-
return self._transaction.get_latest_timestamp()
70+
return self._transaction.get_latest_timestamp(prefix=self._prefix)
7171

7272
def expire_windows(
7373
self, duration_ms: int, grace_ms: int = 0

quixstreams/state/rocksdb/windowed/transaction.py

Lines changed: 76 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
from rocksdict import ReadOptions
55

66
from quixstreams.state.base.transaction import PartitionTransaction
7-
from quixstreams.state.exceptions import InvalidChangelogOffset
87
from quixstreams.state.metadata import DEFAULT_PREFIX, PREFIX_SEPARATOR
98
from quixstreams.state.recovery import ChangelogProducer
10-
from quixstreams.state.serialization import (
11-
DumpsFunc,
12-
LoadsFunc,
13-
serialize,
14-
)
9+
from quixstreams.state.serialization import DumpsFunc, LoadsFunc, serialize
1510

16-
from .metadata import LATEST_EXPIRED_WINDOW_CF_NAME, LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY
11+
from .metadata import (
12+
LATEST_EXPIRED_WINDOW_CF_NAME,
13+
LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
14+
LATEST_TIMESTAMP_KEY,
15+
LATEST_TIMESTAMPS_CF_NAME,
16+
)
1717
from .serialization import encode_window_key, encode_window_prefix, parse_window_key
1818
from .state import WindowedTransactionState
1919

@@ -22,14 +22,11 @@
2222

2323

2424
class WindowedRocksDBPartitionTransaction(PartitionTransaction):
25-
__slots__ = ("_latest_timestamp_ms",)
26-
2725
def __init__(
2826
self,
2927
partition: "WindowedRocksDBStorePartition",
3028
dumps: DumpsFunc,
3129
loads: LoadsFunc,
32-
latest_timestamp_ms: int,
3330
changelog_producer: Optional[ChangelogProducer] = None,
3431
):
3532
super().__init__(
@@ -39,7 +36,11 @@ def __init__(
3936
changelog_producer=changelog_producer,
4037
)
4138
self._partition = cast("WindowedRocksDBStorePartition", self._partition)
42-
self._latest_timestamp_ms = latest_timestamp_ms
39+
# Cache the metadata separately to avoid serdes on each access
40+
# (we are 100% sure that the underlying types are immutable, while windows'
41+
# values are not)
42+
self._latest_timestamps: dict[bytes, int] = {}
43+
self._last_expired_timestamps: dict[bytes, int] = {}
4344

4445
def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState:
4546
return WindowedTransactionState(
@@ -51,15 +52,19 @@ def as_state(self, prefix: Any = DEFAULT_PREFIX) -> WindowedTransactionState:
5152
),
5253
)
5354

54-
def get_latest_timestamp(self) -> int:
55-
return self._latest_timestamp_ms
55+
def get_latest_timestamp(self, prefix: bytes) -> int:
56+
cached_ts = self._latest_timestamps.get(prefix)
57+
if cached_ts is not None:
58+
return cached_ts
5659

57-
def _validate_duration(self, start_ms: int, end_ms: int):
58-
if end_ms <= start_ms:
59-
raise ValueError(
60-
f"Invalid window duration: window end {end_ms} is smaller or equal "
61-
f"than window start {start_ms}"
62-
)
60+
stored_ts = self.get(
61+
key=LATEST_TIMESTAMP_KEY,
62+
prefix=prefix,
63+
cf_name=LATEST_TIMESTAMPS_CF_NAME,
64+
default=0,
65+
)
66+
self._latest_timestamps[prefix] = stored_ts
67+
return stored_ts
6368

6469
def get_window(
6570
self,
@@ -81,34 +86,16 @@ def update_window(
8186

8287
key = encode_window_key(start_ms, end_ms)
8388
self.set(key=key, value=value, prefix=prefix)
84-
self._latest_timestamp_ms = max(self._latest_timestamp_ms, timestamp_ms)
89+
latest_timestamp_ms = self.get_latest_timestamp(prefix=prefix)
90+
self._set_latest_timestamp(
91+
prefix=prefix, timestamp_ms=max(latest_timestamp_ms, timestamp_ms)
92+
)
8593

8694
def delete_window(self, start_ms: int, end_ms: int, prefix: bytes):
8795
self._validate_duration(start_ms=start_ms, end_ms=end_ms)
8896
key = encode_window_key(start_ms, end_ms)
8997
self.delete(key=key, prefix=prefix)
9098

91-
def _flush(self, processed_offset: Optional[int], changelog_offset: Optional[int]):
92-
if self._update_cache.is_empty():
93-
return
94-
95-
if changelog_offset is not None:
96-
current_changelog_offset = self._partition.get_changelog_offset()
97-
if (
98-
current_changelog_offset is not None
99-
and changelog_offset < current_changelog_offset
100-
):
101-
raise InvalidChangelogOffset(
102-
"Cannot set changelog offset lower than already saved one"
103-
)
104-
105-
self._partition.write(
106-
cache=self._update_cache,
107-
processed_offset=processed_offset,
108-
changelog_offset=changelog_offset,
109-
latest_timestamp_ms=self._latest_timestamp_ms,
110-
)
111-
11299
def expire_windows(
113100
self, duration_ms: int, prefix: bytes, grace_ms: int = 0
114101
) -> list[tuple[tuple[int, int], Any]]:
@@ -134,16 +121,12 @@ def expire_windows(
134121
Defaults to 0, meaning no grace period is applied.
135122
:return: A sorted list of tuples in the format `((start, end), value)`.
136123
"""
137-
latest_timestamp = self._latest_timestamp_ms
124+
latest_timestamp = self.get_latest_timestamp(prefix=prefix)
138125
start_to = latest_timestamp - duration_ms - grace_ms
139126
start_from = -1
140127

141128
# Find the latest start timestamp of the expired windows for the given key
142-
last_expired = self.get(
143-
key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
144-
prefix=prefix,
145-
cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
146-
)
129+
last_expired = self._get_last_expired_timestamp(prefix=prefix)
147130
if last_expired is not None:
148131
start_from = max(start_from, last_expired)
149132

@@ -160,22 +143,15 @@ def expire_windows(
160143
# Save the start of the latest expired window to the expiration index
161144
latest_window = expired_windows[-1]
162145
last_expired__gt = latest_window[0][0]
163-
self.set(
164-
key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
165-
value=last_expired__gt,
166-
prefix=prefix,
167-
cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
146+
147+
self._set_last_expired_timestamp(
148+
prefix=prefix, timestamp_ms=last_expired__gt
168149
)
169150
# Delete expired windows from the state
170151
for (start, end), _ in expired_windows:
171152
self.delete_window(start, end, prefix=prefix)
172153
return expired_windows
173154

174-
def _serialize_key(self, key: Any, prefix: bytes) -> bytes:
175-
# Allow bytes keys in WindowedStore
176-
key_bytes = key if isinstance(key, bytes) else serialize(key, dumps=self._dumps)
177-
return prefix + PREFIX_SEPARATOR + key_bytes
178-
179155
def get_windows(
180156
self,
181157
start_from_ms: int,
@@ -240,3 +216,46 @@ def get_windows(
240216
result.append(((start, end), value))
241217

242218
return result
219+
220+
def _set_latest_timestamp(self, prefix: bytes, timestamp_ms: int):
221+
self._latest_timestamps[prefix] = timestamp_ms
222+
self.set(
223+
key=LATEST_TIMESTAMP_KEY,
224+
value=timestamp_ms,
225+
prefix=prefix,
226+
cf_name=LATEST_TIMESTAMPS_CF_NAME,
227+
)
228+
229+
def _get_last_expired_timestamp(self, prefix: bytes) -> Optional[int]:
230+
cached_ts = self._last_expired_timestamps.get(prefix)
231+
if cached_ts is not None:
232+
return cached_ts
233+
234+
stored_ts = self.get(
235+
key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
236+
prefix=prefix,
237+
cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
238+
)
239+
self._last_expired_timestamps[prefix] = stored_ts
240+
return stored_ts
241+
242+
def _set_last_expired_timestamp(self, prefix: bytes, timestamp_ms: int):
243+
self._last_expired_timestamps[prefix] = timestamp_ms
244+
self.set(
245+
key=LATEST_EXPIRED_WINDOW_TIMESTAMP_KEY,
246+
value=timestamp_ms,
247+
prefix=prefix,
248+
cf_name=LATEST_EXPIRED_WINDOW_CF_NAME,
249+
)
250+
251+
def _validate_duration(self, start_ms: int, end_ms: int):
252+
if end_ms <= start_ms:
253+
raise ValueError(
254+
f"Invalid window duration: window end {end_ms} is smaller or equal "
255+
f"than window start {start_ms}"
256+
)
257+
258+
def _serialize_key(self, key: Any, prefix: bytes) -> bytes:
259+
# Allow bytes keys in WindowedStore
260+
key_bytes = key if isinstance(key, bytes) else serialize(key, dumps=self._dumps)
261+
return prefix + PREFIX_SEPARATOR + key_bytes

quixstreams/state/types.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,10 @@ def update_window(
163163
"""
164164
...
165165

166-
def get_latest_timestamp(self) -> int:
166+
def get_latest_timestamp(self, prefix: bytes) -> int:
167167
"""
168-
Get the latest observed timestamp for the current state partition.
168+
Get the latest observed timestamp for the current state prefix
169+
(same as message key).
169170
170171
Use this timestamp to determine if the arriving event is late and should be
171172
discarded from the processing.

0 commit comments

Comments
 (0)