Skip to content

Commit 22cd89f

Browse files
committed
add combine rule to simplify vector stores
1 parent f676be0 commit 22cd89f

File tree

2 files changed

+186
-130
lines changed

2 files changed

+186
-130
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4619,26 +4619,109 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
46194619
return SDValue();
46204620
}
46214621

4622+
// If {Lo, Hi} = <packed f32x2 val>, returns that value
4623+
static SDValue peekThroughF32x2Copy(const SDValue &Lo, const SDValue &Hi) {
4624+
if (Lo.getValueType() != MVT::f32 || Lo.getOpcode() != ISD::CopyFromReg ||
4625+
Lo.getNode() != Hi.getNode() || Lo == Hi)
4626+
return SDValue();
4627+
4628+
SDNode *CopyF = Lo.getNode();
4629+
SDNode *CopyT = CopyF->getOperand(0).getNode();
4630+
if (CopyT->getOpcode() != ISD::CopyToReg)
4631+
return SDValue();
4632+
4633+
// check the two registers are the same
4634+
if (cast<RegisterSDNode>(CopyF->getOperand(1))->getReg() !=
4635+
cast<RegisterSDNode>(CopyT->getOperand(1))->getReg())
4636+
return SDValue();
4637+
4638+
SDValue OrigV = CopyT->getOperand(2);
4639+
if (OrigV.getValueType() != MVT::i64)
4640+
return SDValue();
4641+
return OrigV;
4642+
}
4643+
4644+
/// Rewrite a vector store of f32 elements as a narrower store of packed
/// 64-bit values when every adjacent pair of stored elements is the two
/// halves of one packed value (as recognized by peekThroughF32x2Copy).
///
/// Handles StoreRetvalV2/V4 and StoreParamV2/V4, which become
/// StoreRetval/StoreRetvalV2 and StoreParam/StoreParamV2 respectively.
///
/// \param N         the store node; must be a MemSDNode.
/// \param DCI       combiner state used to build the replacement node.
/// \param OptLevel  the combine is skipped at CodeGenOptLevel::None.
/// \param Front     number of leading operands that are not stored values
///                  (e.g. chain and offset); they are copied through verbatim.
/// \param Back      number of trailing operands that are not stored values
///                  (e.g. glue); they are copied through verbatim.
/// \returns the replacement store, or an empty SDValue if nothing changed.
///
/// The value operands are exactly [Front, NumOperands - Back). Previously the
/// scan always started at operand 2 and ran to the end, which paired the
/// offset operand of StoreParam nodes ({Chain, ArgID, Offset, Val..., Glue})
/// with the first value, so the combine could never fire for them, and the
/// rebuilt operand list would have lost the offset and glue operands.
static SDValue
PerformPackedF32StoreCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                             CodeGenOptLevel OptLevel, std::size_t Front = 2,
                             std::size_t Back = 0) {
  if (OptLevel == CodeGenOptLevel::None)
    return SDValue();

  // Only stores whose memory element type is f32 are candidates.
  auto *MemN = cast<MemSDNode>(N);
  if (MemN->getMemoryVT() != MVT::f32)
    return SDValue();

  // Packing two f32 elements into one 64-bit value halves the element count,
  // so each vector store maps to the next narrower store opcode.
  std::optional<NVPTXISD::NodeType> NewOpcode;
  switch (MemN->getOpcode()) {
  case NVPTXISD::StoreRetvalV2:
    NewOpcode = NVPTXISD::StoreRetval;
    break;
  case NVPTXISD::StoreRetvalV4:
    NewOpcode = NVPTXISD::StoreRetvalV2;
    break;
  case NVPTXISD::StoreParamV2:
    NewOpcode = NVPTXISD::StoreParam;
    break;
  case NVPTXISD::StoreParamV4:
    NewOpcode = NVPTXISD::StoreParamV2;
    break;
  default:
    break;
  }
  if (!NewOpcode)
    return SDValue();

  const unsigned NumOps = N->getNumOperands();
  const unsigned FirstVal = static_cast<unsigned>(Front);
  const unsigned NumTrailing = static_cast<unsigned>(Back);
  if (NumOps < FirstVal + NumTrailing)
    return SDValue();

  // Values are consumed two at a time; an empty or odd-sized value list can
  // not be fully packed (and an odd one would read past the value range).
  const unsigned EndVal = NumOps - NumTrailing;
  const unsigned NumVals = EndVal - FirstVal;
  if (NumVals == 0 || NumVals % 2 != 0)
    return SDValue();

  SmallVector<SDValue> NewOps;

  // Leading non-value operands are kept as they are.
  for (unsigned I = 0; I < FirstVal; ++I)
    NewOps.push_back(N->getOperand(I));

  // Every adjacent pair must come from a single packed value; otherwise the
  // store is left untouched.
  for (unsigned I = FirstVal; I < EndVal; I += 2) {
    SDValue Packed =
        peekThroughF32x2Copy(N->getOperand(I), N->getOperand(I + 1));
    if (!Packed)
      return SDValue();
    NewOps.push_back(Packed);
  }

  // Trailing non-value operands are kept as they are.
  for (unsigned I = EndVal; I < NumOps; ++I)
    NewOps.push_back(N->getOperand(I));

  return DCI.DAG.getMemIntrinsicNode(
      *NewOpcode, SDLoc(N), N->getVTList(), NewOps, MVT::i64,
      MemN->getPointerInfo(), MemN->getAlign(), MachineMemOperand::MOStore);
}

