@@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
162
162
return isPreLegalize () || isLegal (Query);
163
163
}
164
164
165
+ bool CombinerHelper::isLegalorHasWidenScalar (const LegalityQuery &Query) const {
166
+ return isLegal (Query) ||
167
+ LI->getAction (Query).Action == LegalizeActions::WidenScalar;
168
+ }
169
+
165
170
bool CombinerHelper::isConstantLegalOrBeforeLegalizer (const LLT Ty) const {
166
171
if (!Ty.isVector ())
167
172
return isLegalOrBeforeLegalizer ({TargetOpcode::G_CONSTANT, {Ty}});
@@ -5522,6 +5527,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
5522
5527
Register Dst = MI.getOperand (0 ).getReg ();
5523
5528
Register RHS = MI.getOperand (2 ).getReg ();
5524
5529
LLT DstTy = MRI.getType (Dst);
5530
+ auto SizeInBits = DstTy.getScalarSizeInBits ();
5531
+ LLT WideTy = DstTy.changeElementSize (SizeInBits * 2 );
5525
5532
5526
5533
auto &MF = *MI.getMF ();
5527
5534
AttributeList Attr = MF.getFunction ().getAttributes ();
@@ -5541,8 +5548,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
5541
5548
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue (); });
5542
5549
}
5543
5550
5544
- // Don't support the general case for now.
5545
- return false ;
5551
+ auto *RHSDef = MRI.getVRegDef (RHS);
5552
+ if (!isConstantOrConstantVector (*RHSDef, MRI))
5553
+ return false ;
5554
+
5555
+ // Don't do this if the types are not going to be legal.
5556
+ if (LI) {
5557
+ if (!isLegalOrBeforeLegalizer ({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5558
+ return false ;
5559
+ if (!isLegal ({TargetOpcode::G_SMULH, {DstTy}}) &&
5560
+ !isLegalorHasWidenScalar ({TargetOpcode::G_MUL, {WideTy, WideTy}}))
5561
+ return false ;
5562
+ }
5563
+
5564
+ return matchUnaryPredicate (
5565
+ MRI, RHS, [](const Constant *C) { return C && !C->isNullValue (); });
5546
5566
}
5547
5567
5548
5568
void CombinerHelper::applySDivByConst (MachineInstr &MI) const {
@@ -5558,21 +5578,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
5558
5578
Register RHS = SDiv.getReg (2 );
5559
5579
LLT Ty = MRI.getType (Dst);
5560
5580
LLT ScalarTy = Ty.getScalarType ();
5581
+ const unsigned EltBits = ScalarTy.getScalarSizeInBits ();
5561
5582
LLT ShiftAmtTy = getTargetLowering ().getPreferredShiftAmountTy (Ty);
5562
5583
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType ();
5563
5584
auto &MIB = Builder;
5564
5585
5565
5586
bool UseSRA = false ;
5566
- SmallVector<Register, 16 > Shifts, Factors ;
5587
+ SmallVector<Register, 16 > ExactShifts, ExactFactors ;
5567
5588
5568
- auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies (RHS, MRI));
5569
- bool IsSplat = getIConstantSplatVal (*RHSDef , MRI).has_value ();
5589
+ auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies (RHS, MRI));
5590
+ bool IsSplat = getIConstantSplatVal (*RHSDefInstr , MRI).has_value ();
5570
5591
5571
- auto BuildSDIVPattern = [&](const Constant *C) {
5592
+ auto BuildExactSDIVPattern = [&](const Constant *C) {
5572
5593
// Don't recompute inverses for each splat element.
5573
- if (IsSplat && !Factors .empty ()) {
5574
- Shifts .push_back (Shifts [0 ]);
5575
- Factors .push_back (Factors [0 ]);
5594
+ if (IsSplat && !ExactFactors .empty ()) {
5595
+ ExactShifts .push_back (ExactShifts [0 ]);
5596
+ ExactFactors .push_back (ExactFactors [0 ]);
5576
5597
return true ;
5577
5598
}
5578
5599
@@ -5587,31 +5608,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
5587
5608
// Calculate the multiplicative inverse modulo BW.
5588
5609
// 2^W requires W + 1 bits, so we have to extend and then truncate.
5589
5610
APInt Factor = Divisor.multiplicativeInverse ();
5590
- Shifts .push_back (MIB.buildConstant (ScalarShiftAmtTy, Shift).getReg (0 ));
5591
- Factors .push_back (MIB.buildConstant (ScalarTy, Factor).getReg (0 ));
5611
+ ExactShifts .push_back (MIB.buildConstant (ScalarShiftAmtTy, Shift).getReg (0 ));
5612
+ ExactFactors .push_back (MIB.buildConstant (ScalarTy, Factor).getReg (0 ));
5592
5613
return true ;
5593
5614
};
5594
5615
5595
- // Collect all magic values from the build vector.
5616
+ if (MI.getFlag (MachineInstr::MIFlag::IsExact)) {
5617
+ // Collect all magic values from the build vector.
5618
+ bool Matched = matchUnaryPredicate (MRI, RHS, BuildExactSDIVPattern);
5619
+ (void )Matched;
5620
+ assert (Matched && " Expected unary predicate match to succeed" );
5621
+
5622
+ Register Shift, Factor;
5623
+ if (Ty.isVector ()) {
5624
+ Shift = MIB.buildBuildVector (ShiftAmtTy, ExactShifts).getReg (0 );
5625
+ Factor = MIB.buildBuildVector (Ty, ExactFactors).getReg (0 );
5626
+ } else {
5627
+ Shift = ExactShifts[0 ];
5628
+ Factor = ExactFactors[0 ];
5629
+ }
5630
+
5631
+ Register Res = LHS;
5632
+
5633
+ if (UseSRA)
5634
+ Res = MIB.buildAShr (Ty, Res, Shift, MachineInstr::IsExact).getReg (0 );
5635
+
5636
+ return MIB.buildMul (Ty, Res, Factor);
5637
+ }
5638
+
5639
+ SmallVector<Register, 16 > MagicFactors, Factors, Shifts, ShiftMasks;
5640
+
5641
+ auto BuildSDIVPattern = [&](const Constant *C) {
5642
+ auto *CI = cast<ConstantInt>(C);
5643
+ const APInt &Divisor = CI->getValue ();
5644
+
5645
+ SignedDivisionByConstantInfo magics =
5646
+ SignedDivisionByConstantInfo::get (Divisor);
5647
+ int NumeratorFactor = 0 ;
5648
+ int ShiftMask = -1 ;
5649
+
5650
+ if (Divisor.isOne () || Divisor.isAllOnes ()) {
5651
+ // If d is +1/-1, we just multiply the numerator by +1/-1.
5652
+ NumeratorFactor = Divisor.getSExtValue ();
5653
+ magics.Magic = 0 ;
5654
+ magics.ShiftAmount = 0 ;
5655
+ ShiftMask = 0 ;
5656
+ } else if (Divisor.isStrictlyPositive () && magics.Magic .isNegative ()) {
5657
+ // If d > 0 and m < 0, add the numerator.
5658
+ NumeratorFactor = 1 ;
5659
+ } else if (Divisor.isNegative () && magics.Magic .isStrictlyPositive ()) {
5660
+ // If d < 0 and m > 0, subtract the numerator.
5661
+ NumeratorFactor = -1 ;
5662
+ }
5663
+
5664
+ MagicFactors.push_back (MIB.buildConstant (ScalarTy, magics.Magic ).getReg (0 ));
5665
+ Factors.push_back (MIB.buildConstant (ScalarTy, NumeratorFactor).getReg (0 ));
5666
+ Shifts.push_back (
5667
+ MIB.buildConstant (ScalarShiftAmtTy, magics.ShiftAmount ).getReg (0 ));
5668
+ ShiftMasks.push_back (MIB.buildConstant (ScalarTy, ShiftMask).getReg (0 ));
5669
+
5670
+ return true ;
5671
+ };
5672
+
5673
+ // Collect the shifts/magic values from each element.
5596
5674
bool Matched = matchUnaryPredicate (MRI, RHS, BuildSDIVPattern);
5597
5675
(void )Matched;
5598
5676
assert (Matched && " Expected unary predicate match to succeed" );
5599
5677
5600
- Register Shift, Factor;
5601
- if (Ty.isVector ()) {
5602
- Shift = MIB.buildBuildVector (ShiftAmtTy, Shifts).getReg (0 );
5678
+ Register MagicFactor, Factor, Shift, ShiftMask;
5679
+ auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5680
+ if (RHSDef) {
5681
+ MagicFactor = MIB.buildBuildVector (Ty, MagicFactors).getReg (0 );
5603
5682
Factor = MIB.buildBuildVector (Ty, Factors).getReg (0 );
5683
+ Shift = MIB.buildBuildVector (ShiftAmtTy, Shifts).getReg (0 );
5684
+ ShiftMask = MIB.buildBuildVector (Ty, ShiftMasks).getReg (0 );
5604
5685
} else {
5605
- Shift = Shifts[0 ];
5686
+ assert (MRI.getType (RHS).isScalar () &&
5687
+ " Non-build_vector operation should have been a scalar" );
5688
+ MagicFactor = MagicFactors[0 ];
5606
5689
Factor = Factors[0 ];
5690
+ Shift = Shifts[0 ];
5691
+ ShiftMask = ShiftMasks[0 ];
5607
5692
}
5608
5693
5609
- Register Res = LHS;
5694
+ Register Q = LHS;
5695
+ Q = MIB.buildSMulH (Ty, LHS, MagicFactor).getReg (0 );
5696
+
5697
+ // (Optionally) Add/subtract the numerator using Factor.
5698
+ Factor = MIB.buildMul (Ty, LHS, Factor).getReg (0 );
5699
+ Q = MIB.buildAdd (Ty, Q, Factor).getReg (0 );
5610
5700
5611
- if (UseSRA)
5612
- Res = MIB.buildAShr (Ty, Res , Shift, MachineInstr::IsExact ).getReg (0 );
5701
+ // Shift right algebraic by shift value.
5702
+ Q = MIB.buildAShr (Ty, Q , Shift).getReg (0 );
5613
5703
5614
- return MIB.buildMul (Ty, Res, Factor);
5704
+ // Extract the sign bit, mask it and add it to the quotient.
5705
+ auto SignShift = MIB.buildConstant (ShiftAmtTy, EltBits - 1 );
5706
+ auto T = MIB.buildLShr (Ty, Q, SignShift);
5707
+ T = MIB.buildAnd (Ty, T, ShiftMask);
5708
+ return MIB.buildAdd (Ty, Q, T);
5615
5709
}
5616
5710
5617
5711
bool CombinerHelper::matchDivByPow2 (MachineInstr &MI, bool IsSigned) const {
0 commit comments