Skip to content

Commit 31f10c7

Browse files
Add AccumulateOp.
1 parent 431ea5b commit 31f10c7

File tree

1 file changed

+115
-8
lines changed
  • experimental/iterators/include/iterators/Dialect/Iterators/IR

1 file changed

+115
-8
lines changed

experimental/iterators/include/iterators/Dialect/Iterators/IR/IteratorsOps.td

Lines changed: 115 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,121 @@ def Iterators_PrintOp : Iterators_Base_Op<"print"> {
5555
// High-level iterators
5656
//===----------------------------------------------------------------------===//
5757

58+
/// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
59+
/// given op and returns the function type of that symbol.
60+
class LookupFuncType<string opName, string symbolName>
61+
: StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
62+
" &$" # opName # ","
63+
" " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
64+
" .getFunctionType()">;
65+
66+
/// An init function, that is, a function with a signature of the form () -> T.
67+
class Iterators_InitFunctionType
68+
: Type<And<[FunctionType.predicate,
69+
CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 0">,
70+
CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">
71+
]>,
72+
"function with signature () -> T",
73+
"FunctionType">;
74+
75+
/// A FlatSymbolRef referring to an init function.
76+
def Iterators_InitFunctionSymbol
77+
: Confined<FlatSymbolRefAttr, [
78+
ReferToOp<"func::FuncOp">,
79+
AttrConstraint<
80+
SubstLeaves<"$_self",
81+
LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
82+
Iterators_InitFunctionType<>.predicate>,
83+
"referring to a " # Iterators_InitFunctionType<>.summary>
84+
]>;
85+
86+
/// An accumulate function, that is, a function with a signature of the form
87+
/// (T1, T2) -> T1.
88+
class Iterators_AccumulateFunctionType
89+
: Type<And<[FunctionType.predicate,
90+
CPred<"$_self.dyn_cast<FunctionType>().getInputs().size() == 2">,
91+
CPred<"$_self.dyn_cast<FunctionType>().getResults().size() == 1">,
92+
AllMatchPred<["$_self.dyn_cast<FunctionType>().getInput(0)",
93+
"$_self.dyn_cast<FunctionType>().getResult(0)"]>
94+
]>,
95+
"function with signature (T1, T2) -> T1",
96+
"FunctionType">;
97+
98+
/// A FlatSymbolRef referring to an accumulate function.
99+
def Iterators_AccumulateFunctionSymbol
100+
: Confined<FlatSymbolRefAttr, [
101+
ReferToOp<"func::FuncOp">,
102+
AttrConstraint<
103+
SubstLeaves<"$_self",
104+
LookupFuncType<"_op", "$_self.dyn_cast<FlatSymbolRefAttr>()">.result,
105+
Iterators_AccumulateFunctionType<>.predicate>,
106+
"referring to a " # Iterators_AccumulateFunctionType<>.summary>
107+
]>;
108+
109+
def Iterators_AccumulateOp
110+
: Iterators_Op<"accumulate",
111+
[AllMatch<["getInitFunc().getResultTypes().front()",
112+
"$result.getType().dyn_cast<StreamType>().getElementType()"],
113+
"the return type of the init function must match the result "
114+
"element type">,
115+
AllMatch<["getAccumulateFunc().getArgumentTypes()[0]",
116+
"$result.getType().dyn_cast<StreamType>().getElementType()"],
117+
"the type of the first argument of the accumulate function must "
118+
"match the result element type">,
119+
AllMatch<["getAccumulateFunc().getArgumentTypes()[1]",
120+
"$input.getType().dyn_cast<StreamType>().getElementType()"],
121+
"the type of the second argument of the accumulate function must "
122+
"match the input element type">]> {
123+
let summary = "Accumulate the elements of a stream into one element";
124+
let description = [{
125+
Accumulate the elements of the input stream into a single element, i.e.,
126+
compute their generalized sum. This is similar to
127+
[`std::accumulate`](https://en.cppreference.com/w/cpp/algorithm/accumulate)
128+
in C++ and
129+
[`functools.reduce`](https://docs.python.org/3/library/functools.html#functools.reduce)
130+
with *initializer* in Python. The accumulator is initialized with the value
131+
returned by the provided init function; the logic of the accumulation is
132+
given by the provided accumulate function.
133+
134+
Pseudo-code:
135+
```
136+
accumulator = @initFuncRef()
137+
while (next = upstream->next()):
138+
accumulator = @accumulateFuncRef(accumulator, next->value())
139+
return accumulator
140+
141+
Example:
142+
```mlir
143+
%0 = iterators.accumulate(%input, @zero_struct, @sum_struct)
144+
: (!iterators.stream<!llvm.struct<(i32)>>)
145+
-> !iterators.stream<!llvm.struct<(i32)>>
146+
```
147+
}];
148+
let arguments = (ins
149+
Iterators_Stream:$input,
150+
Iterators_InitFunctionSymbol:$initFuncRef,
151+
Iterators_AccumulateFunctionSymbol:$accumulateFuncRef
152+
);
153+
let results = (outs Iterators_Stream:$result);
154+
let assemblyFormat = [{
155+
`(` $input `,` $initFuncRef `,` $accumulateFuncRef `)`
156+
attr-dict `:` `(` type($input) `)` `->` type($result)
157+
}];
158+
let extraClassDeclaration = [{
159+
/// Return the init function op that the initFuncRef refers to.
160+
func::FuncOp getInitFunc() {
161+
return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
162+
*this, initFuncRefAttr());
163+
}
164+
165+
/// Return the accumulate function op that the accumulateFuncRef refers to.
166+
func::FuncOp getAccumulateFunc() {
167+
return SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(
168+
*this, accumulateFuncRefAttr());
169+
}
170+
}];
171+
}
172+
58173
/// Verifies that the element types of nested arrays in the $value array
59174
/// correspond to the types of the LLVM-struct element type of the $result
60175
/// Stream.
@@ -97,14 +212,6 @@ def Iterators_ConstantStreamOp
97212
let results = (outs Iterators_StreamOfLLVMStructOfNumerics:$result);
98213
}
99214

100-
/// Looks up the given symbol, which must refer to a FuncOp, in the scope of the
101-
/// given op and returns the function type of that symbol.
102-
class LookupFuncType<string opName, string symbolName>
103-
: StrFunc<"::mlir::SymbolTable::lookupNearestSymbolFrom<func::FuncOp>("
104-
" &$" # opName # ","
105-
" " # symbolName # ".dyn_cast<FlatSymbolRefAttr>())"
106-
" .getFunctionType()">;
107-
108215
/// A predicate, that is, a function with a signature of the form (T) -> i1.
109216
class Iterators_PredicateFunctionType
110217
: Type<And<[FunctionType.predicate,

0 commit comments

Comments
 (0)