diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 4539efd591c8b..b097fb9b414ea 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3606,11 +3606,15 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits, if (Use.getOperandNo() == 1 && Bits >= Log2_32(Subtarget->getXLen())) break; return false; - case RISCV::SLLI: + case RISCV::SLLI: { // SLLI only uses the lower (XLen - ShAmt) bits. - if (Bits >= Subtarget->getXLen() - User->getConstantOperandVal(1)) + uint64_t ShAmt = User->getConstantOperandVal(1); + if (Bits >= Subtarget->getXLen() - ShAmt) + break; + if (hasAllNBitUsers(User, Bits + ShAmt, Depth + 1)) break; return false; + } case RISCV::ANDI: if (Bits >= (unsigned)llvm::bit_width(User->getConstantOperandVal(1))) break; @@ -3621,20 +3625,39 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits, break; [[fallthrough]]; } + case RISCV::COPY: + case RISCV::PHI: + case RISCV::ADD: + case RISCV::ADDI: case RISCV::AND: + case RISCV::MUL: case RISCV::OR: + case RISCV::SUB: case RISCV::XOR: case RISCV::XORI: case RISCV::ANDN: + case RISCV::BREV8: + case RISCV::CLMUL: + case RISCV::ORC_B: case RISCV::ORN: case RISCV::XNOR: case RISCV::SH1ADD: case RISCV::SH2ADD: case RISCV::SH3ADD: + case RISCV::BSETI: + case RISCV::BCLRI: + case RISCV::BINVI: RecCheck: if (hasAllNBitUsers(User, Bits, Depth + 1)) break; return false; + case RISCV::CZERO_EQZ: + case RISCV::CZERO_NEZ: + if (Use.getOperandNo() != 0) + return false; + if (hasAllNBitUsers(User, Bits, Depth + 1)) + break; + return false; case RISCV::SRLI: { unsigned ShAmt = User->getConstantOperandVal(1); // If we are shifting right by less than Bits, and users don't demand any @@ -3670,6 +3693,10 @@ bool RISCVDAGToDAGISel::hasAllNBitUsers(SDNode *Node, unsigned Bits, if (Use.getOperandNo() == 0 && Bits >= 32) break; return false; + case RISCV::BEXTI: + if (User->getConstantOperandVal(1) >= Bits) + return false; + break; case RISCV::SB: if (Use.getOperandNo() == 0 && Bits >= 8) break;