Skip to content

Commit 9ccf613

Browse files
[mlir][Transforms][NFC] Store per-pattern IR modifications in separate state (#145319)
This commit adds extra state to `ConversionPatternRewriterImpl` to store all modified / newly-created operations and moved / newly-created blocks in separate lists on a per-pattern basis. This is in preparation of the One-Shot Dialect Conversion refactoring: the new driver will no longer maintain a list of all IR rewrites, so information about newly-created operations (which is needed to trigger recursive legalization) must be retained in a different data structure. This commit is also expected to improve the performance of the existing driver. The previous implementation iterated over all new IR modifications and then filtered them by type. It also required an additional pointer indirection (through `std::unique_ptr<IRRewrite>`) to retrieve the operation/block pointer.
1 parent fff720d commit 9ccf613

File tree

1 file changed

+75
-64
lines changed

1 file changed

+75
-64
lines changed

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 75 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
10801080
/// to modify/access them is invalid rewriter API usage.
10811081
SetVector<Operation *> replacedOps;
10821082

1083+
/// A set of operations that were created by the current pattern.
1084+
SetVector<Operation *> patternNewOps;
1085+
1086+
/// A set of operations that were modified by the current pattern.
1087+
SetVector<Operation *> patternModifiedOps;
1088+
1089+
/// A set of blocks that were inserted (newly-created blocks or moved blocks)
1090+
/// by the current pattern.
1091+
SetVector<Block *> patternInsertedBlocks;
1092+
10831093
/// A mapping of all unresolved materializations (UnrealizedConversionCastOp)
10841094
/// to the corresponding rewrite objects.
10851095
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
15711581
if (!previous.isSet()) {
15721582
// This is a newly created op.
15731583
appendRewrite<CreateOperationRewrite>(op);
1584+
patternNewOps.insert(op);
15741585
return;
15751586
}
15761587
Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
16551666
}
16561667
});
16571668

1669+
patternInsertedBlocks.insert(block);
1670+
16581671
if (!previous) {
16591672
// This is a newly created block.
16601673
appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
18521865
assert(!impl->wasOpReplaced(op) &&
18531866
"attempting to modify a replaced/erased op");
18541867
PatternRewriter::finalizeOpModification(op);
1868+
impl->patternModifiedOps.insert(op);
1869+
18551870
// There is nothing to do here, we only need to track the operation at the
18561871
// start of the update.
18571872
#ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
19641979
/// Legalize the resultant IR after successfully applying the given pattern.
19651980
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
19661981
ConversionPatternRewriter &rewriter,
1967-
RewriterState &curState);
1982+
const SetVector<Operation *> &newOps,
1983+
const SetVector<Operation *> &modifiedOps,
1984+
const SetVector<Block *> &insertedBlocks);
19681985

19691986
/// Legalizes the actions registered during the execution of a pattern.
19701987
LogicalResult
19711988
legalizePatternBlockRewrites(Operation *op,
19721989
ConversionPatternRewriter &rewriter,
19731990
ConversionPatternRewriterImpl &impl,
1974-
RewriterState &state, RewriterState &newState);
1975-
LogicalResult legalizePatternCreatedOperations(
1976-
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1977-
RewriterState &state, RewriterState &newState);
1978-
LogicalResult legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1979-
ConversionPatternRewriterImpl &impl,
1980-
RewriterState &state,
1981-
RewriterState &newState);
1991+
const SetVector<Block *> &insertedBlocks,
1992+
const SetVector<Operation *> &newOps);
1993+
LogicalResult
1994+
legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
1995+
ConversionPatternRewriterImpl &impl,
1996+
const SetVector<Operation *> &newOps);
1997+
LogicalResult
1998+
legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
1999+
ConversionPatternRewriterImpl &impl,
2000+
const SetVector<Operation *> &modifiedOps);
19822001

19832002
//===--------------------------------------------------------------------===//
19842003
// Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
21312150
return failure();
21322151
}
21332152

