From 3759dbf479ae058ca176b2bf4a5a2e0eae650bc0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 3 Jul 2025 19:54:19 +0000 Subject: [PATCH 1/2] [mlir][IR][WIP] Set insertion point when erasing an operation --- mlir/lib/IR/PatternMatch.cpp | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 1e6084822a99a..95b8d4cac2b29 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -152,12 +152,45 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { eraseOp(op); } +/// Returns the given block iterator if it lies within the block `b`. +/// Otherwise, otherwise finds the ancestor of the given block iterator that +/// lies within `b`. Returns and "empty" iterator if the latter fails. +/// +/// Note: This is a variant of Block::findAncestorOpInBlock that operates on +/// block iterators instead of ops. +static std::pair +findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) { + // Case 1: The iterator lies within the block. + if (itBlock == b) + return std::make_pair(itBlock, it); + + // Otherwise: Find ancestor iterator. Bail if we run out of parent ops. + Operation *parentOp = itBlock->getParentOp(); + if (!parentOp) + return std::make_pair(static_cast(nullptr), Block::iterator()); + Operation *op = b->findAncestorOpInBlock(*parentOp); + if (!op) + return std::make_pair(static_cast(nullptr), Block::iterator()); + return std::make_pair(op->getBlock(), op->getIterator()); +} + /// This method erases an operation that is known to have no uses. The uses of /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present(listener); + // If the current insertion point is before/within the erased operation, we + // need to adjust the insertion point to be after the operation. + if (getInsertionBlock()) { + Block *insertionBlock; + Block::iterator insertionPoint; + std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock( + op->getBlock(), getInsertionBlock(), getInsertionPoint()); + if (insertionBlock && insertionPoint == op->getIterator()) + setInsertionPointAfter(op); + } + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -322,6 +355,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); From c2c125f826bc6b0f8a8dc809e00886588414afdc Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 4 Jul 2025 09:48:51 +0000 Subject: [PATCH 2/2] address comments --- mlir/include/mlir/IR/PatternMatch.h | 5 ++++ mlir/lib/IR/PatternMatch.cpp | 37 ++++------------------------- 2 files changed, 10 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index afeb784b85a12..7c330f556ac36 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -504,6 +504,11 @@ class RewriterBase : public OpBuilder { } /// This method erases an operation that is known to have no uses. + /// + /// If the current insertion point is before the erased operation, it is + /// adjusted to the following operation (or the end of the block). If the + /// current insertion point is within the erased operation, the insertion + /// point is left in an invalid state. virtual void eraseOp(Operation *op); /// This method erases all operations in a block. diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 95b8d4cac2b29..0922df4573d3e 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -152,44 +152,17 @@ void RewriterBase::replaceOp(Operation *op, Operation *newOp) { eraseOp(op); } -/// Returns the given block iterator if it lies within the block `b`. -/// Otherwise, otherwise finds the ancestor of the given block iterator that -/// lies within `b`. Returns and "empty" iterator if the latter fails. -/// -/// Note: This is a variant of Block::findAncestorOpInBlock that operates on -/// block iterators instead of ops. -static std::pair -findAncestorIteratorInBlock(Block *b, Block *itBlock, Block::iterator it) { - // Case 1: The iterator lies within the block. - if (itBlock == b) - return std::make_pair(itBlock, it); - - // Otherwise: Find ancestor iterator. Bail if we run out of parent ops. - Operation *parentOp = itBlock->getParentOp(); - if (!parentOp) - return std::make_pair(static_cast(nullptr), Block::iterator()); - Operation *op = b->findAncestorOpInBlock(*parentOp); - if (!op) - return std::make_pair(static_cast(nullptr), Block::iterator()); - return std::make_pair(op->getBlock(), op->getIterator()); -} - /// This method erases an operation that is known to have no uses. The uses of /// the given operation *must* be known to be dead. void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present(listener); - // If the current insertion point is before/within the erased operation, we - // need to adjust the insertion point to be after the operation. - if (getInsertionBlock()) { - Block *insertionBlock; - Block::iterator insertionPoint; - std::tie(insertionBlock, insertionPoint) = findAncestorIteratorInBlock( - op->getBlock(), getInsertionBlock(), getInsertionPoint()); - if (insertionBlock && insertionPoint == op->getIterator()) - setInsertionPointAfter(op); - } + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionBlock() == op->getBlock() && + getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) {