Skip to content

Commit 5dc339d

Browse files
author
Jessica Paquette
committed
[AArch64][GlobalISel] Fold 64-bit cmps with 64-bit adds
G_ICMP is selected to an arithmetic overflow op (ADDS/SUBS/etc.) with a dead destination plus a CSINC instruction. We have a fold which allows us to combine 32-bit adds with G_ICMP. The problem with G_ICMP is that we model it as always having a 32-bit destination even though it can be a 64-bit operation. So, we were missing some opportunities for 64-bit folds. This patch teaches the fold to recognize 64-bit G_ICMPs and refactors some of the code surrounding CSINC accordingly. (Later down the line, I think we should probably change the way we handle G_ICMP in general.) Differential Revision: https://reviews.llvm.org/D111088
1 parent 2ba572a commit 5dc339d

File tree

2 files changed

+285
-139
lines changed

2 files changed

+285
-139
lines changed

llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -276,13 +276,9 @@ class AArch64InstructionSelector : public InstructionSelector {
276276
const RegisterBank &DstRB, LLT ScalarTy,
277277
Register VecReg, unsigned LaneIdx,
278278
MachineIRBuilder &MIRBuilder) const;
279-
280-
/// Emit a CSet for an integer compare.
281-
///
282-
/// \p DefReg and \p SrcReg are expected to be 32-bit scalar registers.
283-
MachineInstr *emitCSetForICMP(Register DefReg, unsigned Pred,
284-
MachineIRBuilder &MIRBuilder,
285-
Register SrcReg = AArch64::WZR) const;
279+
MachineInstr *emitCSINC(Register Dst, Register Src1, Register Src2,
280+
AArch64CC::CondCode Pred,
281+
MachineIRBuilder &MIRBuilder) const;
286282
/// Emit a CSet for a FP compare.
287283
///
288284
/// \p Dst is expected to be a 32-bit scalar register.
@@ -2213,27 +2209,55 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) {
22132209
// fold the add into the cset for the cmp by using cinc.
22142210
//
22152211
// FIXME: This would probably be a lot nicer in PostLegalizerLowering.
2216-
Register X = I.getOperand(1).getReg();
2217-
2218-
// Only handle scalars. Scalar G_ICMP is only legal for s32, so bail out
2219-
// early if we see it.
2220-
LLT Ty = MRI.getType(X);
2221-
if (Ty.isVector() || Ty.getSizeInBits() != 32)
2212+
Register AddDst = I.getOperand(0).getReg();
2213+
Register AddLHS = I.getOperand(1).getReg();
2214+
Register AddRHS = I.getOperand(2).getReg();
2215+
// Only handle scalars.
2216+
LLT Ty = MRI.getType(AddLHS);
2217+
if (Ty.isVector())
22222218
return false;
2223-
2224-
Register CmpReg = I.getOperand(2).getReg();
2225-
MachineInstr *Cmp = getOpcodeDef(TargetOpcode::G_ICMP, CmpReg, MRI);
2219+
// Since G_ICMP is modeled as ADDS/SUBS/ANDS, we can handle 32 bits or 64
2220+
// bits.
2221+
unsigned Size = Ty.getSizeInBits();
2222+
if (Size != 32 && Size != 64)
2223+
return false;
2224+
auto MatchCmp = [&](Register Reg) -> MachineInstr * {
2225+
if (!MRI.hasOneNonDBGUse(Reg))
2226+
return nullptr;
2227+
// If the LHS of the add is 32 bits, then we want to fold a 32-bit
2228+
// compare.
2229+
if (Size == 32)
2230+
return getOpcodeDef(TargetOpcode::G_ICMP, Reg, MRI);
2231+
// We model scalar compares using 32-bit destinations right now.
2232+
// If it's a 64-bit compare, it'll have 64-bit sources.
2233+
Register ZExt;
2234+
if (!mi_match(Reg, MRI,
2235+
m_OneNonDBGUse(m_GZExt(m_OneNonDBGUse(m_Reg(ZExt))))))
2236+
return nullptr;
2237+
auto *Cmp = getOpcodeDef(TargetOpcode::G_ICMP, ZExt, MRI);
2238+
if (!Cmp ||
2239+
MRI.getType(Cmp->getOperand(2).getReg()).getSizeInBits() != 64)
2240+
return nullptr;
2241+
return Cmp;
2242+
};
2243+
// Try to match
2244+
// z + (cmp pred, x, y)
2245+
MachineInstr *Cmp = MatchCmp(AddRHS);
22262246
if (!Cmp) {
2227-
std::swap(X, CmpReg);
2228-
Cmp = getOpcodeDef(TargetOpcode::G_ICMP, CmpReg, MRI);
2247+
// (cmp pred, x, y) + z
2248+
std::swap(AddLHS, AddRHS);
2249+
Cmp = MatchCmp(AddRHS);
22292250
if (!Cmp)
22302251
return false;
22312252
}
2232-
auto Pred =
2233-
static_cast<CmpInst::Predicate>(Cmp->getOperand(1).getPredicate());
2234-
emitIntegerCompare(Cmp->getOperand(2), Cmp->getOperand(3),
2235-
Cmp->getOperand(1), MIB);
2236-
emitCSetForICMP(I.getOperand(0).getReg(), Pred, MIB, X);
2253+
auto &PredOp = Cmp->getOperand(1);
2254+
auto Pred = static_cast<CmpInst::Predicate>(PredOp.getPredicate());
2255+
const AArch64CC::CondCode InvCC =
2256+
changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
2257+
MIB.setInstrAndDebugLoc(I);
2258+
emitIntegerCompare(/*LHS=*/Cmp->getOperand(2),
2259+
/*RHS=*/Cmp->getOperand(3), PredOp, MIB);
2260+
emitCSINC(/*Dst=*/AddDst, /*Src =*/AddLHS, /*Src2=*/AddLHS, InvCC, MIB);
22372261
I.eraseFromParent();
22382262
return true;
22392263
}
@@ -2963,10 +2987,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
29632987
// false, so to get the increment when it's true, we need to use the
29642988
// inverse. In this case, we want to increment when carry is set.
29652989
Register ZReg = AArch64::WZR;
2966-
auto CsetMI = MIB.buildInstr(AArch64::CSINCWr, {I.getOperand(1).getReg()},
2967-
{ZReg, ZReg})
2968-
.addImm(getInvertedCondCode(OpAndCC.second));
2969-
constrainSelectedInstRegOperands(*CsetMI, TII, TRI, RBI);
2990+
emitCSINC(/*Dst=*/I.getOperand(1).getReg(), /*Src1=*/ZReg, /*Src2=*/ZReg,
2991+
getInvertedCondCode(OpAndCC.second), MIB);
29702992
I.eraseFromParent();
29712993
return true;
29722994
}
@@ -3303,9 +3325,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
33033325
}
33043326

