Skip to content

Commit 18624ae

Browse files
maerharttru
authored andcommitted
[mlir][SliceAnalysis] Fix stack overflow in graph regions (#139694)
This analysis currently just crashes when applied to a graph region that has a use-def cycle. This PR fixes that by keeping track of the operations the DFS has already visited when following use-def edges and stopping once we visit an operation again.
1 parent 6296ebd commit 18624ae

File tree

3 files changed

+79
-19
lines changed

3 files changed

+79
-19
lines changed

mlir/include/mlir/Analysis/SliceAnalysis.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ using ForwardSliceOptions = SliceOptions;
6565
///
6666
/// The implementation traverses the use chains in postorder traversal for
6767
/// efficiency reasons: if an operation is already in `forwardSlice`, no
68-
/// need to traverse its uses again. Since use-def chains form a DAG, this
69-
/// terminates.
68+
/// need to traverse its uses again. In the presence of use-def cycles in a
69+
/// graph region, the traversal stops at the first operation that was already
70+
/// visited (which is not added to the slice anymore).
7071
///
7172
/// Upon return to the root call, `forwardSlice` is filled with a
7273
/// postorder list of uses (i.e. a reverse topological order). To get a proper
@@ -114,8 +115,9 @@ void getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
114115
///
115116
/// The implementation traverses the def chains in postorder traversal for
116117
/// efficiency reasons: if an operation is already in `backwardSlice`, no
117-
/// need to traverse its definitions again. Since useuse-def chains form a DAG,
118-
/// this terminates.
118+
/// need to traverse its definitions again. In the presence of use-def cycles
119+
/// in a graph region, the traversal stops at the first operation that was
120+
/// already visited (which is not added to the slice anymore).
119121
///
120122
/// Upon return to the root call, `backwardSlice` is filled with a
121123
/// postorder list of defs. This happens to be a topological order, from the

mlir/lib/Analysis/SliceAnalysis.cpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
using namespace mlir;
2727

2828
static void
29-
getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
29+
getForwardSliceImpl(Operation *op, DenseSet<Operation *> &visited,
30+
SetVector<Operation *> *forwardSlice,
3031
const SliceOptions::TransitiveFilter &filter = nullptr) {
3132
if (!op)
3233
return;
@@ -40,20 +41,41 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
4041
for (Region &region : op->getRegions())
4142
for (Block &block : region)
4243
for (Operation &blockOp : block)
43-
if (forwardSlice->count(&blockOp) == 0)
44-
getForwardSliceImpl(&blockOp, forwardSlice, filter);
45-
for (Value result : op->getResults()) {
46-
for (Operation *userOp : result.getUsers())
47-
if (forwardSlice->count(userOp) == 0)
48-
getForwardSliceImpl(userOp, forwardSlice, filter);
49-
}
44+
if (forwardSlice->count(&blockOp) == 0) {
45+
// We don't have to check if the 'blockOp' is already visited because
46+
// there cannot be a traversal path from this nested op to the parent
47+
// and thus a cycle cannot be closed here. We still have to mark it
48+
// as visited to stop before visiting this operation again if it is
49+
// part of a cycle.
50+
visited.insert(&blockOp);
51+
getForwardSliceImpl(&blockOp, visited, forwardSlice, filter);
52+
visited.erase(&blockOp);
53+
}
54+
55+
for (Value result : op->getResults())
56+
for (Operation *userOp : result.getUsers()) {
57+
// A cycle can only occur within a basic block (not across regions or
58+
// basic blocks) because the parent region must be a graph region, graph
59+
// regions are restricted to always have 0 or 1 blocks, and there cannot
60+
// be a def-use edge from a nested operation to an operation in an
61+
// ancestor region. Therefore, we don't have to but may use the same
62+
// 'visited' set across regions/blocks as long as we remove operations
63+
// from the set again when the DFS traverses back from the leaf to the
64+
// root.
65+
if (forwardSlice->count(userOp) == 0 && visited.insert(userOp).second)
66+
getForwardSliceImpl(userOp, visited, forwardSlice, filter);
67+
68+
visited.erase(userOp);
69+
}
5070

5171
forwardSlice->insert(op);
5272
}
5373

5474
void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
5575
const ForwardSliceOptions &options) {
56-
getForwardSliceImpl(op, forwardSlice, options.filter);
76+
DenseSet<Operation *> visited;
77+
visited.insert(op);
78+
getForwardSliceImpl(op, visited, forwardSlice, options.filter);
5779
if (!options.inclusive) {
5880
// Don't insert the top level operation, we just queried on it and don't
5981
// want it in the results.
@@ -69,8 +91,12 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
6991

7092
void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
7193
const SliceOptions &options) {
72-
for (Operation *user : root.getUsers())
73-
getForwardSliceImpl(user, forwardSlice, options.filter);
94+
DenseSet<Operation *> visited;
95+
for (Operation *user : root.getUsers()) {
96+
visited.insert(user);
97+
getForwardSliceImpl(user, visited, forwardSlice, options.filter);
98+
visited.erase(user);
99+
}
74100

75101
// Reverse to get back the actual topological order.
76102
// std::reverse does not work out of the box on SetVector and I want an
@@ -80,6 +106,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
80106
}
81107

82108
static LogicalResult getBackwardSliceImpl(Operation *op,
109+
DenseSet<Operation *> &visited,
83110
SetVector<Operation *> *backwardSlice,
84111
const BackwardSliceOptions &options) {
85112
if (!op || op->hasTrait<OpTrait::IsIsolatedFromAbove>())
@@ -93,8 +120,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
93120

94121
auto processValue = [&](Value value) {
95122
if (auto *definingOp = value.getDefiningOp()) {
96-
if (backwardSlice->count(definingOp) == 0)
97-
return getBackwardSliceImpl(definingOp, backwardSlice, options);
123+
if (backwardSlice->count(definingOp) == 0 &&
124+
visited.insert(definingOp).second)
125+
return getBackwardSliceImpl(definingOp, visited, backwardSlice,
126+
options);
127+
128+
visited.erase(definingOp);
98129
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
99130
if (options.omitBlockArguments)
100131
return success();
@@ -107,7 +138,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
107138
if (parentOp && backwardSlice->count(parentOp) == 0) {
108139
if (parentOp->getNumRegions() == 1 &&
109140
llvm::hasSingleElement(parentOp->getRegion(0).getBlocks())) {
110-
return getBackwardSliceImpl(parentOp, backwardSlice, options);
141+
return getBackwardSliceImpl(parentOp, visited, backwardSlice,
142+
options);
111143
}
112144
}
113145
} else {
@@ -145,7 +177,10 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
145177
LogicalResult mlir::getBackwardSlice(Operation *op,
146178
SetVector<Operation *> *backwardSlice,
147179
const BackwardSliceOptions &options) {
148-
LogicalResult result = getBackwardSliceImpl(op, backwardSlice, options);
180+
DenseSet<Operation *> visited;
181+
visited.insert(op);
182+
LogicalResult result =
183+
getBackwardSliceImpl(op, visited, backwardSlice, options);
149184

150185
if (!options.inclusive) {
151186
// Don't insert the top level operation, we just queried on it and don't

mlir/test/Dialect/Affine/slicing-utils.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,26 @@ func.func @slicing_test_multiple_return(%arg0: index) -> (index, index) {
292292
%0:2 = "slicing-test-op"(%arg0, %arg0): (index, index) -> (index, index)
293293
return %0#0, %0#1 : index, index
294294
}
295+
296+
// -----
297+
298+
// FWD-LABEL: graph_region_with_cycle
299+
// BWD-LABEL: graph_region_with_cycle
300+
// FWDBWD-LABEL: graph_region_with_cycle
301+
func.func @graph_region_with_cycle() {
302+
test.isolated_graph_region {
303+
// FWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 forward static slice:
304+
// FWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
305+
// FWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 forward static slice:
306+
// FWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1
307+
308+
// BWD: matched: [[V0:%.+]] = "slicing-test-op"([[V1:%.+]]) : (i1) -> i1 backward static slice:
309+
// BWD: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1
310+
// BWD: matched: [[V1]] = "slicing-test-op"([[V0]]) : (i1) -> i1 backward static slice:
311+
// BWD: [[V0]] = "slicing-test-op"([[V1]]) : (i1) -> i1
312+
%0 = "slicing-test-op"(%1) : (i1) -> i1
313+
%1 = "slicing-test-op"(%0) : (i1) -> i1
314+
}
315+
316+
return
317+
}

0 commit comments

Comments
 (0)