Skip to content

Commit 0a77c0e

Browse files
asbDisasm
authored andcommitted
[RISCV] Custom-legalise 32-bit variable shifts on RV64
The previous DAG combiner-based approach had an issue with infinite loops between the target-dependent and target-independent combiner logic (see PR40333). Although this was worked around in rL351806, the combiner-based approach is still potentially brittle and can fail to select the 32-bit shift variant when profitable to do so, as demonstrated in the pr40333.ll test case. This patch instead introduces target-specific SelectionDAG nodes for SHLW/SRLW/SRAW and custom-lowers variable i32 shifts to them. pr40333.ll is a good example of how this approach can improve codegen. This adds DAG combine that does SimplifyDemandedBits on the operands (only lower 32-bits of first operand and lower 5 bits of second operand are read). This seems better than implementing SimplifyDemandedBitsForTargetNode as there is no guarantee that would be called (and it's not for e.g. the anyext return test cases). Also implements ComputeNumSignBitsForTargetNode. There are codegen changes in atomic-rmw.ll and atomic-cmpxchg.ll but the new instruction sequences are semantically equivalent. Differential Revision: https://reviews.llvm.org/D57085 llvm-svn: 352169
1 parent c6daa4f commit 0a77c0e

File tree

6 files changed

+366
-335
lines changed

6 files changed

+366
-335
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 86 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
8181
setOperationAction(ISD::SIGN_EXTEND_INREG, VT, Expand);
8282

8383
if (Subtarget.is64Bit()) {
84-
setTargetDAGCombine(ISD::SHL);
85-
setTargetDAGCombine(ISD::SRL);
86-
setTargetDAGCombine(ISD::SRA);
8784
setTargetDAGCombine(ISD::ANY_EXTEND);
85+
setOperationAction(ISD::SHL, MVT::i32, Custom);
86+
setOperationAction(ISD::SRA, MVT::i32, Custom);
87+
setOperationAction(ISD::SRL, MVT::i32, Custom);
8888
}
8989

9090
if (!Subtarget.hasStdExtM()) {
@@ -513,15 +513,52 @@ SDValue RISCVTargetLowering::lowerRETURNADDR(SDValue Op,
513513
return DAG.getCopyFromReg(DAG.getEntryNode(), DL, Reg, XLenVT);
514514
}
515515

516-
// Return true if the given node is a shift with a non-constant shift amount.
517-
static bool isVariableShift(SDValue Val) {
518-
switch (Val.getOpcode()) {
516+
// Returns the opcode of the target-specific SDNode that implements the 32-bit
517+
// form of the given Opcode.
518+
static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
519+
switch (Opcode) {
519520
default:
520-
return false;
521+
llvm_unreachable("Unexpected opcode");
521522
case ISD::SHL:
523+
return RISCVISD::SLLW;
522524
case ISD::SRA:
525+
return RISCVISD::SRAW;
523526
case ISD::SRL:
524-
return Val.getOperand(1).getOpcode() != ISD::Constant;
527+
return RISCVISD::SRLW;
528+
}
529+
}
530+
531+
// Converts the given 32-bit operation to a target-specific SelectionDAG node.
532+
// Because i32 isn't a legal type for RV64, these operations would otherwise
533+
// be promoted to i64, making it difficult to select the SLLW/DIVUW/.../*W
534+
// later one because the fact the operation was originally of type i32 is
535+
// lost.
536+
static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG) {
537+
SDLoc DL(N);
538+
RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
539+
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
540+
SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
541+
SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
542+
// ReplaceNodeResults requires we maintain the same type for the return value.
543+
return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
544+
}
545+
546+
void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
547+
SmallVectorImpl<SDValue> &Results,
548+
SelectionDAG &DAG) const {
549+
SDLoc DL(N);
550+
switch (N->getOpcode()) {
551+
default:
552+
llvm_unreachable("Don't know how to custom type legalize this operation!");
553+
case ISD::SHL:
554+
case ISD::SRA:
555+
case ISD::SRL:
556+
assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
557+
"Unexpected custom legalisation");
558+
if (N->getOperand(1).getOpcode() == ISD::Constant)
559+
return;
560+
Results.push_back(customLegalizeToWOp(N, DAG));
561+
break;
525562
}
526563
}
527564

