Skip to content

Commit b700760

Browse files
committed
[GlobalISel] Allow expanding of sdiv -> mul by constant combine for general case
1 parent 84e5451 commit b700760

File tree

12 files changed

+2289
-317
lines changed

12 files changed

+2289
-317
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -143,6 +143,10 @@ class CombinerHelper {
143143
/// Query is legal on the target.
144144
bool isLegalOrBeforeLegalizer(const LegalityQuery &Query) const;
145145

146+
/// \return true if \p Query is legal on the target, or if the target's
147+
/// legalization action for \p Query is WidenScalar.
148+
bool isLegalorHasWidenScalar(const LegalityQuery &Query) const;
149+
146150
/// \return true if the combine is running prior to legalization, or if \p Ty
147151
/// is a legal integer constant type on the target.
148152
bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2054,9 +2054,9 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
20542054
div_rem_to_divrem, funnel_shift_combines, bitreverse_shift, commute_shift,
20552055
form_bitfield_extract, constant_fold_binops, constant_fold_fma,
20562056
constant_fold_cast_op, fabs_fneg_fold,
2057-
intdiv_combines, mulh_combines, redundant_neg_operands,
2057+
mulh_combines, redundant_neg_operands,
20582058
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
2059-
intrem_combines, sub_add_reg, select_to_minmax,
2059+
intrem_combines, intdiv_combines, sub_add_reg, select_to_minmax,
20602060
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
20612061
simplify_neg_minmax, combine_concat_vector,
20622062
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 114 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
162162
return isPreLegalize() || isLegal(Query);
163163
}
164164

165+
/// Returns true when isLegal(\p Query) holds, or when the action reported by
/// LI->getAction(\p Query) is LegalizeActions::WidenScalar; false otherwise.
/// NOTE(review): LI is dereferenced here without a null check — callers are
/// presumably expected to confirm LI is set before calling; verify.
bool CombinerHelper::isLegalorHasWidenScalar(const LegalityQuery &Query) const {
166+
return isLegal(Query) ||
167+
LI->getAction(Query).Action == LegalizeActions::WidenScalar;
168+
}
169+
165170
bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
166171
if (!Ty.isVector())
167172
return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
@@ -5522,6 +5527,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
55225527
Register Dst = MI.getOperand(0).getReg();
55235528
Register RHS = MI.getOperand(2).getReg();
55245529
LLT DstTy = MRI.getType(Dst);
5530+
auto SizeInBits = DstTy.getScalarSizeInBits();
5531+
LLT WideTy = DstTy.changeElementSize(SizeInBits * 2);
55255532

55265533
auto &MF = *MI.getMF();
55275534
AttributeList Attr = MF.getFunction().getAttributes();
@@ -5541,8 +5548,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
55415548
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
55425549
}
55435550

5544-
// Don't support the general case for now.
5545-
return false;
5551+
auto *RHSDef = MRI.getVRegDef(RHS);
5552+
if (!isConstantOrConstantVector(*RHSDef, MRI))
5553+
return false;
5554+
5555+
// Don't do this if the types are not going to be legal.
5556+
if (LI) {
5557+
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5558+
return false;
5559+
if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) &&
5560+
!isLegalorHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}}))
5561+
return false;
5562+
}
5563+
5564+
return matchUnaryPredicate(
5565+
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
55465566
}
55475567

