From d5d2790750f5b2d3d11c8c36c392f0b92faa5766 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Tue, 15 Jul 2025 11:31:16 +0100 Subject: [PATCH 1/2] [mlir][SliceAnalysis] Fix stack overflow in graph regions --- mlir/include/mlir/Analysis/SliceAnalysis.h | 10 +++--- mlir/lib/Analysis/SliceAnalysis.cpp | 40 ++++++++++++++------- mlir/test/Dialect/Affine/slicing-utils.mlir | 23 ++++++++++++ 3 files changed, 56 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Analysis/SliceAnalysis.h b/mlir/include/mlir/Analysis/SliceAnalysis.h index d082d2d9f758b..18349d071bb2e 100644 --- a/mlir/include/mlir/Analysis/SliceAnalysis.h +++ b/mlir/include/mlir/Analysis/SliceAnalysis.h @@ -65,8 +65,9 @@ using ForwardSliceOptions = SliceOptions; /// /// The implementation traverses the use chains in postorder traversal for /// efficiency reasons: if an operation is already in `forwardSlice`, no -/// need to traverse its uses again. Since use-def chains form a DAG, this -/// terminates. +/// need to traverse its uses again. In the presence of use-def cycles in a +/// graph region, the traversal stops at the first operation that was already +/// visited (which is not added to the slice anymore). /// /// Upon return to the root call, `forwardSlice` is filled with a /// postorder list of uses (i.e. a reverse topological order). To get a proper @@ -114,8 +115,9 @@ void getForwardSlice(Value root, SetVector *forwardSlice, /// /// The implementation traverses the def chains in postorder traversal for /// efficiency reasons: if an operation is already in `backwardSlice`, no -/// need to traverse its definitions again. Since useuse-def chains form a DAG, -/// this terminates. +/// need to traverse its definitions again. In the presence of use-def cycles +/// in a graph region, the traversal stops at the first operation that was +/// already visited (which is not added to the slice anymore). /// /// Upon return to the root call, `backwardSlice` is filled with a /// postorder list of defs. This happens to be a topological order, from the diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index 36a9812bd7972..c50b652aaf283 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -26,7 +26,8 @@ using namespace mlir; static void -getForwardSliceImpl(Operation *op, SetVector *forwardSlice, +getForwardSliceImpl(Operation *op, DenseSet &visited, + SetVector *forwardSlice, const SliceOptions::TransitiveFilter &filter = nullptr) { if (!op) return; @@ -41,19 +42,23 @@ getForwardSliceImpl(Operation *op, SetVector *forwardSlice, for (Block &block : region) for (Operation &blockOp : block) if (forwardSlice->count(&blockOp) == 0) - getForwardSliceImpl(&blockOp, forwardSlice, filter); - for (Value result : op->getResults()) { - for (Operation *userOp : result.getUsers()) - if (forwardSlice->count(userOp) == 0) - getForwardSliceImpl(userOp, forwardSlice, filter); - } + getForwardSliceImpl(&blockOp, visited, forwardSlice, filter); + + for (Value result : op->getResults()) + for (Operation *userOp : result.getUsers()) { + if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second) + getForwardSliceImpl(userOp, visited, forwardSlice, filter); + + visited.erase(userOp); + } forwardSlice->insert(op); } void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, const ForwardSliceOptions &options) { - getForwardSliceImpl(op, forwardSlice, options.filter); + DenseSet visited; + getForwardSliceImpl(op, visited, forwardSlice, options.filter); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't // want it in the results. @@ -69,8 +74,9 @@ void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, void mlir::getForwardSlice(Value root, SetVector *forwardSlice, const SliceOptions &options) { + DenseSet visited; for (Operation *user : root.getUsers()) - getForwardSliceImpl(user, forwardSlice, options.filter); + getForwardSliceImpl(user, visited, forwardSlice, options.filter); // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an @@ -80,6 +86,7 @@ void mlir::getForwardSlice(Value root, SetVector *forwardSlice, } static LogicalResult getBackwardSliceImpl(Operation *op, + DenseSet &visited, SetVector *backwardSlice, const BackwardSliceOptions &options) { if (!op || op->hasTrait()) @@ -93,8 +100,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op, auto processValue = [&](Value value) { if (auto *definingOp = value.getDefiningOp()) { - if (backwardSlice->count(definingOp) == 0) - return getBackwardSliceImpl(definingOp, backwardSlice, options); + if (backwardSlice->count(definingOp) == 0 && + visited.insert(definingOp).second) + return getBackwardSliceImpl(definingOp, visited, backwardSlice, + options); + + visited.erase(definingOp); } else if (auto blockArg = dyn_cast(value)) { if (options.omitBlockArguments) return success(); @@ -107,7 +118,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op, if (parentOp && backwardSlice->count(parentOp) == 0) { if (parentOp->getNumRegions() == 1 && llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) { - return getBackwardSliceImpl(parentOp, backwardSlice, options); + return getBackwardSliceImpl(parentOp, visited, backwardSlice, + options); } } } else { @@ -145,7 +157,9 @@ static LogicalResult getBackwardSliceImpl(Operation *op, LogicalResult mlir::getBackwardSlice(Operation *op, SetVector *backwardSlice, const BackwardSliceOptions &options) { - LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options); + DenseSet visited; + LogicalResult result = + getBackwardSliceImpl(op, visited, backwardSlice, options); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't diff --git a/mlir/test/Dialect/Affine/slicing-utils.mlir b/mlir/test/Dialect/Affine/slicing-utils.mlir index 0848a924b9d96..c53667a98cfbe 100644 --- a/mlir/test/Dialect/Affine/slicing-utils.mlir +++ b/mlir/test/Dialect/Affine/slicing-utils.mlir @@ -292,3 +292,26 @@ func.func @slicing_test_multiple_return(%arg0: index) -> (index, index) { %0:2 = "slicing-test-op"(%arg0, %arg0): (index, index) -> (index, index) return %0#0, %0#1 : index, index } + +// ----- + +// FWD-LABEL: graph_region_with_cycle +// BWD-LABEL: graph_region_with_cycle +// FWDBWD-LABEL: graph_region_with_cycle +func.func @graph_region_with_cycle() { + test.isolated_graph_region { + // FWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 forward static slice: + // FWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 + // FWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 forward static slice: + // FWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1 + + // BWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 backward static slice: + // BWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 + // BWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 backward static slice: + // BWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1 + %0 = "slicing-test-op"(%1) : (i1) -> i1 + %1 = "slicing-test-op"(%0) : (i1) -> i1 + } + + return +} From 363a0930ba25831b29f7ba50c5b276211fe23e5e Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Tue, 13 May 2025 10:03:44 +0100 Subject: [PATCH 2/2] Add comments and missing inserts --- mlir/lib/Analysis/SliceAnalysis.cpp | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp index c50b652aaf283..991c71e3f689a 100644 --- a/mlir/lib/Analysis/SliceAnalysis.cpp +++ b/mlir/lib/Analysis/SliceAnalysis.cpp @@ -41,11 +41,27 @@ getForwardSliceImpl(Operation *op, DenseSet &visited, for (Region ®ion : op->getRegions()) for (Block &block : region) for (Operation &blockOp : block) - if (forwardSlice->count(&blockOp) == 0) + if (forwardSlice->count(&blockOp) == 0) { + // We don't have to check if the 'blockOp' is already visited because + // there cannot be a traversal path from this nested op to the parent + // and thus a cycle cannot be closed here. We still have to mark it + // as visited to stop before visiting this operation again if it is + // part of a cycle. + visited.insert(&blockOp); getForwardSliceImpl(&blockOp, visited, forwardSlice, filter); + visited.erase(&blockOp); + } for (Value result : op->getResults()) for (Operation *userOp : result.getUsers()) { + // A cycle can only occur within a basic block (not across regions or + // basic blocks) because the parent region must be a graph region, graph + // regions are restricted to always have 0 or 1 blocks, and there cannot + // be a def-use edge from a nested operation to an operation in an + // ancestor region. Therefore, we don't have to but may use the same + // 'visited' set across regions/blocks as long as we remove operations + // from the set again when the DFS traverses back from the leaf to the + // root. if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second) getForwardSliceImpl(userOp, visited, forwardSlice, filter); @@ -58,6 +74,7 @@ getForwardSliceImpl(Operation *op, DenseSet &visited, void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, const ForwardSliceOptions &options) { DenseSet visited; + visited.insert(op); getForwardSliceImpl(op, visited, forwardSlice, options.filter); if (!options.inclusive) { // Don't insert the top level operation, we just queried on it and don't @@ -75,8 +92,11 @@ void mlir::getForwardSlice(Operation *op, SetVector *forwardSlice, void mlir::getForwardSlice(Value root, SetVector *forwardSlice, const SliceOptions &options) { DenseSet visited; - for (Operation *user : root.getUsers()) + for (Operation *user : root.getUsers()) { + visited.insert(user); getForwardSliceImpl(user, visited, forwardSlice, options.filter); + visited.erase(user); + } // Reverse to get back the actual topological order. // std::reverse does not work out of the box on SetVector and I want an @@ -158,6 +178,7 @@ LogicalResult mlir::getBackwardSlice(Operation *op, SetVector *backwardSlice, const BackwardSliceOptions &options) { DenseSet visited; + visited.insert(op); LogicalResult result = getBackwardSliceImpl(op, visited, backwardSlice, options);