@@ -626,6 +626,18 @@ static OpResult getAliasingOpResult(OpOperand &opOperand) {
626
626
[&](Operation *op) { return getInplaceableOpResult (opOperand); });
627
627
}
628
628
629
+ // / Return `true` if the given OpOperand does not bufferize to a memory read or
630
+ // / write, but creates an alias when bufferized inplace.
631
+ static bool bufferizesToAliasOnly (OpOperand &opOperand) {
632
+ Operation *owner = opOperand.getOwner ();
633
+ // TODO: In the future this may need to evolve into a TypeSwitch. For all
634
+ // currently supported ops, the aliasing-only OpOperand is always the first
635
+ // one.
636
+ return isa<ExtractSliceOp, TensorCollapseShapeOp, TensorExpandShapeOp,
637
+ tensor::CastOp>(owner) &&
638
+ &opOperand == &owner->getOpOperand (0 );
639
+ }
640
+
629
641
// Predeclaration of function.
630
642
static bool bufferizesToMemoryRead (OpOperand &opOperand);
631
643
@@ -640,8 +652,8 @@ static bool isValueRead(Value value) {
640
652
while (!workingSet.empty ()) {
641
653
OpOperand *uMaybeReading = workingSet.pop_back_val ();
642
654
// Skip over all ops that create an alias but do not read.
643
- if (isa<ExtractSliceOp, tensor::CastOp>( uMaybeReading-> getOwner () ))
644
- for (OpOperand &use : uMaybeReading-> getOwner ()-> getResult ( 0 ).getUses ())
655
+ if (bufferizesToAliasOnly (* uMaybeReading))
656
+ for (OpOperand &use : getAliasingOpResult (*uMaybeReading ).getUses ())
645
657
workingSet.push_back (&use);
646
658
if (bufferizesToMemoryRead (*uMaybeReading))
647
659
return true ;
@@ -658,7 +670,7 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
658
670
return true ;
659
671
// Some ops alone do not bufferize to a memory read, but one of their uses
660
672
// may.
661
- if (isa<ExtractSliceOp, tensor::CastOp> (opOperand. getOwner () ))
673
+ if (bufferizesToAliasOnly (opOperand))
662
674
return false ;
663
675
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
664
676
// matching bbArg may.
@@ -690,7 +702,7 @@ static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
690
702
return false ;
691
703
// Some ops alone do not bufferize to a memory write, but one of their uses
692
704
// may.
693
- if (isa<ExtractSliceOp, tensor::CastOp> (opOperand. getOwner () ))
705
+ if (bufferizesToAliasOnly (opOperand))
694
706
return false ;
695
707
// CallOpInterface alone doesn't bufferize to a memory write, one of the uses
696
708
// of the matching bbArg may. It is the responsibility of the caller to
@@ -2318,9 +2330,8 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
2318
2330
return success ();
2319
2331
}
2320
2332
2321
- // / This analysis function is used for ops where the first OpOperand aliases
2322
- // / with the first OpResult, without creating a read or write. There are a few
2323
- // / ops besides ExtractSliceOp that have such semantics.
2333
+ // / This analysis function is used for OpOperands that alias with an OpResult
2334
+ // / but are not inplaceable on it. E.g., ExtractSliceOp.
2324
2335
// /
2325
2336
// / Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
2326
2337
// /
@@ -2335,11 +2346,12 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
2335
2346
// / An analysis is required to ensure inplace bufferization would not result in
2336
2347
// / RaW dependence violations.
2337
2348
static LogicalResult
2338
- bufferizableInPlaceAnalysisAliasOnlyOp (Operation *op ,
2349
+ bufferizableInPlaceAnalysisAliasOnlyOp (OpOperand &operand ,
2339
2350
BufferizationAliasInfo &aliasInfo,
2340
2351
const DominanceInfo &domInfo) {
2341
- return bufferizableInPlaceAnalysisImpl (
2342
- op->getOpOperand (0 ), op->getOpResult (0 ), aliasInfo, domInfo);
2352
+ OpResult result = getAliasingOpResult (operand);
2353
+ assert (result && " expected that the OpOperand has an aliasing OpResult" );
2354
+ return bufferizableInPlaceAnalysisImpl (operand, result, aliasInfo, domInfo);
2343
2355
}
2344
2356
2345
2357
// / Determine if `operand` can be bufferized in-place with one of the op's
@@ -2372,16 +2384,17 @@ LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector<Operation *> &ops,
2372
2384
2373
2385
// Walk ops in reverse for better interference analysis.
2374
2386
for (Operation *op : reverse (ops)) {
2375
- for (OpOperand &opOperand : op->getOpOperands ())
2387
+ for (OpOperand &opOperand : op->getOpOperands ()) {
2376
2388
if (failed (bufferizableInPlaceAnalysis (opOperand, aliasInfo, domInfo)))
2377
2389
return failure ();
2378
2390
2379
- // Special logic to analyze ops who's OpResults are not inplaceable on an
2380
- // OpOperand but may create an alias.
2381
- if (isa<ExtractSliceOp, tensor::CastOp>(op))
2382
- if (failed (
2383
- bufferizableInPlaceAnalysisAliasOnlyOp (op, aliasInfo, domInfo)))
2384
- return failure ();
2391
+ // Special logic to analyze OpOperands that are not inplaceable on an
2392
+ // OpResult but may create an alias.
2393
+ if (bufferizesToAliasOnly (opOperand))
2394
+ if (failed (bufferizableInPlaceAnalysisAliasOnlyOp (opOperand, aliasInfo,
2395
+ domInfo)))
2396
+ return failure ();
2397
+ }
2385
2398
}
2386
2399
2387
2400
return success ();
@@ -3049,8 +3062,8 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
3049
3062
aliasInfo.createAliasInfoEntry (extractOp.result ());
3050
3063
3051
3064
// Run analysis on the ExtractSliceOp.
3052
- if (failed (bufferizableInPlaceAnalysisAliasOnlyOp (extractOp, aliasInfo,
3053
- domInfo)))
3065
+ if (failed (bufferizableInPlaceAnalysisAliasOnlyOp (
3066
+ extractOp-> getOpOperand ( 0 ), aliasInfo, domInfo)))
3054
3067
return WalkResult::interrupt ();
3055
3068
3056
3069
// Advance to the next operation.
0 commit comments