33053327
auto Pred = static_cast<CmpInst::Predicate>(I.getOperand(1).getPredicate());
3306-
emitIntegerCompare(I.getOperand(2), I.getOperand(3), I.getOperand(1),
3307-
MIB);
3308-
emitCSetForICMP(I.getOperand(0).getReg(), Pred, MIB);
3328+
const AArch64CC::CondCode InvCC =
3329+
changeICMPPredToAArch64CC(CmpInst::getInversePredicate(Pred));
3330+
emitIntegerCompare(I.getOperand(2), I.getOperand(3), I.getOperand(1), MIB);
3331+
emitCSINC(/*Dst=*/I.getOperand(0).getReg(), /*Src1=*/AArch64::WZR,
3332+
/*Src2=*/AArch64::WZR, InvCC, MIB);
33093333
I.eraseFromParent();
33103334
return true;
33113335
}
@@ -4451,25 +4475,19 @@ MachineInstr *AArch64InstructionSelector::emitCSetForFCmp(
44514475
assert(!Ty.isVector() && Ty.getSizeInBits() == 32 &&
44524476
"Expected a 32-bit scalar register?");
44534477
#endif
4454-
const Register ZeroReg = AArch64::WZR;
4455-
auto EmitCSet = [&](Register CsetDst, AArch64CC::CondCode CC) {
4456-
auto CSet =
4457-
MIRBuilder.buildInstr(AArch64::CSINCWr, {CsetDst}, {ZeroReg, ZeroReg})
4458-
.addImm(getInvertedCondCode(CC));
4459-
constrainSelectedInstRegOperands(*CSet, TII, TRI, RBI);
4460-
return &*CSet;
4461-
};
4462-
4478+
const Register ZReg = AArch64::WZR;
44634479
AArch64CC::CondCode CC1, CC2;
44644480
changeFCMPPredToAArch64CC(Pred, CC1, CC2);
4481+
auto InvCC1 = AArch64CC::getInvertedCondCode(CC1);
44654482
if (CC2 == AArch64CC::AL)
4466-
return EmitCSet(Dst, CC1);
4467-
4483+
return emitCSINC(/*Dst=*/Dst, /*Src1=*/ZReg, /*Src2=*/ZReg, InvCC1,
4484+
MIRBuilder);
44684485
const TargetRegisterClass *RC = &AArch64::GPR32RegClass;
44694486
Register Def1Reg = MRI.createVirtualRegister(RC);
44704487
Register Def2Reg = MRI.createVirtualRegister(RC);
4471-
EmitCSet(Def1Reg, CC1);
4472-
EmitCSet(Def2Reg, CC2);
4488+
auto InvCC2 = AArch64CC::getInvertedCondCode(CC2);
4489+
emitCSINC(/*Dst=*/Def1Reg, /*Src1=*/ZReg, /*Src2=*/ZReg, InvCC1, MIRBuilder);
4490+
emitCSINC(/*Dst=*/Def2Reg, /*Src1=*/ZReg, /*Src2=*/ZReg, InvCC2, MIRBuilder);
44734491
auto OrMI = MIRBuilder.buildInstr(AArch64::ORRWrr, {Dst}, {Def1Reg, Def2Reg});
44744492
constrainSelectedInstRegOperands(*OrMI, TII, TRI, RBI);
44754493
return &*OrMI;
@@ -4578,16 +4596,25 @@ MachineInstr *AArch64InstructionSelector::emitVectorConcat(
45784596
}
45794597

45804598
MachineInstr *
4581-
AArch64InstructionSelector::emitCSetForICMP(Register DefReg, unsigned Pred,
4582-
MachineIRBuilder &MIRBuilder,
4583-
Register SrcReg) const {
4584-
// CSINC increments the result when the predicate is false. Invert it.
4585-
const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC(
4586-
CmpInst::getInversePredicate((CmpInst::Predicate)Pred));
4587-
auto I = MIRBuilder.buildInstr(AArch64::CSINCWr, {DefReg}, {SrcReg, SrcReg})
4588-
.addImm(InvCC);
4589-
constrainSelectedInstRegOperands(*I, TII, TRI, RBI);
4590-
return &*I;
4599+
AArch64InstructionSelector::emitCSINC(Register Dst, Register Src1,
4600+
Register Src2, AArch64CC::CondCode Pred,
4601+
MachineIRBuilder &MIRBuilder) const {
4602+
auto &MRI = *MIRBuilder.getMRI();
4603+
const RegClassOrRegBank &RegClassOrBank = MRI.getRegClassOrRegBank(Dst);
4604+
// If we used a register class, then this won't necessarily have an LLT.
4605+
// Compute the size based off whether or not we have a class or bank.
4606+
unsigned Size;
4607+
if (const auto *RC = RegClassOrBank.dyn_cast<const TargetRegisterClass *>())
4608+
Size = TRI.getRegSizeInBits(*RC);
4609+
else
4610+
Size = MRI.getType(Dst).getSizeInBits();
4611+
// Some opcodes use s1.
4612+
assert(Size <= 64 && "Expected 64 bits or less only!");
4613+
static const unsigned OpcTable[2] = {AArch64::CSINCWr, AArch64::CSINCXr};
4614+
unsigned Opc = OpcTable[Size == 64];
4615+
auto CSINC = MIRBuilder.buildInstr(Opc, {Dst}, {Src1, Src2}).addImm(Pred);
4616+
constrainSelectedInstRegOperands(*CSINC, TII, TRI, RBI);
4617+
return &*CSINC;
45914618
}
45924619

45934620
std::pair<MachineInstr *, AArch64CC::CondCode>

0 commit comments

Comments
 (0)