@@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
162
162
return isPreLegalize () || isLegal (Query);
163
163
}
164
164
165
+ bool CombinerHelper::isLegalorHasWidenScalar (const LegalityQuery &Query) const {
166
+ return isLegal (Query) ||
167
+ LI->getAction (Query).Action == LegalizeActions::WidenScalar;
168
+ }
169
+
165
170
bool CombinerHelper::isConstantLegalOrBeforeLegalizer (const LLT Ty) const {
166
171
if (!Ty.isVector ())
167
172
return isLegalOrBeforeLegalizer ({TargetOpcode::G_CONSTANT, {Ty}});
@@ -5510,6 +5515,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
5510
5515
Register Dst = MI.getOperand (0 ).getReg ();
5511
5516
Register RHS = MI.getOperand (2 ).getReg ();
5512
5517
LLT DstTy = MRI.getType (Dst);
5518
+ auto SizeInBits = DstTy.getScalarSizeInBits ();
5519
+ LLT WideTy = DstTy.changeElementSize (SizeInBits * 2 );
5513
5520
5514
5521
auto &MF = *MI.getMF ();
5515
5522
AttributeList Attr = MF.getFunction ().getAttributes ();
@@ -5529,8 +5536,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
5529
5536
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue (); });
5530
5537
}
5531
5538
5532
- // Don't support the general case for now.
5533
- return false ;
5539
+ auto *RHSDef = MRI.getVRegDef (RHS);
5540
+ if (!isConstantOrConstantVector (*RHSDef, MRI))
5541
+ return false ;
5542
+
5543
+ // Don't do this if the types are not going to be legal.
5544
+ if (LI) {
5545
+ if (!isLegalOrBeforeLegalizer ({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5546
+ return false ;
5547
+ if (!isLegal ({TargetOpcode::G_SMULH, {DstTy}}) &&
5548
+ !isLegalorHasWidenScalar ({TargetOpcode::G_MUL, {WideTy, WideTy}}))
5549
+ return false ;
5550
+ }
5551
+
5552
+ return matchUnaryPredicate (
5553
+ MRI, RHS, [](const Constant *C) { return C && !C->isNullValue (); });
5534
5554
}
5535
5555
5536
5556
void CombinerHelper::applySDivByConst (MachineInstr &MI) const {
@@ -5546,21 +5566,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
5546
5566
Register RHS = SDiv.getReg (2 );
5547
5567
LLT Ty = MRI.getType (Dst);
5548
5568
LLT ScalarTy = Ty.getScalarType ();
5569
+ const unsigned EltBits = ScalarTy.getScalarSizeInBits ();
5549
5570
LLT ShiftAmtTy = getTargetLowering ().getPreferredShiftAmountTy (Ty);
5550
5571
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType ();
5551
5572
auto &MIB = Builder;
5552
5573
5553
5574
bool UseSRA = false ;
5554
- SmallVector<Register, 16 > Shifts, Factors ;
5575
+ SmallVector<Register, 16 > ExactShifts, ExactFactors ;
5555
5576
5556
- auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies (RHS, MRI));
5557
- bool IsSplat = getIConstantSplatVal (*RHSDef , MRI).has_value ();
5577
+ auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies (RHS, MRI));
5578
+ bool IsSplat = getIConstantSplatVal (*RHSDefInstr , MRI).has_value ();
5558
5579
5559
- auto BuildSDIVPattern = [&](const Constant *C) {
5580
+ auto BuildExactSDIVPattern = [&](const Constant *C) {
5560
5581
// Don't recompute inverses for each splat element.
5561
- if (IsSplat && !Factors .empty ()) {
5562
- Shifts .push_back (Shifts [0 ]);
5563
- Factors .push_back (Factors [0 ]);
5582
+ if (IsSplat && !ExactFactors .empty ()) {
5583
+ ExactShifts .push_back (ExactShifts [0 ]);
5584
+ ExactFactors .push_back (ExactFactors [0 ]);
5564
5585
return true ;
5565
5586
}
5566
5587
@@ -5575,31 +5596,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
5575
5596
// Calculate the multiplicative inverse modulo BW.
5576
5597
// 2^W requires W + 1 bits, so we have to extend and then truncate.
5577
5598
APInt Factor = Divisor.multiplicativeInverse ();
5578
- Shifts .push_back (MIB.buildConstant (ScalarShiftAmtTy, Shift).getReg (0 ));
5579
- Factors .push_back (MIB.buildConstant (ScalarTy, Factor).getReg (0 ));
5599
+ ExactShifts .push_back (MIB.buildConstant (ScalarShiftAmtTy, Shift).getReg (0 ));
5600
+ ExactFactors .push_back (MIB.buildConstant (ScalarTy, Factor).getReg (0 ));
5580
5601
return true ;
5581
5602
};
5582
5603
5583
- // Collect all magic values from the build vector.
5604
+ if (MI.getFlag (MachineInstr::MIFlag::IsExact)) {
5605
+ // Collect all magic values from the build vector.
5606
+ bool Matched = matchUnaryPredicate (MRI, RHS, BuildExactSDIVPattern);
5607
+ (void )Matched;
5608
+ assert (Matched && " Expected unary predicate match to succeed" );
5609
+
5610
+ Register Shift, Factor;
5611
+ if (Ty.isVector ()) {
5612
+ Shift = MIB.buildBuildVector (ShiftAmtTy, ExactShifts).getReg (0 );
5613
+ Factor = MIB.buildBuildVector (Ty, ExactFactors).getReg (0 );
5614
+ } else {
5615
+ Shift = ExactShifts[0 ];
5616
+ Factor = ExactFactors[0 ];
5617
+ }
5618
+
5619
+ Register Res = LHS;
5620
+
5621
+ if (UseSRA)
5622
+ Res = MIB.buildAShr (Ty, Res, Shift, MachineInstr::IsExact).getReg (0 );
5623
+
5624
+ return MIB.buildMul (Ty, Res, Factor);
5625
+ }
5626
+
5627
+ SmallVector<Register, 16 > MagicFactors, Factors, Shifts, ShiftMasks;
5628
+
5629
+ auto BuildSDIVPattern = [&](const Constant *C) {
5630
+ auto *CI = cast<ConstantInt>(C);
5631
+ const APInt &Divisor = CI->getValue ();
5632
+
5633
+ SignedDivisionByConstantInfo magics =
5634
+ SignedDivisionByConstantInfo::get (Divisor);
5635
+ int NumeratorFactor = 0 ;
5636
+ int ShiftMask = -1 ;
5637
+
5638
+ if (Divisor.isOne () || Divisor.isAllOnes ()) {
5639
+ // If d is +1/-1, we just multiply the numerator by +1/-1.
5640
+ NumeratorFactor = Divisor.getSExtValue ();
5641
+ magics.Magic = 0 ;
5642
+ magics.ShiftAmount = 0 ;
5643
+ ShiftMask = 0 ;
5644
+ } else if (Divisor.isStrictlyPositive () && magics.Magic .isNegative ()) {
5645
+ // If d > 0 and m < 0, add the numerator.
5646
+ NumeratorFactor = 1 ;
5647
+ } else if (Divisor.isNegative () && magics.Magic .isStrictlyPositive ()) {
5648
+ // If d < 0 and m > 0, subtract the numerator.
5649
+ NumeratorFactor = -1 ;
5650
+ }
5651
+
5652
+ MagicFactors.push_back (MIB.buildConstant (ScalarTy, magics.Magic ).getReg (0 ));
5653
+ Factors.push_back (MIB.buildConstant (ScalarTy, NumeratorFactor).getReg (0 ));
5654
+ Shifts.push_back (
5655
+ MIB.buildConstant (ScalarShiftAmtTy, magics.ShiftAmount ).getReg (0 ));
5656
+ ShiftMasks.push_back (MIB.buildConstant (ScalarTy, ShiftMask).getReg (0 ));
5657
+
5658
+ return true ;
5659
+ };
5660
+
5661
+ // Collect the shifts/magic values from each element.
5584
5662
bool Matched = matchUnaryPredicate (MRI, RHS, BuildSDIVPattern);
5585
5663
(void )Matched;
5586
5664
assert (Matched && " Expected unary predicate match to succeed" );
5587
5665
5588
- Register Shift, Factor;
5589
- if (Ty.isVector ()) {
5590
- Shift = MIB.buildBuildVector (ShiftAmtTy, Shifts).getReg (0 );
5666
+ Register MagicFactor, Factor, Shift, ShiftMask;
5667
+ auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5668
+ if (RHSDef) {
5669
+ MagicFactor = MIB.buildBuildVector (Ty, MagicFactors).getReg (0 );
5591
5670
Factor = MIB.buildBuildVector (Ty, Factors).getReg (0 );
5671
+ Shift = MIB.buildBuildVector (ShiftAmtTy, Shifts).getReg (0 );
5672
+ ShiftMask = MIB.buildBuildVector (Ty, ShiftMasks).getReg (0 );
5592
5673
} else {
5593
- Shift = Shifts[0 ];
5674
+ assert (MRI.getType (RHS).isScalar () &&
5675
+ " Non-build_vector operation should have been a scalar" );
5676
+ MagicFactor = MagicFactors[0 ];
5594
5677
Factor = Factors[0 ];
5678
+ Shift = Shifts[0 ];
5679
+ ShiftMask = ShiftMasks[0 ];
5595
5680
}
5596
5681
5597
- Register Res = LHS;
5682
+ Register Q = LHS;
5683
+ Q = MIB.buildSMulH (Ty, LHS, MagicFactor).getReg (0 );
5684
+
5685
+ // (Optionally) Add/subtract the numerator using Factor.
5686
+ Factor = MIB.buildMul (Ty, LHS, Factor).getReg (0 );
5687
+ Q = MIB.buildAdd (Ty, Q, Factor).getReg (0 );
5598
5688
5599
- if (UseSRA)
5600
- Res = MIB.buildAShr (Ty, Res , Shift, MachineInstr::IsExact ).getReg (0 );
5689
+ // Shift right algebraic by shift value.
5690
+ Q = MIB.buildAShr (Ty, Q , Shift).getReg (0 );
5601
5691
5602
- return MIB.buildMul (Ty, Res, Factor);
5692
+ // Extract the sign bit, mask it and add it to the quotient.
5693
+ auto SignShift = MIB.buildConstant (ShiftAmtTy, EltBits - 1 );
5694
+ auto T = MIB.buildLShr (Ty, Q, SignShift);
5695
+ T = MIB.buildAnd (Ty, T, ShiftMask);
5696
+ return MIB.buildAdd (Ty, Q, T);
5603
5697
}
5604
5698
5605
5699
bool CombinerHelper::matchDivByPow2 (MachineInstr &MI, bool IsSigned) const {
0 commit comments