26
26
using namespace mlir ;
27
27
28
28
static void
29
- getForwardSliceImpl (Operation *op, SetVector<Operation *> *forwardSlice,
29
+ getForwardSliceImpl (Operation *op, DenseSet<Operation *> &visited,
30
+ SetVector<Operation *> *forwardSlice,
30
31
const SliceOptions::TransitiveFilter &filter = nullptr ) {
31
32
if (!op)
32
33
return ;
@@ -41,19 +42,23 @@ getForwardSliceImpl(Operation *op, SetVector<Operation *> *forwardSlice,
41
42
for (Block &block : region)
42
43
for (Operation &blockOp : block)
43
44
if (forwardSlice->count (&blockOp) == 0 )
44
- getForwardSliceImpl (&blockOp, forwardSlice, filter);
45
- for (Value result : op->getResults ()) {
46
- for (Operation *userOp : result.getUsers ())
47
- if (forwardSlice->count (userOp) == 0 )
48
- getForwardSliceImpl (userOp, forwardSlice, filter);
49
- }
45
+ getForwardSliceImpl (&blockOp, visited, forwardSlice, filter);
46
+
47
+ for (Value result : op->getResults ())
48
+ for (Operation *userOp : result.getUsers ()) {
49
+ if (forwardSlice->count (userOp) == 0 && visited.insert (userOp).second )
50
+ getForwardSliceImpl (userOp, visited, forwardSlice, filter);
51
+
52
+ visited.erase (userOp);
53
+ }
50
54
51
55
forwardSlice->insert (op);
52
56
}
53
57
54
58
void mlir::getForwardSlice (Operation *op, SetVector<Operation *> *forwardSlice,
55
59
const ForwardSliceOptions &options) {
56
- getForwardSliceImpl (op, forwardSlice, options.filter );
60
+ DenseSet<Operation *> visited;
61
+ getForwardSliceImpl (op, visited, forwardSlice, options.filter );
57
62
if (!options.inclusive ) {
58
63
// Don't insert the top level operation, we just queried on it and don't
59
64
// want it in the results.
@@ -69,8 +74,9 @@ void mlir::getForwardSlice(Operation *op, SetVector<Operation *> *forwardSlice,
69
74
70
75
void mlir::getForwardSlice (Value root, SetVector<Operation *> *forwardSlice,
71
76
const SliceOptions &options) {
77
+ DenseSet<Operation *> visited;
72
78
for (Operation *user : root.getUsers ())
73
- getForwardSliceImpl (user, forwardSlice, options.filter );
79
+ getForwardSliceImpl (user, visited, forwardSlice, options.filter );
74
80
75
81
// Reverse to get back the actual topological order.
76
82
// std::reverse does not work out of the box on SetVector and I want an
@@ -80,6 +86,7 @@ void mlir::getForwardSlice(Value root, SetVector<Operation *> *forwardSlice,
80
86
}
81
87
82
88
static LogicalResult getBackwardSliceImpl (Operation *op,
89
+ DenseSet<Operation *> &visited,
83
90
SetVector<Operation *> *backwardSlice,
84
91
const BackwardSliceOptions &options) {
85
92
if (!op || op->hasTrait <OpTrait::IsIsolatedFromAbove>())
@@ -93,8 +100,12 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
93
100
94
101
auto processValue = [&](Value value) {
95
102
if (auto *definingOp = value.getDefiningOp ()) {
96
- if (backwardSlice->count (definingOp) == 0 )
97
- return getBackwardSliceImpl (definingOp, backwardSlice, options);
103
+ if (backwardSlice->count (definingOp) == 0 &&
104
+ visited.insert (definingOp).second )
105
+ return getBackwardSliceImpl (definingOp, visited, backwardSlice,
106
+ options);
107
+
108
+ visited.erase (definingOp);
98
109
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
99
110
if (options.omitBlockArguments )
100
111
return success ();
@@ -107,7 +118,8 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
107
118
if (parentOp && backwardSlice->count (parentOp) == 0 ) {
108
119
if (parentOp->getNumRegions () == 1 &&
109
120
llvm::hasSingleElement (parentOp->getRegion (0 ).getBlocks ())) {
110
- return getBackwardSliceImpl (parentOp, backwardSlice, options);
121
+ return getBackwardSliceImpl (parentOp, visited, backwardSlice,
122
+ options);
111
123
}
112
124
}
113
125
} else {
@@ -145,7 +157,9 @@ static LogicalResult getBackwardSliceImpl(Operation *op,
145
157
LogicalResult mlir::getBackwardSlice (Operation *op,
146
158
SetVector<Operation *> *backwardSlice,
147
159
const BackwardSliceOptions &options) {
148
- LogicalResult result = getBackwardSliceImpl (op, backwardSlice, options);
160
+ DenseSet<Operation *> visited;
161
+ LogicalResult result =
162
+ getBackwardSliceImpl (op, visited, backwardSlice, options);
149
163
150
164
if (!options.inclusive ) {
151
165
// Don't insert the top level operation, we just queried on it and don't
0 commit comments