Skip to content

Commit ca7167d

Browse files
[mlir][Transforms][NFC] GreedyPatternRewriteDriver: Add worklist class
Encapsulate all worklist-related functionality in a separate `Worklist` class. This makes the remaining code more readable and allows for custom worklist implementations (e.g., a randomized worklist for fuzzing pattern application: D142447). Differential Revision: https://reviews.llvm.org/D151345
1 parent 811cbfc commit ca7167d

File tree

1 file changed

+103
-53
lines changed

1 file changed

+103
-53
lines changed

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Lines changed: 103 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ using namespace mlir;
3131

3232
#define DEBUG_TYPE "greedy-rewriter"
3333

34+
namespace {
35+
3436
//===----------------------------------------------------------------------===//
3537
// Debugging Infrastructure
3638
//===----------------------------------------------------------------------===//
3739

38-
namespace {
3940
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
4041
/// A helper struct that stores finger prints of ops in order to detect broken
4142
/// RewritePatterns. A rewrite pattern is broken if it modifies IR without
@@ -130,6 +131,100 @@ struct DebugFingerPrints : public RewriterBase::ForwardingListener {
130131
};
131132
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
132133

134+
//===----------------------------------------------------------------------===//
135+
// Worklist
136+
//===----------------------------------------------------------------------===//
137+
138+
/// A LIFO worklist of operations with efficient removal and set semantics.
139+
///
140+
/// This class maintains a vector of operations and a mapping of operations to
141+
/// positions in the vector, so that operations can be removed efficiently at
142+
/// random. When an operation is removed, it is replaced with nullptr. Such
143+
/// nullptr are skipped when pop'ing elements.
144+
class Worklist {
145+
public:
146+
Worklist();
147+
148+
/// Clear the worklist.
149+
void clear();
150+
151+
/// Return whether the worklist is empty.
152+
bool empty() const;
153+
154+
/// Push an operation to the end of the worklist, unless the operation is
155+
/// already on the worklist.
156+
void push(Operation *op);
157+
158+
/// Pop the an operation from the end of the worklist. Only allowed on
159+
/// non-empty worklists.
160+
Operation *pop();
161+
162+
/// Remove an operation from the worklist.
163+
void remove(Operation *op);
164+
165+
/// Reverse the worklist.
166+
void reverse();
167+
168+
private:
169+
/// The worklist of operations.
170+
std::vector<Operation *> list;
171+
172+
/// A mapping of operations to positions in `list`.
173+
DenseMap<Operation *, unsigned> map;
174+
};
175+
176+
Worklist::Worklist() { list.reserve(64); }
177+
178+
void Worklist::clear() {
179+
list.clear();
180+
map.clear();
181+
}
182+
183+
bool Worklist::empty() const {
184+
// Skip all nullptr.
185+
return !llvm::any_of(list,
186+
[](Operation *op) { return static_cast<bool>(op); });
187+
}
188+
189+
void Worklist::push(Operation *op) {
190+
assert(op && "cannot push nullptr to worklist");
191+
// Check to see if the worklist already contains this op.
192+
if (map.count(op))
193+
return;
194+
map[op] = list.size();
195+
list.push_back(op);
196+
}
197+
198+
Operation *Worklist::pop() {
199+
assert(!empty() && "cannot pop from empty worklist");
200+
// Skip and remove all trailing nullptr.
201+
while (!list.back())
202+
list.pop_back();
203+
Operation *op = list.back();
204+
list.pop_back();
205+
map.erase(op);
206+
// Cleanup: Remove all trailing nullptr.
207+
while (!list.empty() && !list.back())
208+
list.pop_back();
209+
return op;
210+
}
211+
212+
void Worklist::remove(Operation *op) {
213+
assert(op && "cannot remove nullptr from worklist");
214+
auto it = map.find(op);
215+
if (it != map.end()) {
216+
assert(list[it->second] == op && "malformed worklist data structure");
217+
list[it->second] = nullptr;
218+
map.erase(it);
219+
}
220+
}
221+
222+
void Worklist::reverse() {
223+
std::reverse(list.begin(), list.end());
224+
for (size_t i = 0, e = list.size(); i != e; ++i)
225+
map[list[i]] = i;
226+
}
227+
133228
//===----------------------------------------------------------------------===//
134229
// GreedyPatternRewriteDriver
135230
//===----------------------------------------------------------------------===//
@@ -176,11 +271,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
176271
bool processWorklist();
177272

178273
/// The worklist for this transformation keeps track of the operations that
179-
/// need to be revisited, plus their index in the worklist. This allows us to
180-
/// efficiently remove operations from the worklist when they are erased, even
181-
/// if they aren't the root of a pattern.
182-
std::vector<Operation *> worklist;
183-
DenseMap<Operation *, unsigned> worklistMap;
274+
/// need to be (re)visited.
275+
Worklist worklist;
184276

185277
/// Non-pattern based folder for operations.
186278
OperationFolder folder;
@@ -201,9 +293,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
201293
/// simplifications.
202294
void addOperandsToWorklist(ValueRange operands);
203295

204-
/// Pop the next operation from the worklist.
205-
Operation *popFromWorklist();
206-
207296
/// Notify the driver that the given block was created.
208297
void notifyBlockCreated(Block *block) override;
209298

@@ -212,9 +301,6 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
212301
notifyMatchFailure(Location loc,
213302
function_ref<void(Diagnostic &)> reasonCallback) override;
214303

215-
/// If the specified operation is in the worklist, remove it.
216-
void removeFromWorklist(Operation *op);
217-
218304
#ifndef NDEBUG
219305
/// A logger used to emit information during the application process.
220306
llvm::ScopedPrinter logger{llvm::dbgs()};
@@ -239,8 +325,6 @@ GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
239325
// clang-format on
240326
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
241327
{
242-
worklist.reserve(64);
243-
244328
// Apply a simple cost model based solely on pattern benefit.
245329
matcher.applyDefaultCostModel();
246330

@@ -278,12 +362,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
278362
while (!worklist.empty() &&
279363
(numRewrites < config.maxNumRewrites ||
280364
config.maxNumRewrites == GreedyRewriteConfig::kNoLimit)) {
281-
auto *op = popFromWorklist();
282-
283-
// Nulls get added to the worklist when operations are removed, ignore
284-
// them.
285-
if (op == nullptr)
286-
continue;
365+
auto *op = worklist.pop();
287366

288367
LLVM_DEBUG({
289368
logger.getOStream() << "\n";
@@ -395,33 +474,8 @@ void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
395474

396475
void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
397476
if (config.strictMode == GreedyRewriteStrictness::AnyOp ||
398-
strictModeFilteredOps.contains(op)) {
399-
// Check to see if the worklist already contains this op.
400-
if (worklistMap.count(op))
401-
return;
402-
403-
worklistMap[op] = worklist.size();
404-
worklist.push_back(op);
405-
}
406-
}
407-
408-
Operation *GreedyPatternRewriteDriver::popFromWorklist() {
409-
auto *op = worklist.back();
410-
worklist.pop_back();
411-
412-
// This operation is no longer in the worklist, keep worklistMap up to date.
413-
if (op)
414-
worklistMap.erase(op);
415-
return op;
416-
}
417-
418-
void GreedyPatternRewriteDriver::removeFromWorklist(Operation *op) {
419-
auto it = worklistMap.find(op);
420-
if (it != worklistMap.end()) {
421-
assert(worklist[it->second] == op && "malformed worklist data structure");
422-
worklist[it->second] = nullptr;
423-
worklistMap.erase(it);
424-
}
477+
strictModeFilteredOps.contains(op))
478+
worklist.push(op);
425479
}
426480

427481
void GreedyPatternRewriteDriver::notifyBlockCreated(Block *block) {
@@ -475,7 +529,7 @@ void GreedyPatternRewriteDriver::notifyOperationRemoved(Operation *op) {
475529

476530
addOperandsToWorklist(op->getOperands());
477531
op->walk([this](Operation *operation) {
478-
removeFromWorklist(operation);
532+
worklist.remove(operation);
479533
folder.notifyRemoval(operation);
480534
});
481535

@@ -580,7 +634,6 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
580634
break;
581635

582636
worklist.clear();
583-
worklistMap.clear();
584637

585638
if (!config.useTopDownTraversal) {
586639
// Add operations to the worklist in postorder.
@@ -599,10 +652,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
599652
});
600653

601654
// Reverse the list so our pop-back loop processes them in-order.
602-
std::reverse(worklist.begin(), worklist.end());
603-
// Remember the reverse index.
604-
for (size_t i = 0, e = worklist.size(); i != e; ++i)
605-
worklistMap[worklist[i]] = i;
655+
worklist.reverse();
606656
}
607657

608658
ctx->executeAction<GreedyPatternRewriteIteration>(

0 commit comments

Comments
 (0)