1
1
import abc
2
2
from abc import abstractmethod
3
- from typing import TYPE_CHECKING , Any , Callable , Optional , Union
3
+ from typing import TYPE_CHECKING , Any , Callable , Generic , Optional , TypeVar , Union
4
4
5
5
from .aggregations import (
6
6
BaseAggregator ,
45
45
"TumblingTimeWindowDefinition" ,
46
46
]
47
47
48
+ WindowT = TypeVar ("WindowT" , bound = Window )
48
49
49
- class WindowDefinition (abc .ABC ):
50
+
51
+ class WindowDefinition (abc .ABC , Generic [WindowT ]):
50
52
def __init__ (
51
53
self ,
52
54
name : Optional [str ],
@@ -65,9 +67,9 @@ def _create_window(
65
67
func_name : Optional [str ],
66
68
aggregators : Optional [dict [str , BaseAggregator ]] = None ,
67
69
collectors : Optional [dict [str , BaseCollector ]] = None ,
68
- ) -> Window : ...
70
+ ) -> WindowT : ...
69
71
70
- def sum (self ) -> "Window" :
72
+ def sum (self ) -> WindowT :
71
73
"""
72
74
Configure the window to aggregate data by summing up values within
73
75
each window period.
@@ -80,7 +82,7 @@ def sum(self) -> "Window":
80
82
aggregators = {"value" : Sum (column = None )},
81
83
)
82
84
83
- def count (self ) -> "Window" :
85
+ def count (self ) -> WindowT :
84
86
"""
85
87
Configure the window to aggregate data by counting the number of values
86
88
within each window period.
@@ -93,7 +95,7 @@ def count(self) -> "Window":
93
95
aggregators = {"value" : Count ()},
94
96
)
95
97
96
- def mean (self ) -> "Window" :
98
+ def mean (self ) -> WindowT :
97
99
"""
98
100
Configure the window to aggregate data by calculating the mean of the values
99
101
within each window period.
@@ -109,7 +111,7 @@ def mean(self) -> "Window":
109
111
110
112
def reduce (
111
113
self , reducer : Callable [[Any , Any ], Any ], initializer : Callable [[Any ], Any ]
112
- ) -> "Window" :
114
+ ) -> WindowT :
113
115
"""
114
116
Configure the window to perform a custom aggregation using `reducer`
115
117
and `initializer` functions.
@@ -151,7 +153,7 @@ def initializer(current) -> dict:
151
153
aggregators = {"value" : Reduce (reducer = reducer , initializer = initializer )},
152
154
)
153
155
154
- def max (self ) -> "Window" :
156
+ def max (self ) -> WindowT :
155
157
"""
156
158
Configure a window to aggregate the maximum value within each window period.
157
159
@@ -164,7 +166,7 @@ def max(self) -> "Window":
164
166
aggregators = {"value" : Max (column = None )},
165
167
)
166
168
167
- def min (self ) -> "Window" :
169
+ def min (self ) -> WindowT :
168
170
"""
169
171
Configure a window to aggregate the minimum value within each window period.
170
172
@@ -177,7 +179,7 @@ def min(self) -> "Window":
177
179
aggregators = {"value" : Min (column = None )},
178
180
)
179
181
180
- def collect (self ) -> "Window" :
182
+ def collect (self ) -> WindowT :
181
183
"""
182
184
Configure the window to collect all values within each window period into a
183
185
list, without performing any aggregation.
@@ -202,7 +204,7 @@ def collect(self) -> "Window":
202
204
collectors = {"value" : Collect (column = None )},
203
205
)
204
206
205
- def agg (self , ** operations : Union [BaseAggregator , BaseCollector ]) -> "Window" :
207
+ def agg (self , ** operations : Union [BaseAggregator , BaseCollector ]) -> WindowT :
206
208
if "start" in operations or "end" in operations :
207
209
raise ValueError (
208
210
"`start` and `end` are reserved keywords for the window boundaries"
@@ -228,7 +230,7 @@ def agg(self, **operations: Union[BaseAggregator, BaseCollector]) -> "Window":
228
230
)
229
231
230
232
231
- class TimeWindowDefinition (WindowDefinition ):
233
+ class TimeWindowDefinition (WindowDefinition [ WindowT ], Generic [ WindowT ] ):
232
234
def __init__ (
233
235
self ,
234
236
duration_ms : int ,
@@ -270,7 +272,7 @@ def step_ms(self) -> Optional[int]:
270
272
return self ._step_ms
271
273
272
274
273
- class HoppingTimeWindowDefinition (TimeWindowDefinition ):
275
+ class HoppingTimeWindowDefinition (TimeWindowDefinition [ TimeWindow ] ):
274
276
def __init__ (
275
277
self ,
276
278
duration_ms : int ,
@@ -321,7 +323,7 @@ def _create_window(
321
323
)
322
324
323
325
324
- class TumblingTimeWindowDefinition (TimeWindowDefinition ):
326
+ class TumblingTimeWindowDefinition (TimeWindowDefinition [ TimeWindow ] ):
325
327
def __init__ (
326
328
self ,
327
329
duration_ms : int ,
@@ -369,7 +371,7 @@ def _create_window(
369
371
)
370
372
371
373
372
- class SlidingTimeWindowDefinition (TimeWindowDefinition ):
374
+ class SlidingTimeWindowDefinition (TimeWindowDefinition [ SlidingWindow ] ):
373
375
def __init__ (
374
376
self ,
375
377
duration_ms : int ,
@@ -418,7 +420,7 @@ def _create_window(
418
420
)
419
421
420
422
421
- class CountWindowDefinition (WindowDefinition ):
423
+ class CountWindowDefinition (WindowDefinition [ CountWindow ] ):
422
424
def __init__ (
423
425
self , count : int , dataframe : "StreamingDataFrame" , name : Optional [str ] = None
424
426
) -> None :
0 commit comments