@@ -55,6 +55,121 @@ def Iterators_PrintOp : Iterators_Base_Op<"print"> {
5555// High-level iterators
5656//===----------------------------------------------------------------------===//
5757
58+ /// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
59+ /// given op and returns the function type of that symbol.
60+ class LookupFuncType<string opName, string symbolName>
61+ : StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
62+ " &$" # opName # ","
63+ " " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
64+ " .getFunctionType()">;
65+
66+ /// An init function, that is, a function with a signature of the form () -> T.
67+ class Iterators_InitFunctionType
68+ : Type<And<[FunctionType.predicate,
69+ CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 0">,
70+ CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">
71+ ]>,
72+ "function with signature () -> T",
73+ "FunctionType">;
74+
75+ /// A FlatSymbolRef referring to an init function.
76+ def Iterators_InitFunctionSymbol
77+ : Confined<FlatSymbolRefAttr, [
78+ ReferToOp<"func::FuncOp">,
79+ AttrConstraint<
80+ SubstLeaves<"$_self",
81+ LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
82+ Iterators_InitFunctionType<>.predicate>,
83+ "referring to a " # Iterators_InitFunctionType<>.summary>
84+ ]>;
85+
86+ /// An accumulate function, that is, a function with a signature of the form
87+ /// (T1, T2) -> T1.
88+ class Iterators_AccumulateFunctionType
89+ : Type<And<[FunctionType.predicate,
90+ CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 2">,
91+ CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">,
92+ AllMatchPred<["$_self.dyn_cast<FunctionType>().getInput(0)",
93+ "$_self.dyn_cast<FunctionType>().getResult(0)"]>
94+ ]>,
95+ "function with signature (T1, T2) -> T1",
96+ "FunctionType">;
97+
98+ /// A FlatSymbolRef referring to an accumulate function.
99+ def Iterators_AccumulateFunctionSymbol
100+ : Confined<FlatSymbolRefAttr, [
101+ ReferToOp<"func::FuncOp">,
102+ AttrConstraint<
103+ SubstLeaves<"$_self",
104+ LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
105+ Iterators_AccumulateFunctionType<>.predicate>,
106+ "referring to a " # Iterators_AccumulateFunctionType<>.summary>
107+ ]>;
108+
109+ def Iterators_AccumulateOp
110+ : Iterators_Op<"accumulate",
111+ [AllMatch<["getInitFunc().getResultTypes().front()",
112+ "$result.getType().dyn_cast<StreamType>().getElementType()"],
113+ "the return type of the init function must match the result "
114+ "element type">,
115+ AllMatch<["getAccumulateFunc().getArgumentTypes()[0]",
116+ "$result.getType().dyn_cast<StreamType>().getElementType()"],
117+ "the type of the first argument of the accumulate function must "
118+ "match the result element type">,
119+ AllMatch<["getAccumulateFunc().getArgumentTypes()[1]",
120+ "$input.getType().dyn_cast<StreamType>().getElementType()"],
121+ "the type of the second argument of the accumulate function must "
122+ "match the input element type">]> {
123+ let summary = "Accumulate the elements of a stream into one element";
124+ let description = [{
125+ Accumulate the elements of the input stream into a single element, i.e.,
126+ compute their generalized sum. This is similar to
127+ [`std::accumulate`](https://en.cppreference.com/w/cpp/algorithm/accumulate)
128+ in C++ and
129+ [`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)
130+ with *initializer* in Python. The accumulator is initialized with the value
131+ returned by the provided init function; the logic of the accumulation is
132+ given by the provided accumulate function.
133+
134+ Pseudo-code:
135+ ```
136+ accumulator = @initFuncRef()
137+ while (next = upstream->next()):
138+ accumulator = @accumulateFuncRef(accumulator, next->value())
139+ return accumulator
140+
141+ Example:
142+ ```mlir
143+ %0 = iterators.accumulate(%input, @zero_struct, @sum_struct)
144+ : (!iterators.stream<!llvm.struct<(i32)>>)
145+ -> !iterators.stream<!llvm.struct<(i32)>>
146+ ```
147+ }];
148+ let arguments = (ins
149+ Iterators_Stream:$input,
150+ Iterators_InitFunctionSymbol:$initFuncRef,
151+ Iterators_AccumulateFunctionSymbol:$accumulateFuncRef
152+ );
153+ let results = (outs Iterators_Stream:$result);
154+ let assemblyFormat = [{
155+ `(` $input `,` $initFuncRef `,` $accumulateFuncRef `)`
156+ attr-dict `:` `(` type($input) `)` `->` type($result)
157+ }];
158+ let extraClassDeclaration = [{
159+ /// Return the init function op that the initFuncRef refers to.
160+ func::FuncOp getInitFunc() {
161+ return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
162+ *this, initFuncRefAttr());
163+ }
164+
165+ /// Return the accumulate function op that the accumulateFuncRef refers to.
166+ func::FuncOp getAccumulateFunc() {
167+ return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
168+ *this, accumulateFuncRefAttr());
169+ }
170+ }];
171+ }
172+
58173/// Verifies that the element types of nested arrays in the $value array
59174/// correspond to the types of the LLVM-struct element type of the $result
60175/// Stream.
@@ -97,14 +212,6 @@ def Iterators_ConstantStreamOp
97212 let results = (outs Iterators_StreamOfLLVMStructOfNumerics:$result);
98213}
99214
100- /// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
101- /// given op and returns the function type of that symbol.
102- class LookupFuncType<string opName, string symbolName>
103- : StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
104- " &$" # opName # ","
105- " " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
106- " .getFunctionType()">;
107-
108215/// A predicate, that is, a function with a signature of the form (T) -> i1.
109216class Iterators_PredicateFunctionType
110217 : Type<And<[FunctionType.predicate,
0 commit comments