@@ -546,34 +583,14 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
546583
switch (N->getOpcode()) {
547584
default:
548585
break;
549-
case ISD::SHL:
550-
case ISD::SRL:
551-
case ISD::SRA: {
552-
assert(Subtarget.getXLen() == 64 && "Combine should be 64-bit only");
553-
if (!DCI.isBeforeLegalize())
554-
break;
555-
SDValue RHS = N->getOperand(1);
556-
if (N->getValueType(0) != MVT::i32 || RHS->getOpcode() == ISD::Constant ||
557-
(RHS->getOpcode() == ISD::AssertZext &&
558-
cast<VTSDNode>(RHS->getOperand(1))->getVT().getSizeInBits() <= 5))
559-
break;
560-
SDValue LHS = N->getOperand(0);
561-
SDLoc DL(N);
562-
SDValue NewRHS =
563-
DAG.getNode(ISD::AssertZext, DL, RHS.getValueType(), RHS,
564-
DAG.getValueType(EVT::getIntegerVT(*DAG.getContext(), 5)));
565-
return DCI.CombineTo(
566-
N, DAG.getNode(N->getOpcode(), DL, LHS.getValueType(), LHS, NewRHS));
567-
}
568586
case ISD::ANY_EXTEND: {
569-
// If any-extending an i32 variable-length shift or sdiv/udiv/urem to i64,
570-
// then instead sign-extend in order to increase the chance of being able
571-
// to select the sllw/srlw/sraw/divw/divuw/remuw instructions.
587+
// If any-extending an i32 sdiv/udiv/urem to i64, then instead sign-extend
588+
// in order to increase the chance of being able to select the
589+
// divw/divuw/remuw instructions.
572590
SDValue Src = N->getOperand(0);
573591
if (N->getValueType(0) != MVT::i64 || Src.getValueType() != MVT::i32)
574592
break;
575-
if (!isVariableShift(Src) &&
576-
!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
593+
if (!(Subtarget.hasStdExtM() && isVariableSDivUDivURem(Src)))
577594
break;
578595
SDLoc DL(N);
579596
// Don't add the new node to the DAGCombiner worklist, in order to avoid
@@ -590,11 +607,42 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
590607
break;
591608
return DCI.CombineTo(N, Op0.getOperand(0), Op0.getOperand(1));
592609
}
610+
case RISCVISD::SLLW:
611+
case RISCVISD::SRAW:
612+
case RISCVISD::SRLW: {
613+
// Only the lower 32 bits of LHS and lower 5 bits of RHS are read.
614+
SDValue LHS = N->getOperand(0);
615+
SDValue RHS = N->getOperand(1);
616+
APInt LHSMask = APInt::getLowBitsSet(LHS.getValueSizeInBits(), 32);
617+
APInt RHSMask = APInt::getLowBitsSet(RHS.getValueSizeInBits(), 5);
618+
if ((SimplifyDemandedBits(N->getOperand(0), LHSMask, DCI)) ||
619+
(SimplifyDemandedBits(N->getOperand(1), RHSMask, DCI)))
620+
return SDValue();
621+
break;
622+
}
593623
}
594624

595625
return SDValue();
596626
}
597627

628+
unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
629+
SDValue Op, const APInt &DemandedElts, const SelectionDAG &DAG,
630+
unsigned Depth) const {
631+
switch (Op.getOpcode()) {
632+
default:
633+
break;
634+
case RISCVISD::SLLW:
635+
case RISCVISD::SRAW:
636+
case RISCVISD::SRLW:
637+
// TODO: As the result is sign-extended, this is conservatively correct. A
638+
// more precise answer could be calculated for SRAW depending on known
639+
// bits in the shift amount.
640+
return 33;
641+
}
642+
643+
return 1;
644+
}
645+
598646
static MachineBasicBlock *emitSplitF64Pseudo(MachineInstr &MI,
599647
MachineBasicBlock *BB) {
600648
assert(MI.getOpcode() == RISCV::SplitF64Pseudo && "Unexpected instruction");
@@ -1683,6 +1731,12 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
16831731
return "RISCVISD::SplitF64";
16841732
case RISCVISD::TAIL:
16851733
return "RISCVISD::TAIL";
1734+
case RISCVISD::SLLW:
1735+
return "RISCVISD::SLLW";
1736+
case RISCVISD::SRAW:
1737+
return "RISCVISD::SRAW";
1738+
case RISCVISD::SRLW:
1739+
return "RISCVISD::SRLW";
16861740
}
16871741
return nullptr;
16881742
}

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ enum NodeType : unsigned {
3232
SELECT_CC,
3333
BuildPairF64,
3434
SplitF64,
35-
TAIL
35+
TAIL,
36+
// RV64I shifts, directly matching the semantics of the named RISC-V
37+
// instructions.
38+
SLLW,
39+
SRAW,
40+
SRLW
3641
};
3742
}
3843

@@ -58,9 +63,16 @@ class RISCVTargetLowering : public TargetLowering {
5863

5964
// Provide custom lowering hooks for some operations.
6065
SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
66+
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
67+
SelectionDAG &DAG) const override;
6168

