Skip to content

Commit 835ad39

Browse files
authored
typing: make state protocols and ABCs generic (#777)
* typing: make state protocols and ABCs generic
* windows: set transaction type in count windows
1 parent b15f21b commit 835ad39

File tree

5 files changed: 98 additions (+98) and 53 deletions (-53).

quixstreams/dataframe/windows/count_based.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def process_window(
5858
value: Any,
5959
key: Any,
6060
timestamp_ms: int,
61-
transaction: WindowedPartitionTransaction,
61+
transaction: WindowedPartitionTransaction[str, CountWindowsData],
6262
) -> tuple[Iterable[WindowKeyResult], Iterable[WindowKeyResult]]:
6363
"""
6464
Count based windows are different from time based windows as we don't
@@ -81,9 +81,7 @@ def process_window(
8181
optimisation. Instead the msg id reset to 0 on every new window.
8282
"""
8383
state = transaction.as_state(prefix=key)
84-
data = state.get(key=self.STATE_KEY)
85-
if data is None:
86-
data = CountWindowsData(windows=[])
84+
data = state.get(key=self.STATE_KEY, default=CountWindowsData(windows=[]))
8785

8886
msg_id = None
8987
if len(data["windows"]) == 0:

quixstreams/state/base/state.py

Lines changed: 26 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3-
from typing import TYPE_CHECKING, Any, Optional
3+
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, overload
44

55
if TYPE_CHECKING:
66
from .transaction import PartitionTransaction
@@ -10,13 +10,23 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
class State(ABC):
13+
K = TypeVar("K")
14+
V = TypeVar("V")
15+
16+
17+
class State(ABC, Generic[K, V]):
1418
"""
1519
Primary interface for working with key-value state data from `StreamingDataFrame`
1620
"""
1721

22+
@overload
23+
def get(self, key: K) -> Optional[V]: ...
24+
25+
@overload
26+
def get(self, key: K, default: V) -> V: ...
27+
1828
@abstractmethod
19-
def get(self, key: Any, default: Any = None) -> Optional[Any]:
29+
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
2030
"""
2131
Get the value for key if key is present in the state, else default
2232
@@ -27,7 +37,7 @@ def get(self, key: Any, default: Any = None) -> Optional[Any]:
2737
...
2838

2939
@abstractmethod
30-
def set(self, key: Any, value: Any):
40+
def set(self, key: K, value: V):
3141
"""
3242
Set value for the key.
3343
:param key: key
@@ -36,7 +46,7 @@ def set(self, key: Any, value: Any):
3646
...
3747

3848
@abstractmethod
39-
def delete(self, key: Any):
49+
def delete(self, key: K):
4050
"""
4151
Delete value for the key.
4252
@@ -46,7 +56,7 @@ def delete(self, key: Any):
4656
...
4757

4858
@abstractmethod
49-
def exists(self, key: Any) -> bool:
59+
def exists(self, key: K) -> bool:
5060
"""
5161
Check if the key exists in state.
5262
:param key: key
@@ -70,7 +80,13 @@ def __init__(self, prefix: bytes, transaction: "PartitionTransaction"):
7080
self._prefix = prefix
7181
self._transaction = transaction
7282

73-
def get(self, key: Any, default: Any = None) -> Optional[Any]:
83+
@overload
84+
def get(self, key: K) -> Optional[V]: ...
85+
86+
@overload
87+
def get(self, key: K, default: V) -> V: ...
88+
89+
def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
7490
"""
7591
Get the value for key if key is present in the state, else default
7692
@@ -80,15 +96,15 @@ def get(self, key: Any, default: Any = None) -> Optional[Any]:
8096
"""
8197
return self._transaction.get(key=key, prefix=self._prefix, default=default)
8298

83-
def set(self, key: Any, value: Any):
99+
def set(self, key: K, value: V):
84100
"""
85101
Set value for the key.
86102
:param key: key
87103
:param value: value
88104
"""
89105
return self._transaction.set(key=key, value=value, prefix=self._prefix)
90106

91-
def delete(self, key: Any):
107+
def delete(self, key: K):
92108
"""
93109
Delete value for the key.
94110
@@ -97,7 +113,7 @@ def delete(self, key: Any):
97113
"""
98114
return self._transaction.delete(key=key, prefix=self._prefix)
99115

100-
def exists(self, key: Any) -> bool:
116+
def exists(self, key: K) -> bool:
101117
"""
102118
Check if the key exists in state.
103119
:param key: key

quixstreams/state/base/transaction.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -7,10 +7,13 @@
77
TYPE_CHECKING,
88
Any,
99
Dict,
10+
Generic,
1011
Optional,
1112
Set,
1213
Tuple,
14+
TypeVar,
1315
Union,
16+
overload,
1417
)
1518

1619
from quixstreams.models import Headers
@@ -179,7 +182,11 @@ def _wrapper(tx: "PartitionTransaction", *args, **kwargs):
179182
return wrapper
180183

181184

182-
class PartitionTransaction(ABC):
185+
K = TypeVar("K")
186+
V = TypeVar("V")
187+
188+
189+
class PartitionTransaction(ABC, Generic[K, V]):
183190
"""
184191
A transaction class to perform simple key-value operations like
185192
"get", "set", "delete" and "exists" on a single storage partition.
@@ -257,18 +264,18 @@ def changelog_topic_partition(self) -> Optional[Tuple[str, int]]:
257264
self.changelog_producer.partition,
258265
)
259266

