Skip to content

Commit 81e8c39

Browse files
[mlir][linalg][bufferize][NFC] Add bufferizesToAliasOnly
The list of operations that do neither read nor write, but create an alias when bufferizing inplace, is getting longer. This commit adds a helper function so that we do not have to spell out the entire list each time. Differential Revision: https://reviews.llvm.org/D112515
1 parent 4e14bac commit 81e8c39

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,18 @@ static OpResult getAliasingOpResult(OpOperand &opOperand) {
626626
[&](Operation *op) { return getInplaceableOpResult(opOperand); });
627627
}
628628

629+
/// Return `true` if the given OpOperand does not bufferize to a memory read or
630+
/// write, but creates an alias when bufferized inplace.
631+
static bool bufferizesToAliasOnly(OpOperand &opOperand) {
632+
Operation *owner = opOperand.getOwner();
633+
// TODO: In the future this may need to evolve into a TypeSwitch. For all
634+
// currently supported ops, the aliasing-only OpOperand is always the first
635+
// one.
636+
return isa<ExtractSliceOp, TensorCollapseShapeOp, TensorExpandShapeOp,
637+
tensor::CastOp>(owner) &&
638+
&opOperand == &owner->getOpOperand(0);
639+
}
640+
629641
// Predeclaration of function.
630642
static bool bufferizesToMemoryRead(OpOperand &opOperand);
631643

@@ -640,8 +652,8 @@ static bool isValueRead(Value value) {
640652
while (!workingSet.empty()) {
641653
OpOperand *uMaybeReading = workingSet.pop_back_val();
642654
// Skip over all ops that create an alias but do not read.
643-
if (isa<ExtractSliceOp, tensor::CastOp>(uMaybeReading->getOwner()))
644-
for (OpOperand &use : uMaybeReading->getOwner()->getResult(0).getUses())
655+
if (bufferizesToAliasOnly(*uMaybeReading))
656+
for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
645657
workingSet.push_back(&use);
646658
if (bufferizesToMemoryRead(*uMaybeReading))
647659
return true;
@@ -658,7 +670,7 @@ static bool bufferizesToMemoryRead(OpOperand &opOperand) {
658670
return true;
659671
// Some ops alone do not bufferize to a memory read, but one of their uses
660672
// may.
661-
if (isa<ExtractSliceOp, tensor::CastOp>(opOperand.getOwner()))
673+
if (bufferizesToAliasOnly(opOperand))
662674
return false;
663675
// scf::ForOp alone doesn't bufferize to a memory read, one of the uses of its
664676
// matching bbArg may.
@@ -690,7 +702,7 @@ static bool bufferizesToMemoryWrite(OpOperand &opOperand) {
690702
return false;
691703
// Some ops alone do not bufferize to a memory write, but one of their uses
692704
// may.
693-
if (isa<ExtractSliceOp, tensor::CastOp>(opOperand.getOwner()))
705+
if (bufferizesToAliasOnly(opOperand))
694706
return false;
695707
// CallOpInterface alone doesn't bufferize to a memory write, one of the uses
696708
// of the matching bbArg may. It is the responsibility of the caller to
@@ -2318,9 +2330,8 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
23182330
return success();
23192331
}
23202332

2321-
/// This analysis function is used for ops where the first OpOperand aliases
2322-
/// with the first OpResult, without creating a read or write. There are a few
2323-
/// ops besides ExtractSliceOp that have such semantics.
2333+
/// This analysis function is used for OpOperands that alias with an OpResult
2334+
/// but are not inplaceable on it. E.g., ExtractSliceOp.
23242335
///
23252336
/// Rationale for bufferizing `%1 = tensor.extract_slice %0[...]` inplace:
23262337
///
@@ -2335,11 +2346,12 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OpResult result,
23352346
/// An analysis is required to ensure inplace bufferization would not result in
23362347
/// RaW dependence violations.
23372348
static LogicalResult
2338-
bufferizableInPlaceAnalysisAliasOnlyOp(Operation *op,
2349+
bufferizableInPlaceAnalysisAliasOnlyOp(OpOperand &operand,
23392350
BufferizationAliasInfo &aliasInfo,
23402351
const DominanceInfo &domInfo) {
2341-
return bufferizableInPlaceAnalysisImpl(
2342-
op->getOpOperand(0), op->getOpResult(0), aliasInfo, domInfo);
2352+
OpResult result = getAliasingOpResult(operand);
2353+
assert(result && "expected that the OpOperand has an aliasing OpResult");
2354+
return bufferizableInPlaceAnalysisImpl(operand, result, aliasInfo, domInfo);
23432355
}
23442356

23452357
/// Determine if `operand` can be bufferized in-place with one of the op's
@@ -2372,16 +2384,17 @@ LogicalResult mlir::linalg::inPlaceAnalysis(SmallVector<Operation *> &ops,
23722384

23732385
// Walk ops in reverse for better interference analysis.
23742386
for (Operation *op : reverse(ops)) {
2375-
for (OpOperand &opOperand : op->getOpOperands())
2387+
for (OpOperand &opOperand : op->getOpOperands()) {
23762388
if (failed(bufferizableInPlaceAnalysis(opOperand, aliasInfo, domInfo)))
23772389
return failure();
23782390

2379-
// Special logic to analyze ops who's OpResults are not inplaceable on an
2380-
// OpOperand but may create an alias.
2381-
if (isa<ExtractSliceOp, tensor::CastOp>(op))
2382-
if (failed(
2383-
bufferizableInPlaceAnalysisAliasOnlyOp(op, aliasInfo, domInfo)))
2384-
return failure();
2391+
// Special logic to analyze OpOperands that are not inplaceable on an
2392+
// OpResult but may create an alias.
2393+
if (bufferizesToAliasOnly(opOperand))
2394+
if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(opOperand, aliasInfo,
2395+
domInfo)))
2396+
return failure();
2397+
}
23852398
}
23862399

23872400
return success();
@@ -3049,8 +3062,8 @@ static LogicalResult runInitTensorElimination(FuncOp funcOp,
30493062
aliasInfo.createAliasInfoEntry(extractOp.result());
30503063

30513064
// Run analysis on the ExtractSliceOp.
3052-
if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(extractOp, aliasInfo,
3053-
domInfo)))
3065+
if (failed(bufferizableInPlaceAnalysisAliasOnlyOp(
3066+
extractOp->getOpOperand(0), aliasInfo, domInfo)))
30543067
return WalkResult::interrupt();
30553068

30563069
// Advance to the next operation.

0 commit comments

Comments
 (0)