6269
SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;
6370

71+
unsigned ComputeNumSignBitsForTargetNode(SDValue Op,
72+
const APInt &DemandedElts,
73+
const SelectionDAG &DAG,
74+
unsigned Depth) const override;
75+
6476
// This method returns the name of a target specific DAG node.
6577
const char *getTargetNodeName(unsigned Opcode) const override;
6678

llvm/lib/Target/RISCV/RISCVInstrInfo.td

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def riscv_selectcc : SDNode<"RISCVISD::SELECT_CC", SDT_RISCVSelectCC,
5252
def riscv_tail : SDNode<"RISCVISD::TAIL", SDT_RISCVCall,
5353
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
5454
SDNPVariadic]>;
55+
def riscv_sllw : SDNode<"RISCVISD::SLLW", SDTIntShiftOp>;
56+
def riscv_sraw : SDNode<"RISCVISD::SRAW", SDTIntShiftOp>;
57+
def riscv_srlw : SDNode<"RISCVISD::SRLW", SDTIntShiftOp>;
5558

5659
//===----------------------------------------------------------------------===//
5760
// Operand and SDNode transformation definitions.
@@ -668,21 +671,9 @@ def sexti32 : PatFrags<(ops node:$src),
668671
def assertzexti32 : PatFrag<(ops node:$src), (assertzext node:$src), [{
669672
return cast<VTSDNode>(N->getOperand(1))->getVT() == MVT::i32;
670673
}]>;
671-
def assertzexti5 : PatFrag<(ops node:$src), (assertzext node:$src), [{
672-
return cast<VTSDNode>(N->getOperand(1))->getVT().getSizeInBits() <= 5;
673-
}]>;
674674
def zexti32 : PatFrags<(ops node:$src),
675675
[(and node:$src, 0xffffffff),
676676
(assertzexti32 node:$src)]>;
677-
// Defines a legal mask for (assertzexti5 (and src, mask)) to be combinable
678-
// with a shiftw operation. The mask mustn't modify the lower 5 bits or the
679-
// upper 32 bits.
680-
def shiftwamt_mask : ImmLeaf<XLenVT, [{
681-
return countTrailingOnes<uint64_t>(Imm) >= 5 && isUInt<32>(Imm);
682-
}]>;
683-
def shiftwamt : PatFrags<(ops node:$src),
684-
[(assertzexti5 (and node:$src, shiftwamt_mask)),
685-
(assertzexti5 node:$src)]>;
686677

687678
/// Immediates
688679

@@ -942,28 +933,9 @@ def : Pat<(sext_inreg (shl GPR:$rs1, uimm5:$shamt), i32),
942933
def : Pat<(sra (sext_inreg GPR:$rs1, i32), uimm5:$shamt),
943934
(SRAIW GPR:$rs1, uimm5:$shamt)>;
944935

945-
// For variable-length shifts, we rely on assertzexti5 being inserted during
946-
// lowering (see RISCVTargetLowering::PerformDAGCombine). This enables us to
947-
// guarantee that selecting a 32-bit variable shift is legal (as the variable
948-
// shift is known to be <= 32). We must also be careful not to create
949-
// semantically incorrect patterns. For instance, selecting SRLW for
950-
// (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
951-
// is not guaranteed to be safe, as we don't know whether the upper 32-bits of
952-
// the result are used or not (in the case where rs2=0, this is a
953-
// sign-extension operation).
954-
955-
def : Pat<(sext_inreg (shl GPR:$rs1, (shiftwamt GPR:$rs2)), i32),
956-
(SLLW GPR:$rs1, GPR:$rs2)>;
957-
def : Pat<(zexti32 (shl GPR:$rs1, (shiftwamt GPR:$rs2))),
958-
(SRLI (SLLI (SLLW GPR:$rs1, GPR:$rs2), 32), 32)>;
959-
960-
def : Pat<(sext_inreg (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2)), i32),
961-
(SRLW GPR:$rs1, GPR:$rs2)>;
962-
def : Pat<(zexti32 (srl (zexti32 GPR:$rs1), (shiftwamt GPR:$rs2))),
963-
(SRLI (SLLI (SRLW GPR:$rs1, GPR:$rs2), 32), 32)>;
964-
965-
def : Pat<(sra (sexti32 GPR:$rs1), (shiftwamt GPR:$rs2)),
966-
(SRAW GPR:$rs1, GPR:$rs2)>;
936+
def : PatGprGpr<riscv_sllw, SLLW>;
937+
def : PatGprGpr<riscv_srlw, SRLW>;
938+
def : PatGprGpr<riscv_sraw, SRAW>;
967939

968940
/// Loads
969941

0 commit comments

Comments
 (0)