Skip to content

Commit f178c38

Browse files
[mlir][scf][bufferize][NFC] Simplify verifyAnalysis implementation
Differential Revision: https://reviews.llvm.org/D124928
1 parent 47c559d commit f178c38

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -444,9 +444,12 @@ struct ForOpInterface
444444
}
445445

446446
/// Assert that yielded values of an scf.for op are equivalent to their
447-
/// corresponding bbArgs. Otherwise, an alloc+copy are inserted and yielded
448-
/// from the loop. This could be a performance problem, so it must be
449-
/// explicitly activated with `alloc-return-allocs`.
447+
/// corresponding bbArgs. In that case, the buffer relations of the
448+
/// corresponding OpResults are "Equivalent".
449+
///
450+
/// If this is not the case, an allocs+copies are inserted and yielded from
451+
/// the loop. This could be a performance problem, so it must be explicitly
452+
/// activated with `alloc-return-allocs`.
450453
LogicalResult verifyAnalysis(Operation *op,
451454
const AnalysisState &state) const {
452455
const auto &options =
@@ -457,22 +460,19 @@ struct ForOpInterface
457460
auto forOp = cast<scf::ForOp>(op);
458461
auto yieldOp =
459462
cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
460-
for (OpOperand &operand : yieldOp->getOpOperands()) {
461-
auto tensorType = operand.get().getType().dyn_cast<TensorType>();
462-
if (!tensorType)
463+
for (OpResult opResult : op->getOpResults()) {
464+
if (!opResult.getType().isa<TensorType>())
463465
continue;
464466

465-
OpOperand &forOperand = forOp.getOpOperandForResult(
466-
forOp->getResult(operand.getOperandNumber()));
467-
auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
468467
// Note: This is overly strict. We should check for aliasing bufferized
469468
// values. But we don't have a "must-alias" analysis yet.
470-
if (!state.areEquivalentBufferizedValues(operand.get(), bbArg))
469+
if (bufferRelation(op, opResult, state) != BufferRelation::Equivalent)
471470
return yieldOp->emitError()
472-
<< "Yield operand #" << operand.getOperandNumber()
471+
<< "Yield operand #" << opResult.getResultNumber()
473472
<< " does not bufferize to a buffer that is aliasing the "
474473
"matching enclosing scf::for operand";
475474
}
475+
476476
return success();
477477
}
478478
};

0 commit comments

Comments
 (0)