Skip to content

Commit b3c293c

Browse files
authored
[LoopInterchange] Drop nuw/nsw flags from reduction ops when interchanging (llvm#148612)
Before this patch, when a reduction exists in the loop, the legality check of LoopInterchange only verified whether a non-reassociative floating-point instruction exists in the reduction calculation. However, this is insufficient, because reordering integer reductions can also lead to incorrect transformations. Consider the following example:

```c
int A[2][2] = {
  { INT_MAX, INT_MAX },
  { INT_MIN, INT_MIN },
};
int sum = 0;
for (int i = 0; i < 2; i++)
  for (int j = 0; j < 2; j++)
    sum += A[j][i];
```

In the original order the running sum alternates between adding INT_MAX and INT_MIN and never overflows; after the interchange the first two values added are both INT_MAX, so the running sum overflows. To make this exchange legal, we must drop the nuw/nsw flags from the instructions involved in the reduction operations. This patch extends the legality check to correctly handle such cases. In particular, for integer addition and multiplication, it records the involved instructions that have the nsw or nuw flag set, and drops those flags when the transformation is actually performed. This patch also introduces explicit checks for the kind of reduction and permits only those that are known to be safe for interchange. Consequently, some "unknown" reductions (at the moment, `FindFirst*` and `FindLast*`) are rejected.

Fixes llvm#148228
1 parent 9eb0fc8 commit b3c293c

File tree

3 files changed

+959
-6
lines changed

3 files changed

+959
-6
lines changed

llvm/lib/Transforms/Scalar/LoopInterchange.cpp

Lines changed: 84 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,10 @@ class LoopInterchangeLegality {
379379
return InnerLoopInductions;
380380
}
381381

382+
ArrayRef<Instruction *> getHasNoWrapReductions() const {
383+
return HasNoWrapReductions;
384+
}
385+
382386
private:
383387
bool tightlyNested(Loop *Outer, Loop *Inner);
384388
bool containsUnsafeInstructions(BasicBlock *BB);
@@ -405,6 +409,11 @@ class LoopInterchangeLegality {
405409

406410
/// Set of inner loop induction PHIs
407411
SmallVector<PHINode *, 8> InnerLoopInductions;
412+
413+
/// Hold instructions that have nuw/nsw flags and are involved in reductions,
414+
/// like integer addition/multiplication. Those flags must be dropped when
415+
/// interchanging the loops.
416+
SmallVector<Instruction *, 4> HasNoWrapReductions;
408417
};
409418

410419
/// Manages information utilized by the profitability check for cache. The main
@@ -473,7 +482,7 @@ class LoopInterchangeTransform {
473482
: OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), LIL(LIL) {}
474483

475484
/// Interchange OuterLoop and InnerLoop.
476-
bool transform();
485+
bool transform(ArrayRef<Instruction *> DropNoWrapInsts);
477486
void restructureLoops(Loop *NewInner, Loop *NewOuter,
478487
BasicBlock *OrigInnerPreHeader,
479488
BasicBlock *OrigOuterPreHeader);
@@ -613,7 +622,7 @@ struct LoopInterchange {
613622
});
614623

615624
LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, LIL);
616-
LIT.transform();
625+
LIT.transform(LIL.getHasNoWrapReductions());
617626
LLVM_DEBUG(dbgs() << "Loops interchanged.\n");
618627
LoopsInterchanged++;
619628

@@ -798,7 +807,9 @@ static Value *followLCSSA(Value *SV) {
798807
}
799808

