Skip to content

Commit 3efd812

Browse files
committed
Add generic typing to the StorePartition classes to correctly reveal the transaction type
1 parent da3df51 commit 3efd812

File tree

3 files changed

+36
-16
lines changed

3 files changed

+36
-16
lines changed

quixstreams/state/base/partition.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from abc import ABC, abstractmethod
3-
from typing import TYPE_CHECKING, Literal, Optional, Union
3+
from typing import TYPE_CHECKING, Generic, Literal, Optional, TypeVar, Union, cast
44

55
from quixstreams.state.metadata import (
66
Marker,
@@ -17,8 +17,10 @@
1717

1818
logger = logging.getLogger(__name__)
1919

20+
T = TypeVar("T", bound=PartitionTransaction)
2021

21-
class StorePartition(ABC):
22+
23+
class StorePartition(ABC, Generic[T]):
2224
"""
2325
A base class to access state in the underlying storage.
2426
It represents a single instance of some storage (e.g. a single database for
@@ -108,17 +110,20 @@ def recover_from_changelog_message(
108110
:param offset: changelog message offset
109111
"""
110112

111-
def begin(self) -> PartitionTransaction:
113+
def begin(self) -> T:
112114
"""
113115
Start a new `PartitionTransaction`
114116
115117
Using `PartitionTransaction` is a recommended way for accessing the data.
116118
"""
117-
return self.partition_transaction_class(
118-
partition=self,
119-
dumps=self._dumps,
120-
loads=self._loads,
121-
changelog_producer=self._changelog_producer,
119+
return cast(
120+
T,
121+
self.partition_transaction_class(
122+
partition=self,
123+
dumps=self._dumps,
124+
loads=self._loads,
125+
changelog_producer=self._changelog_producer,
126+
),
122127
)
123128

124129
def __enter__(self):

quixstreams/state/rocksdb/partition.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
import logging
22
import time
3-
from typing import Dict, Iterator, List, Literal, Optional, Union, cast
3+
from typing import (
4+
Dict,
5+
Iterator,
6+
List,
7+
Literal,
8+
Optional,
9+
TypeVar,
10+
Union,
11+
cast,
12+
)
413

514
from rocksdict import AccessType, ColumnFamily, Rdict, ReadOptions, WriteBatch
615

7-
from quixstreams.state.base import PartitionTransactionCache, StorePartition
16+
from quixstreams.state.base import (
17+
PartitionTransaction,
18+
PartitionTransactionCache,
19+
StorePartition,
20+
)
821
from quixstreams.state.exceptions import ColumnFamilyDoesNotExist
922
from quixstreams.state.metadata import METADATA_CF_NAME, Marker
1023
from quixstreams.state.recovery import ChangelogProducer
11-
from quixstreams.state.serialization import (
12-
int_from_int64_bytes,
13-
int_to_int64_bytes,
14-
)
24+
from quixstreams.state.serialization import int_from_int64_bytes, int_to_int64_bytes
1525

1626
from .exceptions import ColumnFamilyAlreadyExists
1727
from .metadata import (
@@ -22,10 +32,13 @@
2232

2333
__all__ = ("RocksDBStorePartition",)
2434

35+
2536
logger = logging.getLogger(__name__)
2637

38+
T = TypeVar("T", bound=PartitionTransaction)
39+
2740

28-
class RocksDBStorePartition(StorePartition):
41+
class RocksDBStorePartition(StorePartition[T]):
2942
"""
3043
A base class to access state in RocksDB.
3144
It represents a single RocksDB database.

quixstreams/state/rocksdb/windowed/partition.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18-
class WindowedRocksDBStorePartition(RocksDBStorePartition):
18+
class WindowedRocksDBStorePartition(
19+
RocksDBStorePartition[WindowedRocksDBPartitionTransaction]
20+
):
1921
"""
2022
A base class to access windowed state in RocksDB.
2123
It represents a single RocksDB database.

0 commit comments

Comments
 (0)