Skip to content

Commit 585a8a3

Browse files
[mlir][bufferize] OpOperands can have multiple aliasing OpResults
This makes getAliasingOpResult symmetric to getAliasingOpOperand. The previous implementation was confusing for users and implemented in such a way only because there are currently no bufferizable ops that have multiple aliasing OpResults. Differential Revision: https://reviews.llvm.org/D119259
1 parent 22a1973 commit 585a8a3

File tree

11 files changed

+122
-97
lines changed

11 files changed

+122
-97
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,8 @@ class BufferizationState {
180180
SmallVector<OpOperand *> getAliasingOpOperand(OpResult result) const;
181181

182182
/// Determine which OpResult will alias with `opOperand` if the op is
183-
/// bufferized in place. Return an empty OpResult if the op is not
184-
/// bufferizable.
185-
OpResult getAliasingOpResult(OpOperand &opOperand) const;
183+
/// bufferized in place. Return an empty vector if the op is not bufferizable.
184+
SmallVector<OpResult> getAliasingOpResult(OpOperand &opOperand) const;
186185

187186
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if
188187
/// the op is not bufferizable.
@@ -396,9 +395,10 @@ struct AllocationHoistingBarrierOnly
396395
return {};
397396
}
398397

399-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
400-
const BufferizationState &state) const {
401-
return OpResult();
398+
SmallVector<OpResult>
399+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
400+
const BufferizationState &state) const {
401+
return {};
402402
}
403403

404404
BufferRelation bufferRelation(Operation *op, OpResult opResult,

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
124124
bufferized in-place. This method will never be called on OpOperands
125125
that do not have a tensor type.
126126
}],
127-
/*retType=*/"OpResult",
127+
/*retType=*/"SmallVector<OpResult>",
128128
/*methodName=*/"getAliasingOpResult",
129129
/*args=*/(ins "OpOperand &":$opOperand,
130130
"const BufferizationState &":$state),
@@ -162,8 +162,10 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
162162
for (OpOperand &opOperand : $_op.getOperation()->getOpOperands()) {
163163
if (!opOperand.get().getType().isa<TensorType>())
164164
continue;
165-
if (bufferizableOp.getAliasingOpResult(opOperand, state) ==
166-
opResult)
165+
SmallVector<OpResult> aliasingOpResults =
166+
bufferizableOp.getAliasingOpResult(opOperand, state);
167+
if (llvm::find(aliasingOpResults, opResult)
168+
!= aliasingOpResults.end())
167169
result.push_back(&opOperand);
168170
}
169171
return result;
@@ -304,8 +306,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
304306
cast<BufferizableOpInterface>(getOperation());
305307
return !bufferizableOp.bufferizesToMemoryRead(opOperand, state)
306308
&& !bufferizableOp.bufferizesToMemoryWrite(opOperand, state)
307-
&& static_cast<bool>(
308-
bufferizableOp.getAliasingOpResult(opOperand, state));
309+
&& !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
309310
}
310311

311312
// TODO: The following two attributes should belong to the tensor dialect.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,9 +211,9 @@ def Bufferization_ToMemrefOp : Bufferization_Op<"to_memref", [
211211
return true;
212212
}
213213

214-
OpResult getAliasingOpResult(OpOperand &opOperand,
215-
const BufferizationState &state) const {
216-
return OpResult();
214+
SmallVector<OpResult> getAliasingOpResult(
215+
OpOperand &opOperand, const BufferizationState &state) const {
216+
return {};
217217
}
218218

219219
LogicalResult bufferize(RewriterBase &rewriter,

mlir/lib/Dialect/Arithmetic/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,10 @@ struct IndexCastOpInterface
6969
return false;
7070
}
7171

72-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
73-
const BufferizationState &state) const {
74-
return op->getResult(0);
72+
SmallVector<OpResult>
73+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
74+
const BufferizationState &state) const {
75+
return {op->getResult(0)};
7576
}
7677

7778
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -114,9 +115,10 @@ struct SelectOpInterface
114115
return false;
115116
}
116117

117-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
118-
const BufferizationState &state) const {
119-
return op->getOpResult(0) /*result*/;
118+
SmallVector<OpResult>
119+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
120+
const BufferizationState &state) const {
121+
return {op->getOpResult(0) /*result*/};
120122
}
121123

122124
SmallVector<OpOperand *>

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,13 @@ BufferizationState::getAliasingOpOperand(OpResult result) const {
8787
}
8888

8989
/// Determine which OpResult will alias with `opOperand` if the op is bufferized
90-
/// in place. Return an empty OpResult if the op is not bufferizable.
91-
OpResult BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
90+
/// in place. Return an empty vector if the op is not bufferizable.
91+
SmallVector<OpResult>
92+
BufferizationState::getAliasingOpResult(OpOperand &opOperand) const {
9293
if (auto bufferizableOp =
9394
dyn_cast<BufferizableOpInterface>(opOperand.getOwner()))
9495
return bufferizableOp.getAliasingOpResult(opOperand, *this);
95-
return OpResult();
96+
return {};
9697
}
9798

