10
10
#include " mlir/Interfaces/SideEffectInterfaces.h"
11
11
#include " llvm/ADT/SmallPtrSet.h"
12
12
#include " llvm/Support/Debug.h"
13
+ #include < queue>
13
14
14
15
using namespace mlir ;
15
16
@@ -26,75 +27,88 @@ using namespace mlir;
26
27
// LoopLike Utilities
27
28
// ===----------------------------------------------------------------------===//
28
29
29
- // Checks whether the given op can be hoisted by checking that
30
- // - the op and any of its contained operations do not depend on SSA values
31
- // defined inside of the loop (by means of calling definedOutside).
32
- // - the op has no side-effects. If sideEffecting is Never, sideeffects of this
33
- // op and its nested ops are ignored.
34
- static bool canBeHoisted (Operation *op,
35
- function_ref<bool (Value)> definedOutside) {
36
- // Check that dependencies are defined outside of loop.
37
- if (!llvm::all_of (op->getOperands (), definedOutside))
38
- return false ;
39
- // Check whether this op is side-effect free. If we already know that there
40
- // can be no side-effects because the surrounding op has claimed so, we can
41
- // (and have to) skip this step.
30
+ // / Returns true if the given operation is side-effect free as are all of its
31
+ // / nested operations.
32
+ // /
33
+ // / TODO: There is a duplicate function in ControlFlowSink. Move
34
+ // / `moveLoopInvariantCode` to TransformUtils and then factor out this function.
35
+ static bool isSideEffectFree (Operation *op) {
42
36
if (auto memInterface = dyn_cast<MemoryEffectOpInterface>(op)) {
37
+ // If the op has side-effects, it cannot be moved.
43
38
if (!memInterface.hasNoEffect ())
44
39
return false ;
45
- // If the operation doesn't have side effects and it doesn't recursively
46
- // have side effects, it can always be hoisted.
40
+ // If the op does not have recursive side effects, then it can be moved.
47
41
if (!op->hasTrait <OpTrait::HasRecursiveSideEffects>())
48
42
return true ;
49
-
50
- // Otherwise, if the operation doesn't provide the memory effect interface
51
- // and it doesn't have recursive side effects we treat it conservatively as
52
- // side-effecting.
53
43
} else if (!op->hasTrait <OpTrait::HasRecursiveSideEffects>()) {
44
+ // Otherwise, if the op does not implement the memory effect interface and
45
+ // it does not have recursive side effects, then it cannot be known that the
46
+ // op is moveable.
54
47
return false ;
55
48
}
56
49
57
- // Recurse into the regions for this op and check whether the contained ops
58
- // can be hoisted.
59
- for (auto ®ion : op->getRegions ()) {
60
- for (auto &block : region) {
61
- for (auto &innerOp : block)
62
- if (!canBeHoisted (&innerOp, definedOutside))
63
- return false ;
64
- }
65
- }
50
+ // Recurse into the regions and ensure that all nested ops can also be moved.
51
+ for (Region ®ion : op->getRegions ())
52
+ for (Operation &op : region.getOps ())
53
+ if (!isSideEffectFree (&op))
54
+ return false ;
66
55
return true ;
67
56
}
68
57
58
+ // / Checks whether the given op can be hoisted by checking that
59
+ // / - the op and none of its contained operations depend on values inside of the
60
+ // / loop (by means of calling definedOutside).
61
+ // / - the op has no side-effects.
62
+ static bool canBeHoisted (Operation *op,
63
+ function_ref<bool (Value)> definedOutside) {
64
+ if (!isSideEffectFree (op))
65
+ return false ;
66
+
67
+ // Do not move terminators.
68
+ if (op->hasTrait <OpTrait::IsTerminator>())
69
+ return false ;
70
+
71
+ // Walk the nested operations and check that all used values are either
72
+ // defined outside of the loop or in a nested region, but not at the level of
73
+ // the loop body.
74
+ auto walkFn = [&](Operation *child) {
75
+ for (Value operand : child->getOperands ()) {
76
+ // Ignore values defined in a nested region.
77
+ if (op->isAncestor (operand.getParentRegion ()->getParentOp ()))
78
+ continue ;
79
+ if (!definedOutside (operand))
80
+ return WalkResult::interrupt ();
81
+ }
82
+ return WalkResult::advance ();
83
+ };
84
+ return !op->walk (walkFn).wasInterrupted ();
85
+ }
86
+
69
87
void mlir::moveLoopInvariantCode (LoopLikeOpInterface looplike) {
70
- auto &loopBody = looplike.getLoopBody ();
71
-
72
- // We use two collections here as we need to preserve the order for insertion
73
- // and this is easiest.
74
- SmallPtrSet<Operation *, 8 > willBeMovedSet;
75
- SmallVector<Operation *, 8 > opsToMove;
76
-
77
- // Helper to check whether an operation is loop invariant wrt. SSA properties.
78
- auto isDefinedOutsideOfBody = [&](Value value) {
79
- auto *definingOp = value.getDefiningOp ();
80
- return (definingOp && !!willBeMovedSet.count (definingOp)) ||
81
- looplike.isDefinedOutsideOfLoop (value);
88
+ Region *loopBody = &looplike.getLoopBody ();
89
+
90
+ std::queue<Operation *> worklist;
91
+ // Add top-level operations in the loop body to the worklist.
92
+ for (Operation &op : loopBody->getOps ())
93
+ worklist.push (&op);
94
+
95
+ auto definedOutside = [&](Value value) {
96
+ return looplike.isDefinedOutsideOfLoop (value);
82
97
};
83
98
84
- // Do not use walk here, as we do not want to go into nested regions and hoist
85
- // operations from there. These regions might have semantics unknown to this
86
- // rewriting. If the nested regions are loops, they will have been processed.
87
- for (auto &block : loopBody) {
88
- for (auto &op : block.without_terminator ()) {
89
- if (canBeHoisted (&op, isDefinedOutsideOfBody)) {
90
- opsToMove.push_back (&op);
91
- willBeMovedSet.insert (&op);
92
- }
93
- }
94
- }
99
+ while (!worklist.empty ()) {
100
+ Operation *op = worklist.front ();
101
+ worklist.pop ();
102
+ // Skip ops that have already been moved. Check if the op can be hoisted.
103
+ if (op->getParentRegion () != loopBody || !canBeHoisted (op, definedOutside))
104
+ continue ;
95
105
96
- // For all instructions that we found to be invariant, move outside of the
97
- // loop.
98
- for (Operation *op : opsToMove)
99
106
looplike.moveOutOfLoop (op);
107
+
108
+ // Since the op has been moved, we need to check its users within the
109
+ // top-level of the loop body.
110
+ for (Operation *user : op->getUsers ())
111
+ if (user->getParentRegion () == loopBody)
112
+ worklist.push (user);
113
+ }
100
114
}
0 commit comments