Skip to content

Commit 4e14bac

Browse files
[mlir][linalg][bufferize] tensor::CastOp is an alias-only op
tensor::CastOp by itself does not bufferize to memory read/write. Differential Revision: https://reviews.llvm.org/D112514
1 parent 3fe4b54 commit 4e14bac

File tree

1 file changed

+47
-50
lines changed

1 file changed

+47
-50
lines changed

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -516,14 +516,6 @@ static OpResult getInplaceableOpResult(InsertSliceOp op, OpOperand &opOperand) {
516516
return op->getResult(0);
517517
}
518518

519-
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
520-
/// when the op is bufferized inplace.
521-
/// Return null if no such result exists.
522-
static OpResult getInplaceableOpResult(tensor::CastOp op,
523-
OpOperand &opOperand) {
524-
return op->getResult(0);
525-
}
526-
527519
/// Return the OpResult that may bufferize into the same buffer as `opOperand`
528520
/// when the op is bufferized inplace.
529521
/// The inplace analysis uses this information along with interfering read
@@ -534,16 +526,16 @@ static OpResult getInplaceableOpResult(OpOperand &opOperand) {
534526
// clang-format off
535527
// Ops that perform destructive updates on operand(s) to produce
536528
// result(s).
537-
.Case<tensor::CastOp,
538-
scf::ForOp,
529+
.Case<scf::ForOp,
539530
InsertSliceOp,
540531
LinalgOp,
541532
TiledLoopOp,
542533
VectorTransferOpInterface>(
543534
[&](auto op) { return getInplaceableOpResult(op, opOperand); })
544-
// ExtractSliceOp is special, when bufferized inplace it just returns an
545-
// alias to its operand. Its result is never inplaceable on its operand.
546-
.Case([&](ExtractSliceOp op) { return OpResult(); })
535+
// Some ops just return an alias to an operand when bufferized inplace.
536+
// Such OpResults are never inplaceable on an OpOperand.
537+
.Case<ExtractSliceOp, tensor::CastOp>(
538+
[] (auto op) { return OpResult(); })
547539
// CallOpInterface is special, it needs to wait for the callee to be
548540
// bufferized and needs to inspect the BufferAliasInfo object. It can't
549541
// make a proper determination by itself and needs to be conservative.
@@ -572,9 +564,9 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
572564
if (!hasKnownBufferizationAliasingBehavior(result.getDefiningOp()))
573565
return SmallVector<OpOperand *>();
574566
TypeSwitch<Operation *>(result.getDefiningOp())
575-
.Case([&](tensor::CastOp op) { r.push_back(&op->getOpOperand(0)); })
576-
.Case([&](ExtractSliceOp op) { r.push_back(&op->getOpOperand(0)); })
577567
.Case([&](scf::IfOp op) { populateAliasingOpOperands(op, result, r); })
568+
.Case<ExtractSliceOp, tensor::CastOp>(
569+
[&](auto op) { r.push_back(&op->getOpOperand(0)); })
578570
// In the case of scf::ForOp, this currently assumes the iter_args / yield
579571
// are 1-1. This may fail and is verified at the end.
580572
// TODO: update this.
@@ -606,7 +598,15 @@ static SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) {
606598
/// If the an ExtractSliceOp is bufferized in-place, the source operand will
607599
/// alias with the result.
608600
static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) {
609-
if (op.source() == opOperand.get())
601+
if (&op->getOpOperand(0) == &opOperand)
602+
return op->getResult(0);
603+
return OpResult();
604+
}
605+
606+
/// If the a tensor::CastOp is bufferized in-place, the source operand will
607+
/// alias with the result.
608+
static OpResult getAliasingOpResult(tensor::CastOp op, OpOperand &opOperand) {
609+
if (&op->getOpOperand(0) == &opOperand)
610610
return op->getResult(0);
611611
return OpResult();
612612
}
@@ -616,11 +616,11 @@ static OpResult getAliasingOpResult(ExtractSliceOp op, OpOperand &opOperand) {
616616
/// TODO: in the future this may need to evolve towards a list of OpResult.
617617
static OpResult getAliasingOpResult(OpOperand &opOperand) {
618618
return TypeSwitch<Operation *, OpResult>(opOperand.getOwner())
619-
// ExtractSliceOp is different: its result is not inplaceable on op.source
620-
// but when bufferized inplace, the result is an aliasing subregion of
621-
// op.source.
622-
.Case(
623-
[&](ExtractSliceOp op) { return getAliasingOpResult(op, opOperand); })
619+
// Some ops are different: Their result is not inplaceable on an OpOperand
620+
// but when bufferized inplace, their result is aliasing (a subregion of)
621+
// an OpOperand.
622+
.Case<ExtractSliceOp, tensor::CastOp>(
623+
[&](auto op) { return getAliasingOpResult(op, opOperand); })
624624
// All other ops, return the result of `getInplaceableOpResult`.
625625
.Default(
626626
[&](Operation *op) { return getInplaceableOpResult(opOperand); });
@@ -639,11 +639,9 @@ static bool isValueRead(Value value) {
639639

640640
while (!workingSet.empty()) {
641641
OpOperand *uMaybeReading = workingSet.pop_back_val();
642-
// Skip over all ExtractSliceOps. These do not read by themselves but just
643-
// add a new alias.
644-
if (auto extractSliceOp =
645-
dyn_cast<ExtractSliceOp>(uMaybeReading->getOwner()))
646-
for (OpOperand &use : extractSliceOp.result().getUses())
642+
// Skip over all ops that create an alias but do not read.
643+
if (isa<ExtractSliceOp, tensor::CastOp>(uMaybeReading->getOwner()))
644+
for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses())
647645
workingSet.push_back(&use);
648646
if (bufferizesToMemoryRead(*uMaybeReading))
649647
return true;
@@ -658,9 +656,9 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
658656
// it. Conservatively return true.
659657
if (!hasKnownBufferizationAliasingBehavior(opOperand.getOwner()))
660658
return true;
661-
// ExtractSliceOp alone doesn't bufferize to a memory read, one of its uses
659+
// Some ops alone do not bufferize to a memory read, but one of their uses
662660
// may.
663-
if (isa<ExtractSliceOp>(opOperand.getOwner()))
661+
if (isa<ExtractSliceOp, tensor::CastOp>(opOperand.getOwner()))
664662
return false;
665663
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
666664
// matching bbArg may.
@@ -690,9 +688,9 @@ static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
690688
// These terminators are not writes.
691689
if (isa<ReturnOp, linalg::YieldOp, scf::YieldOp>(opOperand.getOwner()))
692690
return false;
693-
// ExtractSliceOp alone doesn't bufferize to a memory write, one of its uses
691+
// Some ops alone do not bufferize to a memory write, but one of their uses
694692
// may.
695-
if (isa<ExtractSliceOp>(opOperand.getOwner()))
693+
if (isa<ExtractSliceOp, tensor::CastOp>(opOperand.getOwner()))
696694
return false;
697695
// CallOpInterface alone doesn't bufferize to a memory write, one of the uses
698696
// of the matching bbArg may. It is the responsibility of the caller to
@@ -2320,27 +2318,28 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
23202318
return success();
23212319
}
23222320

2321+
/// This analysis function is used for ops where the first OpOperand aliases
2322+
/// with the first OpResult, without creating a read or write. There are a few
2323+
/// ops besides ExtractSliceOp that have such semantics.
23232324
///
2324-
/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace.
2325-
/// ===========================================================
2325+
/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
23262326
///
2327-
/// When bufferized out of place, a ExtractSlice lowers to alloc + copy. This
2327+
/// When bufferized out of place, an ExtractSliceOp lowers to alloc + copy. This
23282328
/// cannot change the flow of information for either the source or the
23292329
/// result buffers.
23302330
///
2331-
/// When bufferized inplace, a ExtractSliceOp does not by itself create any read
2332-
/// or write from memory. Instead, it has the effect of merging the alias sets
2333-
/// of the source and the result buffers.
2331+
/// When bufferized inplace, an ExtractSliceOp does not by itself create any
2332+
/// read or write from memory. Instead, it has the effect of merging the alias
2333+
/// sets of the source and the result buffers.
23342334
///
23352335
/// An analysis is required to ensure inplace bufferization would not result in
23362336
/// RaW dependence violations.
23372337
static LogicalResult
2338-
bufferizableInPlaceAnalysis(ExtractSliceOp extractSliceOp,
2339-
BufferizationAliasInfo &aliasInfo,
2340-
const DominanceInfo &domInfo) {
2341-
return bufferizableInPlaceAnalysisImpl(extractSliceOp->getOpOperand(0),
2342-
extractSliceOp->getOpResult(0),
2343-
aliasInfo, domInfo);
2338+
bufferizableInPlaceAnalysisAliasOnlyOp(Operation *op,
2339+
BufferizationAliasInfo &aliasInfo,
2340+
const DominanceInfo &domInfo) {
2341+
return bufferizableInPlaceAnalysisImpl(
2342+
op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo);
23442343
}
23452344

23462345
/// Determine if `operand` can be bufferized in-place with one of the op's
@@ -2377,14 +2376,11 @@ LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector<Operation *> &ops,
23772376
if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo)))
23782377
return failure();
23792378