9899
/// Return true if `opOperand` bufferizes to a memory read. Return `true` if the
@@ -144,8 +145,9 @@ bool BufferizationState::isValueRead(Value value) const {
144145
OpOperand *uMaybeReading = workingSet.pop_back_val();
145146
// Skip over all ops that neither read nor write (but create an alias).
146147
if (bufferizesToAliasOnly(*uMaybeReading))
147-
for (OpOperand &use : getAliasingOpResult(*uMaybeReading).getUses())
148-
workingSet.push_back(&use);
148+
for (OpResult opResult : getAliasingOpResult(*uMaybeReading))
149+
for (OpOperand &use : opResult.getUses())
150+
workingSet.push_back(&use);
149151
if (bufferizesToMemoryRead(*uMaybeReading))
150152
return true;
151153
}
@@ -266,9 +268,10 @@ FailureOr<Value> BufferizationState::getBuffer(
266268
}))
267269
return resultBuffer;
268270
// Do not copy if the copied data is never read.
269-
OpResult aliasingOpResult = getAliasingOpResult(opOperand);
270-
if (aliasingOpResult && !bufferizesToMemoryRead(opOperand) &&
271-
!isValueRead(aliasingOpResult))
271+
SmallVector<OpResult> aliasingOpResults = getAliasingOpResult(opOperand);
272+
if (!aliasingOpResults.empty() && !bufferizesToMemoryRead(opOperand) &&
273+
llvm::none_of(aliasingOpResults,
274+
[&](OpResult opResult) { return isValueRead(opResult); }))
272275
return resultBuffer;
273276
// Do not copy if this op does not read the data, but writes it.
274277
if (bufferizesToMemoryWrite(opOperand) && !bufferizesToMemoryRead(opOperand))

mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ bool BufferizationAliasInfo::isInPlace(OpOperand &operand) const {
140140
void BufferizationAliasInfo::bufferizeInPlace(OpOperand &operand,
141141
BufferizationState &state) {
142142
markInPlace(operand);
143-
if (OpResult result = state.getAliasingOpResult(operand))
143+
for (OpResult result : state.getAliasingOpResult(operand))
144144
aliasInfo.unionSets(result, operand.get());
145145
}
146146

@@ -196,8 +196,8 @@ AnalysisBufferizationState::AnalysisBufferizationState(
196196
for (OpOperand &opOperand : bufferizableOp->getOpOperands()) {
197197
if (opOperand.get().getType().isa<TensorType>())
198198
if (bufferizableOp.mustBufferizeInPlace(opOperand, *this)) {
199-
if (OpResult opResult =
200-
bufferizableOp.getAliasingOpResult(opOperand, *this))
199+
for (OpResult opResult :
200+
bufferizableOp.getAliasingOpResult(opOperand, *this))
201201
aliasInfo.unionAliasSets(opOperand.get(), opResult);
202202
aliasInfo.markInPlace(opOperand);
203203
}
@@ -404,7 +404,9 @@ static bool hasReadAfterWriteInterference(
404404

405405
// No conflict if the conflicting write and the last write are the same
406406
// use.
407-
if (state.getAliasingOpResult(*uConflictingWrite) == lastWrite)
407+
SmallVector<OpResult> aliasingOpResult =
408+
state.getAliasingOpResult(*uConflictingWrite);
409+
if (aliasingOpResult.size() == 1 && aliasingOpResult[0] == lastWrite)
408410
continue;
409411

410412
// All requirements are met. Conflict found!
@@ -477,7 +479,7 @@ static bool wouldCreateReadAfterWriteInterference(
477479
DenseSet<OpOperand *> usesRead, usesWrite;
478480
getAliasingReads(usesRead, operand.get());
479481
getAliasingInplaceWrites(usesWrite, operand.get());
480-
if (OpResult result = state.getAliasingOpResult(operand)) {
482+
for (OpResult result : state.getAliasingOpResult(operand)) {
481483
getAliasingReads(usesRead, result);
482484
getAliasingInplaceWrites(usesWrite, result);
483485
}
@@ -506,7 +508,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand,
506508
bool hasWrite = aliasesInPlaceWrite(opOperand.get(), aliasInfo, state) ||
507509
state.bufferizesToMemoryWrite(opOperand);
508510

509-
if (OpResult opResult = state.getAliasingOpResult(opOperand))
511+
for (OpResult opResult : state.getAliasingOpResult(opOperand))
510512
hasWrite |= aliasesInPlaceWrite(opResult, aliasInfo, state);
511513

512514
return hasWrite;

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ struct LinalgOpInterface
168168
// Operand is written to if it has an aliasing OpResult. For more details,
169169
// see `computeAliasingPairs`.
170170
auto bufferizableOp = cast<BufferizableOpInterface>(op);
171-
return static_cast<bool>(
172-
bufferizableOp.getAliasingOpResult(opOperand, state));
171+
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
173172
}
174173

175174
SmallVector<OpOperand *>
@@ -185,13 +184,16 @@ struct LinalgOpInterface
185184
return {};
186185
}
187186

188-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
189-
const BufferizationState &state) const {
187+
SmallVector<OpResult>
188+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
189+
const BufferizationState &state) const {
190190
auto genericOp = cast<linalg::LinalgOp>(op);
191191

192192
// Aliasing OpOperand/OpResult pairs are computed by `computeAliasingPairs`.
193193
DenseMap<OpOperand *, OpResult> pairs = computeAliasingPairs(genericOp);
194-
return pairs[&opOperand];
194+
if (!pairs.count(&opOperand))
195+
return {};
196+
return {pairs[&opOperand]};
195197
}
196198

197199
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -252,16 +254,19 @@ struct TiledLoopOpInterface
252254

253255
// Only operands with an aliasing OpResult (i.e., output operands) bufferize
254256
// to a memory write.
255-
return static_cast<bool>(
256-
bufferizableOp.getAliasingOpResult(opOperand, state));
257+
return !bufferizableOp.getAliasingOpResult(opOperand, state).empty();
257258
}
258259

259-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
260-
const BufferizationState &state) const {
260+
SmallVector<OpResult>
261+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
262+
const BufferizationState &state) const {
261263
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
262264

263265
// Output operands are tied to their corresponding OpResults.
264-
return tiledLoopOp.getTiedOpResult(opOperand);
266+
OpResult opResult = tiledLoopOp.getTiedOpResult(opOperand);
267+
if (!opResult)
268+
return {};
269+
return {opResult};
265270
}
266271

267272
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -397,9 +402,10 @@ struct YieldOpInterface
397402
return false;
398403
}
399404

