@@ -31,11 +31,12 @@ using namespace mlir;
31
31
32
32
#define DEBUG_TYPE " greedy-rewriter"
33
33
34
+ namespace {
35
+
34
36
// ===----------------------------------------------------------------------===//
35
37
// Debugging Infrastructure
36
38
// ===----------------------------------------------------------------------===//
37
39
38
- namespace {
39
40
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
40
41
// / A helper struct that stores finger prints of ops in order to detect broken
41
42
// / RewritePatterns. A rewrite pattern is broken if it modifies IR without
@@ -130,6 +131,100 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
130
131
};
131
132
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
132
133
134
+ // ===----------------------------------------------------------------------===//
135
+ // Worklist
136
+ // ===----------------------------------------------------------------------===//
137
+
138
+ // / A LIFO worklist of operations with efficient removal and set semantics.
139
+ // /
140
+ // / This class maintains a vector of operations and a mapping of operations to
141
+ // / positions in the vector, so that operations can be removed efficiently at
142
+ // / random. When an operation is removed, it is replaced with nullptr. Such
143
+ // / nullptr are skipped when pop'ing elements.
144
+ class Worklist {
145
+ public:
146
+ Worklist ();
147
+
148
+ // / Clear the worklist.
149
+ void clear ();
150
+
151
+ // / Return whether the worklist is empty.
152
+ bool empty () const ;
153
+
154
+ // / Push an operation to the end of the worklist, unless the operation is
155
+ // / already on the worklist.
156
+ void push (Operation *op);
157
+
158
+ // / Pop the an operation from the end of the worklist. Only allowed on
159
+ // / non-empty worklists.
160
+ Operation *pop ();
161
+
162
+ // / Remove an operation from the worklist.
163
+ void remove (Operation *op);
164
+
165
+ // / Reverse the worklist.
166
+ void reverse ();
167
+
168
+ private:
169
+ // / The worklist of operations.
170
+ std::vector<Operation *> list;
171
+
172
+ // / A mapping of operations to positions in `list`.
173
+ DenseMap<Operation *, unsigned > map;
174
+ };
175
+
176
+ Worklist::Worklist () { list.reserve (64 ); }
177
+
178
+ void Worklist::clear () {
179
+ list.clear ();
180
+ map.clear ();
181
+ }
182
+
183
+ bool Worklist::empty () const {
184
+ // Skip all nullptr.
185
+ return !llvm::any_of (list,
186
+ [](Operation *op) { return static_cast <bool >(op); });
187
+ }
188
+
189
+ void Worklist::push (Operation *op) {
190
+ assert (op && " cannot push nullptr to worklist" );
191
+ // Check to see if the worklist already contains this op.
192
+ if (map.count (op))
193
+ return ;
194
+ map[op] = list.size ();
195
+ list.push_back (op);
196
+ }
197
+
198
+ Operation *Worklist::pop () {
199
+ assert (!empty () && " cannot pop from empty worklist" );
200
+ // Skip and remove all trailing nullptr.
201
+ while (!list.back ())
202
+ list.pop_back ();
203
+ Operation *op = list.back ();
204
+ list.pop_back ();
205
+ map.erase (op);
206
+ // Cleanup: Remove all trailing nullptr.
207
+ while (!list.empty () && !list.back ())
208
+ list.pop_back ();
209
+ return op;
210
+ }
211
+
212
+ void Worklist::remove (Operation *op) {
213
+ assert (op && " cannot remove nullptr from worklist" );
214
+ auto it = map.find (op);
215
+ if (it != map.end ()) {
216
+ assert (list[it->second ] == op && " malformed worklist data structure" );
217
+ list[it->second ] = nullptr ;
218
+ map.erase (it);
219
+ }
220
+ }
221
+
222
+ void Worklist::reverse () {
223
+ std::reverse (list.begin (), list.end ());
224
+ for (size_t i = 0 , e = list.size (); i != e; ++i)
225
+ map[list[i]] = i;
226
+ }
227
+
133
228
// ===----------------------------------------------------------------------===//
134
229
// GreedyPatternRewriteDriver
135
230
// ===----------------------------------------------------------------------===//
@@ -176,11 +271,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
176
271
bool processWorklist ();
177
272
178
273
// / The worklist for this transformation keeps track of the operations that
179
- // / need to be revisited, plus their index in the worklist. This allows us to
180
- // / efficiently remove operations from the worklist when they are erased, even
181
- // / if they aren't the root of a pattern.
182
- std::vector<Operation *> worklist;
183
- DenseMap<Operation *, unsigned > worklistMap;
274
+ // / need to be (re)visited.
275
+ Worklist worklist;
184
276
185
277
// / Non-pattern based folder for operations.
186
278
OperationFolder folder;
@@ -201,9 +293,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
201
293
// / simplifications.
202
294
void addOperandsToWorklist (ValueRange operands);
203
295
204
- // / Pop the next operation from the worklist.
205
- Operation *popFromWorklist ();
206
-
207
296
// / Notify the driver that the given block was created.
208
297
void notifyBlockCreated (Block *block) override ;
209
298
@@ -212,9 +301,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
212
301
notifyMatchFailure (Location loc,
213
302
function_ref<void (Diagnostic &)> reasonCallback) override ;
214
303
215
- // / If the specified operation is in the worklist, remove it.
216
- void removeFromWorklist (Operation *op);
217
-
218
304
#ifndef NDEBUG
219
305
// / A logger used to emit information during the application process.
220
306
llvm::ScopedPrinter logger{llvm::dbgs ()};
@@ -239,8 +325,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
239
325
// clang-format on
240
326
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
241
327
{
242
- worklist.reserve (64 );
243
-
244
328
// Apply a simple cost model based solely on pattern benefit.
245
329
matcher.applyDefaultCostModel ();
246
330
@@ -278,12 +362,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
278
362
while (!worklist.empty () &&
279
363
(numRewrites < config.maxNumRewrites ||
280
364
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit )) {
281
- auto *op = popFromWorklist ();
282
-
283
- // Nulls get added to the worklist when operations are removed, ignore
284
- // them.
285
- if (op == nullptr )
286
- continue ;
365
+ auto *op = worklist.pop ();
287
366
288
367
LLVM_DEBUG ({
289
368
logger.getOStream () << " \n " ;
@@ -395,33 +474,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
395
474
396
475
void GreedyPatternRewriteDriver::addSingleOpToWorklist (Operation *op) {
397
476
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
398
- strictModeFilteredOps.contains (op)) {
399
- // Check to see if the worklist already contains this op.
400
- if (worklistMap.count (op))
401
- return ;
402
-
403
- worklistMap[op] = worklist.size ();
404
- worklist.push_back (op);
405
- }
406
- }
407
-
408
- Operation *GreedyPatternRewriteDriver::popFromWorklist () {
409
- auto *op = worklist.back ();
410
- worklist.pop_back ();
411
-
412
- // This operation is no longer in the worklist, keep worklistMap up to date.
413
- if (op)
414
- worklistMap.erase (op);
415
- return op;
416
- }
417
-
418
- void GreedyPatternRewriteDriver::removeFromWorklist (Operation *op) {
419
- auto it = worklistMap.find (op);
420
- if (it != worklistMap.end ()) {
421
- assert (worklist[it->second ] == op && " malformed worklist data structure" );
422
- worklist[it->second ] = nullptr ;
423
- worklistMap.erase (it);
424
- }
477
+ strictModeFilteredOps.contains (op))
478
+ worklist.push (op);
425
479
}
426
480
427
481
void GreedyPatternRewriteDriver::notifyBlockCreated (Block *block) {
@@ -475,7 +529,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
475
529
476
530
addOperandsToWorklist (op->getOperands ());
477
531
op->walk ([this ](Operation *operation) {
478
- removeFromWorklist (operation);
532
+ worklist. remove (operation);
479
533
folder.notifyRemoval (operation);
480
534
});
481
535
@@ -580,7 +634,6 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
580
634
break ;
581
635
582
636
worklist.clear ();
583
- worklistMap.clear ();
584
637
585
638
if (!config.useTopDownTraversal ) {
586
639
// Add operations to the worklist in postorder.
@@ -599,10 +652,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
599
652
});
600
653
601
654
// Reverse the list so our pop-back loop processes them in-order.
602
- std::reverse (worklist.begin (), worklist.end ());
603
- // Remember the reverse index.
604
- for (size_t i = 0 , e = worklist.size (); i != e; ++i)
605
- worklistMap[worklist[i]] = i;
655
+ worklist.reverse ();
606
656
}
607
657
608
658
ctx->executeAction <GreedyPatternRewriteIteration>(
0 commit comments