Skip to content

Commit be32e53

Browse files
authored
windows: Split Aggregations and collectors is own classes (#772)
* windows: Split Aggregations and collectors is own classes in preparation for multiple aggregations and collectors support
1 parent 3e686d0 commit be32e53

File tree

6 files changed

+375
-163
lines changed

6 files changed

+375
-163
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from abc import ABC, abstractmethod
2+
from typing import (
3+
Any,
4+
Callable,
5+
Generic,
6+
Hashable,
7+
Iterable,
8+
Optional,
9+
TypeVar,
10+
Union,
11+
)
12+
13+
from typing_extensions import TypeAlias
14+
15+
16+
class Aggregator(ABC):
17+
"""
18+
Base class for window aggregation.
19+
20+
Subclass it to implement custom aggregations.
21+
22+
An Aggregator reduce incoming items into a single value or group of values. When the window
23+
is closed the aggregator produce a result based on the reduced value.
24+
25+
To store all incoming items without reducing them use a `Collector`.
26+
"""
27+
28+
@abstractmethod
29+
def initialize(self) -> Any:
30+
"""
31+
This method is triggered once to build the aggregation starting value.
32+
It should return the initial value for the aggregation.
33+
"""
34+
...
35+
36+
@abstractmethod
37+
def agg(self, old: Any, new: Any) -> Any:
38+
"""
39+
This method is trigged when a window is updated with a new value.
40+
It should return the updated aggregated value.
41+
"""
42+
...
43+
44+
@abstractmethod
45+
def result(self, value: Any) -> Any:
46+
"""
47+
This method is triggered when a window is closed.
48+
It should return the final aggregation result.
49+
"""
50+
...
51+
52+
53+
V = TypeVar("V", int, float)
54+
55+
56+
class ROOT:
57+
pass
58+
59+
60+
Column: TypeAlias = Union[Hashable, type[ROOT]]
61+
62+
63+
class Sum(Aggregator):
64+
def __init__(self, column: Column = ROOT) -> None:
65+
self.column = column
66+
67+
def initialize(self) -> int:
68+
return 0
69+
70+
def agg(self, old: V, new: Any) -> V:
71+
if self.column is ROOT:
72+
return old + new
73+
return old + new[self.column]
74+
75+
def result(self, value: V) -> V:
76+
return value
77+
78+
79+
class Count(Aggregator):
80+
def initialize(self) -> int:
81+
return 0
82+
83+
def agg(self, old: int, new: Any) -> int:
84+
return old + 1
85+
86+
def result(self, value: int) -> int:
87+
return value
88+
89+
90+
class Mean(Aggregator):
91+
def __init__(self, column: Column = ROOT) -> None:
92+
self.column = column
93+
94+
def initialize(self) -> tuple[float, int]:
95+
return 0.0, 0
96+
97+
def agg(self, old: tuple[V, int], new: Any) -> tuple[V, int]:
98+
old_sum, old_count = old
99+
if self.column is ROOT:
100+
return old_sum + new, old_count + 1
101+
return old_sum + new[self.column], old_count + 1
102+
103+
def result(self, value: tuple[Union[int, float], int]) -> float:
104+
sum_, count_ = value
105+
return sum_ / count_
106+
107+
108+
R = TypeVar("R", int, float)
109+
110+
111+
class Reduce(Aggregator, Generic[R]):
112+
def __init__(
113+
self,
114+
reducer: Callable[[R, Any], R],
115+
initializer: Callable[[Any], R],
116+
) -> None:
117+
self._initializer: Callable[[Any], R] = initializer
118+
self._reducer: Callable[[R, Any], R] = reducer
119+
120+
def initialize(self) -> Any:
121+
return None
122+
123+
def agg(self, old: R, new: Any) -> Any:
124+
return self._initializer(new) if old is None else self._reducer(old, new)
125+
126+
def result(self, value: R) -> R:
127+
return value
128+
129+
130+
class Max(Aggregator):
131+
def __init__(self, column: Column = ROOT) -> None:
132+
self.column = column
133+
134+
def initialize(self) -> None:
135+
return None
136+
137+
def agg(self, old: Optional[V], new: Any) -> V:
138+
if self.column is not ROOT:
139+
new = new[self.column]
140+
if old is None:
141+
return new
142+
return max(old, new)
143+
144+
def result(self, value: V) -> V:
145+
return value
146+
147+
148+
class Min(Aggregator):
149+
def __init__(self, column: Column = ROOT) -> None:
150+
self.column = column
151+
152+
def initialize(self) -> None:
153+
return None
154+
155+
def agg(self, old: Optional[V], new: Any) -> V:
156+
if self.column is not ROOT:
157+
new = new[self.column]
158+
if old is None:
159+
return new
160+
return min(old, new)
161+
162+
def result(self, value: V) -> V:
163+
return value
164+
165+
166+
I = TypeVar("I")
167+
168+
169+
class Collector(ABC, Generic[I]):
170+
"""
171+
Base class for window collections.
172+
173+
Subclass it to implement custom collections.
174+
175+
A Collector store incoming items un-modified in an optimized way.
176+
177+
To reduce incoming items as they come in use an `Aggregator`.
178+
"""
179+
180+
@property
181+
@abstractmethod
182+
def column(self) -> Column:
183+
"""
184+
The column to collect.
185+
186+
Use `ROOT` to collect the whole message.
187+
"""
188+
...
189+
190+
@abstractmethod
191+
def result(self, items: Iterable[I]) -> Any:
192+
"""
193+
This method is triggered when a window is closed.
194+
It should return the final collection result.
195+
"""
196+
...
197+
198+
199+
class Collect(Collector):
200+
def __init__(self, column: Column = ROOT) -> None:
201+
self._column = column
202+
203+
@property
204+
def column(self) -> Column:
205+
return self._column
206+
207+
def result(self, items: Iterable[Any]) -> list[Any]:
208+
return list(items)

quixstreams/dataframe/windows/base.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from quixstreams.processing import ProcessingContext
2323
from quixstreams.state import WindowedPartitionTransaction
2424

25+
from .aggregations import Aggregator, Collector
26+
2527
if TYPE_CHECKING:
2628
from quixstreams.dataframe.dataframe import StreamingDataFrame
2729

@@ -38,7 +40,6 @@ class WindowResult(TypedDict):
3840
Message: TypeAlias = tuple[WindowResult, Any, int, Any]
3941

4042
WindowAggregateFunc = Callable[[Any, Any], Any]
41-
WindowMergeFunc = Callable[[Any], Any]
4243

4344
TransformRecordCallbackExpandedWindowed = Callable[
4445
[Any, Any, int, Any, WindowedPartitionTransaction],
@@ -51,13 +52,26 @@ def __init__(
5152
self,
5253
name: str,
5354
dataframe: "StreamingDataFrame",
55+
aggregators: dict[str, Aggregator],
56+
collectors: dict[str, Collector],
5457
) -> None:
5558
if not name:
5659
raise ValueError("Window name must not be empty")
5760

5861
self._name = name
5962
self._dataframe = dataframe
6063

64+
self._aggregators = aggregators
65+
self._aggregate = len(aggregators) > 0
66+
67+
self._collectors = collectors
68+
self._collect = len(collectors) > 0
69+
70+
if not self._collect and not self._aggregate:
71+
raise ValueError("At least one aggregation or collector must be defined")
72+
elif len(collectors) + len(aggregators) > 1:
73+
raise ValueError("Only one aggregation or collector can be defined")
74+
6175
@property
6276
def name(self) -> str:
6377
return self._name
@@ -260,7 +274,3 @@ def get_window_ranges(
260274
current_window_start -= step_ms
261275

262276
return window_ranges
263-
264-
265-
def default_merge_func(state_value: Any) -> Any:
266-
return state_value

quixstreams/dataframe/windows/count_based.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33

44
from quixstreams.state import WindowedPartitionTransaction
55

6+
from .aggregations import Aggregator, Collector
67
from .base import (
78
Window,
8-
WindowAggregateFunc,
99
WindowKeyResult,
10-
WindowMergeFunc,
1110
WindowResult,
12-
default_merge_func,
1311
)
1412

1513
if TYPE_CHECKING:
@@ -35,22 +33,21 @@ class CountWindow(Window):
3533

3634
def __init__(
3735
self,
38-
name: str,
3936
count: int,
40-
aggregate_func: WindowAggregateFunc,
41-
aggregate_default: Any,
42-
aggregate_collection: bool,
37+
name: str,
4338
dataframe: "StreamingDataFrame",
44-
merge_func: Optional[WindowMergeFunc] = None,
39+
aggregators: dict[str, Aggregator],
40+
collectors: dict[str, Collector],
4541
step: Optional[int] = None,
4642
):
47-
super().__init__(name, dataframe)
43+
super().__init__(
44+
name=name,
45+
dataframe=dataframe,
46+
aggregators=aggregators,
47+
collectors=collectors,
48+
)
4849

4950
self._max_count = count
50-
self._aggregate_func = aggregate_func
51-
self._aggregate_default = aggregate_default
52-
self._aggregate_collection = aggregate_collection
53-
self._merge_func = merge_func or default_merge_func
5451
self._step = step
5552

5653
def process_window(
@@ -86,34 +83,37 @@ def process_window(
8683
msg_id = None
8784
if len(data["windows"]) == 0:
8885
# for new tumbling window, reset the collection id to 0
89-
msg_id = 0
86+
if self._collect:
87+
window_value = msg_id = 0
88+
else:
89+
window_value = self._aggregators["value"].initialize()
9090

9191
data["windows"].append(
9292
CountWindowData(
9393
count=0,
9494
start=timestamp_ms,
9595
end=timestamp_ms,
96-
value=msg_id
97-
if self._aggregate_collection
98-
else self._aggregate_default,
96+
value=window_value,
9997
)
10098
)
10199
elif self._step is not None and data["windows"][0]["count"] % self._step == 0:
102-
if self._aggregate_collection:
103-
msg_id = data["windows"][0]["value"] + data["windows"][0]["count"]
100+
if self._collect:
101+
window_value = msg_id = (
102+
data["windows"][0]["value"] + data["windows"][0]["count"]
103+
)
104+
else:
105+
window_value = self._aggregators["value"].initialize()
104106

105107
data["windows"].append(
106108
CountWindowData(
107109
count=0,
108110
start=timestamp_ms,
109111
end=timestamp_ms,
110-
value=msg_id
111-
if self._aggregate_collection
112-
else self._aggregate_default,
112+
value=window_value,
113113
)
114114
)
115115

116-
if self._aggregate_collection:
116+
if self._collect:
117117
if msg_id is None:
118118
msg_id = data["windows"][0]["value"] + data["windows"][0]["count"]
119119

@@ -127,7 +127,7 @@ def process_window(
127127
elif timestamp_ms > window["end"]:
128128
window["end"] = timestamp_ms
129129

130-
if self._aggregate_collection:
130+
if self._collect:
131131
# window must close
132132
if window["count"] >= self._max_count:
133133
values = state.get_from_collection(
@@ -141,7 +141,7 @@ def process_window(
141141
WindowResult(
142142
start=window["start"],
143143
end=window["end"],
144-
value=self._merge_func(values),
144+
value=self._collectors["value"].result(values),
145145
),
146146
)
147147
)
@@ -157,14 +157,14 @@ def process_window(
157157

158158
state.delete_from_collection(end=delete_end, start=delete_start)
159159
else:
160-
window["value"] = self._aggregate_func(window["value"], value)
160+
window["value"] = self._aggregators["value"].agg(window["value"], value)
161161

162162
result = (
163163
key,
164164
WindowResult(
165165
start=window["start"],
166166
end=window["end"],
167-
value=self._merge_func(window["value"]),
167+
value=self._aggregators["value"].result(window["value"]),
168168
),
169169
)
170170

0 commit comments

Comments
 (0)