@@ -516,14 +516,6 @@ static OpResult getInplaceableOpResult(InsertSliceOp op, OpOperand &opOperand) {
516
516
return op->getResult (0 );
517
517
}
518
518
519
- // / Return the OpResult that may bufferize into the same buffer as `opOperand`
520
- // / when the op is bufferized inplace.
521
- // / Return null if no such result exists.
522
- static OpResult getInplaceableOpResult (tensor::CastOp op,
523
- OpOperand &opOperand) {
524
- return op->getResult (0 );
525
- }
526
-
527
519
// / Return the OpResult that may bufferize into the same buffer as `opOperand`
528
520
// / when the op is bufferized inplace.
529
521
// / The inplace analysis uses this information along with interfering read
@@ -534,16 +526,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
534
526
// clang-format off
535
527
// Ops that perform destructive updates on operand(s) to produce
536
528
// result(s).
537
- .Case <tensor::CastOp,
538
- scf::ForOp,
529
+ .Case <scf::ForOp,
539
530
InsertSliceOp,
540
531
LinalgOp,
541
532
TiledLoopOp,
542
533
VectorTransferOpInterface>(
543
534
[&](auto op) { return getInplaceableOpResult (op, opOperand); })
544
- // ExtractSliceOp is special, when bufferized inplace it just returns an
545
- // alias to its operand. Its result is never inplaceable on its operand.
546
- .Case ([&](ExtractSliceOp op) { return OpResult (); })
535
+ // Some ops just return an alias to an operand when bufferized inplace.
536
+ // Such OpResults are never inplaceable on an OpOperand.
537
+ .Case <ExtractSliceOp, tensor::CastOp>(
538
+ [] (auto op) { return OpResult (); })
547
539
// CallOpInterface is special, it needs to wait for the callee to be
548
540
// bufferized and needs to inspect the BufferAliasInfo object. It can't
549
541
// make a proper determination by itself and needs to be conservative.
@@ -572,9 +564,9 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
572
564
if (!hasKnownBufferizationAliasingBehavior (result.getDefiningOp ()))
573
565
return SmallVector<OpOperand *>();
574
566
TypeSwitch<Operation *>(result.getDefiningOp ())
575
- .Case ([&](tensor::CastOp op) { r.push_back (&op->getOpOperand (0 )); })
576
- .Case ([&](ExtractSliceOp op) { r.push_back (&op->getOpOperand (0 )); })
577
567
.Case ([&](scf::IfOp op) { populateAliasingOpOperands (op, result, r); })
568
+ .Case <ExtractSliceOp, tensor::CastOp>(
569
+ [&](auto op) { r.push_back (&op->getOpOperand (0 )); })
578
570
// In the case of scf::ForOp, this currently assumes the iter_args / yield
579
571
// are 1-1. This may fail and is verified at the end.
580
572
// TODO: update this.
@@ -606,7 +598,15 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
606
598
// / If an ExtractSliceOp is bufferized in-place, the source operand will
607
599
// / alias with the result.
608
600
static OpResult getAliasingOpResult (ExtractSliceOp op, OpOperand &opOperand) {
609
- if (op.source () == opOperand.get ())
601
+ if (&op->getOpOperand (0 ) == &opOperand)
602
+ return op->getResult (0 );
603
+ return OpResult ();
604
+ }
605
+
606
+ // / If a tensor::CastOp is bufferized in-place, the source operand will
607
+ // / alias with the result.
608
+ static OpResult getAliasingOpResult (tensor::CastOp op, OpOperand &opOperand) {
609
+ if (&op->getOpOperand (0 ) == &opOperand)
610
610
return op->getResult (0 );
611
611
return OpResult ();
612
612
}
@@ -616,11 +616,11 @@ static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) {
616
616
// / TODO: in the future this may need to evolve towards a list of OpResult.
617
617
static OpResult getAliasingOpResult (OpOperand &opOperand) {
618
618
return TypeSwitch<Operation *, OpResult>(opOperand.getOwner ())
619
- // ExtractSliceOp is different: its result is not inplaceable on op.source
620
- // but when bufferized inplace, the result is an aliasing subregion of
621
- // op.source .
622
- .Case (
623
- [&](ExtractSliceOp op) { return getAliasingOpResult (op, opOperand); })
619
+ // Some ops are different: Their result is not inplaceable on an OpOperand
620
+ // but when bufferized inplace, their result is aliasing (a subregion of)
621
+ // an OpOperand .
622
+ .Case <ExtractSliceOp, tensor::CastOp> (
623
+ [&](auto op) { return getAliasingOpResult (op, opOperand); })
624
624
// All other ops, return the result of `getInplaceableOpResult`.
625
625
.Default (
626
626
[&](Operation *op) { return getInplaceableOpResult (opOperand); });
@@ -639,11 +639,9 @@ static bool isValueRead(Value value) {
639
639
640
640
while (!workingSet.empty ()) {
641
641
OpOperand *uMaybeReading = workingSet.pop_back_val ();
642
- // Skip over all ExtractSliceOps. These do not read by themselves but just
643
- // add a new alias.
644
- if (auto extractSliceOp =
645
- dyn_cast<ExtractSliceOp>(uMaybeReading->getOwner ()))
646
- for (OpOperand &use : extractSliceOp.result ().getUses ())
642
+ // Skip over all ops that create an alias but do not read.
643
+ if (isa<ExtractSliceOp, tensor::CastOp>(uMaybeReading->getOwner ()))
644
+ for (OpOperand &use : uMaybeReading->getOwner ()->getResult (0 ).getUses ())
647
645
workingSet.push_back (&use);
648
646
if (bufferizesToMemoryRead (*uMaybeReading))
649
647
return true ;
@@ -658,9 +656,9 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
658
656
// it. Conservatively return true.
659
657
if (!hasKnownBufferizationAliasingBehavior (opOperand.getOwner ()))
660
658
return true ;
661
- // ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses
659
+ // Some ops alone do not bufferize to a memory read, but one of their uses
662
660
// may.
663
- if (isa<ExtractSliceOp>(opOperand.getOwner ()))
661
+ if (isa<ExtractSliceOp, tensor::CastOp >(opOperand.getOwner ()))
664
662
return false ;
665
663
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
666
664
// matching bbArg may.
@@ -690,9 +688,9 @@ static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
690
688
// These terminators are not writes.
691
689
if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner ()))
692
690
return false ;
693
- // ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
691
+ // Some ops alone do not bufferize to a memory write, but one of their uses
694
692
// may.
695
- if (isa<ExtractSliceOp>(opOperand.getOwner ()))
693
+ if (isa<ExtractSliceOp, tensor::CastOp >(opOperand.getOwner ()))
696
694
return false ;
697
695
// CallOpInterface alone doesn't bufferize to a memory write, one of the uses
698
696
// of the matching bbArg may. It is the responsibility of the caller to
@@ -2320,27 +2318,28 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
2320
2318
return success ();
2321
2319
}
2322
2320
2321
+ // / This analysis function is used for ops where the first OpOperand aliases
2322
+ // / with the first OpResult, without creating a read or write. There are a few
2323
+ // / ops besides ExtractSliceOp that have such semantics.
2323
2324
// /
2324
- // / Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace.
2325
- // / ===========================================================
2325
+ // / Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
2326
2326
// /
2327
- // / When bufferized out of place, a ExtractSlice lowers to alloc + copy. This
2327
+ // / When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
2328
2328
// / cannot change the flow of information for either the source or the
2329
2329
// / result buffers.
2330
2330
// /
2331
- // / When bufferized inplace, a ExtractSliceOp does not by itself create any read
2332
- // / or write from memory. Instead, it has the effect of merging the alias sets
2333
- // / of the source and the result buffers.
2331
+ // / When bufferized inplace, an ExtractSliceOp does not by itself create any
2332
+ // / read or write from memory. Instead, it has the effect of merging the alias
2333
+ // / sets of the source and the result buffers.
2334
2334
// /
2335
2335
// / An analysis is required to ensure inplace bufferization would not result in
2336
2336
// / RaW dependence violations.
2337
2337
static LogicalResult
2338
- bufferizableInPlaceAnalysis (ExtractSliceOp extractSliceOp,
2339
- BufferizationAliasInfo &aliasInfo,
2340
- const DominanceInfo &domInfo) {
2341
- return bufferizableInPlaceAnalysisImpl (extractSliceOp->getOpOperand (0 ),
2342
- extractSliceOp->getOpResult (0 ),
2343
- aliasInfo, domInfo);
2338
+ bufferizableInPlaceAnalysisAliasOnlyOp (Operation *op,
2339
+ BufferizationAliasInfo &aliasInfo,
2340
+ const DominanceInfo &domInfo) {
2341
+ return bufferizableInPlaceAnalysisImpl (
2342
+ op->getOpOperand (0 ), op->getOpResult (0 ), aliasInfo, domInfo);
2344
2343
}
2345
2344
2346
2345
// / Determine if `operand` can be bufferized in-place with one of the op's
@@ -2377,14 +2376,11 @@ LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector<Operation *> &ops,
2377
2376
if (failed (bufferizableInPlaceAnalysis (opOperand, aliasInfo, domInfo)))
2378
2377
return failure ();
2379
2378
2380
- // Special logic to analyze ExtractSliceOp.
2381
- // Note that ExtractSliceOp analysis needs to be interleaved with other ops
2382
- // to properly capture aliases.
2383
- // Walk ExtractSliceOps in reverse for better clobbering analysis behavior:
2384
- // it is easier to detect clobbers of smaller slices before larger ones.
2385
- if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op))
2379
+ // Special logic to analyze ops whose OpResults are not inplaceable on an
2380
+ // OpOperand but may create an alias.
2381
+ if (isa<ExtractSliceOp, tensor::CastOp>(op))
2386
2382
if (failed (
2387
- bufferizableInPlaceAnalysis (extractSliceOp , aliasInfo, domInfo)))
2383
+ bufferizableInPlaceAnalysisAliasOnlyOp (op , aliasInfo, domInfo)))
2388
2384
return failure ();
2389
2385
}
2390
2386
@@ -3053,7 +3049,8 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
3053
3049
aliasInfo.createAliasInfoEntry (extractOp.result ());
3054
3050
3055
3051
// Run analysis on the ExtractSliceOp.
3056
- if (failed (bufferizableInPlaceAnalysis (extractOp, aliasInfo, domInfo)))
3052
+ if (failed (bufferizableInPlaceAnalysisAliasOnlyOp (extractOp, aliasInfo,
3053
+ domInfo)))
3057
3054
return WalkResult::interrupt ();
3058
3055
3059
3056
// Advance to the next operation.
0 commit comments