Skip to content

Commit acb49fa

Browse files
authored
Use generic WindowType (#942)
* Use generic WindowType * Rename WindowType to WindowT
1 parent 1ccb896 commit acb49fa

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

quixstreams/dataframe/windows/definitions.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import abc
22
from abc import abstractmethod
3-
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
3+
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar, Union
44

55
from .aggregations import (
66
BaseAggregator,
@@ -45,8 +45,10 @@
4545
"TumblingTimeWindowDefinition",
4646
]
4747

48+
WindowT = TypeVar("WindowT", bound=Window)
4849

49-
class WindowDefinition(abc.ABC):
50+
51+
class WindowDefinition(abc.ABC, Generic[WindowT]):
5052
def __init__(
5153
self,
5254
name: Optional[str],
@@ -65,9 +67,9 @@ def _create_window(
6567
func_name: Optional[str],
6668
aggregators: Optional[dict[str, BaseAggregator]] = None,
6769
collectors: Optional[dict[str, BaseCollector]] = None,
68-
) -> Window: ...
70+
) -> WindowT: ...
6971

70-
def sum(self) -> "Window":
72+
def sum(self) -> WindowT:
7173
"""
7274
Configure the window to aggregate data by summing up values within
7375
each window period.
@@ -80,7 +82,7 @@ def sum(self) -> "Window":
8082
aggregators={"value": Sum(column=None)},
8183
)
8284

83-
def count(self) -> "Window":
85+
def count(self) -> WindowT:
8486
"""
8587
Configure the window to aggregate data by counting the number of values
8688
within each window period.
@@ -93,7 +95,7 @@ def count(self) -> "Window":
9395
aggregators={"value": Count()},
9496
)
9597

96-
def mean(self) -> "Window":
98+
def mean(self) -> WindowT:
9799
"""
98100
Configure the window to aggregate data by calculating the mean of the values
99101
within each window period.
@@ -109,7 +111,7 @@ def mean(self) -> "Window":
109111

110112
def reduce(
111113
self, reducer: Callable[[Any, Any], Any], initializer: Callable[[Any], Any]
112-
) -> "Window":
114+
) -> WindowT:
113115
"""
114116
Configure the window to perform a custom aggregation using `reducer`
115117
and `initializer` functions.
@@ -151,7 +153,7 @@ def initializer(current) -> dict:
151153
aggregators={"value": Reduce(reducer=reducer, initializer=initializer)},
152154
)
153155

154-
def max(self) -> "Window":
156+
def max(self) -> WindowT:
155157
"""
156158
Configure a window to aggregate the maximum value within each window period.
157159
@@ -164,7 +166,7 @@ def max(self) -> "Window":
164166
aggregators={"value": Max(column=None)},
165167
)
166168

167-
def min(self) -> "Window":
169+
def min(self) -> WindowT:
168170
"""
169171
Configure a window to aggregate the minimum value within each window period.
170172
@@ -177,7 +179,7 @@ def min(self) -> "Window":
177179
aggregators={"value": Min(column=None)},
178180
)
179181

180-
def collect(self) -> "Window":
182+
def collect(self) -> WindowT:
181183
"""
182184
Configure the window to collect all values within each window period into a
183185
list, without performing any aggregation.
@@ -202,7 +204,7 @@ def collect(self) -> "Window":
202204
collectors={"value": Collect(column=None)},
203205
)
204206

205-
def agg(self, **operations: Union[BaseAggregator, BaseCollector]) -> "Window":
207+
def agg(self, **operations: Union[BaseAggregator, BaseCollector]) -> WindowT:
206208
if "start" in operations or "end" in operations:
207209
raise ValueError(
208210
"`start` and `end` are reserved keywords for the window boundaries"
@@ -228,7 +230,7 @@ def agg(self, **operations: Union[BaseAggregator, BaseCollector]) -> "Window":
228230
)
229231

230232

231-
class TimeWindowDefinition(WindowDefinition):
233+
class TimeWindowDefinition(WindowDefinition[WindowT], Generic[WindowT]):
232234
def __init__(
233235
self,
234236
duration_ms: int,
@@ -270,7 +272,7 @@ def step_ms(self) -> Optional[int]:
270272
return self._step_ms
271273

272274

273-
class HoppingTimeWindowDefinition(TimeWindowDefinition):
275+
class HoppingTimeWindowDefinition(TimeWindowDefinition[TimeWindow]):
274276
def __init__(
275277
self,
276278
duration_ms: int,
@@ -321,7 +323,7 @@ def _create_window(
321323
)
322324

323325

324-
class TumblingTimeWindowDefinition(TimeWindowDefinition):
326+
class TumblingTimeWindowDefinition(TimeWindowDefinition[TimeWindow]):
325327
def __init__(
326328
self,
327329
duration_ms: int,
@@ -369,7 +371,7 @@ def _create_window(
369371
)
370372

371373

372-
class SlidingTimeWindowDefinition(TimeWindowDefinition):
374+
class SlidingTimeWindowDefinition(TimeWindowDefinition[SlidingWindow]):
373375
def __init__(
374376
self,
375377
duration_ms: int,
@@ -418,7 +420,7 @@ def _create_window(
418420
)
419421

420422

421-
class CountWindowDefinition(WindowDefinition):
423+
class CountWindowDefinition(WindowDefinition[CountWindow]):
422424
def __init__(
423425
self, count: int, dataframe: "StreamingDataFrame", name: Optional[str] = None
424426
) -> None:

0 commit comments

Comments
 (0)