Skip to content

Commit 7e749d4

Browse files
authored
[mlir][bufferization]-Add ControlBuildSubsetExtractionFn to TensorEmptyElimination (#120851)
This PR Adds a `ControlBuildSubsetExtractionFn` to the tensor empty elimination util, This will control the building of the subsets extraction of the `SubsetInsertionOpInterface`. This control function returns the subsets extraction value that will replace the `emptyTensorOp` use which is being consumed by a specefic user (which the util expects to eliminate it). The default control function will stay like today's behavior without any additional changes.
1 parent 8e965d8 commit 7e749d4

File tree

2 files changed

+53
-20
lines changed

2 files changed

+53
-20
lines changed

mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_TRANSFORMS_H
1111

1212
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1314
#include "mlir/IR/Operation.h"
15+
#include "mlir/Interfaces/SubsetOpInterface.h"
1416

1517
namespace mlir {
1618
namespace bufferization {
@@ -34,13 +36,35 @@ struct OneShotBufferizationOptions;
3436
/// "tensor.empty" op.
3537
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op);
3638

39+
/// A function type that defines a callback to control the construction
40+
/// of the subset extraction of the `SubsetInsertionOpInterface`.
41+
/// The subset extraction value can be used as a replacement for the
42+
/// `emptyTensorOp` value which is being consumed by `user`, failing
43+
/// of building such a value should be indicated with an empty value.
44+
/// This function should guarantee the legality of the replacement,
45+
/// i.e. the replacement should dominate the user of the `emptyTensorOp`
46+
/// being eliminated.
47+
using ControlBuildSubsetExtractionFn =
48+
std::function<Value(RewriterBase &, SubsetInsertionOpInterface,
49+
tensor::EmptyOp emptyTensorOp, Operation *user)>;
50+
51+
/// This method builds and returns a subset extraction value for the
52+
/// destination tensor that the given `op` inserts into.
53+
/// It returns a value which should replace the `emptyTensorOp` use
54+
/// that is being consumed by `user`.
55+
/// If no such a value found it will return an empty Value.
56+
Value buildSubsetExtraction(RewriterBase &rewriter,
57+
SubsetInsertionOpInterface op,
58+
tensor::EmptyOp emptyTensorOp, Operation *user);
59+
3760
/// Try to eliminate "tensor.empty" ops inside `op`.
3861
///
3962
/// This function overload accepts an existing `OneShotAnalysisState`, which
4063
/// contains in-place bufferization decisions. This overload is useful if an
4164
/// existing analysis should be reused for empty tensor elimination.
42-
LogicalResult eliminateEmptyTensors(RewriterBase &rewriter, Operation *op,
43-
OneShotAnalysisState &state);
65+
LogicalResult eliminateEmptyTensors(
66+
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
67+
ControlBuildSubsetExtractionFn subsetsExtractionFn = buildSubsetExtraction);
4468

4569
/// Within the given operation, hoist buffers from loops where possible. See
4670
/// "BufferLoopHoistingPass" for more information.

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
9393
return nullptr;
9494
}
9595

96+
Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
97+
SubsetInsertionOpInterface op,
98+
tensor::EmptyOp emptyTensorOp,
99+
Operation *user) {
100+
101+
mlir::OpBuilder::InsertionGuard guard(rewriter);
102+
// All values that are needed to create the replacement op.
103+
SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
104+
// Find a suitable insertion point. If no suitable insertion point
105+
// for the replacement can be found, return an empty value to skip
106+
// this replacement.
107+
Operation *insertionPoint =
108+
findValidInsertionPoint(emptyTensorOp, user, neededValues);
109+
if (!insertionPoint)
110+
return {};
111+
112+
rewriter.setInsertionPoint(insertionPoint);
113+
Value replacement =
114+
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
115+
return replacement;
116+
}
117+
96118
LogicalResult mlir::bufferization::eliminateEmptyTensors(
97-
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
119+
RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
120+
ControlBuildSubsetExtractionFn subsetsExtractionFn) {
98121
OpBuilder::InsertionGuard g(rewriter);
99122
llvm::DenseSet<OpOperand *> visitedOpOperands;
100123
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
105128
if (!state.isInPlace(source))
106129
return WalkResult::skip();
107130

108-
// All values that are needed to create the replacement op.
109-
SmallVector<Value> neededValues =
110-
op.getValuesNeededToBuildSubsetExtraction();
111-
112131
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
113132
// equivalent tensors. I.e., stop when there are ops such as extract_slice
114133
// on the path.
@@ -129,8 +148,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
129148
&visitedOpOperands);
130149

131150
for (Value v : emptyTensors) {
132-
Operation *emptyTensorOp = v.getDefiningOp();
133-
151+
auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
152+
assert(emptyTensorOp && "expected tensor.empty op");
134153
// Find the use to be replaced from the use-def chain.
135154
auto iter = llvm::find_if(
136155
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
@@ -142,17 +161,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
142161
continue;
143162
OpOperand *useToBeReplaced = *iter;
144163
Operation *user = useToBeReplaced->getOwner();
145-
146-
// Find a suitable insertion point. If no suitable insertion point for
147-
// the replacement can be found, skip this replacement.
148-
Operation *insertionPoint =
149-
findValidInsertionPoint(emptyTensorOp, user, neededValues);
150-
if (!insertionPoint)
151-
continue;
152-
153-
rewriter.setInsertionPoint(insertionPoint);
154-
Value replacement =
155-
op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
164+
auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
156165
if (!replacement)
157166
continue;
158167
if (emptyTensorOp == replacement.getDefiningOp())

0 commit comments

Comments
 (0)