2153+
/// Helper function that moves and returns the given object. Also resets the
2154+
/// original object, so that it is in a valid, empty state again.
2155+
template <typename T>
2156+
static T moveAndReset(T &obj) {
2157+
T result = std::move(obj);
2158+
obj = T();
2159+
return result;
2160+
}
2161+
21342162
LogicalResult
21352163
OperationLegalizer::legalizeWithFold(Operation *op,
21362164
ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
21922220
RewriterState curState = rewriterImpl.getCurrentState();
21932221
auto onFailure = [&](const Pattern &pattern) {
21942222
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2223+
rewriterImpl.patternNewOps.clear();
2224+
rewriterImpl.patternModifiedOps.clear();
2225+
rewriterImpl.patternInsertedBlocks.clear();
21952226
LLVM_DEBUG({
21962227
logFailure(rewriterImpl.logger, "pattern failed to match");
21972228
if (rewriterImpl.config.notifyCallback) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
22122243
// successfully applied.
22132244
auto onSuccess = [&](const Pattern &pattern) {
22142245
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
2215-
auto result = legalizePatternResult(op, pattern, rewriter, curState);
2246+
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
2247+
SetVector<Operation *> modifiedOps =
2248+
moveAndReset(rewriterImpl.patternModifiedOps);
2249+
SetVector<Block *> insertedBlocks =
2250+
moveAndReset(rewriterImpl.patternInsertedBlocks);
2251+
auto result = legalizePatternResult(op, pattern, rewriter, newOps,
2252+
modifiedOps, insertedBlocks);
22162253
appliedPatterns.erase(&pattern);
22172254
if (failed(result)) {
22182255
if (!rewriterImpl.config.allowPatternRollback)
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
22532290
return true;
22542291
}
22552292

2256-
LogicalResult
2257-
OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2258-
ConversionPatternRewriter &rewriter,
2259-
RewriterState &curState) {
2293+
LogicalResult OperationLegalizer::legalizePatternResult(
2294+
Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2295+
const SetVector<Operation *> &newOps,
2296+
const SetVector<Operation *> &modifiedOps,
2297+
const SetVector<Block *> &insertedBlocks) {
22602298
auto &impl = rewriter.getImpl();
22612299
assert(impl.pendingRootUpdates.empty() && "dangling root updates");
22622300

@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22742312
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
22752313

22762314
// Legalize each of the actions registered during application.
2277-
RewriterState newState = impl.getCurrentState();
2278-
if (failed(legalizePatternBlockRewrites(op, rewriter, impl, curState,
2279-
newState)) ||
2280-
failed(legalizePatternRootUpdates(rewriter, impl, curState, newState)) ||
2281-
failed(legalizePatternCreatedOperations(rewriter, impl, curState,
2282-
newState))) {
2315+
if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
2316+
newOps)) ||
2317+
failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
2318+
failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
22832319
return failure();
22842320
}
22852321

@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
22892325

22902326
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
22912327
Operation *op, ConversionPatternRewriter &rewriter,
2292-
ConversionPatternRewriterImpl &impl, RewriterState &state,
2293-
RewriterState &newState) {
2294-
SmallPtrSet<Operation *, 16> operationsToIgnore;
2328+
ConversionPatternRewriterImpl &impl,
2329+
const SetVector<Block *> &insertedBlocks,
2330+
const SetVector<Operation *> &newOps) {
2331+
SmallPtrSet<Operation *, 16> alreadyLegalized;
22952332

22962333
// If the pattern moved or created any blocks, make sure the types of block
22972334
// arguments get legalized.
2298-
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2299-
BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites[i].get());
2300-
if (!rewrite)
2301-
continue;
2302-
Block *block = rewrite->getBlock();
2303-
if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2304-
ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
2305-
continue;
2335+
for (Block *block : insertedBlocks) {
23062336
// Only check blocks outside of the current operation.
23072337
Operation *parentOp = block->getParentOp();
23082338
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
@@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
23222352
continue;
23232353
}
23242354

2325-
// Otherwise, check that this operation isn't one generated by this pattern.
2326-
// This is because we will attempt to legalize the parent operation, and
2327-
// blocks in regions created by this pattern will already be legalized later
2328-
// on. If we haven't built the set yet, build it now.
2329-
if (operationsToIgnore.empty()) {
2330-
for (unsigned i = state.numRewrites, e = impl.rewrites.size(); i != e;
2331-
++i) {
2332-
auto *createOp =
2333-
dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2334-
if (!createOp)
2335-
continue;
2336-
operationsToIgnore.insert(createOp->getOperation());
2355+
// Otherwise, try to legalize the parent operation if it was not generated
2356+
// by this pattern. This is because we will attempt to legalize the parent
2357+
// operation, and blocks in regions created by this pattern will already be
2358+
// legalized later on.
2359+
if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) {
2360+
if (failed(legalize(parentOp, rewriter))) {
2361+
LLVM_DEBUG(logFailure(
2362+
impl.logger, "operation '{0}'({1}) became illegal after rewrite",
2363+
parentOp->getName(), parentOp));
2364+
return failure();
23372365
}
23382366
}
2339-
2340-
// If this operation should be considered for re-legalization, try it.
2341-
if (operationsToIgnore.insert(parentOp).second &&
2342-
failed(legalize(parentOp, rewriter))) {
2343-
LLVM_DEBUG(logFailure(impl.logger,
2344-
"operation '{0}'({1}) became illegal after rewrite",
2345-
parentOp->getName(), parentOp));
2346-
return failure();
2347-
}
23482367
}
23492368
return success();
23502369
}
23512370

23522371
LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
23532372
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2354-
RewriterState &state, RewriterState &newState) {
2355-
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2356-
auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites[i].get());
2357-
if (!createOp)
2358-
continue;
2359-
Operation *op = createOp->getOperation();
2373+
const SetVector<Operation *> &newOps) {
2374+
for (Operation *op : newOps) {
23602375
if (failed(legalize(op, rewriter))) {
23612376
LLVM_DEBUG(logFailure(impl.logger,
23622377
"failed to legalize generated operation '{0}'({1})",
@@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
23692384

23702385
LogicalResult OperationLegalizer::legalizePatternRootUpdates(
23712386
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2372-
RewriterState &state, RewriterState &newState) {
2373-
for (int i = state.numRewrites, e = newState.numRewrites; i != e; ++i) {
2374-
auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites[i].get());
2375-
if (!rewrite)
2376-
continue;
2377-
Operation *op = rewrite->getOperation();
2387+
const SetVector<Operation *> &modifiedOps) {
2388+
for (Operation *op : modifiedOps) {
23782389
if (failed(legalize(op, rewriter))) {
23792390
LLVM_DEBUG(logFailure(
23802391
impl.logger, "failed to legalize operation updated in-place '{0}'",

0 commit comments

Comments
 (0)