|
38 | 38 | from .partition import WindowedRocksDBStorePartition
|
39 | 39 |
|
40 | 40 |
|
41 |
| -class WindowedRocksDBPartitionTransaction(PartitionTransaction[bytes, dict]): |
| 41 | +class WindowedRocksDBPartitionTransaction(PartitionTransaction[bytes, Any]): |
42 | 42 | def __init__(
|
43 | 43 | self,
|
44 | 44 | partition: "WindowedRocksDBStorePartition",
|
@@ -108,13 +108,11 @@ def keys(self, cf_name: str = "default") -> Iterable[Any]:
|
108 | 108 | yield key
|
109 | 109 |
|
110 | 110 | def get_latest_timestamp(self, prefix: bytes) -> int:
|
111 |
| - return self._get_timestamp( |
112 |
| - prefix=prefix, cache=self._latest_timestamps, default=0 |
113 |
| - ) |
| 111 | + return self._get_timestamp(prefix=prefix, cache=self._latest_timestamps) or 0 |
114 | 112 |
|
115 | 113 | def get_latest_expired(self, prefix: bytes) -> int:
|
116 |
| - return self._get_timestamp( |
117 |
| - prefix=prefix, cache=self._last_expired_timestamps, default=0 |
| 114 | + return ( |
| 115 | + self._get_timestamp(prefix=prefix, cache=self._last_expired_timestamps) or 0 |
118 | 116 | )
|
119 | 117 |
|
120 | 118 | def get_window(
|
@@ -315,9 +313,7 @@ def expire_all_windows(
|
315 | 313 | :param delete: If True, expired windows will be deleted.
|
316 | 314 | :param collect: If True, values will be collected into windows.
|
317 | 315 | """
|
318 |
| - last_expired = self._get_timestamp( |
319 |
| - prefix=b"", cache=self._last_expired_timestamps, default=0 |
320 |
| - ) |
| 316 | + last_expired = self.get_latest_expired(prefix=b"") |
321 | 317 |
|
322 | 318 | to_delete: set[tuple[bytes, int, int]] = set()
|
323 | 319 | collected = []
|
@@ -526,18 +522,15 @@ def _get_items(
|
526 | 522 | # Sort and deserialize items merged from the cache and store
|
527 | 523 | return sorted(merged_items.items(), key=lambda kv: kv[0], reverse=backwards)
|
528 | 524 |
|
529 |
| - def _get_timestamp( |
530 |
| - self, cache: TimestampsCache, prefix: bytes, default: Any = None |
531 |
| - ) -> int: |
532 |
| - cached_ts = cache.timestamps.get(prefix) |
533 |
| - if cached_ts is not None: |
534 |
| - return cached_ts |
| 525 | + def _get_timestamp(self, cache: TimestampsCache, prefix: bytes) -> Optional[int]: |
| 526 | + if prefix in cache.timestamps: |
| 527 | + # Return the cached value if it has been set at least once |
| 528 | + return cache.timestamps[prefix] |
535 | 529 |
|
536 | 530 | stored_ts = self.get(
|
537 | 531 | key=cache.key,
|
538 | 532 | prefix=prefix,
|
539 | 533 | cf_name=cache.cf_name,
|
540 |
| - default=default, |
541 | 534 | )
|
542 | 535 | if stored_ts is not None and not isinstance(stored_ts, int):
|
543 | 536 | raise ValueError(f"invalid timestamp {stored_ts}")
|
|
0 commit comments