Skip to content

Commit 2fc7ba1

Browse files
committed
[GlobaISel] Allow expanding of sdiv -> mul by constant combine for general case
1 parent 4e30f81 commit 2fc7ba1

File tree

12 files changed

+2289
-317
lines changed

12 files changed

+2289
-317
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CombinerHelper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ class CombinerHelper {
143143
/// Query is legal on the target.
144144
bool isLegalOrBeforeLegalizer(const LegalityQuery &Query) const;
145145

146+
/// \return true if \p Query is legal on the target, or if \p Query will
147+
/// perform WidenScalar action on the target.
148+
bool isLegalorHasWidenScalar(const LegalityQuery &Query) const;
149+
146150
/// \return true if the combine is running prior to legalization, or if \p Ty
147151
/// is a legal integer constant type on the target.
148152
bool isConstantLegalOrBeforeLegalizer(const LLT Ty) const;

llvm/include/llvm/Target/GlobalISel/Combine.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2046,9 +2046,9 @@ def all_combines : GICombineGroup<[integer_reassoc_combines, trivial_combines,
20462046
div_rem_to_divrem, funnel_shift_combines, bitreverse_shift, commute_shift,
20472047
form_bitfield_extract, constant_fold_binops, constant_fold_fma,
20482048
constant_fold_cast_op, fabs_fneg_fold,
2049-
intdiv_combines, mulh_combines, redundant_neg_operands,
2049+
mulh_combines, redundant_neg_operands,
20502050
and_or_disjoint_mask, fma_combines, fold_binop_into_select,
2051-
sub_add_reg, select_to_minmax,
2051+
intdiv_combines, sub_add_reg, select_to_minmax,
20522052
fsub_to_fneg, commute_constant_to_rhs, match_ands, match_ors,
20532053
simplify_neg_minmax, combine_concat_vector,
20542054
sext_trunc, zext_trunc, prefer_sign_combines, shuffle_combines,

llvm/lib/CodeGen/GlobalISel/CombinerHelper.cpp

Lines changed: 114 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ bool CombinerHelper::isLegalOrBeforeLegalizer(
162162
return isPreLegalize() || isLegal(Query);
163163
}
164164

165+
bool CombinerHelper::isLegalorHasWidenScalar(const LegalityQuery &Query) const {
166+
return isLegal(Query) ||
167+
LI->getAction(Query).Action == LegalizeActions::WidenScalar;
168+
}
169+
165170
bool CombinerHelper::isConstantLegalOrBeforeLegalizer(const LLT Ty) const {
166171
if (!Ty.isVector())
167172
return isLegalOrBeforeLegalizer({TargetOpcode::G_CONSTANT, {Ty}});
@@ -5510,6 +5515,8 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
55105515
Register Dst = MI.getOperand(0).getReg();
55115516
Register RHS = MI.getOperand(2).getReg();
55125517
LLT DstTy = MRI.getType(Dst);
5518+
auto SizeInBits = DstTy.getScalarSizeInBits();
5519+
LLT WideTy = DstTy.changeElementSize(SizeInBits * 2);
55135520

55145521
auto &MF = *MI.getMF();
55155522
AttributeList Attr = MF.getFunction().getAttributes();
@@ -5529,8 +5536,21 @@ bool CombinerHelper::matchSDivByConst(MachineInstr &MI) const {
55295536
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
55305537
}
55315538

5532-
// Don't support the general case for now.
5533-
return false;
5539+
auto *RHSDef = MRI.getVRegDef(RHS);
5540+
if (!isConstantOrConstantVector(*RHSDef, MRI))
5541+
return false;
5542+
5543+
// Don't do this if the types are not going to be legal.
5544+
if (LI) {
5545+
if (!isLegalOrBeforeLegalizer({TargetOpcode::G_MUL, {DstTy, DstTy}}))
5546+
return false;
5547+
if (!isLegal({TargetOpcode::G_SMULH, {DstTy}}) &&
5548+
!isLegalorHasWidenScalar({TargetOpcode::G_MUL, {WideTy, WideTy}}))
5549+
return false;
5550+
}
5551+
5552+
return matchUnaryPredicate(
5553+
MRI, RHS, [](const Constant *C) { return C && !C->isNullValue(); });
55345554
}
55355555

55365556
void CombinerHelper::applySDivByConst(MachineInstr &MI) const {
@@ -5546,21 +5566,22 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
55465566
Register RHS = SDiv.getReg(2);
55475567
LLT Ty = MRI.getType(Dst);
55485568
LLT ScalarTy = Ty.getScalarType();
5569+
const unsigned EltBits = ScalarTy.getScalarSizeInBits();
55495570
LLT ShiftAmtTy = getTargetLowering().getPreferredShiftAmountTy(Ty);
55505571
LLT ScalarShiftAmtTy = ShiftAmtTy.getScalarType();
55515572
auto &MIB = Builder;
55525573

55535574
bool UseSRA = false;
5554-
SmallVector<Register, 16> Shifts, Factors;
5575+
SmallVector<Register, 16> ExactShifts, ExactFactors;
55555576

5556-
auto *RHSDef = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5557-
bool IsSplat = getIConstantSplatVal(*RHSDef, MRI).has_value();
5577+
auto *RHSDefInstr = cast<GenericMachineInstr>(getDefIgnoringCopies(RHS, MRI));
5578+
bool IsSplat = getIConstantSplatVal(*RHSDefInstr, MRI).has_value();
55585579

5559-
auto BuildSDIVPattern = [&](const Constant *C) {
5580+
auto BuildExactSDIVPattern = [&](const Constant *C) {
55605581
// Don't recompute inverses for each splat element.
5561-
if (IsSplat && !Factors.empty()) {
5562-
Shifts.push_back(Shifts[0]);
5563-
Factors.push_back(Factors[0]);
5582+
if (IsSplat && !ExactFactors.empty()) {
5583+
ExactShifts.push_back(ExactShifts[0]);
5584+
ExactFactors.push_back(ExactFactors[0]);
55645585
return true;
55655586
}
55665587

@@ -5575,31 +5596,104 @@ MachineInstr *CombinerHelper::buildSDivUsingMul(MachineInstr &MI) const {
55755596
// Calculate the multiplicative inverse modulo BW.
55765597
// 2^W requires W + 1 bits, so we have to extend and then truncate.
55775598
APInt Factor = Divisor.multiplicativeInverse();
5578-
Shifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5579-
Factors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
5599+
ExactShifts.push_back(MIB.buildConstant(ScalarShiftAmtTy, Shift).getReg(0));
5600+
ExactFactors.push_back(MIB.buildConstant(ScalarTy, Factor).getReg(0));
55805601
return true;
55815602
};
55825603

5583-
// Collect all magic values from the build vector.
5604+
if (MI.getFlag(MachineInstr::MIFlag::IsExact)) {
5605+
// Collect all magic values from the build vector.
5606+
bool Matched = matchUnaryPredicate(MRI, RHS, BuildExactSDIVPattern);
5607+
(void)Matched;
5608+
assert(Matched && "Expected unary predicate match to succeed");
5609+
5610+
Register Shift, Factor;
5611+
if (Ty.isVector()) {
5612+
Shift = MIB.buildBuildVector(ShiftAmtTy, ExactShifts).getReg(0);
5613+
Factor = MIB.buildBuildVector(Ty, ExactFactors).getReg(0);
5614+
} else {
5615+
Shift = ExactShifts[0];
5616+
Factor = ExactFactors[0];
5617+
}
5618+
5619+
Register Res = LHS;
5620+
5621+
if (UseSRA)
5622+
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5623+
5624+
return MIB.buildMul(Ty, Res, Factor);
5625+
}
5626+
5627+
SmallVector<Register, 16> MagicFactors, Factors, Shifts, ShiftMasks;
5628+
5629+
auto BuildSDIVPattern = [&](const Constant *C) {
5630+
auto *CI = cast<ConstantInt>(C);
5631+
const APInt &Divisor = CI->getValue();
5632+
5633+
SignedDivisionByConstantInfo magics =
5634+
SignedDivisionByConstantInfo::get(Divisor);
5635+
int NumeratorFactor = 0;
5636+
int ShiftMask = -1;
5637+
5638+
if (Divisor.isOne() || Divisor.isAllOnes()) {
5639+
// If d is +1/-1, we just multiply the numerator by +1/-1.
5640+
NumeratorFactor = Divisor.getSExtValue();
5641+
magics.Magic = 0;
5642+
magics.ShiftAmount = 0;
5643+
ShiftMask = 0;
5644+
} else if (Divisor.isStrictlyPositive() && magics.Magic.isNegative()) {
5645+
// If d > 0 and m < 0, add the numerator.
5646+
NumeratorFactor = 1;
5647+
} else if (Divisor.isNegative() && magics.Magic.isStrictlyPositive()) {
5648+
// If d < 0 and m > 0, subtract the numerator.
5649+
NumeratorFactor = -1;
5650+
}
5651+
5652+
MagicFactors.push_back(MIB.buildConstant(ScalarTy, magics.Magic).getReg(0));
5653+
Factors.push_back(MIB.buildConstant(ScalarTy, NumeratorFactor).getReg(0));
5654+
Shifts.push_back(
5655+
MIB.buildConstant(ScalarShiftAmtTy, magics.ShiftAmount).getReg(0));
5656+
ShiftMasks.push_back(MIB.buildConstant(ScalarTy, ShiftMask).getReg(0));
5657+
5658+
return true;
5659+
};
5660+
5661+
// Collect the shifts/magic values from each element.
55845662
bool Matched = matchUnaryPredicate(MRI, RHS, BuildSDIVPattern);
55855663
(void)Matched;
55865664
assert(Matched && "Expected unary predicate match to succeed");
55875665

5588-
Register Shift, Factor;
5589-
if (Ty.isVector()) {
5590-
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5666+
Register MagicFactor, Factor, Shift, ShiftMask;
5667+
auto *RHSDef = getOpcodeDef<GBuildVector>(RHS, MRI);
5668+
if (RHSDef) {
5669+
MagicFactor = MIB.buildBuildVector(Ty, MagicFactors).getReg(0);
55915670
Factor = MIB.buildBuildVector(Ty, Factors).getReg(0);
5671+
Shift = MIB.buildBuildVector(ShiftAmtTy, Shifts).getReg(0);
5672+
ShiftMask = MIB.buildBuildVector(Ty, ShiftMasks).getReg(0);
55925673
} else {
5593-
Shift = Shifts[0];
5674+
assert(MRI.getType(RHS).isScalar() &&
5675+
"Non-build_vector operation should have been a scalar");
5676+
MagicFactor = MagicFactors[0];
55945677
Factor = Factors[0];
5678+
Shift = Shifts[0];
5679+
ShiftMask = ShiftMasks[0];
55955680
}
55965681

5597-
Register Res = LHS;
5682+
Register Q = LHS;
5683+
Q = MIB.buildSMulH(Ty, LHS, MagicFactor).getReg(0);
5684+
5685+
// (Optionally) Add/subtract the numerator using Factor.
5686+
Factor = MIB.buildMul(Ty, LHS, Factor).getReg(0);
5687+
Q = MIB.buildAdd(Ty, Q, Factor).getReg(0);
55985688

5599-
if (UseSRA)
5600-
Res = MIB.buildAShr(Ty, Res, Shift, MachineInstr::IsExact).getReg(0);
5689+
// Shift right algebraic by shift value.
5690+
Q = MIB.buildAShr(Ty, Q, Shift).getReg(0);
56015691

5602-
return MIB.buildMul(Ty, Res, Factor);
5692+
// Extract the sign bit, mask it and add it to the quotient.
5693+
auto SignShift = MIB.buildConstant(ShiftAmtTy, EltBits - 1);
5694+
auto T = MIB.buildLShr(Ty, Q, SignShift);
5695+
T = MIB.buildAnd(Ty, T, ShiftMask);
5696+
return MIB.buildAdd(Ty, Q, T);
56035697
}
56045698

56055699
bool CombinerHelper::matchDivByPow2(MachineInstr &MI, bool IsSigned) const {

0 commit comments

Comments
 (0)