55485568
void CombinerHelper::applySDivByConst(MachineInstr &MI) const {
@@ -5558,21 +5578,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
55585578
Register RHS = SDiv.getReg(2);
55595579
LLT Ty = MRI.getType(Dst);
55605580
LLT ScalarTy = Ty.getScalarType();
5581+
const unsigned EltBits = ScalarTy.getScalarSizeInBits();
55615582
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
55625583
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
55635584
auto &MIB = Builder;
55645585

55655586
bool UseSRA = false;
5566-
SmallVector<Register, 16> Shifts, Factors;
5587+
SmallVector<Register, 16> ExactShifts, ExactFactors;
55675588

5568-
auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5569-
bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
5589+
auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5590+
bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();
55705591

5571-
auto BuildSDIVPattern = [&](const Constant *C) {
5592+
auto BuildExactSDIVPattern = [&](const Constant *C) {
55725593
// Don't recompute inverses for each splat element.
5573-
if (IsSplat && !Factors.empty()) {
5574-
Shifts.push_back(Shifts[0]);
5575-
Factors.push_back(Factors[0]);
5594+
if (IsSplat && !ExactFactors.empty()) {
5595+
ExactShifts.push_back(ExactShifts[0]);
5596+
ExactFactors.push_back(ExactFactors[0]);
55765597
return true;
55775598
}
55785599

@@ -5587,31 +5608,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
55875608
// Calculate the multiplicative inverse modulo BW.
55885609
// 2^W requires W + 1 bits, so we have to extend and then truncate.
55895610
APInt Factor = Divisor.multiplicativeInverse();
5590-
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5591-
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5611+
ExactShifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5612+
ExactFactors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
55925613
return true;
55935614
};
55945615

5595-
// Collect all magic values from the build vector.
5616+
if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5617+
// Collect all magic values from the build vector.
5618+
bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactSDIVPattern);
5619+
(void)Matched;
5620+
assert(Matched && "Expected unary predicate match to succeed");
5621+
5622+
Register Shift, Factor;
5623+
if (Ty.isVector()) {
5624+
Shift = MIB.buildBuildVector(ShiftAmtTy, ExactShifts).getReg(0);
5625+
Factor = MIB.buildBuildVector(Ty, ExactFactors).getReg(0);
5626+
} else {
5627+
Shift = ExactShifts[0];
5628+
Factor = ExactFactors[0];
5629+
}
5630+
5631+
Register Res = LHS;
5632+
5633+
if (UseSRA)
5634+
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5635+
5636+
return MIB.buildMul(Ty, Res, Factor);
5637+
}
5638+
5639+
SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5640+
5641+
auto BuildSDIVPattern = [&](const Constant *C) {
5642+
auto *CI = cast<ConstantInt>(C);
5643+
const APInt &Divisor = CI->getValue();
5644+
5645+
SignedDivisionByConstantInfo magics =
5646+
SignedDivisionByConstantInfo::get(Divisor);
5647+
int NumeratorFactor = 0;
5648+
int ShiftMask = -1;
5649+
5650+
if (Divisor.isOne() || Divisor.isAllOnes()) {
5651+
// If d is +1/-1, we just multiply the numerator by +1/-1.
5652+
NumeratorFactor = Divisor.getSExtValue();
5653+
magics.Magic = 0;
5654+
magics.ShiftAmount = 0;
5655+
ShiftMask = 0;
5656+
} else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
5657+
// If d > 0 and m < 0, add the numerator.
5658+
NumeratorFactor = 1;
5659+
} else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
5660+
// If d < 0 and m > 0, subtract the numerator.
5661+
NumeratorFactor = -1;
5662+
}
5663+
5664+
MagicFactors.push_back(MIB.buildConstant(ScalarTy, magics.Magic).getReg(0));
5665+
Factors.push_back(MIB.buildConstant(ScalarTy, NumeratorFactor).getReg(0));
5666+
Shifts.push_back(
5667+
MIB.buildConstant(ScalarShiftAmtTy, magics.ShiftAmount).getReg(0));
5668+
ShiftMasks.push_back(MIB.buildConstant(ScalarTy, ShiftMask).getReg(0));
5669+
5670+
return true;
5671+
};
5672+
5673+
// Collect the shifts/magic values from each element.
55965674
bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
55975675
(void)Matched;
55985676
assert(Matched && "Expected unary predicate match to succeed");
55995677

5600-
Register Shift, Factor;
5601-
if (Ty.isVector()) {
5602-
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5678+
Register MagicFactor, Factor, Shift, ShiftMask;
5679+
auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5680+
if (RHSDef) {
5681+
MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
56035682
Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5683+
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5684+
ShiftMask = MIB.buildBuildVector(Ty, ShiftMasks).getReg(0);
56045685
} else {
5605-
Shift = Shifts[0];
5686+
assert(MRI.getType(RHS).isScalar() &&
5687+
"Non-build_vector operation should have been a scalar");
5688+
MagicFactor = MagicFactors[0];
56065689
Factor = Factors[0];
5690+
Shift = Shifts[0];
5691+
ShiftMask = ShiftMasks[0];
56075692
}
56085693

5609-
Register Res = LHS;
5694+
Register Q = LHS;
5695+
Q = MIB.buildSMulH(Ty, LHS, MagicFactor).getReg(0);
5696+
5697+
// (Optionally) Add/subtract the numerator using Factor.
5698+
Factor = MIB.buildMul(Ty, LHS, Factor).getReg(0);
5699+
Q = MIB.buildAdd(Ty, Q, Factor).getReg(0);
56105700

5611-
if (UseSRA)
5612-
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5701+
// Shift right algebraic by shift value.
5702+
Q = MIB.buildAShr(Ty, Q, Shift).getReg(0);
56135703

5614-
return MIB.buildMul(Ty, Res, Factor);
5704+
// Extract the sign bit, mask it and add it to the quotient.
5705+
auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1);
5706+
auto T = MIB.buildLShr(Ty, Q, SignShift);
5707+
T = MIB.buildAnd(Ty, T, ShiftMask);
5708+
return MIB.buildAdd(Ty, Q, T);
56155709
}
56165710

56175711
bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {

0 commit comments

Comments
 (0)