diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h index 6ad54fcd6d0e5..8cf00cc370821 100644 --- a/llvm/include/llvm/CodeGen/AsmPrinter.h +++ b/llvm/include/llvm/CodeGen/AsmPrinter.h @@ -26,6 +26,7 @@ #include "llvm/CodeGen/StackMaps.h" #include "llvm/DebugInfo/CodeView/CodeView.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" #include @@ -34,6 +35,7 @@ #include namespace llvm { +extern cl::opt EmitJumpTableSizesSection; class AddrLabelMap; class AsmPrinterHandler; diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index a2c3b50b24670..6a93569e52b28 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -168,7 +168,7 @@ static cl::opt BBAddrMapSkipEmitBBEntries( "unnecessary for some PGOAnalysisMap features."), cl::Hidden, cl::init(false)); -static cl::opt EmitJumpTableSizesSection( +cl::opt llvm::EmitJumpTableSizesSection( "emit-jump-table-sizes-section", cl::desc("Emit a section containing jump table addresses and sizes"), cl::Hidden, cl::init(false)); diff --git a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp index 7d1819134d162..3a8f559be942c 100644 --- a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp +++ b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp @@ -232,6 +232,7 @@ struct BPFOperand : public MCParsedAsmOperand { .Case("callx", true) .Case("goto", true) .Case("gotol", true) + .Case("gotox", true) .Case("may_goto", true) .Case("*", true) .Case("exit", true) diff --git a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp index 5dd71cc91427a..e2856bab354c8 100644 --- a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp +++ b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp @@ -57,6 +57,7 @@ class BPFAsmPrinter : public AsmPrinter { } // namespace bool BPFAsmPrinter::doInitialization(Module &M) { + EmitJumpTableSizesSection = true; AsmPrinter::doInitialization(M); // Only emit BTF when debuginfo available. diff --git a/llvm/lib/Target/BPF/BPFISelLowering.cpp b/llvm/lib/Target/BPF/BPFISelLowering.cpp index f4f414d192df0..154db34be786a 100644 --- a/llvm/lib/Target/BPF/BPFISelLowering.cpp +++ b/llvm/lib/Target/BPF/BPFISelLowering.cpp @@ -38,6 +38,10 @@ static cl::opt BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order", cl::Hidden, cl::init(false), cl::desc("Expand memcpy into load/store pairs in order")); +static cl::opt BPFMinimumJumpTableEntries( + "bpf-min-jump-table-entries", cl::init(4), cl::Hidden, + cl::desc("Set minimum number of entries to use a jump table on BPF")); + static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg, SDValue Val = {}) { std::string Str; @@ -67,12 +71,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BR_CC, MVT::i64, Custom); setOperationAction(ISD::BR_JT, MVT::Other, Expand); - setOperationAction(ISD::BRIND, MVT::Other, Expand); setOperationAction(ISD::BRCOND, MVT::Other, Expand); setOperationAction(ISD::TRAP, MVT::Other, Custom); - setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom); + setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable, + ISD::BlockAddress}, + MVT::i64, Custom); setOperationAction(ISD::DYNAMIC_STACKALLOC, MVT::i64, Custom); setOperationAction(ISD::STACKSAVE, MVT::Other, Expand); @@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM, setBooleanContents(ZeroOrOneBooleanContent); setMaxAtomicSizeInBitsSupported(64); + setMinimumJumpTableEntries(BPFMinimumJumpTableEntries); // Function alignments setMinFunctionAlignment(Align(8)); @@ -316,10 +322,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode())); case ISD::BR_CC: return LowerBR_CC(Op, DAG); + case ISD::JumpTable: + return LowerJumpTable(Op, DAG); case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG); case ISD::ConstantPool: return LowerConstantPool(Op, DAG); + case ISD::BlockAddress: + return LowerBlockAddress(Op, DAG); case ISD::SELECT_CC: return LowerSELECT_CC(Op, DAG); case ISD::SDIV: @@ -780,6 +790,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const { return LowerCall(CLI, InVals); } +SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const { + JumpTableSDNode *N = cast(Op); + return getAddr(N, DAG); +} + const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const { switch ((BPFISD::NodeType)Opcode) { case BPFISD::FIRST_NUMBER: @@ -811,6 +826,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty, N->getOffset(), Flags); } +static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty, + SelectionDAG &DAG, unsigned Flags) { + return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(), + Flags); +} + +static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty, + SelectionDAG &DAG, unsigned Flags) { + return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags); +} + template SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags) const { @@ -837,6 +863,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op, return getAddr(N, DAG); } +SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op, + SelectionDAG &DAG) const { + BlockAddressSDNode *N = cast(Op); + return getAddr(N, DAG); +} + unsigned BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB, unsigned Reg, bool isSigned) const { diff --git a/llvm/lib/Target/BPF/BPFISelLowering.h b/llvm/lib/Target/BPF/BPFISelLowering.h index 23cbce7094e6b..acb8f27c647d7 100644 --- a/llvm/lib/Target/BPF/BPFISelLowering.h +++ b/llvm/lib/Target/BPF/BPFISelLowering.h @@ -81,6 +81,8 @@ class BPFTargetLowering : public TargetLowering { SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const; SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const; SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const; template SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const; diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.cpp b/llvm/lib/Target/BPF/BPFInstrInfo.cpp index 70bc163615f61..78626c39e80f7 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.cpp +++ b/llvm/lib/Target/BPF/BPFInstrInfo.cpp @@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB, if (!isUnpredicatedTerminator(*I)) break; + // If a JX insn, we're done. + if (I->getOpcode() == BPF::JX) + break; + // A terminator that isn't a branch can't easily be handled // by this analysis. if (!I->isBranch()) @@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB, return Count; } + +int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const { + // The pattern looks like: + // %0 = LD_imm64 %jump-table.0 ; load jump-table address + // %1 = ADD_rr %0, $another_reg ; address + offset + // %2 = LDD %1, 0 ; load the actual label + // JX %2 + const MachineFunction &MF = *MI.getParent()->getParent(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + + Register Reg = MI.getOperand(0).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg); + if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD) + return -1; + + Reg = Ldd->getOperand(1).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *Add = MRI.getUniqueVRegDef(Reg); + if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr) + return -1; + + Reg = Add->getOperand(1).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg); + if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64) + return -1; + + const MachineOperand &MO = LDimm64->getOperand(1); + if (!MO.isJTI()) + return -1; + + return MO.getIndex(); +} diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.h b/llvm/lib/Target/BPF/BPFInstrInfo.h index d8bbad44e314e..d88e37975980a 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.h +++ b/llvm/lib/Target/BPF/BPFInstrInfo.h @@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo { MachineBasicBlock *FBB, ArrayRef Cond, const DebugLoc &DL, int *BytesAdded = nullptr) const override; + + int getJumpTableIndex(const MachineInstr &MI) const override; + private: void expandMEMCPY(MachineBasicBlock::iterator) const; diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.td b/llvm/lib/Target/BPF/BPFInstrInfo.td index b21f1a0eee3b0..c715bdb01866a 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.td +++ b/llvm/lib/Target/BPF/BPFInstrInfo.td @@ -183,6 +183,15 @@ class TYPE_LD_ST mode, bits<2> size, let Inst{60-59} = size; } +// For indirect jump +class TYPE_IND_JMP op, bits<1> srctype, + dag outs, dag ins, string asmstr, list pattern> + : InstBPF { + + let Inst{63-60} = op; + let Inst{59} = srctype; +} + // jump instructions class JMP_RR : TYPE_ALU_JMP let BPFClass = BPF_JMP; } +class JMP_IND Pattern> + : TYPE_ALU_JMP { + bits<4> dst; + + let Inst{51-48} = dst; + let BPFClass = BPF_JMP; +} + class JMP_JCOND Pattern> : TYPE_ALU_JMP; defm JSLE : J; def JCOND : JMP_JCOND; + +let isIndirectBranch = 1 in { + def JX : JMP_IND; +} } // ALU instructions @@ -851,6 +876,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in { // load 64-bit global addr into register def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>; def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>; +def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>; +def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>; // 0xffffFFFF doesn't fit into simm32, optimize common case def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)), diff --git a/llvm/lib/Target/BPF/BPFMCInstLower.cpp b/llvm/lib/Target/BPF/BPFMCInstLower.cpp index 040a1fb750702..164d172c241c8 100644 --- a/llvm/lib/Target/BPF/BPFMCInstLower.cpp +++ b/llvm/lib/Target/BPF/BPFMCInstLower.cpp @@ -77,6 +77,9 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const { case MachineOperand::MO_ConstantPoolIndex: MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex())); break; + case MachineOperand::MO_JumpTableIndex: + MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex())); + break; } OutMI.addOperand(MCOp);