2380-
// Special logic to analyze ExtractSliceOp.
2381-
// Note that ExtractSliceOp analysis needs to be interleaved with other ops
2382-
// to properly capture aliases.
2383-
// Walk ExtractSliceOps in reverse for better clobbering analysis behavior:
2384-
// it is easier to detect clobbers of smaller slices before larger ones.
2385-
if (auto extractSliceOp = dyn_cast<ExtractSliceOp>(op))
2379+
// Special logic to analyze ops who's OpResults are not inplaceable on an
2380+
// OpOperand but may create an alias.
2381+
if (isa<ExtractSliceOp, tensor::CastOp>(op))
23862382
if (failed(
2387-
bufferizableInPlaceAnalysis(extractSliceOp, aliasInfo, domInfo)))
2383+
bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo)))
23882384
return failure();
23892385
}
23902386

@@ -3053,7 +3049,8 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
30533049
aliasInfo.createAliasInfoEntry(extractOp.result());
30543050

30553051
// Run analysis on the ExtractSliceOp.
3056-
if (failed(bufferizableInPlaceAnalysis(extractOp, aliasInfo, domInfo)))
3052+
if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(extractOp, aliasInfo,
3053+
domInfo)))
30573054
return WalkResult::interrupt();
30583055

30593056
// Advance to the next operation.

0 commit comments

Comments
 (0)