260-
def _serialize_value(self, value: Any) -> bytes:
267+
def _serialize_value(self, value: V) -> bytes:
261268
return serialize(value, dumps=self._dumps)
262269

263-
def _deserialize_value(self, value: bytes) -> Any:
270+
def _deserialize_value(self, value: bytes) -> V:
264271
return deserialize(value, loads=self._loads)
265272

266-
def _serialize_key(self, key: Any, prefix: bytes) -> bytes:
273+
def _serialize_key(self, key: K, prefix: bytes) -> bytes:
267274
key_bytes = serialize(key, dumps=self._dumps)
268275
prefix = prefix + SEPARATOR if prefix else b""
269276
return prefix + key_bytes
270277

271-
def as_state(self, prefix: Any = DEFAULT_PREFIX) -> State:
278+
def as_state(self, prefix: Any = DEFAULT_PREFIX) -> State[K, V]:
272279
"""
273280
Create an instance implementing the `State` protocol to be provided
274281
to `StreamingDataFrame` functions.
@@ -286,14 +293,20 @@ def as_state(self, prefix: Any = DEFAULT_PREFIX) -> State:
286293
),
287294
)
288295

296+
@overload
297+
def get(self, key: K, prefix: bytes, cf_name: str = "default") -> Optional[V]: ...
298+
299+
@overload
300+
def get(self, key: K, prefix: bytes, default: V, cf_name: str = "default") -> V: ...
301+
289302
@validate_transaction_status(PartitionTransactionStatus.STARTED)
290303
def get(
291304
self,
292-
key: Any,
305+
key: K,
293306
prefix: bytes,
294-
default: Any = None,
307+
default: Optional[V] = None,
295308
cf_name: str = "default",
296-
) -> Any:
309+
) -> Optional[V]:
297310
"""
298311
Get a key from the store.
299312
@@ -323,7 +336,7 @@ def get(
323336
return self._deserialize_value(stored)
324337

325338
@validate_transaction_status(PartitionTransactionStatus.STARTED)
326-
def set(self, key: Any, value: Any, prefix: bytes, cf_name: str = "default"):
339+
def set(self, key: K, value: V, prefix: bytes, cf_name: str = "default"):
327340
"""
328341
Set value for the key.
329342
:param key: key
@@ -346,7 +359,7 @@ def set(self, key: Any, value: Any, prefix: bytes, cf_name: str = "default"):
346359
raise
347360

348361
@validate_transaction_status(PartitionTransactionStatus.STARTED)
349-
def delete(self, key: Any, prefix: bytes, cf_name: str = "default"):
362+
def delete(self, key: K, prefix: bytes, cf_name: str = "default"):
350363
"""
351364
Delete value for the key.
352365
@@ -365,7 +378,7 @@ def delete(self, key: Any, prefix: bytes, cf_name: str = "default"):
365378
raise
366379

367380
@validate_transaction_status(PartitionTransactionStatus.STARTED)
368-
def exists(self, key: Any, prefix: bytes, cf_name: str = "default") -> bool:
381+
def exists(self, key: K, prefix: bytes, cf_name: str = "default") -> bool:
369382
"""
370383
Check if the key exists in state.
371384
:param key: key

quixstreams/state/rocksdb/windowed/transaction.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -358,13 +358,14 @@ def expire_all_windows(
358358
prefix, start, end = parse_window_key(key)
359359
to_delete.add((prefix, start, end))
360360
if collect:
361-
value = self.get_from_collection(
361+
value: Any = self.get_from_collection(
362362
start=start,
363363
end=end,
364364
prefix=prefix,
365365
)
366366
else:
367367
value = self.get(encode_integer_pair(start, end), prefix=prefix)
368+
assert value is not None # noqa: S101
368369

369370
yield (start, end), value, prefix
370371
else:
@@ -385,6 +386,7 @@ def expire_all_windows(
385386
)
386387
else:
387388
value = self.get(encode_integer_pair(start, end), prefix=prefix)
389+
assert value is not None # noqa: S101
388390

389391
yield (start, end), value, prefix
390392

@@ -603,14 +605,15 @@ def _get_next_count(self) -> int:
603605
:return: Next sequential counter value
604606
"""
605607
cache = self._global_counter
606-
kwargs = {"key": cache.key, "prefix": b"", "cf_name": cache.cf_name}
607608

608609
if cache.counter is None:
609-
cache.counter = self.get(default=-1, **kwargs)
610+
cache.counter = self.get(
611+
default=-1, key=cache.key, prefix=b"", cf_name=cache.cf_name
612+
)
610613

611614
cache.counter += 1
612615

613-
self.set(value=cache.counter, **kwargs)
616+
self.set(value=cache.counter, key=cache.key, prefix=b"", cf_name=cache.cf_name)
614617
return cache.counter
615618

616619

0 commit comments

Comments
 (0)