/// Shared combine for the StoreParam* and StoreRetval* nodes.
///
/// \param N         the store node being combined.
/// \param Front     number of leading operands that are not stored values.
/// \param Back      number of trailing operands that are not stored values.
/// \param DCI       combiner state, forwarded to the packed-store combine.
/// \param OptLevel  optimization level, forwarded to the packed-store combine.
/// \returns the chain operand if every stored value is undef, otherwise the
///          result of the packed-f32 store combine (possibly empty).
static SDValue PerformStoreCombineHelper(SDNode *N, std::size_t Front,
                                         std::size_t Back,
                                         TargetLowering::DAGCombinerInfo &DCI,
                                         CodeGenOptLevel OptLevel) {
  if (all_of(N->ops().drop_front(Front).drop_back(Back),
             [](const SDUse &U) { return U.get()->isUndef(); }))
    // Operand 0 is the previous value in the chain. Cannot return EntryToken
    // as the previous value will become unused and eliminated later.
    return N->getOperand(0);

  // Forward the operand layout so the packed-store combine looks at exactly
  // the same value operands as the undef check above.
  if (SDValue V = PerformPackedF32StoreCombine(N, DCI, OptLevel, Front, Back))
    return V;

  return SDValue();
}
46324711

4633-
/// Combine entry point for the StoreParam* nodes; defers to the shared store
/// combine helper with the StoreParam operand layout.
static SDValue PerformStoreParamCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        CodeGenOptLevel OptLevel) {
  // Operand layout: {Chain, ArgID, Offset, Val..., Glue}. The stored values
  // sit between the three leading operands and the single trailing one.
  constexpr std::size_t NumLeadingOps = 3;
  constexpr std::size_t NumTrailingOps = 1;
  return PerformStoreCombineHelper(N, NumLeadingOps, NumTrailingOps, DCI,
                                   OptLevel);
}
46384719

4639-
/// Combine entry point for the StoreRetval* nodes; defers to the shared store
/// combine helper with the StoreRetval operand layout.
static SDValue PerformStoreRetvalCombine(SDNode *N,
                                         TargetLowering::DAGCombinerInfo &DCI,
                                         CodeGenOptLevel OptLevel) {
  // The stored values are every operand after the first two; there are no
  // trailing non-value operands.
  constexpr std::size_t NumLeadingOps = 2;
  constexpr std::size_t NumTrailingOps = 0;
  return PerformStoreCombineHelper(N, NumLeadingOps, NumTrailingOps, DCI,
                                   OptLevel);
}
46434726

46444727
/// PerformADDCombine - Target-specific dag combine xforms for ISD::ADD.
@@ -5329,11 +5412,11 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
53295412
case NVPTXISD::StoreRetval:
53305413
case NVPTXISD::StoreRetvalV2:
53315414
case NVPTXISD::StoreRetvalV4:
5332-
return PerformStoreRetvalCombine(N);
5415+
return PerformStoreRetvalCombine(N, DCI, OptLevel);
53335416
case NVPTXISD::StoreParam:
53345417
case NVPTXISD::StoreParamV2:
53355418
case NVPTXISD::StoreParamV4:
5336-
return PerformStoreParamCombine(N);
5419+
return PerformStoreParamCombine(N, DCI, OptLevel);
53375420
case ISD::EXTRACT_VECTOR_ELT:
53385421
return PerformEXTRACTCombine(N, DCI);
53395422
case ISD::VSELECT:

0 commit comments

Comments
 (0)