Skip to content

Commit 695c341

Browse files
[mlir][bufferize] Generalize filtering mechanism in BufferizationOptions
Support ALLOW filters and DENY filters. This is needed for compatibility with existing code that specifies more complex op filters. Differential Revision: https://reviews.llvm.org/D119820
1 parent 655d0d8 commit 695c341

File tree

4 files changed

+85
-31
lines changed

4 files changed

+85
-31
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -45,32 +45,93 @@ struct BufferizationOptions {
4545
using MemCpyFn =
4646
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
4747

48+
/// An op filter entry. Filters can be used to specify which ops should be
49+
/// processed by the bufferization.
50+
struct OpFilterEntry {
51+
/// If the filter function evaluates to `true`, the filter matches.
52+
using FilterFn = std::function<bool(Operation *)>;
53+
54+
/// Filter type: A filter can either be a DENY filter or an ALLOW filter.
55+
enum FilterType : int8_t { DENY = 0, ALLOW = 1 };
56+
57+
FilterFn fn;
58+
FilterType type;
59+
};
60+
4861
BufferizationOptions();
4962

50-
/// Return `true` if the op is allowed to be bufferized.
63+
/// Return whether the op should be bufferized or not.
64+
///
65+
/// If no filter is specified (`hasFilter` = false), every op will be
66+
/// bufferized. Otherwise, an op is bufferized if:
67+
///
68+
/// - At least one ALLOW filter says `true`.
69+
/// - And, no DENY filter says `true`.
5170
bool isOpAllowed(Operation *op) const {
5271
if (!hasFilter)
5372
return true;
54-
return dialectFilter.contains(op->getDialect()->getNamespace()) ||
55-
operationFilter.contains(op->getName().getStringRef());
73+
bool isAllowed = false;
74+
for (const OpFilterEntry &entry : opFilter) {
75+
bool filterResult = entry.fn(op);
76+
switch (entry.type) {
77+
case OpFilterEntry::ALLOW:
78+
isAllowed |= filterResult;
79+
break;
80+
case OpFilterEntry::DENY:
81+
if (filterResult)
82+
// DENY filter matches. This op is no allowed. (Even if other ALLOW
83+
// filters may match.)
84+
return false;
85+
};
86+
}
87+
return isAllowed;
5688
}
5789

5890
/// Allow the given dialects and activate the filter (`hasFilter`).
91+
///
92+
/// This function adds one or multiple ALLOW filters.
5993
template <typename... DialectTs>
60-
void addToDialectFilter() {
61-
// The following expands a call to addToDialectFilterImpl for each dialect
94+
void allowDialectInFilter() {
95+
// The following expands a call to allowDialectInFilterImpl for each dialect
6296
// in 'DialectTs'. This magic is necessary due to a limitation in the places
6397
// that a parameter pack can be expanded in c++11.
6498
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
6599
(void)std::initializer_list<int>{
66-
0, (addToDialectFilterImpl<DialectTs>(), 0)...};
100+
0, (allowDialectInFilterImpl<DialectTs>(), 0)...};
101+
}
102+
103+
/// Allow the given dialect and activate the filter (`hasFilter`).
104+
///
105+
/// This function adds an ALLOW filter.
106+
void allowDialectInFilter(StringRef dialectNamespace) {
107+
hasFilter = true;
108+
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
109+
return op->getDialect()->getNamespace() == dialectNamespace;
110+
};
111+
opFilter.push_back(
112+
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
67113
}
68114

69115
/// Allow the given ops and activate the filter (`hasFilter`).
70-
template <typename... OpTys> void addToOperationFilter() {
116+
///
117+
/// This function adds one or multiple ALLOW filters.
118+
template <typename... OpTys>
119+
void allowOperationInFilter() {
71120
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
72-
(void)std::initializer_list<int>{0,
73-
(addToOperationFilterImpl<OpTys>(), 0)...};
121+
(void)std::initializer_list<int>{
122+
0, (allowOperationInFilterImpl<OpTys>(), 0)...};
123+
}
124+
125+
/// Allow the given op and activate the filter (`hasFilter`).
126+
///
127+
/// This function adds an ALLOW filter.
128+
void allowOperationInFilter(StringRef opName) {
129+
hasFilter = true;
130+
OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
131+
return op->getName().getStringRef() == opName;
132+
};
133+
opFilter.push_back(
134+
OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
74135
}
75136

76137
/// Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -118,33 +179,26 @@ struct BufferizationOptions {
118179
/// Buffer alignment for new memory allocations.
119180
unsigned int bufferAlignment = 128;
120181

121-
/// If set to `true`, only ops that belong to a filtered dialect
122-
/// (`dialectFilter`) and filtered ops (`operationFilter`) are processed. All
123-
/// other ops are ignored. If set to `false`, all ops are bufferized (as long
124-
/// as they implement BufferizableOpInterface).
125-
///
126-
/// If a filter is specified, `allowUnknownOps` should be enabled. Otherwise,
127-
/// bufferization would fail when encountering a non-filtered op.
182+
/// If set to `false`, all ops are bufferized (as long as they implement
183+
/// BufferizableOpInterface). Otherwise, only filtered ops are bufferized.
128184
bool hasFilter = false;
129185

130-
/// A set of allowed dialects.
131-
DenseSet<StringRef> dialectFilter;
132-
133-
/// A set of allowed ops.
134-
DenseSet<StringRef> operationFilter;
186+
/// A list of op filters that determine whether an op should be processed or
187+
/// ignored by the bufferization. If `hasFilter`, only ops that are not
188+
/// DENY-filtered and have at least one matching ALLOW filter are processed.
189+
SmallVector<OpFilterEntry> opFilter;
135190

136191
private:
137192
/// Allow a dialect.
138193
template <typename DialectT>
139-
void addToDialectFilterImpl() {
140-
hasFilter = true;
141-
dialectFilter.insert(DialectT::getDialectNamespace());
194+
void allowDialectInFilterImpl() {
195+
allowDialectInFilter(DialectT::getDialectNamespace());
142196
}
143197

144198
/// Allow an op.
145-
template <typename OpTy> void addToOperationFilterImpl() {
146-
hasFilter = true;
147-
operationFilter.insert(OpTy::getOperationName());
199+
template <typename OpTy>
200+
void allowOperationInFilterImpl() {
201+
allowOperationInFilter(OpTy::getOperationName());
148202
}
149203
};
150204

mlir/lib/Dialect/Arithmetic/Transforms/Bufferize.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ struct ArithmeticBufferizePass
3131
void runOnOperation() override {
3232
BufferizationOptions options = getPartialBufferizationOptions();
3333
if (constantOpOnly) {
34-
options.addToOperationFilter<arith::ConstantOp>();
34+
options.allowOperationInFilter<arith::ConstantOp>();
3535
} else {
36-
options.addToDialectFilter<arith::ArithmeticDialect>();
36+
options.allowDialectInFilter<arith::ArithmeticDialect>();
3737
}
3838
options.bufferAlignment = alignment;
3939

mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace {
3131
struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
3232
void runOnOperation() override {
3333
BufferizationOptions options = getPartialBufferizationOptions();
34-
options.addToDialectFilter<tensor::TensorDialect>();
34+
options.allowDialectInFilter<tensor::TensorDialect>();
3535

3636
if (failed(bufferizeOp(getOperation(), options)))
3737
signalPassFailure();

mlir/test/lib/Dialect/Linalg/TestComprehensiveBufferize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ void TestComprehensiveFunctionBufferize::runOnOperation() {
116116
if (dialectFilter.hasValue()) {
117117
options->hasFilter = true;
118118
for (const std::string &dialectNamespace : dialectFilter)
119-
options->dialectFilter.insert(dialectNamespace);
119+
options->allowDialectInFilter(dialectNamespace);
120120
}
121121

122122
Operation *op = getOperation();

0 commit comments

Comments
 (0)