@@ -55,6 +55,121 @@ def Iterators_PrintOp : Iterators_Base_Op<"print"> {
55
55
// High-level iterators
56
56
//===----------------------------------------------------------------------===//
57
57
58
+ /// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
59
+ /// given op and returns the function type of that symbol.
60
+ class LookupFuncType<string opName, string symbolName>
61
+ : StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
62
+ " &$" # opName # ","
63
+ " " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
64
+ " .getFunctionType()">;
65
+
66
+ /// An init function, that is, a function with a signature of the form () -> T.
67
+ class Iterators_InitFunctionType
68
+ : Type<And<[FunctionType.predicate,
69
+ CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 0">,
70
+ CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">
71
+ ]>,
72
+ "function with signature () -> T",
73
+ "FunctionType">;
74
+
75
+ /// A FlatSymbolRef referring to an init function.
76
+ def Iterators_InitFunctionSymbol
77
+ : Confined<FlatSymbolRefAttr, [
78
+ ReferToOp<"func::FuncOp">,
79
+ AttrConstraint<
80
+ SubstLeaves<"$_self",
81
+ LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
82
+ Iterators_InitFunctionType<>.predicate>,
83
+ "referring to a " # Iterators_InitFunctionType<>.summary>
84
+ ]>;
85
+
86
+ /// An accumulate function, that is, a function with a signature of the form
87
+ /// (T1, T2) -> T1.
88
+ class Iterators_AccumulateFunctionType
89
+ : Type<And<[FunctionType.predicate,
90
+ CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 2">,
91
+ CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">,
92
+ AllMatchPred<["$_self.dyn_cast<FunctionType>().getInput(0)",
93
+ "$_self.dyn_cast<FunctionType>().getResult(0)"]>
94
+ ]>,
95
+ "function with signature (T1, T2) -> T1",
96
+ "FunctionType">;
97
+
98
+ /// A FlatSymbolRef referring to an accumulate function.
99
+ def Iterators_AccumulateFunctionSymbol
100
+ : Confined<FlatSymbolRefAttr, [
101
+ ReferToOp<"func::FuncOp">,
102
+ AttrConstraint<
103
+ SubstLeaves<"$_self",
104
+ LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
105
+ Iterators_AccumulateFunctionType<>.predicate>,
106
+ "referring to a " # Iterators_AccumulateFunctionType<>.summary>
107
+ ]>;
108
+
109
+ def Iterators_AccumulateOp
110
+ : Iterators_Op<"accumulate",
111
+ [AllMatch<["getInitFunc().getResultTypes().front()",
112
+ "$result.getType().dyn_cast<StreamType>().getElementType()"],
113
+ "the return type of the init function must match the result "
114
+ "element type">,
115
+ AllMatch<["getAccumulateFunc().getArgumentTypes()[0]",
116
+ "$result.getType().dyn_cast<StreamType>().getElementType()"],
117
+ "the type of the first argument of the accumulate function must "
118
+ "match the result element type">,
119
+ AllMatch<["getAccumulateFunc().getArgumentTypes()[1]",
120
+ "$input.getType().dyn_cast<StreamType>().getElementType()"],
121
+ "the type of the second argument of the accumulate function must "
122
+ "match the input element type">]> {
123
+ let summary = "Accumulate the elements of a stream into one element";
124
+ let description = [{
125
+ Accumulate the elements of the input stream into a single element, i.e.,
126
+ compute their generalized sum. This is similar to
127
+ [`std::accumulate`](https://en.cppreference.com/w/cpp/algorithm/accumulate)
128
+ in C++ and
129
+ [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)
130
+ with *initializer* in Python. The accumulator is initialized with the value
131
+ returned by the provided init function; the logic of the accumulation is
132
+ given by the provided accumulate function.
133
+
134
+ Pseudo-code:
135
+ ```
136
+ accumulator = @initFuncRef()
137
+ while (next = upstream->next()):
138
+ accumulator = @accumulateFuncRef(accumulator, next->value())
139
+ return accumulator
140
+
141
+ Example:
142
+ ```mlir
143
+ %0 = iterators.accumulate(%input, @zero_struct, @sum_struct)
144
+ : (!iterators.stream<!llvm.struct<(i32)>>)
145
+ -> !iterators.stream<!llvm.struct<(i32)>>
146
+ ```
147
+ }];
148
+ let arguments = (ins
149
+ Iterators_Stream:$input,
150
+ Iterators_InitFunctionSymbol:$initFuncRef,
151
+ Iterators_AccumulateFunctionSymbol:$accumulateFuncRef
152
+ );
153
+ let results = (outs Iterators_Stream:$result);
154
+ let assemblyFormat = [{
155
+ `(` $input `,` $initFuncRef `,` $accumulateFuncRef `)`
156
+ attr-dict `:` `(` type($input) `)` `->` type($result)
157
+ }];
158
+ let extraClassDeclaration = [{
159
+ /// Return the init function op that the initFuncRef refers to.
160
+ func::FuncOp getInitFunc() {
161
+ return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
162
+ *this, initFuncRefAttr());
163
+ }
164
+
165
+ /// Return the accumulate function op that the accumulateFuncRef refers to.
166
+ func::FuncOp getAccumulateFunc() {
167
+ return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
168
+ *this, accumulateFuncRefAttr());
169
+ }
170
+ }];
171
+ }
172
+
58
173
/// Verifies that the element types of nested arrays in the $value array
59
174
/// correspond to the types of the LLVM-struct element type of the $result
60
175
/// Stream.
@@ -97,14 +212,6 @@ def Iterators_ConstantStreamOp
97
212
let results = (outs Iterators_StreamOfLLVMStructOfNumerics:$result);
98
213
}
99
214
100
- /// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
101
- /// given op and returns the function type of that symbol.
102
- class LookupFuncType<string opName, string symbolName>
103
- : StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
104
- " &$" # opName # ","
105
- " " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
106
- " .getFunctionType()">;
107
-
108
215
/// A predicate, that is, a function with a signature of the form (T) -> i1.
109
216
class Iterators_PredicateFunctionType
110
217
: Type<And<[FunctionType.predicate,
0 commit comments