Skip to content

Commit 57e87dd

Browse files
committed
[ARM][LowOverheadLoops] Fix branch target codegen
While lowering test.set.loop.iterations, it wasn't checked how the brcond was using the result and so the wls could branch to the loop preheader instead of not entering it. The same was true for loop.decrement.reg. So brcond and br_cc and now lowered manually when using the hwloop intrinsics. During this we now check whether the result has been negated and whether we're using SETEQ or SETNE and 0 or 1. We can then figure out which basic block the WLS and LE should be targeting. Differential Revision: https://reviews.llvm.org/D64616 llvm-svn: 366809
1 parent c60c12f commit 57e87dd

File tree

5 files changed

+696
-37
lines changed

5 files changed

+696
-37
lines changed

llvm/lib/Target/ARM/ARMISelDAGToDAG.cpp

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2998,13 +2998,26 @@ void ARMDAGToDAGISel::Select(SDNode *N) {
29982998
// Other cases are autogenerated.
29992999
break;
30003000
}
3001-
case ARMISD::WLS: {
3002-
SDValue Ops[] = { N->getOperand(1), // Loop count
3003-
N->getOperand(2), // Exit target
3001+
case ARMISD::WLS:
3002+
case ARMISD::LE: {
3003+
SDValue Ops[] = { N->getOperand(1),
3004+
N->getOperand(2),
30043005
N->getOperand(0) };
3005-
SDNode *LoopStart =
3006-
CurDAG->getMachineNode(ARM::t2WhileLoopStart, dl, MVT::Other, Ops);
3007-
ReplaceUses(N, LoopStart);
3006+
unsigned Opc = N->getOpcode() == ARMISD::WLS ?
3007+
ARM::t2WhileLoopStart : ARM::t2LoopEnd;
3008+
SDNode *New = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops);
3009+
ReplaceUses(N, New);
3010+
CurDAG->RemoveDeadNode(N);
3011+
return;
3012+
}
3013+
case ARMISD::LOOP_DEC: {
3014+
SDValue Ops[] = { N->getOperand(1),
3015+
N->getOperand(2),
3016+
N->getOperand(0) };
3017+
SDNode *Dec =
3018+
CurDAG->getMachineNode(ARM::t2LoopDec, dl,
3019+
CurDAG->getVTList(MVT::i32, MVT::Other), Ops);
3020+
ReplaceUses(N, Dec);
30083021
CurDAG->RemoveDeadNode(N);
30093022
return;
30103023
}

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 157 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -669,8 +669,10 @@ ARMTargetLowering::ARMTargetLowering(const TargetMachine &TM,
669669
addMVEVectorTypes(Subtarget->hasMVEFloatOps());
670670

671671
// Combine low-overhead loop intrinsics so that we can lower i1 types.
672-
if (Subtarget->hasLOB())
672+
if (Subtarget->hasLOB()) {
673673
setTargetDAGCombine(ISD::BRCOND);
674+
setTargetDAGCombine(ISD::BR_CC);
675+
}
674676

675677
if (Subtarget->hasNEON()) {
676678
addDRTypeForNEON(MVT::v2f32);
@@ -1589,6 +1591,8 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
15891591
case ARMISD::VST3LN_UPD: return "ARMISD::VST3LN_UPD";
15901592
case ARMISD::VST4LN_UPD: return "ARMISD::VST4LN_UPD";
15911593
case ARMISD::WLS: return "ARMISD::WLS";
1594+
case ARMISD::LE: return "ARMISD::LE";
1595+
case ARMISD::LOOP_DEC: return "ARMISD::LOOP_DEC";
15921596
}
15931597
return nullptr;
15941598
}
@@ -13034,43 +13038,169 @@ SDValue ARMTargetLowering::PerformCMOVToBFICombine(SDNode *CMOV, SelectionDAG &D
1303413038
return V;
1303513039
}
1303613040

13041+
// Given N, the value controlling the conditional branch, search for the loop
13042+
// intrinsic, returning it, along with how the value is used. We need to handle
13043+
// patterns such as the following:
13044+
// (brcond (xor (setcc (loop.decrement), 0, ne), 1), exit)
13045+
// (brcond (setcc (loop.decrement), 0, eq), exit)
13046+
// (brcond (setcc (loop.decrement), 0, ne), header)
13047+
static SDValue SearchLoopIntrinsic(SDValue N, ISD::CondCode &CC, int &Imm,
13048+
bool &Negate) {
13049+
switch (N->getOpcode()) {
13050+
default:
13051+
break;
13052+
case ISD::XOR: {
13053+
if (!isa<ConstantSDNode>(N.getOperand(1)))
13054+
return SDValue();
13055+
if (!cast<ConstantSDNode>(N.getOperand(1))->isOne())
13056+
return SDValue();
13057+
Negate = !Negate;
13058+
return SearchLoopIntrinsic(N.getOperand(0), CC, Imm, Negate);
13059+
}
13060+
case ISD::SETCC: {
13061+
auto *Const = dyn_cast<ConstantSDNode>(N.getOperand(1));
13062+
if (!Const)
13063+
return SDValue();
13064+
if (Const->isNullValue())
13065+
Imm = 0;
13066+
else if (Const->isOne())
13067+
Imm = 1;
13068+
else
13069+
return SDValue();
13070+
CC = cast<CondCodeSDNode>(N.getOperand(2))->get();
13071+
return SearchLoopIntrinsic(N->getOperand(0), CC, Imm, Negate);
13072+
}
13073+
case ISD::INTRINSIC_W_CHAIN: {
13074+
unsigned IntOp = cast<ConstantSDNode>(N.getOperand(1))->getZExtValue();
13075+
if (IntOp != Intrinsic::test_set_loop_iterations &&
13076+
IntOp != Intrinsic::loop_decrement_reg)
13077+
return SDValue();
13078+
return N;
13079+
}
13080+
}
13081+
return SDValue();
13082+
}
13083+
1303713084
static SDValue PerformHWLoopCombine(SDNode *N,
1303813085
TargetLowering::DAGCombinerInfo &DCI,
1303913086
const ARMSubtarget *ST) {
13040-
// Look for (brcond (xor test.set.loop.iterations, -1)
13041-
SDValue CC = N->getOperand(1);
13042-
unsigned Opc = CC->getOpcode();
13043-
SDValue Int;
1304413087

13045-
if ((Opc == ISD::XOR || Opc == ISD::SETCC) &&
13046-
(CC->getOperand(0)->getOpcode() == ISD::INTRINSIC_W_CHAIN)) {
13088+
// The hwloop intrinsics that we're interested are used for control-flow,
13089+
// either for entering or exiting the loop:
13090+
// - test.set.loop.iterations will test whether its operand is zero. If it
13091+
// is zero, the proceeding branch should not enter the loop.
13092+
// - loop.decrement.reg also tests whether its operand is zero. If it is
13093+
// zero, the proceeding branch should not branch back to the beginning of
13094+
// the loop.
13095+
// So here, we need to check that how the brcond is using the result of each
13096+
// of the intrinsics to ensure that we're branching to the right place at the
13097+
// right time.
13098+
13099+
ISD::CondCode CC;
13100+
SDValue Cond;
13101+
int Imm = 1;
13102+
bool Negate = false;
13103+
SDValue Chain = N->getOperand(0);
13104+
SDValue Dest;
1304713105

13048-
assert((isa<ConstantSDNode>(CC->getOperand(1)) &&
13049-
cast<ConstantSDNode>(CC->getOperand(1))->isOne()) &&
13050-
"Expected to compare against 1");
13106+
if (N->getOpcode() == ISD::BRCOND) {
13107+
CC = ISD::SETEQ;
13108+
Cond = N->getOperand(1);
13109+
Dest = N->getOperand(2);
13110+
} else {
13111+
assert(N->getOpcode() == ISD::BR_CC && "Expected BRCOND or BR_CC!");
13112+
CC = cast<CondCodeSDNode>(N->getOperand(1))->get();
13113+
Cond = N->getOperand(2);
13114+
Dest = N->getOperand(4);
13115+
if (auto *Const = dyn_cast<ConstantSDNode>(N->getOperand(3))) {
13116+
if (!Const->isOne() && !Const->isNullValue())
13117+
return SDValue();
13118+
Imm = Const->getZExtValue();
13119+
} else
13120+
return SDValue();
13121+
}
1305113122

13052-
Int = CC->getOperand(0);
13053-
} else if (CC->getOpcode() == ISD::INTRINSIC_W_CHAIN)
13054-
Int = CC;
13055-
else
13123+
SDValue Int = SearchLoopIntrinsic(Cond, CC, Imm, Negate);
13124+
if (!Int)
1305613125
return SDValue();
1305713126

13058-
unsigned IntOp = cast<ConstantSDNode>(Int.getOperand(1))->getZExtValue();
13059-
if (IntOp != Intrinsic::test_set_loop_iterations)
13060-
return SDValue();
13127+
if (Negate)
13128+
CC = ISD::getSetCCInverse(CC, true);
13129+
13130+
auto IsTrueIfZero = [](ISD::CondCode CC, int Imm) {
13131+
return (CC == ISD::SETEQ && Imm == 0) ||
13132+
(CC == ISD::SETNE && Imm == 1) ||
13133+
(CC == ISD::SETLT && Imm == 1) ||
13134+
(CC == ISD::SETULT && Imm == 1);
13135+
};
13136+
13137+
auto IsFalseIfZero = [](ISD::CondCode CC, int Imm) {
13138+
return (CC == ISD::SETEQ && Imm == 1) ||
13139+
(CC == ISD::SETNE && Imm == 0) ||
13140+
(CC == ISD::SETGT && Imm == 0) ||
13141+
(CC == ISD::SETUGT && Imm == 0) ||
13142+
(CC == ISD::SETGE && Imm == 1) ||
13143+
(CC == ISD::SETUGE && Imm == 1);
13144+
};
13145+
13146+
assert((IsTrueIfZero(CC, Imm) || IsFalseIfZero(CC, Imm)) &&
13147+
"unsupported condition");
1306113148

1306213149
SDLoc dl(Int);
13063-
SDValue Chain = N->getOperand(0);
13150+
SelectionDAG &DAG = DCI.DAG;
1306413151
SDValue Elements = Int.getOperand(2);
13065-
SDValue ExitBlock = N->getOperand(2);
13152+
unsigned IntOp = cast<ConstantSDNode>(Int->getOperand(1))->getZExtValue();
13153+
assert((N->hasOneUse() && N->use_begin()->getOpcode() == ISD::BR)
13154+
&& "expected single br user");
13155+
SDNode *Br = *N->use_begin();
13156+
SDValue OtherTarget = Br->getOperand(1);
13157+
13158+
// Update the unconditional branch to branch to the given Dest.
13159+
auto UpdateUncondBr = [](SDNode *Br, SDValue Dest, SelectionDAG &DAG) {
13160+
SDValue NewBrOps[] = { Br->getOperand(0), Dest };
13161+
SDValue NewBr = DAG.getNode(ISD::BR, SDLoc(Br), MVT::Other, NewBrOps);
13162+
DAG.ReplaceAllUsesOfValueWith(SDValue(Br, 0), NewBr);
13163+
};
1306613164

13067-
// TODO: Once we start supporting tail predication, we can add another
13068-
// operand to WLS for the number of elements processed in a vector loop.
13165+
if (IntOp == Intrinsic::test_set_loop_iterations) {
13166+
SDValue Res;
13167+
// We expect this 'instruction' to branch when the counter is zero.
13168+
if (IsTrueIfZero(CC, Imm)) {
13169+
SDValue Ops[] = { Chain, Elements, Dest };
13170+
Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
13171+
} else {
13172+
// The logic is the reverse of what we need for WLS, so find the other
13173+
// basic block target: the target of the proceeding br.
13174+
UpdateUncondBr(Br, Dest, DAG);
1306913175

13070-
SDValue Ops[] = { Chain, Elements, ExitBlock };
13071-
SDValue Res = DCI.DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
13072-
DCI.DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0));
13073-
return Res;
13176+
SDValue Ops[] = { Chain, Elements, OtherTarget };
13177+
Res = DAG.getNode(ARMISD::WLS, dl, MVT::Other, Ops);
13178+
}
13179+
DAG.ReplaceAllUsesOfValueWith(Int.getValue(1), Int.getOperand(0));
13180+
return Res;
13181+
} else {
13182+
SDValue Size = DAG.getTargetConstant(
13183+
cast<ConstantSDNode>(Int.getOperand(3))->getZExtValue(), dl, MVT::i32);
13184+
SDValue Args[] = { Int.getOperand(0), Elements, Size, };
13185+
SDValue LoopDec = DAG.getNode(ARMISD::LOOP_DEC, dl,
13186+
DAG.getVTList(MVT::i32, MVT::Other), Args);
13187+
DAG.ReplaceAllUsesWith(Int.getNode(), LoopDec.getNode());
13188+
13189+
// We expect this instruction to branch when the count is not zero.
13190+
SDValue Target = IsFalseIfZero(CC, Imm) ? Dest : OtherTarget;
13191+
13192+
// Update the unconditional branch to target the loop preheader if we've
13193+
// found the condition has been reversed.
13194+
if (Target == OtherTarget)
13195+
UpdateUncondBr(Br, Dest, DAG);
13196+
13197+
Chain = DAG.getNode(ISD::TokenFactor, dl, MVT::Other,
13198+
SDValue(LoopDec.getNode(), 1), Chain);
13199+
13200+
SDValue EndArgs[] = { Chain, SDValue(LoopDec.getNode(), 0), Target };
13201+
return DAG.getNode(ARMISD::LE, dl, MVT::Other, EndArgs);
13202+
}
13203+
return SDValue();
1307413204
}
1307513205

