@@ -276,13 +276,9 @@ class AArch64InstructionSelector : public InstructionSelector {
276
276
const RegisterBank &DstRB, LLT ScalarTy,
277
277
Register VecReg, unsigned LaneIdx,
278
278
MachineIRBuilder &MIRBuilder) const ;
279
-
280
- // / Emit a CSet for an integer compare.
281
- // /
282
- // / \p DefReg and \p SrcReg are expected to be 32-bit scalar registers.
283
- MachineInstr *emitCSetForICMP (Register DefReg, unsigned Pred,
284
- MachineIRBuilder &MIRBuilder,
285
- Register SrcReg = AArch64::WZR) const ;
279
+ MachineInstr *emitCSINC (Register Dst, Register Src1, Register Src2,
280
+ AArch64CC::CondCode Pred,
281
+ MachineIRBuilder &MIRBuilder) const ;
286
282
// / Emit a CSet for a FP compare.
287
283
// /
288
284
// / \p Dst is expected to be a 32-bit scalar register.
@@ -2213,27 +2209,55 @@ bool AArch64InstructionSelector::earlySelect(MachineInstr &I) {
2213
2209
// fold the add into the cset for the cmp by using cinc.
2214
2210
//
2215
2211
// FIXME: This would probably be a lot nicer in PostLegalizerLowering.
2216
- Register X = I.getOperand (1 ).getReg ();
2217
-
2218
- // Only handle scalars. Scalar G_ICMP is only legal for s32, so bail out
2219
- // early if we see it .
2220
- LLT Ty = MRI.getType (X );
2221
- if (Ty.isVector () || Ty. getSizeInBits () != 32 )
2212
+ Register AddDst = I.getOperand (0 ).getReg ();
2213
+ Register AddLHS = I. getOperand ( 1 ). getReg ();
2214
+ Register AddRHS = I. getOperand ( 2 ). getReg ();
2215
+ // Only handle scalars .
2216
+ LLT Ty = MRI.getType (AddLHS );
2217
+ if (Ty.isVector ())
2222
2218
return false ;
2223
-
2224
- Register CmpReg = I.getOperand (2 ).getReg ();
2225
- MachineInstr *Cmp = getOpcodeDef (TargetOpcode::G_ICMP, CmpReg, MRI);
2219
+ // Since G_ICMP is modeled as ADDS/SUBS/ANDS, we can handle 32 bits or 64
2220
+ // bits.
2221
+ unsigned Size = Ty.getSizeInBits ();
2222
+ if (Size != 32 && Size != 64 )
2223
+ return false ;
2224
+ auto MatchCmp = [&](Register Reg) -> MachineInstr * {
2225
+ if (!MRI.hasOneNonDBGUse (Reg))
2226
+ return nullptr ;
2227
+ // If the LHS of the add is 32 bits, then we want to fold a 32-bit
2228
+ // compare.
2229
+ if (Size == 32 )
2230
+ return getOpcodeDef (TargetOpcode::G_ICMP, Reg, MRI);
2231
+ // We model scalar compares using 32-bit destinations right now.
2232
+ // If it's a 64-bit compare, it'll have 64-bit sources.
2233
+ Register ZExt;
2234
+ if (!mi_match (Reg, MRI,
2235
+ m_OneNonDBGUse (m_GZExt (m_OneNonDBGUse (m_Reg (ZExt))))))
2236
+ return nullptr ;
2237
+ auto *Cmp = getOpcodeDef (TargetOpcode::G_ICMP, ZExt, MRI);
2238
+ if (!Cmp ||
2239
+ MRI.getType (Cmp->getOperand (2 ).getReg ()).getSizeInBits () != 64 )
2240
+ return nullptr ;
2241
+ return Cmp;
2242
+ };
2243
+ // Try to match
2244
+ // z + (cmp pred, x, y)
2245
+ MachineInstr *Cmp = MatchCmp (AddRHS);
2226
2246
if (!Cmp) {
2227
- std::swap (X, CmpReg);
2228
- Cmp = getOpcodeDef (TargetOpcode::G_ICMP, CmpReg, MRI);
2247
+ // (cmp pred, x, y) + z
2248
+ std::swap (AddLHS, AddRHS);
2249
+ Cmp = MatchCmp (AddRHS);
2229
2250
if (!Cmp)
2230
2251
return false ;
2231
2252
}
2232
- auto Pred =
2233
- static_cast <CmpInst::Predicate>(Cmp->getOperand (1 ).getPredicate ());
2234
- emitIntegerCompare (Cmp->getOperand (2 ), Cmp->getOperand (3 ),
2235
- Cmp->getOperand (1 ), MIB);
2236
- emitCSetForICMP (I.getOperand (0 ).getReg (), Pred, MIB, X);
2253
+ auto &PredOp = Cmp->getOperand (1 );
2254
+ auto Pred = static_cast <CmpInst::Predicate>(PredOp.getPredicate ());
2255
+ const AArch64CC::CondCode InvCC =
2256
+ changeICMPPredToAArch64CC (CmpInst::getInversePredicate (Pred));
2257
+ MIB.setInstrAndDebugLoc (I);
2258
+ emitIntegerCompare (/* LHS=*/ Cmp->getOperand (2 ),
2259
+ /* RHS=*/ Cmp->getOperand (3 ), PredOp, MIB);
2260
+ emitCSINC (/* Dst=*/ AddDst, /* Src =*/ AddLHS, /* Src2=*/ AddLHS, InvCC, MIB);
2237
2261
I.eraseFromParent ();
2238
2262
return true ;
2239
2263
}
@@ -2963,10 +2987,8 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
2963
2987
// false, so to get the increment when it's true, we need to use the
2964
2988
// inverse. In this case, we want to increment when carry is set.
2965
2989
Register ZReg = AArch64::WZR;
2966
- auto CsetMI = MIB.buildInstr (AArch64::CSINCWr, {I.getOperand (1 ).getReg ()},
2967
- {ZReg, ZReg})
2968
- .addImm (getInvertedCondCode (OpAndCC.second ));
2969
- constrainSelectedInstRegOperands (*CsetMI, TII, TRI, RBI);
2990
+ emitCSINC (/* Dst=*/ I.getOperand (1 ).getReg (), /* Src1=*/ ZReg, /* Src2=*/ ZReg,
2991
+ getInvertedCondCode (OpAndCC.second ), MIB);
2970
2992
I.eraseFromParent ();
2971
2993
return true ;
2972
2994
}
@@ -3303,9 +3325,11 @@ bool AArch64InstructionSelector::select(MachineInstr &I) {
3303
3325
}
3304
3326
3305
3327
auto Pred = static_cast <CmpInst::Predicate>(I.getOperand (1 ).getPredicate ());
3306
- emitIntegerCompare (I.getOperand (2 ), I.getOperand (3 ), I.getOperand (1 ),
3307
- MIB);
3308
- emitCSetForICMP (I.getOperand (0 ).getReg (), Pred, MIB);
3328
+ const AArch64CC::CondCode InvCC =
3329
+ changeICMPPredToAArch64CC (CmpInst::getInversePredicate (Pred));
3330
+ emitIntegerCompare (I.getOperand (2 ), I.getOperand (3 ), I.getOperand (1 ), MIB);
3331
+ emitCSINC (/* Dst=*/ I.getOperand (0 ).getReg (), /* Src1=*/ AArch64::WZR,
3332
+ /* Src2=*/ AArch64::WZR, InvCC, MIB);
3309
3333
I.eraseFromParent ();
3310
3334
return true ;
3311
3335
}
@@ -4451,25 +4475,19 @@ MachineInstr *AArch64InstructionSelector::emitCSetForFCmp(
4451
4475
assert (!Ty.isVector () && Ty.getSizeInBits () == 32 &&
4452
4476
" Expected a 32-bit scalar register?" );
4453
4477
#endif
4454
- const Register ZeroReg = AArch64::WZR;
4455
- auto EmitCSet = [&](Register CsetDst, AArch64CC::CondCode CC) {
4456
- auto CSet =
4457
- MIRBuilder.buildInstr (AArch64::CSINCWr, {CsetDst}, {ZeroReg, ZeroReg})
4458
- .addImm (getInvertedCondCode (CC));
4459
- constrainSelectedInstRegOperands (*CSet, TII, TRI, RBI);
4460
- return &*CSet;
4461
- };
4462
-
4478
+ const Register ZReg = AArch64::WZR;
4463
4479
AArch64CC::CondCode CC1, CC2;
4464
4480
changeFCMPPredToAArch64CC (Pred, CC1, CC2);
4481
+ auto InvCC1 = AArch64CC::getInvertedCondCode (CC1);
4465
4482
if (CC2 == AArch64CC::AL)
4466
- return EmitCSet ( Dst, CC1);
4467
-
4483
+ return emitCSINC ( /* Dst= */ Dst, /* Src1= */ ZReg, /* Src2= */ ZReg, InvCC1,
4484
+ MIRBuilder);
4468
4485
const TargetRegisterClass *RC = &AArch64::GPR32RegClass;
4469
4486
Register Def1Reg = MRI.createVirtualRegister (RC);
4470
4487
Register Def2Reg = MRI.createVirtualRegister (RC);
4471
- EmitCSet (Def1Reg, CC1);
4472
- EmitCSet (Def2Reg, CC2);
4488
+ auto InvCC2 = AArch64CC::getInvertedCondCode (CC2);
4489
+ emitCSINC (/* Dst=*/ Def1Reg, /* Src1=*/ ZReg, /* Src2=*/ ZReg, InvCC1, MIRBuilder);
4490
+ emitCSINC (/* Dst=*/ Def2Reg, /* Src1=*/ ZReg, /* Src2=*/ ZReg, InvCC2, MIRBuilder);
4473
4491
auto OrMI = MIRBuilder.buildInstr (AArch64::ORRWrr, {Dst}, {Def1Reg, Def2Reg});
4474
4492
constrainSelectedInstRegOperands (*OrMI, TII, TRI, RBI);
4475
4493
return &*OrMI;
@@ -4578,16 +4596,25 @@ MachineInstr *AArch64InstructionSelector::emitVectorConcat(
4578
4596
}
4579
4597
4580
4598
MachineInstr *
4581
- AArch64InstructionSelector::emitCSetForICMP (Register DefReg, unsigned Pred,
4582
- MachineIRBuilder &MIRBuilder,
4583
- Register SrcReg) const {
4584
- // CSINC increments the result when the predicate is false. Invert it.
4585
- const AArch64CC::CondCode InvCC = changeICMPPredToAArch64CC (
4586
- CmpInst::getInversePredicate ((CmpInst::Predicate)Pred));
4587
- auto I = MIRBuilder.buildInstr (AArch64::CSINCWr, {DefReg}, {SrcReg, SrcReg})
4588
- .addImm (InvCC);
4589
- constrainSelectedInstRegOperands (*I, TII, TRI, RBI);
4590
- return &*I;
4599
+ AArch64InstructionSelector::emitCSINC (Register Dst, Register Src1,
4600
+ Register Src2, AArch64CC::CondCode Pred,
4601
+ MachineIRBuilder &MIRBuilder) const {
4602
+ auto &MRI = *MIRBuilder.getMRI ();
4603
+ const RegClassOrRegBank &RegClassOrBank = MRI.getRegClassOrRegBank (Dst);
4604
+ // If we used a register class, then this won't necessarily have an LLT.
4605
+ // Compute the size based off whether or not we have a class or bank.
4606
+ unsigned Size;
4607
+ if (const auto *RC = RegClassOrBank.dyn_cast <const TargetRegisterClass *>())
4608
+ Size = TRI.getRegSizeInBits (*RC);
4609
+ else
4610
+ Size = MRI.getType (Dst).getSizeInBits ();
4611
+ // Some opcodes use s1.
4612
+ assert (Size <= 64 && " Expected 64 bits or less only!" );
4613
+ static const unsigned OpcTable[2 ] = {AArch64::CSINCWr, AArch64::CSINCXr};
4614
+ unsigned Opc = OpcTable[Size == 64 ];
4615
+ auto CSINC = MIRBuilder.buildInstr (Opc, {Dst}, {Src1, Src2}).addImm (Pred);
4616
+ constrainSelectedInstRegOperands (*CSINC, TII, TRI, RBI);
4617
+ return &*CSINC;
4591
4618
}
4592
4619
4593
4620
std::pair<MachineInstr *, AArch64CC::CondCode>
0 commit comments