@@ -1080,6 +1080,16 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
1080
1080
// / to modify/access them is invalid rewriter API usage.
1081
1081
SetVector<Operation *> replacedOps;
1082
1082
1083
+ // / A set of operations that were created by the current pattern.
1084
+ SetVector<Operation *> patternNewOps;
1085
+
1086
+ // / A set of operations that were modified by the current pattern.
1087
+ SetVector<Operation *> patternModifiedOps;
1088
+
1089
+ // / A set of blocks that were inserted (newly-created blocks or moved blocks)
1090
+ // / by the current pattern.
1091
+ SetVector<Block *> patternInsertedBlocks;
1092
+
1083
1093
// / A mapping of all unresolved materializations (UnrealizedConversionCastOp)
1084
1094
// / to the corresponding rewrite objects.
1085
1095
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
@@ -1571,6 +1581,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
1571
1581
if (!previous.isSet ()) {
1572
1582
// This is a newly created op.
1573
1583
appendRewrite<CreateOperationRewrite>(op);
1584
+ patternNewOps.insert (op);
1574
1585
return ;
1575
1586
}
1576
1587
Operation *prevOp = previous.getPoint () == previous.getBlock ()->end ()
@@ -1655,6 +1666,8 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
1655
1666
}
1656
1667
});
1657
1668
1669
+ patternInsertedBlocks.insert (block);
1670
+
1658
1671
if (!previous) {
1659
1672
// This is a newly created block.
1660
1673
appendRewrite<CreateBlockRewrite>(block);
@@ -1852,6 +1865,8 @@ void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
1852
1865
assert (!impl->wasOpReplaced (op) &&
1853
1866
" attempting to modify a replaced/erased op" );
1854
1867
PatternRewriter::finalizeOpModification (op);
1868
+ impl->patternModifiedOps .insert (op);
1869
+
1855
1870
// There is nothing to do here, we only need to track the operation at the
1856
1871
// start of the update.
1857
1872
#ifndef NDEBUG
@@ -1964,21 +1979,25 @@ class OperationLegalizer {
1964
1979
// / Legalize the resultant IR after successfully applying the given pattern.
1965
1980
LogicalResult legalizePatternResult (Operation *op, const Pattern &pattern,
1966
1981
ConversionPatternRewriter &rewriter,
1967
- RewriterState &curState);
1982
+ const SetVector<Operation *> &newOps,
1983
+ const SetVector<Operation *> &modifiedOps,
1984
+ const SetVector<Block *> &insertedBlocks);
1968
1985
1969
1986
// / Legalizes the actions registered during the execution of a pattern.
1970
1987
LogicalResult
1971
1988
legalizePatternBlockRewrites (Operation *op,
1972
1989
ConversionPatternRewriter &rewriter,
1973
1990
ConversionPatternRewriterImpl &impl,
1974
- RewriterState &state, RewriterState &newState);
1975
- LogicalResult legalizePatternCreatedOperations (
1976
- ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
1977
- RewriterState &state, RewriterState &newState);
1978
- LogicalResult legalizePatternRootUpdates (ConversionPatternRewriter &rewriter,
1979
- ConversionPatternRewriterImpl &impl,
1980
- RewriterState &state,
1981
- RewriterState &newState);
1991
+ const SetVector<Block *> &insertedBlocks,
1992
+ const SetVector<Operation *> &newOps);
1993
+ LogicalResult
1994
+ legalizePatternCreatedOperations (ConversionPatternRewriter &rewriter,
1995
+ ConversionPatternRewriterImpl &impl,
1996
+ const SetVector<Operation *> &newOps);
1997
+ LogicalResult
1998
+ legalizePatternRootUpdates (ConversionPatternRewriter &rewriter,
1999
+ ConversionPatternRewriterImpl &impl,
2000
+ const SetVector<Operation *> &modifiedOps);
1982
2001
1983
2002
// ===--------------------------------------------------------------------===//
1984
2003
// Cost Model
@@ -2131,6 +2150,15 @@ OperationLegalizer::legalize(Operation *op,
2131
2150
return failure ();
2132
2151
}
2133
2152
2153
+ // / Helper function that moves and returns the given object. Also resets the
2154
+ // / original object, so that it is in a valid, empty state again.
2155
+ template <typename T>
2156
+ static T moveAndReset (T &obj) {
2157
+ T result = std::move (obj);
2158
+ obj = T ();
2159
+ return result;
2160
+ }
2161
+
2134
2162
LogicalResult
2135
2163
OperationLegalizer::legalizeWithFold (Operation *op,
2136
2164
ConversionPatternRewriter &rewriter) {
@@ -2192,6 +2220,9 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2192
2220
RewriterState curState = rewriterImpl.getCurrentState ();
2193
2221
auto onFailure = [&](const Pattern &pattern) {
2194
2222
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2223
+ rewriterImpl.patternNewOps .clear ();
2224
+ rewriterImpl.patternModifiedOps .clear ();
2225
+ rewriterImpl.patternInsertedBlocks .clear ();
2195
2226
LLVM_DEBUG ({
2196
2227
logFailure (rewriterImpl.logger , " pattern failed to match" );
2197
2228
if (rewriterImpl.config .notifyCallback ) {
@@ -2212,7 +2243,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
2212
2243
// successfully applied.
2213
2244
auto onSuccess = [&](const Pattern &pattern) {
2214
2245
assert (rewriterImpl.pendingRootUpdates .empty () && " dangling root updates" );
2215
- auto result = legalizePatternResult (op, pattern, rewriter, curState);
2246
+ SetVector<Operation *> newOps = moveAndReset (rewriterImpl.patternNewOps );
2247
+ SetVector<Operation *> modifiedOps =
2248
+ moveAndReset (rewriterImpl.patternModifiedOps );
2249
+ SetVector<Block *> insertedBlocks =
2250
+ moveAndReset (rewriterImpl.patternInsertedBlocks );
2251
+ auto result = legalizePatternResult (op, pattern, rewriter, newOps,
2252
+ modifiedOps, insertedBlocks);
2216
2253
appliedPatterns.erase (&pattern);
2217
2254
if (failed (result)) {
2218
2255
if (!rewriterImpl.config .allowPatternRollback )
@@ -2253,10 +2290,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
2253
2290
return true ;
2254
2291
}
2255
2292
2256
- LogicalResult
2257
- OperationLegalizer::legalizePatternResult (Operation *op, const Pattern &pattern,
2258
- ConversionPatternRewriter &rewriter,
2259
- RewriterState &curState) {
2293
+ LogicalResult OperationLegalizer::legalizePatternResult (
2294
+ Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
2295
+ const SetVector<Operation *> &newOps,
2296
+ const SetVector<Operation *> &modifiedOps,
2297
+ const SetVector<Block *> &insertedBlocks) {
2260
2298
auto &impl = rewriter.getImpl ();
2261
2299
assert (impl.pendingRootUpdates .empty () && " dangling root updates" );
2262
2300
@@ -2274,12 +2312,10 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2274
2312
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
2275
2313
2276
2314
// Legalize each of the actions registered during application.
2277
- RewriterState newState = impl.getCurrentState ();
2278
- if (failed (legalizePatternBlockRewrites (op, rewriter, impl, curState,
2279
- newState)) ||
2280
- failed (legalizePatternRootUpdates (rewriter, impl, curState, newState)) ||
2281
- failed (legalizePatternCreatedOperations (rewriter, impl, curState,
2282
- newState))) {
2315
+ if (failed (legalizePatternBlockRewrites (op, rewriter, impl, insertedBlocks,
2316
+ newOps)) ||
2317
+ failed (legalizePatternRootUpdates (rewriter, impl, modifiedOps)) ||
2318
+ failed (legalizePatternCreatedOperations (rewriter, impl, newOps))) {
2283
2319
return failure ();
2284
2320
}
2285
2321
@@ -2289,20 +2325,14 @@ OperationLegalizer::legalizePatternResult(Operation *op, const Pattern &pattern,
2289
2325
2290
2326
LogicalResult OperationLegalizer::legalizePatternBlockRewrites (
2291
2327
Operation *op, ConversionPatternRewriter &rewriter,
2292
- ConversionPatternRewriterImpl &impl, RewriterState &state,
2293
- RewriterState &newState) {
2294
- SmallPtrSet<Operation *, 16 > operationsToIgnore;
2328
+ ConversionPatternRewriterImpl &impl,
2329
+ const SetVector<Block *> &insertedBlocks,
2330
+ const SetVector<Operation *> &newOps) {
2331
+ SmallPtrSet<Operation *, 16 > alreadyLegalized;
2295
2332
2296
2333
// If the pattern moved or created any blocks, make sure the types of block
2297
2334
// arguments get legalized.
2298
- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2299
- BlockRewrite *rewrite = dyn_cast<BlockRewrite>(impl.rewrites [i].get ());
2300
- if (!rewrite)
2301
- continue ;
2302
- Block *block = rewrite->getBlock ();
2303
- if (isa<BlockTypeConversionRewrite, EraseBlockRewrite,
2304
- ReplaceBlockArgRewrite, InlineBlockRewrite>(rewrite))
2305
- continue ;
2335
+ for (Block *block : insertedBlocks) {
2306
2336
// Only check blocks outside of the current operation.
2307
2337
Operation *parentOp = block->getParentOp ();
2308
2338
if (!parentOp || parentOp == op || block->getNumArguments () == 0 )
@@ -2322,41 +2352,26 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
2322
2352
continue ;
2323
2353
}
2324
2354
2325
- // Otherwise, check that this operation isn't one generated by this pattern.
2326
- // This is because we will attempt to legalize the parent operation, and
2327
- // blocks in regions created by this pattern will already be legalized later
2328
- // on. If we haven't built the set yet, build it now.
2329
- if (operationsToIgnore.empty ()) {
2330
- for (unsigned i = state.numRewrites , e = impl.rewrites .size (); i != e;
2331
- ++i) {
2332
- auto *createOp =
2333
- dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2334
- if (!createOp)
2335
- continue ;
2336
- operationsToIgnore.insert (createOp->getOperation ());
2355
+ // Otherwise, try to legalize the parent operation if it was not generated
2356
+ // by this pattern. This is because we will attempt to legalize the parent
2357
+ // operation, and blocks in regions created by this pattern will already be
2358
+ // legalized later on.
2359
+ if (!newOps.count (parentOp) && alreadyLegalized.insert (parentOp).second ) {
2360
+ if (failed (legalize (parentOp, rewriter))) {
2361
+ LLVM_DEBUG (logFailure (
2362
+ impl.logger , " operation '{0}'({1}) became illegal after rewrite" ,
2363
+ parentOp->getName (), parentOp));
2364
+ return failure ();
2337
2365
}
2338
2366
}
2339
-
2340
- // If this operation should be considered for re-legalization, try it.
2341
- if (operationsToIgnore.insert (parentOp).second &&
2342
- failed (legalize (parentOp, rewriter))) {
2343
- LLVM_DEBUG (logFailure (impl.logger ,
2344
- " operation '{0}'({1}) became illegal after rewrite" ,
2345
- parentOp->getName (), parentOp));
2346
- return failure ();
2347
- }
2348
2367
}
2349
2368
return success ();
2350
2369
}
2351
2370
2352
2371
LogicalResult OperationLegalizer::legalizePatternCreatedOperations (
2353
2372
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2354
- RewriterState &state, RewriterState &newState) {
2355
- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2356
- auto *createOp = dyn_cast<CreateOperationRewrite>(impl.rewrites [i].get ());
2357
- if (!createOp)
2358
- continue ;
2359
- Operation *op = createOp->getOperation ();
2373
+ const SetVector<Operation *> &newOps) {
2374
+ for (Operation *op : newOps) {
2360
2375
if (failed (legalize (op, rewriter))) {
2361
2376
LLVM_DEBUG (logFailure (impl.logger ,
2362
2377
" failed to legalize generated operation '{0}'({1})" ,
@@ -2369,12 +2384,8 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations(
2369
2384
2370
2385
LogicalResult OperationLegalizer::legalizePatternRootUpdates (
2371
2386
ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl,
2372
- RewriterState &state, RewriterState &newState) {
2373
- for (int i = state.numRewrites , e = newState.numRewrites ; i != e; ++i) {
2374
- auto *rewrite = dyn_cast<ModifyOperationRewrite>(impl.rewrites [i].get ());
2375
- if (!rewrite)
2376
- continue ;
2377
- Operation *op = rewrite->getOperation ();
2387
+ const SetVector<Operation *> &modifiedOps) {
2388
+ for (Operation *op : modifiedOps) {
2378
2389
if (failed (legalize (op, rewriter))) {
2379
2390
LLVM_DEBUG (logFailure (
2380
2391
impl.logger , " failed to legalize operation updated in-place '{0}'" ,
0 commit comments