400-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
401-
const BufferizationState &state) const {
402-
return OpResult();
405+
SmallVector<OpResult>
406+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
407+
const BufferizationState &state) const {
408+
return {};
403409
}
404410

405411
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -723,25 +723,24 @@ struct CallOpInterface
723723
funcOp.getArgument(opOperand.getOperandNumber()));
724724
}
725725

726-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
727-
const BufferizationState &state) const {
726+
SmallVector<OpResult>
727+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
728+
const BufferizationState &state) const {
728729
CallOp callOp = cast<CallOp>(op);
729730
FuncOp funcOp = getCalledFunction(callOp);
730731
assert(funcOp && "expected CallOp to a FuncOp");
731732
const ModuleBufferizationState &moduleState =
732733
getModuleBufferizationState(state);
733734

735+
SmallVector<OpResult> result;
734736
for (int64_t resultIdx = 0; resultIdx < callOp->getNumResults();
735737
++resultIdx)
736738
if (Optional<int64_t> maybeArgNumber =
737739
getEquivalentFuncArgIdx(funcOp, moduleState, resultIdx))
738740
if (*maybeArgNumber == opOperand.getOperandNumber())
739-
return callOp->getOpResult(resultIdx);
741+
result.push_back(callOp->getOpResult(resultIdx));
740742

741-
// Note: Returning a non-equivalent tensor from a FuncOp is currently not
742-
// supported an will fail bufferization. (Even if allow-return-memref, it
743-
// will fail when the function is called.)
744-
return OpResult();
743+
return result;
745744
}
746745

747746
SmallVector<OpOperand *>
@@ -916,9 +915,10 @@ struct ReturnOpInterface
916915
return false;
917916
}
918917

919-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
920-
const BufferizationState &state) const {
921-
return OpResult();
918+
SmallVector<OpResult>
919+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
920+
const BufferizationState &state) const {
921+
return {};
922922
}
923923

924924
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,13 @@ struct ForOpInterface
278278
return true;
279279
}
280280

281-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
282-
const BufferizationState &state) const {
281+
SmallVector<OpResult>
282+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
283+
const BufferizationState &state) const {
283284
auto forOp = cast<scf::ForOp>(op);
284285
if (!opOperand.get().getType().isa<RankedTensorType>())
285-
return OpResult();
286-
return forOp.getResultForOpOperand(opOperand);
286+
return {};
287+
return {forOp.getResultForOpOperand(opOperand)};
287288
}
288289

289290
BufferRelation bufferRelation(Operation *op, OpResult opResult,
@@ -401,13 +402,14 @@ struct YieldOpInterface
401402
return false;
402403
}
403404

404-
OpResult getAliasingOpResult(Operation *op, OpOperand &opOperand,
405-
const BufferizationState &state) const {
405+
SmallVector<OpResult>
406+
getAliasingOpResult(Operation *op, OpOperand &opOperand,
407+
const BufferizationState &state) const {
406408
if (isa<scf::IfOp>(op->getParentOp()))
407-
return op->getParentOp()->getResult(opOperand.getOperandNumber());
409+
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
408410
if (isa<scf::ExecuteRegionOp>(op->getParentOp()))
409-
return op->getParentOp()->getResult(opOperand.getOperandNumber());
410-
return OpResult();
411+
return {op->getParentOp()->getResult(opOperand.getOperandNumber())};
412+
return {};
411413
}
412414

413415
bool mustBufferizeInPlace(Operation *op, OpOperand &opOperand,

0 commit comments

Comments
 (0)