1307613206
/// PerformBRCONDCombine - Target-specific DAG combining for ARMISD::BRCOND.
@@ -13304,7 +13434,8 @@ SDValue ARMTargetLowering::PerformDAGCombine(SDNode *N,
1330413434
case ISD::OR: return PerformORCombine(N, DCI, Subtarget);
1330513435
case ISD::XOR: return PerformXORCombine(N, DCI, Subtarget);
1330613436
case ISD::AND: return PerformANDCombine(N, DCI, Subtarget);
13307-
case ISD::BRCOND: return PerformHWLoopCombine(N, DCI, Subtarget);
13437+
case ISD::BRCOND:
13438+
case ISD::BR_CC: return PerformHWLoopCombine(N, DCI, Subtarget);
1330813439
case ARMISD::ADDC:
1330913440
case ARMISD::SUBC: return PerformAddcSubcCombine(N, DCI, Subtarget);
1331013441
case ARMISD::SUBE: return PerformAddeSubeCombine(N, DCI, Subtarget);

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ class VectorType;
126126
WIN__DBZCHK, // Windows' divide by zero check
127127

128128
WLS, // Low-overhead loops, While Loop Start
129+
LOOP_DEC, // Really a part of LE, performs the sub
130+
LE, // Low-overhead loops, Loop End
129131

130132
VCEQ, // Vector compare equal.
131133
VCEQZ, // Vector compare equal to zero.

llvm/lib/Target/ARM/ARMInstrInfo.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ def SDT_ARMIntShiftParts : SDTypeProfile<2, 3, [SDTCisSameAs<0, 1>,
108108

109109
// TODO Add another operand for 'Size' so that we can re-use this node when we
110110
// start supporting *TP versions.
111-
def SDT_ARMWhileLoop : SDTypeProfile<0, 2, [SDTCisVT<0, i32>,
112-
SDTCisVT<1, OtherVT>]>;
111+
def SDT_ARMLoLoop : SDTypeProfile<0, 2, [SDTCisVT<0, i32>,
112+
SDTCisVT<1, OtherVT>]>;
113113

114114
def ARMSmlald : SDNode<"ARMISD::SMLALD", SDT_LongMac>;
115115
def ARMSmlaldx : SDNode<"ARMISD::SMLALDX", SDT_LongMac>;
@@ -265,9 +265,9 @@ def ARMvshruImm : SDNode<"ARMISD::VSHRuIMM", SDTARMVSHIMM>;
265265
def ARMvshls : SDNode<"ARMISD::VSHLs", SDTARMVSH>;
266266
def ARMvshlu : SDNode<"ARMISD::VSHLu", SDTARMVSH>;
267267

268-
def ARMWLS : SDNode<"ARMISD::WLS", SDT_ARMWhileLoop,
269-
[SDNPHasChain]>;
270-
268+
def ARMWLS : SDNode<"ARMISD::WLS", SDT_ARMLoLoop, [SDNPHasChain]>;
269+
def ARMLE : SDNode<"ARMISD::LE", SDT_ARMLoLoop, [SDNPHasChain]>;
270+
def ARMLoopDec : SDNode<"ARMISD::LOOP_DEC", SDTIntBinOp, [SDNPHasChain]>;
271271
//===----------------------------------------------------------------------===//
272272
// ARM Flag Definitions.
273273

0 commit comments

Comments
 (0)