From db8616817248c054eabd96e1ea0d7a28661023ab Mon Sep 17 00:00:00 2001 From: Yonghong Song Date: Mon, 31 Mar 2025 21:25:26 -0700 Subject: [PATCH] [RFC][BPF] Support Jump Table NOTE: We probably need cpu v5 or other flags to enable this feature. We can add it later when necessary. This patch adds jump table support. A new insn 'gotox <reg>' is added to allow goto through a register. The register represents the address in the current section. The following is a concrete example with bpf selftest progs/user_ringbuf_success.c. Compilation command line to generate .s file: ============================================= clang -g -Wall -Werror -D__TARGET_ARCH_x86 -mlittle-endian \ -I/home/yhs/work/bpf-next/tools/testing/selftests/bpf/tools/include \ -I/home/yhs/work/bpf-next/tools/testing/selftests/bpf \ -I/home/yhs/work/bpf-next/tools/include/uapi \ -I/home/yhs/work/bpf-next/tools/testing/selftests/usr/include -std=gnu11 \ -fno-strict-aliasing -Wno-compare-distinct-pointer-types \ -idirafter /home/yhs/work/llvm-project/llvm/build.21/Release/lib/clang/21/include \ -idirafter /usr/local/include -idirafter /usr/include \ -DENABLE_ATOMICS_TESTS -O2 -S progs/user_ringbuf_success.c \ -o /home/yhs/work/bpf-next/tools/testing/selftests/bpf/user_ringbuf_success.bpf.o.s \ --target=bpf -mcpu=v3 The related assembly: read_protocol_msg: ... r3 <<= 3 r1 = .LJTI1_0 ll r1 += r3 r1 = *(u64 *)(r1 + 0) gotox r1 LBB1_4: r1 = *(u64 *)(r0 + 8) goto LBB1_5 LBB1_7: r1 = *(u64 *)(r0 + 8) goto LBB1_8 LBB1_9: w1 = *(u32 *)(r0 + 8) r1 <<= 32 r1 s>>= 32 r2 = kern_mutated ll r3 = *(u64 *)(r2 + 0) r3 *= r1 *(u64 *)(r2 + 0) = r3 goto LBB1_11 LBB1_6: w1 = *(u32 *)(r0 + 8) r1 <<= 32 r1 s>>= 32 LBB1_5: ... .section .rodata,"a",@progbits .p2align 3, 0x0 .LJTI1_0: .quad LBB1_4 .quad LBB1_6 .quad LBB1_7 .quad LBB1_9 ... publish_next_kern_msg: ... r6 <<= 3 r1 = .LJTI6_0 ll r1 += r6 r1 = *(u64 *)(r1 + 0) gotox r1 LBB6_3: ... LBB6_5: ... LBB6_6: ... LBB6_4: ... 
.section .rodata,"a",@progbits .p2align 3, 0x0 .LJTI6_0: .quad LBB6_3 .quad LBB6_4 .quad LBB6_5 .quad LBB6_6 You can see in the above that .LJTI1_0 and .LJTI6_0 are the jump tables, and these two jump tables are used in insns so that the gotox insn gets the proper jump table target. Now let us look at sections in .o file ======================================= For example, [ 6] .rodata PROGBITS 0000000000000000 000740 0000d6 00 A 0 0 8 [ 7] .rel.rodata REL 0000000000000000 003860 000080 10 I 39 6 8 [ 8] .llvm_jump_table_sizes LLVM_JT_SIZES 0000000000000000 000816 000010 00 0 0 1 [ 9] .rel.llvm_jump_table_sizes REL 0000000000000000 0038e0 000010 10 I 39 8 8 ... [14] .llvm_jump_table_sizes LLVM_JT_SIZES 0000000000000000 000958 000010 00 0 0 1 [15] .rel.llvm_jump_table_sizes REL 0000000000000000 003970 000010 10 I 39 14 8 With llvm-readelf dump section 8 and 14: $ llvm-readelf -x 8 user_ringbuf_success.bpf.o Hex dump of section '.llvm_jump_table_sizes': 0x00000000 00000000 00000000 04000000 00000000 ................ $ llvm-readelf -x 14 user_ringbuf_success.bpf.o Hex dump of section '.llvm_jump_table_sizes': 0x00000000 20000000 00000000 04000000 00000000 ............... You can see there are two jump tables: jump table 1: offset 0, size 4 (4 labels) jump table 2: offset 0x20, size 4 (4 labels) Checking sections 9 and 15, we can find the corresponding relocation sections: Relocation section '.rel.llvm_jump_table_sizes' at offset 0x38e0 contains 1 entries: Offset Info Type Symbol's Value Symbol's Name 0000000000000000 0000000a00000002 R_BPF_64_ABS64 0000000000000000 .rodata Relocation section '.rel.llvm_jump_table_sizes' at offset 0x3970 contains 1 entries: Offset Info Type Symbol's Value Symbol's Name 0000000000000000 0000000a00000002 R_BPF_64_ABS64 0000000000000000 .rodata which confirms that the relocation is against '.rodata'. Dump .rodata section: 0x00000000 a8000000 00000000 10010000 00000000 ................ 0x00000010 b8000000 00000000 c8000000 00000000 ................ 
0x00000020 28040000 00000000 00050000 00000000 (............... 0x00000030 70040000 00000000 b8040000 00000000 p............... 0x00000040 44726169 6e207265 7475726e 65643a20 Drain returned: So we can get two jump tables: .rodata offset 0, # of labels 4: 0x00000000 a8000000 00000000 10010000 00000000 ................ 0x00000010 b8000000 00000000 c8000000 00000000 ................ .rodata offset 0x20, # of labels 4: 0x00000020 28040000 00000000 00050000 00000000 (............... 0x00000030 70040000 00000000 b8040000 00000000 p............... This way, you just need to scan the related code section. As long as it matches one of the jump tables (.rodata relocation, offset also matching), you do not need to care about gotox at all in libbpf. An option -bpf-min-jump-table-entries is implemented to control the minimum number of entries to use a jump table on BPF. The default value is 4, but it can be changed with the following clang option clang ... -mllvm -bpf-min-jump-table-entries=6 where the number of jump table cases needs to be >= 6 in order to use a jump table. 
--- llvm/include/llvm/CodeGen/AsmPrinter.h | 2 + llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp | 2 +- .../lib/Target/BPF/AsmParser/BPFAsmParser.cpp | 1 + llvm/lib/Target/BPF/BPFAsmPrinter.cpp | 1 + llvm/lib/Target/BPF/BPFISelLowering.cpp | 36 +++++++++++++++- llvm/lib/Target/BPF/BPFISelLowering.h | 2 + llvm/lib/Target/BPF/BPFInstrInfo.cpp | 41 +++++++++++++++++++ llvm/lib/Target/BPF/BPFInstrInfo.h | 3 ++ llvm/lib/Target/BPF/BPFInstrInfo.td | 27 ++++++++++++ llvm/lib/Target/BPF/BPFMCInstLower.cpp | 3 ++ 10 files changed, 115 insertions(+), 3 deletions(-) diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h index 6ad54fcd6d0e5..8cf00cc370821 100644 --- a/llvm/include/llvm/CodeGen/AsmPrinter.h +++ b/llvm/include/llvm/CodeGen/AsmPrinter.h @@ -26,6 +26,7 @@ #include "llvm/CodeGen/StackMaps.h" #include "llvm/DebugInfo/CodeView/CodeView.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/ErrorHandling.h" #include @@ -34,6 +35,7 @@ #include namespace llvm { +extern cl::opt EmitJumpTableSizesSection; class AddrLabelMap; class AsmPrinterHandler; diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp index a2c3b50b24670..6a93569e52b28 100644 --- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp +++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp @@ -168,7 +168,7 @@ static cl::opt BBAddrMapSkipEmitBBEntries( "unnecessary for some PGOAnalysisMap features."), cl::Hidden, cl::init(false)); -static cl::opt EmitJumpTableSizesSection( +cl::opt llvm::EmitJumpTableSizesSection( "emit-jump-table-sizes-section", cl::desc("Emit a section containing jump table addresses and sizes"), cl::Hidden, cl::init(false)); diff --git a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp index 7d1819134d162..3a8f559be942c 100644 --- a/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp +++ 
b/llvm/lib/Target/BPF/AsmParser/BPFAsmParser.cpp @@ -232,6 +232,7 @@ struct BPFOperand : public MCParsedAsmOperand { .Case("callx", true) .Case("goto", true) .Case("gotol", true) + .Case("gotox", true) .Case("may_goto", true) .Case("*", true) .Case("exit", true) diff --git a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp index 5dd71cc91427a..e2856bab354c8 100644 --- a/llvm/lib/Target/BPF/BPFAsmPrinter.cpp +++ b/llvm/lib/Target/BPF/BPFAsmPrinter.cpp @@ -57,6 +57,7 @@ class BPFAsmPrinter : public AsmPrinter { } // namespace bool BPFAsmPrinter::doInitialization(Module &M) { + EmitJumpTableSizesSection = true; AsmPrinter::doInitialization(M); // Only emit BTF when debuginfo available. diff --git a/llvm/lib/Target/BPF/BPFISelLowering.cpp b/llvm/lib/Target/BPF/BPFISelLowering.cpp index f4f414d192df0..154db34be786a 100644 --- a/llvm/lib/Target/BPF/BPFISelLowering.cpp +++ b/llvm/lib/Target/BPF/BPFISelLowering.cpp @@ -38,6 +38,10 @@ static cl::opt BPFExpandMemcpyInOrder("bpf-expand-memcpy-in-order", cl::Hidden, cl::init(false), cl::desc("Expand memcpy into load/store pairs in order")); +static cl::opt BPFMinimumJumpTableEntries( + "bpf-min-jump-table-entries", cl::init(4), cl::Hidden, + cl::desc("Set minimum number of entries to use a jump table on BPF")); + static void fail(const SDLoc &DL, SelectionDAG &DAG, const Twine &Msg, SDValue Val = {}) { std::string Str; @@ -67,12 +71,13 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM, setOperationAction(ISD::BR_CC, MVT::i64, Custom); setOperationAction(ISD::BR_JT, MVT::Other, Expand); - setOperationAction(ISD::BRIND, MVT::Other, Expand); setOperationAction(ISD::BRCOND, MVT::Other, Expand); setOperationAction(ISD::TRAP, MVT::Other, Custom); - setOperationAction({ISD::GlobalAddress, ISD::ConstantPool}, MVT::i64, Custom); + setOperationAction({ISD::GlobalAddress, ISD::ConstantPool, ISD::JumpTable, + ISD::BlockAddress}, + MVT::i64, Custom); setOperationAction(ISD::DYNAMIC_STACKALLOC, 
MVT::i64, Custom); setOperationAction(ISD::STACKSAVE, MVT::Other, Expand); @@ -159,6 +164,7 @@ BPFTargetLowering::BPFTargetLowering(const TargetMachine &TM, setBooleanContents(ZeroOrOneBooleanContent); setMaxAtomicSizeInBitsSupported(64); + setMinimumJumpTableEntries(BPFMinimumJumpTableEntries); // Function alignments setMinFunctionAlignment(Align(8)); @@ -316,10 +322,14 @@ SDValue BPFTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const { report_fatal_error("unimplemented opcode: " + Twine(Op.getOpcode())); case ISD::BR_CC: return LowerBR_CC(Op, DAG); + case ISD::JumpTable: + return LowerJumpTable(Op, DAG); case ISD::GlobalAddress: return LowerGlobalAddress(Op, DAG); case ISD::ConstantPool: return LowerConstantPool(Op, DAG); + case ISD::BlockAddress: + return LowerBlockAddress(Op, DAG); case ISD::SELECT_CC: return LowerSELECT_CC(Op, DAG); case ISD::SDIV: @@ -780,6 +790,11 @@ SDValue BPFTargetLowering::LowerTRAP(SDValue Op, SelectionDAG &DAG) const { return LowerCall(CLI, InVals); } +SDValue BPFTargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const { + JumpTableSDNode *N = cast(Op); + return getAddr(N, DAG); +} + const char *BPFTargetLowering::getTargetNodeName(unsigned Opcode) const { switch ((BPFISD::NodeType)Opcode) { case BPFISD::FIRST_NUMBER: @@ -811,6 +826,17 @@ static SDValue getTargetNode(ConstantPoolSDNode *N, const SDLoc &DL, EVT Ty, N->getOffset(), Flags); } +static SDValue getTargetNode(BlockAddressSDNode *N, const SDLoc &DL, EVT Ty, + SelectionDAG &DAG, unsigned Flags) { + return DAG.getTargetBlockAddress(N->getBlockAddress(), Ty, N->getOffset(), + Flags); +} + +static SDValue getTargetNode(JumpTableSDNode *N, const SDLoc &DL, EVT Ty, + SelectionDAG &DAG, unsigned Flags) { + return DAG.getTargetJumpTable(N->getIndex(), Ty, Flags); +} + template SDValue BPFTargetLowering::getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags) const { @@ -837,6 +863,12 @@ SDValue BPFTargetLowering::LowerConstantPool(SDValue Op, return getAddr(N, 
DAG); } +SDValue BPFTargetLowering::LowerBlockAddress(SDValue Op, + SelectionDAG &DAG) const { + BlockAddressSDNode *N = cast(Op); + return getAddr(N, DAG); +} + unsigned BPFTargetLowering::EmitSubregExt(MachineInstr &MI, MachineBasicBlock *BB, unsigned Reg, bool isSigned) const { diff --git a/llvm/lib/Target/BPF/BPFISelLowering.h b/llvm/lib/Target/BPF/BPFISelLowering.h index 23cbce7094e6b..acb8f27c647d7 100644 --- a/llvm/lib/Target/BPF/BPFISelLowering.h +++ b/llvm/lib/Target/BPF/BPFISelLowering.h @@ -81,6 +81,8 @@ class BPFTargetLowering : public TargetLowering { SDValue LowerConstantPool(SDValue Op, SelectionDAG &DAG) const; SDValue LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const; SDValue LowerTRAP(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const; + SDValue LowerJumpTable(SDValue Op, SelectionDAG &DAG) const; template SDValue getAddr(NodeTy *N, SelectionDAG &DAG, unsigned Flags = 0) const; diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.cpp b/llvm/lib/Target/BPF/BPFInstrInfo.cpp index 70bc163615f61..78626c39e80f7 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.cpp +++ b/llvm/lib/Target/BPF/BPFInstrInfo.cpp @@ -181,6 +181,10 @@ bool BPFInstrInfo::analyzeBranch(MachineBasicBlock &MBB, if (!isUnpredicatedTerminator(*I)) break; + // If a JX insn, we're done. + if (I->getOpcode() == BPF::JX) + break; + // A terminator that isn't a branch can't easily be handled // by this analysis. 
if (!I->isBranch()) @@ -259,3 +263,40 @@ unsigned BPFInstrInfo::removeBranch(MachineBasicBlock &MBB, return Count; } + +int BPFInstrInfo::getJumpTableIndex(const MachineInstr &MI) const { + // The pattern looks like: + // %0 = LD_imm64 %jump-table.0 ; load jump-table address + // %1 = ADD_rr %0, $another_reg ; address + offset + // %2 = LDD %1, 0 ; load the actual label + // JX %2 + const MachineFunction &MF = *MI.getParent()->getParent(); + const MachineRegisterInfo &MRI = MF.getRegInfo(); + + Register Reg = MI.getOperand(0).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *Ldd = MRI.getUniqueVRegDef(Reg); + if (Ldd == nullptr || Ldd->getOpcode() != BPF::LDD) + return -1; + + Reg = Ldd->getOperand(1).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *Add = MRI.getUniqueVRegDef(Reg); + if (Add == nullptr || Add->getOpcode() != BPF::ADD_rr) + return -1; + + Reg = Add->getOperand(1).getReg(); + if (!Reg.isVirtual()) + return -1; + MachineInstr *LDimm64 = MRI.getUniqueVRegDef(Reg); + if (LDimm64 == nullptr || LDimm64->getOpcode() != BPF::LD_imm64) + return -1; + + const MachineOperand &MO = LDimm64->getOperand(1); + if (!MO.isJTI()) + return -1; + + return MO.getIndex(); +} diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.h b/llvm/lib/Target/BPF/BPFInstrInfo.h index d8bbad44e314e..d88e37975980a 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.h +++ b/llvm/lib/Target/BPF/BPFInstrInfo.h @@ -58,6 +58,9 @@ class BPFInstrInfo : public BPFGenInstrInfo { MachineBasicBlock *FBB, ArrayRef Cond, const DebugLoc &DL, int *BytesAdded = nullptr) const override; + + int getJumpTableIndex(const MachineInstr &MI) const override; + private: void expandMEMCPY(MachineBasicBlock::iterator) const; diff --git a/llvm/lib/Target/BPF/BPFInstrInfo.td b/llvm/lib/Target/BPF/BPFInstrInfo.td index b21f1a0eee3b0..c715bdb01866a 100644 --- a/llvm/lib/Target/BPF/BPFInstrInfo.td +++ b/llvm/lib/Target/BPF/BPFInstrInfo.td @@ -183,6 +183,15 @@ class TYPE_LD_ST mode, bits<2> size, let 
Inst{60-59} = size; } +// For indirect jump +class TYPE_IND_JMP op, bits<1> srctype, + dag outs, dag ins, string asmstr, list pattern> + : InstBPF { + + let Inst{63-60} = op; + let Inst{59} = srctype; +} + // jump instructions class JMP_RR : TYPE_ALU_JMP let BPFClass = BPF_JMP; } +class JMP_IND Pattern> + : TYPE_ALU_JMP { + bits<4> dst; + + let Inst{51-48} = dst; + let BPFClass = BPF_JMP; +} + class JMP_JCOND Pattern> : TYPE_ALU_JMP; defm JSLE : J; def JCOND : JMP_JCOND; + +let isIndirectBranch = 1 in { + def JX : JMP_IND; +} } // ALU instructions @@ -851,6 +876,8 @@ let usesCustomInserter = 1, isCodeGenOnly = 1 in { // load 64-bit global addr into register def : Pat<(BPFWrapper tglobaladdr:$in), (LD_imm64 tglobaladdr:$in)>; def : Pat<(BPFWrapper tconstpool:$in), (LD_imm64 tconstpool:$in)>; +def : Pat<(BPFWrapper tblockaddress:$in), (LD_imm64 tblockaddress:$in)>; +def : Pat<(BPFWrapper tjumptable:$in), (LD_imm64 tjumptable:$in)>; // 0xffffFFFF doesn't fit into simm32, optimize common case def : Pat<(i64 (and (i64 GPR:$src), 0xffffFFFF)), diff --git a/llvm/lib/Target/BPF/BPFMCInstLower.cpp b/llvm/lib/Target/BPF/BPFMCInstLower.cpp index 040a1fb750702..164d172c241c8 100644 --- a/llvm/lib/Target/BPF/BPFMCInstLower.cpp +++ b/llvm/lib/Target/BPF/BPFMCInstLower.cpp @@ -77,6 +77,9 @@ void BPFMCInstLower::Lower(const MachineInstr *MI, MCInst &OutMI) const { case MachineOperand::MO_ConstantPoolIndex: MCOp = LowerSymbolOperand(MO, Printer.GetCPISymbol(MO.getIndex())); break; + case MachineOperand::MO_JumpTableIndex: + MCOp = LowerSymbolOperand(MO, Printer.GetJTISymbol(MO.getIndex())); + break; } OutMI.addOperand(MCOp);