Skip to content

Commit aceb407

Browse files
[mlir][TOSA] Do not access erased op in MaxPool2dOp lowering
1 parent f2e244f commit aceb407

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -808,6 +808,7 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
808808
dilationAttr);
809809

810810
rewriter.setInsertionPointAfter(op);
811+
auto nanMode = op.getNanMode();
811812
rewriter.replaceOp(op, resultOp);
812813

813814
// NaN propagation has no meaning for non floating point types.
@@ -821,11 +822,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
821822
// we've already produced a named op we will just take its body and modify
822823
// it to include the appropriate checks. If the current value is NaN the
823824
// old value of pool will be taken otherwise we use the result.
824-
if (const auto nanMode = op.getNanMode(); nanMode == "IGNORE") {
825+
if (nanMode == "IGNORE") {
825826
auto genericOp = rewriter.create<linalg::GenericOp>(
826-
op->getLoc(), resultOp.getType(0), resultOp.getInputs(),
827-
resultOp.getOutputs(), resultOp.getIndexingMapsArray(),
828-
resultOp.getIteratorTypesArray(),
827+
loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(),
828+
resultOp.getIndexingMapsArray(), resultOp.getIteratorTypesArray(),
829829
[&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) {
830830
IRMapping map;
831831
auto oldBlock = resultOp.getRegion().begin();
@@ -834,10 +834,10 @@ class MaxPool2dConverter : public OpConversionPattern<tosa::MaxPool2dOp> {
834834
map.map(oldArgs, blockArgs);
835835
auto *newOp = opBuilder.clone(oldMaxOp, map);
836836
Value isNaN = opBuilder.create<arith::CmpFOp>(
837-
op->getLoc(), arith::CmpFPredicate::UNO, blockArgs.front(),
837+
loc, arith::CmpFPredicate::UNO, blockArgs.front(),
838838
blockArgs.front());
839839
auto selectOp = opBuilder.create<arith::SelectOp>(
840-
op->getLoc(), isNaN, blockArgs.back(), newOp->getResult(0));
840+
loc, isNaN, blockArgs.back(), newOp->getResult(0));
841841
opBuilder.create<linalg::YieldOp>(loc, selectOp.getResult());
842842
});
843843
rewriter.replaceOp(resultOp, genericOp);

mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,8 +320,6 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
320320
rewriter.getDenseI64ArrayAttr(sizes),
321321
rewriter.getDenseI64ArrayAttr(strides));
322322

323-
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
324-
325323
// Remove const_shape ops when it no longer has use point.
326324
Operation *startConstShape = sliceOp.getStart().getDefiningOp();
327325
if (startConstShape->getResult(0).hasOneUse())
@@ -331,6 +329,7 @@ class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
331329
if (sizeConstShape->getResult(0).hasOneUse())
332330
rewriter.eraseOp(sizeConstShape);
333331

332+
rewriter.replaceOp(sliceOp, newSliceOp.getResult());
334333
return success();
335334
}
336335
};

0 commit comments

Comments
 (0)