800809
// Check V's users to see if it is involved in a reduction in L.
801-
static PHINode *findInnerReductionPhi(Loop *L, Value *V) {
810+
static PHINode *
811+
findInnerReductionPhi(Loop *L, Value *V,
812+
SmallVectorImpl<Instruction *> &HasNoWrapInsts) {
802813
// Reduction variables cannot be constants.
803814
if (isa<Constant>(V))
804815
return nullptr;
@@ -812,7 +823,65 @@ static PHINode *findInnerReductionPhi(Loop *L, Value *V) {
812823
// Detect floating point reduction only when it can be reordered.
813824
if (RD.getExactFPMathInst() != nullptr)
814825
return nullptr;
815-
return PHI;
826+
827+
RecurKind RK = RD.getRecurrenceKind();
828+
switch (RK) {
829+
case RecurKind::Or:
830+
case RecurKind::And:
831+
case RecurKind::Xor:
832+
case RecurKind::SMin:
833+
case RecurKind::SMax:
834+
case RecurKind::UMin:
835+
case RecurKind::UMax:
836+
case RecurKind::FAdd:
837+
case RecurKind::FMul:
838+
case RecurKind::FMin:
839+
case RecurKind::FMax:
840+
case RecurKind::FMinimum:
841+
case RecurKind::FMaximum:
842+
case RecurKind::FMinimumNum:
843+
case RecurKind::FMaximumNum:
844+
case RecurKind::FMulAdd:
845+
case RecurKind::AnyOf:
846+
return PHI;
847+
848+
// Changing the order of integer addition/multiplication may change the
849+
// semantics. Consider the following case:
850+
//
851+
// int A[2][2] = {{ INT_MAX, INT_MAX }, { INT_MIN, INT_MIN }};
852+
// int sum = 0;
853+
// for (int i = 0; i < 2; i++)
854+
// for (int j = 0; j < 2; j++)
855+
// sum += A[j][i];
856+
//
857+
// If the above loops are exchanged, the addition will cause an
858+
// overflow. To prevent this, we must drop the nuw/nsw flags from the
859+
// addition/multiplication instructions when we actually exchange the
860+
// loops.
861+
case RecurKind::Add:
862+
case RecurKind::Mul: {
863+
unsigned OpCode = RecurrenceDescriptor::getOpcode(RK);
864+
SmallVector<Instruction *, 4> Ops = RD.getReductionOpChain(PHI, L);
865+
866+
// Bail out when we fail to collect the chain of reduction instructions.
867+
if (Ops.empty())
868+
return nullptr;
869+
870+
for (Instruction *I : Ops) {
871+
assert(I->getOpcode() == OpCode &&
872+
"Expected the instruction to be the reduction operation");
873+
874+
// If the instruction has nuw/nsw flags, we must drop them when the
875+
// transformation is actually performed.
876+
if (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())
877+
HasNoWrapInsts.push_back(I);
878+
}
879+
return PHI;
880+
}
881+
882+
default:
883+
return nullptr;
884+
}
816885
}
817886
return nullptr;
818887
}
@@ -844,7 +913,8 @@ bool LoopInterchangeLegality::findInductionAndReductions(
844913
// Check if we have a PHI node in the outer loop that has a reduction
845914
// result from the inner loop as an incoming value.
846915
Value *V = followLCSSA(PHI.getIncomingValueForBlock(L->getLoopLatch()));
847-
PHINode *InnerRedPhi = findInnerReductionPhi(InnerLoop, V);
916+
PHINode *InnerRedPhi =
917+
findInnerReductionPhi(InnerLoop, V, HasNoWrapReductions);
848918
if (!InnerRedPhi ||
849919
!llvm::is_contained(InnerRedPhi->incoming_values(), &PHI)) {
850920
LLVM_DEBUG(
@@ -1430,7 +1500,8 @@ void LoopInterchangeTransform::restructureLoops(
14301500
SE->forgetLoop(NewOuter);
14311501
}
14321502

1433-
bool LoopInterchangeTransform::transform() {
1503+
bool LoopInterchangeTransform::transform(
1504+
ArrayRef<Instruction *> DropNoWrapInsts) {
14341505
bool Transformed = false;
14351506

14361507
if (InnerLoop->getSubLoops().empty()) {
@@ -1531,6 +1602,13 @@ bool LoopInterchangeTransform::transform() {
15311602
return false;
15321603
}
15331604

1605+
// Finally, drop the nsw/nuw flags from the instructions for reduction
1606+
// calculations.
1607+
for (Instruction *Reduction : DropNoWrapInsts) {
1608+
Reduction->setHasNoSignedWrap(false);
1609+
Reduction->setHasNoUnsignedWrap(false);
1610+
}
1611+
15341612
return true;
15351613
}
15361614

0 commit comments

Comments
 (0)