@@ -45,32 +45,93 @@ struct BufferizationOptions {
45
45
using MemCpyFn =
46
46
std::function<LogicalResult(OpBuilder &, Location, Value, Value)>;
47
47
48
+ // / An op filter entry. Filters can be used to specify which ops should be
49
+ // / processed by the bufferization.
50
+ struct OpFilterEntry {
51
+ // / If the filter function evaluates to `true`, the filter matches.
52
+ using FilterFn = std::function<bool (Operation *)>;
53
+
54
+ // / Filter type: A filter can either be a DENY filter or an ALLOW filter.
55
+ enum FilterType : int8_t { DENY = 0 , ALLOW = 1 };
56
+
57
+ FilterFn fn;
58
+ FilterType type;
59
+ };
60
+
48
61
BufferizationOptions ();
49
62
50
- // / Return `true` if the op is allowed to be bufferized.
63
+ // / Return whether the op should be bufferized or not.
64
+ // /
65
+ // / If no filter is specified (`hasFilter` = false), every op will be
66
+ // / bufferized. Otherwise, an op is bufferized if:
67
+ // /
68
+ // / - At least one ALLOW filter says `true`.
69
+ // / - And, no DENY filter says `true`.
51
70
bool isOpAllowed (Operation *op) const {
52
71
if (!hasFilter)
53
72
return true ;
54
- return dialectFilter.contains (op->getDialect ()->getNamespace ()) ||
55
- operationFilter.contains (op->getName ().getStringRef ());
73
+ bool isAllowed = false ;
74
+ for (const OpFilterEntry &entry : opFilter) {
75
+ bool filterResult = entry.fn (op);
76
+ switch (entry.type ) {
77
+ case OpFilterEntry::ALLOW:
78
+ isAllowed |= filterResult;
79
+ break ;
80
+ case OpFilterEntry::DENY:
81
+ if (filterResult)
82
+ // DENY filter matches. This op is no allowed. (Even if other ALLOW
83
+ // filters may match.)
84
+ return false ;
85
+ };
86
+ }
87
+ return isAllowed;
56
88
}
57
89
58
90
// / Allow the given dialects and activate the filter (`hasFilter`).
91
+ // /
92
+ // / This function adds one or multiple ALLOW filters.
59
93
template <typename ... DialectTs>
60
- void addToDialectFilter () {
61
- // The following expands a call to addToDialectFilterImpl for each dialect
94
+ void allowDialectInFilter () {
95
+ // The following expands a call to allowDialectInFilterImpl for each dialect
62
96
// in 'DialectTs'. This magic is necessary due to a limitation in the places
63
97
// that a parameter pack can be expanded in c++11.
64
98
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
65
99
(void )std::initializer_list<int >{
66
- 0 , (addToDialectFilterImpl<DialectTs>(), 0 )...};
100
+ 0 , (allowDialectInFilterImpl<DialectTs>(), 0 )...};
101
+ }
102
+
103
+ // / Allow the given dialect and activate the filter (`hasFilter`).
104
+ // /
105
+ // / This function adds an ALLOW filter.
106
+ void allowDialectInFilter (StringRef dialectNamespace) {
107
+ hasFilter = true ;
108
+ OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
109
+ return op->getDialect ()->getNamespace () == dialectNamespace;
110
+ };
111
+ opFilter.push_back (
112
+ OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
67
113
}
68
114
69
115
// / Allow the given ops and activate the filter (`hasFilter`).
70
- template <typename ... OpTys> void addToOperationFilter () {
116
+ // /
117
+ // / This function adds one or multiple ALLOW filters.
118
+ template <typename ... OpTys>
119
+ void allowOperationInFilter () {
71
120
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
72
- (void )std::initializer_list<int >{0 ,
73
- (addToOperationFilterImpl<OpTys>(), 0 )...};
121
+ (void )std::initializer_list<int >{
122
+ 0 , (allowOperationInFilterImpl<OpTys>(), 0 )...};
123
+ }
124
+
125
+ // / Allow the given op and activate the filter (`hasFilter`).
126
+ // /
127
+ // / This function adds an ALLOW filter.
128
+ void allowOperationInFilter (StringRef opName) {
129
+ hasFilter = true ;
130
+ OpFilterEntry::FilterFn filterFn = [=](Operation *op) {
131
+ return op->getName ().getStringRef () == opName;
132
+ };
133
+ opFilter.push_back (
134
+ OpFilterEntry{filterFn, OpFilterEntry::FilterType::ALLOW});
74
135
}
75
136
76
137
// / Try to cast the given op to BufferizableOpInterface if the op is allow
@@ -118,33 +179,26 @@ struct BufferizationOptions {
118
179
// / Buffer alignment for new memory allocations.
119
180
unsigned int bufferAlignment = 128 ;
120
181
121
- // / If set to `true`, only ops that belong to a filtered dialect
122
- // / (`dialectFilter`) and filtered ops (`operationFilter`) are processed. All
123
- // / other ops are ignored. If set to `false`, all ops are bufferized (as long
124
- // / as they implement BufferizableOpInterface).
125
- // /
126
- // / If a filter is specified, `allowUnknownOps` should be enabled. Otherwise,
127
- // / bufferization would fail when encountering a non-filtered op.
182
+ // / If set to `false`, all ops are bufferized (as long as they implement
183
+ // / BufferizableOpInterface). Otherwise, only filtered ops are bufferized.
128
184
bool hasFilter = false ;
129
185
130
- // / A set of allowed dialects.
131
- DenseSet<StringRef> dialectFilter;
132
-
133
- // / A set of allowed ops.
134
- DenseSet<StringRef> operationFilter;
186
+ // / A list of op filters that determine whether an op should be processed or
187
+ // / ignored by the bufferization. If `hasFilter`, only ops that are not
188
+ // / DENY-filtered and have at least one matching ALLOW filter are processed.
189
+ SmallVector<OpFilterEntry> opFilter;
135
190
136
191
private:
137
192
// / Allow a dialect.
138
193
template <typename DialectT>
139
- void addToDialectFilterImpl () {
140
- hasFilter = true ;
141
- dialectFilter.insert (DialectT::getDialectNamespace ());
194
+ void allowDialectInFilterImpl () {
195
+ allowDialectInFilter (DialectT::getDialectNamespace ());
142
196
}
143
197
144
198
// / Allow an op.
145
- template <typename OpTy> void addToOperationFilterImpl () {
146
- hasFilter = true ;
147
- operationFilter. insert (OpTy::getOperationName ());
199
+ template <typename OpTy>
200
+ void allowOperationInFilterImpl () {
201
+ allowOperationInFilter (OpTy::getOperationName ());
148
202
}
149
203
};
150
204
0 commit comments