Skip to content

Commit e57494b

Browse files
authored
Optimize state operations (#891)
- WindowedRocksDBPartitionTransaction._get_timestamp(): return the cached value if it has been set previously - set_bytes(): remove extra @validate_transaction_state decorator - set_bytes(): check if the value is bytes only at the top level
1 parent 9fccc95 commit e57494b

File tree

2 files changed

+20
-22
lines changed

2 files changed

+20
-22
lines changed

quixstreams/state/base/transaction.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,6 @@ def _get_bytes(
388388

389389
return cached
390390

391-
@validate_transaction_status(PartitionTransactionStatus.STARTED)
392391
def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default") -> None:
393392
"""
394393
Set value for the key.
@@ -404,9 +403,8 @@ def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default") -> None
404403
self._status = PartitionTransactionStatus.FAILED
405404
raise
406405

407-
self.set_bytes(key, value_serialized, prefix, cf_name=cf_name)
406+
self._set_bytes(key, value_serialized, prefix, cf_name=cf_name)
408407

409-
@validate_transaction_status(PartitionTransactionStatus.STARTED)
410408
def set_bytes(
411409
self, key: K, value: bytes, prefix: bytes, cf_name: str = "default"
412410
) -> None:
@@ -417,10 +415,17 @@ def set_bytes(
417415
:param value: value
418416
:param cf_name: column family name
419417
"""
420-
try:
421-
if not isinstance(value, bytes):
422-
raise StateSerializationError("Value must be bytes")
418+
if not isinstance(value, bytes):
419+
self._status = PartitionTransactionStatus.FAILED
420+
raise StateSerializationError("Value must be bytes")
421+
422+
self._set_bytes(key=key, value=value, prefix=prefix, cf_name=cf_name)
423423

424+
@validate_transaction_status(PartitionTransactionStatus.STARTED)
425+
def _set_bytes(
426+
self, key: K, value: bytes, prefix: bytes, cf_name: str = "default"
427+
) -> None:
428+
try:
424429
key_serialized = self._serialize_key(key, prefix=prefix)
425430
self._update_cache.set(
426431
key=key_serialized,

quixstreams/state/rocksdb/windowed/transaction.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from .partition import WindowedRocksDBStorePartition
3939

4040

41-
class WindowedRocksDBPartitionTransaction(PartitionTransaction[bytes, dict]):
41+
class WindowedRocksDBPartitionTransaction(PartitionTransaction[bytes, Any]):
4242
def __init__(
4343
self,
4444
partition: "WindowedRocksDBStorePartition",
@@ -108,13 +108,11 @@ def keys(self, cf_name: str = "default") -> Iterable[Any]:
108108
yield key
109109

110110
def get_latest_timestamp(self, prefix: bytes) -> int:
111-
return self._get_timestamp(
112-
prefix=prefix, cache=self._latest_timestamps, default=0
113-
)
111+
return self._get_timestamp(prefix=prefix, cache=self._latest_timestamps) or 0
114112

115113
def get_latest_expired(self, prefix: bytes) -> int:
116-
return self._get_timestamp(
117-
prefix=prefix, cache=self._last_expired_timestamps, default=0
114+
return (
115+
self._get_timestamp(prefix=prefix, cache=self._last_expired_timestamps) or 0
118116
)
119117

120118
def get_window(
@@ -315,9 +313,7 @@ def expire_all_windows(
315313
:param delete: If True, expired windows will be deleted.
316314
:param collect: If True, values will be collected into windows.
317315
"""
318-
last_expired = self._get_timestamp(
319-
prefix=b"", cache=self._last_expired_timestamps, default=0
320-
)
316+
last_expired = self.get_latest_expired(prefix=b"")
321317

322318
to_delete: set[tuple[bytes, int, int]] = set()
323319
collected = []
@@ -526,18 +522,15 @@ def _get_items(
526522
# Sort and deserialize items merged from the cache and store
527523
return sorted(merged_items.items(), key=lambda kv: kv[0], reverse=backwards)
528524

529-
def _get_timestamp(
530-
self, cache: TimestampsCache, prefix: bytes, default: Any = None
531-
) -> int:
532-
cached_ts = cache.timestamps.get(prefix)
533-
if cached_ts is not None:
534-
return cached_ts
525+
def _get_timestamp(self, cache: TimestampsCache, prefix: bytes) -> Optional[int]:
526+
if prefix in cache.timestamps:
527+
# Return the cached value if it has been set at least once
528+
return cache.timestamps[prefix]
535529

536530
stored_ts = self.get(
537531
key=cache.key,
538532
prefix=prefix,
539533
cf_name=cache.cf_name,
540-
default=default,
541534
)
542535
if stored_ts is not None and not isinstance(stored_ts, int):
543536
raise ValueError(f"invalid timestamp {stored_ts}")

0 commit comments

Comments
 (0)