26
26
using namespace mlir ;
27
27
28
28
static void
29
- getForwardSliceImpl (Operation *op, SetVector<Operation *> *forwardSlice,
29
+ getForwardSliceImpl (Operation *op, DenseSet<Operation *> &visited,
30
+ SetVector<Operation *> *forwardSlice,
30
31
const SliceOptions::TransitiveFilter &filter = nullptr ) {
31
32
if (!op)
32
33
return ;
@@ -40,20 +41,41 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
40
41
for (Region ®ion : op->getRegions ())
41
42
for (Block &block : region)
42
43
for (Operation &blockOp : block)
43
- if (forwardSlice->count (&blockOp) == 0 )
44
- getForwardSliceImpl (&blockOp, forwardSlice, filter);
45
- for (Value result : op->getResults ()) {
46
- for (Operation *userOp : result.getUsers ())
47
- if (forwardSlice->count (userOp) == 0 )
48
- getForwardSliceImpl (userOp, forwardSlice, filter);
49
- }
44
+ if (forwardSlice->count (&blockOp) == 0 ) {
45
+ // We don't have to check if the 'blockOp' is already visited because
46
+ // there cannot be a traversal path from this nested op to the parent
47
+ // and thus a cycle cannot be closed here. We still have to mark it
48
+ // as visited to stop before visiting this operation again if it is
49
+ // part of a cycle.
50
+ visited.insert (&blockOp);
51
+ getForwardSliceImpl (&blockOp, visited, forwardSlice, filter);
52
+ visited.erase (&blockOp);
53
+ }
54
+
55
+ for (Value result : op->getResults ())
56
+ for (Operation *userOp : result.getUsers ()) {
57
+ // A cycle can only occur within a basic block (not across regions or
58
+ // basic blocks) because the parent region must be a graph region, graph
59
+ // regions are restricted to always have 0 or 1 blocks, and there cannot
60
+ // be a def-use edge from a nested operation to an operation in an
61
+ // ancestor region. Therefore, we don't have to but may use the same
62
+ // 'visited' set across regions/blocks as long as we remove operations
63
+ // from the set again when the DFS traverses back from the leaf to the
64
+ // root.
65
+ if (forwardSlice->count (userOp) == 0 && visited.insert (userOp).second )
66
+ getForwardSliceImpl (userOp, visited, forwardSlice, filter);
67
+
68
+ visited.erase (userOp);
69
+ }
50
70
51
71
forwardSlice->insert (op);
52
72
}
53
73
54
74
void mlir::getForwardSlice (Operation *op, SetVector<Operation *> *forwardSlice,
55
75
const ForwardSliceOptions &options) {
56
- getForwardSliceImpl (op, forwardSlice, options.filter );
76
+ DenseSet<Operation *> visited;
77
+ visited.insert (op);
78
+ getForwardSliceImpl (op, visited, forwardSlice, options.filter );
57
79
if (!options.inclusive ) {
58
80
// Don't insert the top level operation, we just queried on it and don't
59
81
// want it in the results.
@@ -69,8 +91,12 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
69
91
70
92
void mlir::getForwardSlice (Value root, SetVector<Operation *> *forwardSlice,
71
93
const SliceOptions &options) {
72
- for (Operation *user : root.getUsers ())
73
- getForwardSliceImpl (user, forwardSlice, options.filter );
94
+ DenseSet<Operation *> visited;
95
+ for (Operation *user : root.getUsers ()) {
96
+ visited.insert (user);
97
+ getForwardSliceImpl (user, visited, forwardSlice, options.filter );
98
+ visited.erase (user);
99
+ }
74
100
75
101
// Reverse to get back the actual topological order.
76
102
// std::reverse does not work out of the box on SetVector and I want an
@@ -80,6 +106,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
80
106
}
81
107
82
108
static LogicalResult getBackwardSliceImpl (Operation *op,
109
+ DenseSet<Operation *> &visited,
83
110
SetVector<Operation *> *backwardSlice,
84
111
const BackwardSliceOptions &options) {
85
112
if (!op || op->hasTrait <OpTrait::IsIsolatedFromAbove>())
@@ -93,8 +120,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
93
120
94
121
auto processValue = [&](Value value) {
95
122
if (auto *definingOp = value.getDefiningOp ()) {
96
- if (backwardSlice->count (definingOp) == 0 )
97
- return getBackwardSliceImpl (definingOp, backwardSlice, options);
123
+ if (backwardSlice->count (definingOp) == 0 &&
124
+ visited.insert (definingOp).second )
125
+ return getBackwardSliceImpl (definingOp, visited, backwardSlice,
126
+ options);
127
+
128
+ visited.erase (definingOp);
98
129
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
99
130
if (options.omitBlockArguments )
100
131
return success ();
@@ -107,7 +138,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
107
138
if (parentOp && backwardSlice->count (parentOp) == 0 ) {
108
139
if (parentOp->getNumRegions () == 1 &&
109
140
llvm::hasSingleElement (parentOp->getRegion (0 ).getBlocks ())) {
110
- return getBackwardSliceImpl (parentOp, backwardSlice, options);
141
+ return getBackwardSliceImpl (parentOp, visited, backwardSlice,
142
+ options);
111
143
}
112
144
}
113
145
} else {
@@ -145,7 +177,10 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
145
177
LogicalResult mlir::getBackwardSlice (Operation *op,
146
178
SetVector<Operation *> *backwardSlice,
147
179
const BackwardSliceOptions &options) {
148
- LogicalResult result = getBackwardSliceImpl (op, backwardSlice, options);
180
+ DenseSet<Operation *> visited;
181
+ visited.insert (op);
182
+ LogicalResult result =
183
+ getBackwardSliceImpl (op, visited, backwardSlice, options);
149
184
150
185
if (!options.inclusive ) {
151
186
// Don't insert the top level operation, we just queried on it and don't
0 commit comments