diff --git a/lld/test/wasm/Inputs/branch-hints-multifile.ll b/lld/test/wasm/Inputs/branch-hints-multifile.ll new file mode 100644 index 0000000000000..ded4e546242a3 --- /dev/null +++ b/lld/test/wasm/Inputs/branch-hints-multifile.ll @@ -0,0 +1,14 @@ +define i32 @bw_bh_test_2(i32 %a, i32 %b) { +entry: + %1 = icmp ult i32 %a, %b + br i1 %1, label %fail, label %success, !prof !0 + +fail: + ret i32 -1 + +success: + ret i32 0 +} + +!0 = !{!"branch_weights", !"expected", i32 1, i32 2000} + diff --git a/lld/test/wasm/branch-hints-multifile.ll b/lld/test/wasm/branch-hints-multifile.ll new file mode 100644 index 0000000000000..57627df59aee9 --- /dev/null +++ b/lld/test/wasm/branch-hints-multifile.ll @@ -0,0 +1,30 @@ + +; RUN: llc -mtriple=wasm32-unknown-unknown -filetype=obj -o %t1.o < %s +; RUN: llc -mtriple=wasm32-unknown-unknown -filetype=obj -o %t2.o < %S/Inputs/branch-hints-multifile.ll +; RUN: wasm-ld -o %t.wasm %t1.o %t2.o --no-entry --no-gc-sections +; RUN: obj2yaml %t.wasm | FileCheck %s + +define i32 @bw_bh_test_1(i32 %a, i32 %b) { +entry: + %1 = icmp ult i32 %a, %b + br i1 %1, label %fail, label %success, !prof !0 + +fail: + ret i32 -1 + +success: + ret i32 0 +} + +!0 = !{!"branch_weights", !"expected", i32 2000, i32 1} + +; Test that we combine branch hint sections properly. The number of functions +; should be reported once at the start (even though it appears in each object +; file, and the hints for each object file should then be concatenated). 
+; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: metadata.code.branch_hint +; CHECK-NEXT: Payload: '8280808000818080800001080100828080800001080101' +; ^^ two functions (5-byte padded LEB) +; ^^hint for func 1^ +; ^^hint for func 2^ + diff --git a/lld/test/wasm/branch-hints.ll b/lld/test/wasm/branch-hints.ll new file mode 100644 index 0000000000000..49cd0ffe9abda --- /dev/null +++ b/lld/test/wasm/branch-hints.ll @@ -0,0 +1,31 @@ + +; RUN: llc -mtriple=wasm32-unknown-unknown -filetype=obj -o %t.o < %s +; RUN: wasm-ld -o %t.wasm %t.o --no-entry --no-gc-sections +; RUN: obj2yaml %t.wasm | FileCheck %s + +define i32 @bw_bh_test(i32 %a, i32 %b) { +; The weights below mean we are far more likely to go to %fail and return -1. +; Codegen will emit the -1 first, so we emit a hint of 0, below, for the value +; of the hint (as in llvm/test/CodeGen/WebAssembly/branch-hints.ll). +entry: + %1 = icmp ult i32 %a, %b + br i1 %1, label %fail, label %success, !prof !0 + +fail: + ret i32 -1 + +success: + ret i32 0 +} + +!0 = !{!"branch_weights", !"expected", i32 2000, i32 1} + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: metadata.code.branch_hint +; CHECK-NEXT: Payload: '8180808000818080800001080100' +; ^^ one function (5-byte padded LEB) +; ^^^^^^^^^^ LEB of function index 1 +; ^^ one hint in function +; ^^ offset 8 +; ^^ hint size 1 +; ^^ hint value: 0 diff --git a/lld/wasm/InputChunks.h b/lld/wasm/InputChunks.h index 1fe78d76631f1..4c0d1b302b8c2 100644 --- a/lld/wasm/InputChunks.h +++ b/lld/wasm/InputChunks.h @@ -354,9 +354,11 @@ class InputSection : public InputChunk { const uint64_t tombstoneValue; + // Public so the Writer can read the raw section to merge branch hints. + const WasmSection &section; + protected: static uint64_t getTombstoneForSection(StringRef name); - const WasmSection &section; }; } // namespace wasm diff --git a/lld/wasm/Writer.cpp b/lld/wasm/Writer.cpp index b704677d36c93..443bcba2e2692 100644 --- a/lld/wasm/Writer.cpp +++ b/lld/wasm/Writer.cpp @@ -28,6 +28,7 @@ #include "llvm/BinaryFormat/Wasm.h" #include
"llvm/Support/FileOutputBuffer.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LEB128.h" #include "llvm/Support/Parallel.h" #include "llvm/Support/RandomNumberGenerator.h" #include "llvm/Support/SHA1.h" @@ -94,6 +95,7 @@ class Writer { void addSections(); void createCustomSections(); + void createBranchHintSection(); void createSyntheticSections(); void createSyntheticSectionsPostLayout(); void finalizeSections(); @@ -164,7 +166,11 @@ void Writer::createCustomSections() { log("createCustomSections"); for (auto &pair : customSectionMapping) { StringRef name = pair.first; - LLVM_DEBUG(dbgs() << "createCustomSection: " << name << "\n"); + + if (name == "metadata.code.branch_hint") // Already emitted by createBranchHintSection(). + continue; + + LLVM_DEBUG(dbgs() << "createCustomSection: " << name << "\n"); OutputSection *sec = make<CustomSection>(std::string(name), pair.second); if (ctx.arg.relocatable || ctx.arg.emitRelocs) { @@ -176,6 +182,122 @@ void Writer::createCustomSections() { } } +// A Branch Hint section is a Custom Section with some custom rules for how it +// is created. Rather than simply concatenate the input sections, we must also +// adjust the field that reports the number of functions, as follows. +// +// Our input chunks each begin with a 5-byte LEB of the number of functions. If +// we simply concatenated, we'd get this: +// +// ;; from object file 1 +// [num functions_1] : 5 byte LEB +// [..data_1..] +// ;; from object file 2 +// [num functions_2] : 5 byte LEB +// [..data_2..] +// .. +// ;; from object file N +// [num functions_N] : 5 byte LEB +// [..data_N..] +// +// But the final output should report the total number of functions at the very +// start. To fix that, we must accumulate the total number of functions and use +// that at the very start (which comes from the first object file), and we must +// remove the first 5 bytes of the others: +// +// [num functions_1 + _2 + .. + _N] : 5 byte LEB +// [..data_1..] +// [..data_2..] +// .. +// [..data_N..] +// +// That is now correct.
+// +// Also, the Branch Hint section must appear *before* the code, so we call this +// earlier than for other custom sections. +void Writer::createBranchHintSection() { + std::string name = "metadata.code.branch_hint"; + + auto iter = customSectionMapping.find(name); + if (iter == customSectionMapping.end()) + return; + auto &inputChunks = iter->second; + + LLVM_DEBUG(dbgs() << "createBranchHintSection: " << name << "\n"); + + assert(!inputChunks.empty()); + CustomSection *sec; + if (inputChunks.size() == 1) { + // Nothing to merge. Copy the list, as the mapping entry is erased below. + sec = make<CustomSection>(name, *make<std::vector<InputChunk *>>(inputChunks)); + } else { + // We need to merge the input chunks in the special format that the spec + // expects, as explained above. To do so, create new input chunks with those + // minor modifications, and then the normal custom section behavior of + // concatenating the chunks will give us the right output. + auto *newInputChunksAlloc = make<std::vector<InputChunk *>>(inputChunks.size()); + auto &newInputChunks = *newInputChunksAlloc; + + // Remove the first 5 bytes from all sections but the first, and count how + // many functions there are (so we can add that to the first). + uint64_t totalFunctions = 0; + for (unsigned i = 1; i < inputChunks.size(); i++) { + assert(InputSection::classof(inputChunks[i])); + auto *section = static_cast<InputSection *>(inputChunks[i]); + const WasmSection &wasmSection = section->section; + + // Read the number of functions in this section (bounded by its contents). + totalFunctions += decodeULEB128(wasmSection.Content.data(), nullptr, wasmSection.Content.end()); + + // Create an adjusted wasm section, without the first 5 bytes. + WasmSection *adjustedWasmSection = make<WasmSection>(wasmSection); + adjustedWasmSection->Content = adjustedWasmSection->Content.slice(5); + for (auto &relocation : adjustedWasmSection->Relocations) + relocation.Offset -= 5; + + newInputChunks[i] = make<InputSection>(*adjustedWasmSection, section->file, section->alignment); + newInputChunks[i]->setRelocations(adjustedWasmSection->Relocations); + } + + // Add the number of functions to the first section.
+ { + assert(InputSection::classof(inputChunks[0])); + auto *section = static_cast<InputSection *>(inputChunks[0]); + const WasmSection &wasmSection = section->section; + + // Read the number of functions in this section (bounded by its contents). + totalFunctions += decodeULEB128(wasmSection.Content.data(), nullptr, wasmSection.Content.end()); + + // Create an adjusted wasm section, with the first 5 bytes modified so that + // we apply the total number of functions. + WasmSection *adjustedWasmSection = make<WasmSection>(wasmSection); + auto *adjustedContent = make<std::vector<uint8_t>>(adjustedWasmSection->Content.begin(), adjustedWasmSection->Content.end()); + + std::string str; + raw_string_ostream os(str); + encodeULEB128(totalFunctions, os, 5); + memcpy(adjustedContent->data(), str.data(), 5); + adjustedWasmSection->Content = ArrayRef<uint8_t>(adjustedContent->data(), adjustedContent->size()); + + newInputChunks[0] = make<InputSection>(*adjustedWasmSection, section->file, section->alignment); + newInputChunks[0]->setRelocations(adjustedWasmSection->Relocations); + } + + sec = make<CustomSection>(name, newInputChunks); + } + + // In both cases, add the section normally, like any custom section. + auto *sym = make<OutputSectionSymbol>(sec); + out.linkingSec->addToSymtab(sym); + sec->sectionSym = sym; + addSection(sec); + + // After emitting this section, avoid processing it again in the place that + // custom sections are normally created, which is later in the binary (inside + // createCustomSections). The key must be the full section name. + customSectionMapping.erase(name); +} + // Create relocations sections in the final output. // These are only created when relocatable output is requested. void Writer::createRelocSections() { @@ -544,6 +666,9 @@ void Writer::addSections() { addSection(out.elemSec); addSection(out.dataCountSec); + // The Branch Hints section must be emitted before the code section.
+ createBranchHintSection(); + addSection(make(out.functionSec->inputFunctions)); addSection(make(segments)); diff --git a/llvm/include/llvm/MC/MCAsmBackend.h b/llvm/include/llvm/MC/MCAsmBackend.h index e49e786a10f58..0c88edcb15b78 100644 --- a/llvm/include/llvm/MC/MCAsmBackend.h +++ b/llvm/include/llvm/MC/MCAsmBackend.h @@ -127,6 +127,11 @@ class LLVM_ABI MCAsmBackend { const MCValue &Target, MutableArrayRef Data, uint64_t Value, bool IsResolved) = 0; + /// Given a ULEB128 of a particular padded size, return the fixup for it. + virtual MCFixupKind getULEB128Fixup(unsigned PadTo) const { + llvm_unreachable("Need to implement hook if target has ULEB128 fixups"); + } + /// @} /// \name Target Relaxation Interfaces diff --git a/llvm/include/llvm/MC/MCObjectStreamer.h b/llvm/include/llvm/MC/MCObjectStreamer.h index c987bc2426e9f..443b7eb4a7af8 100644 --- a/llvm/include/llvm/MC/MCObjectStreamer.h +++ b/llvm/include/llvm/MC/MCObjectStreamer.h @@ -117,7 +117,8 @@ class MCObjectStreamer : public MCStreamer { const MCExpr *Value) override; void emitValueImpl(const MCExpr *Value, unsigned Size, SMLoc Loc = SMLoc()) override; - void emitULEB128Value(const MCExpr *Value) override; + void emitULEB128Value(const MCExpr *Value, + unsigned PadTo = 0) override; void emitSLEB128Value(const MCExpr *Value) override; void emitWeakReference(MCSymbol *Alias, const MCSymbol *Target) override; void changeSection(MCSection *Section, uint32_t Subsection = 0) override; diff --git a/llvm/include/llvm/MC/MCStreamer.h b/llvm/include/llvm/MC/MCStreamer.h index 8f2e137ea0c84..8fa56dad45971 100644 --- a/llvm/include/llvm/MC/MCStreamer.h +++ b/llvm/include/llvm/MC/MCStreamer.h @@ -750,7 +750,8 @@ class LLVM_ABI MCStreamer { emitIntValue(Value, Size); } - virtual void emitULEB128Value(const MCExpr *Value); + virtual void emitULEB128Value(const MCExpr *Value, + unsigned PadTo = 0); virtual void emitSLEB128Value(const MCExpr *Value); diff --git a/llvm/lib/MC/MCAsmStreamer.cpp 
b/llvm/lib/MC/MCAsmStreamer.cpp index da0d99e70d9ea..17943c7b4227e 100644 --- a/llvm/lib/MC/MCAsmStreamer.cpp +++ b/llvm/lib/MC/MCAsmStreamer.cpp @@ -256,7 +256,8 @@ class MCAsmStreamer final : public MCStreamer { void emitIntValueInHex(uint64_t Value, unsigned Size) override; void emitIntValueInHexWithPadding(uint64_t Value, unsigned Size) override; - void emitULEB128Value(const MCExpr *Value) override; + void emitULEB128Value(const MCExpr *Value, + unsigned PadTo) override; void emitSLEB128Value(const MCExpr *Value) override; @@ -1402,13 +1403,21 @@ void MCAsmStreamer::emitValueImpl(const MCExpr *Value, unsigned Size, } } -void MCAsmStreamer::emitULEB128Value(const MCExpr *Value) { +void MCAsmStreamer::emitULEB128Value(const MCExpr *Value, + unsigned PadTo) { int64_t IntValue; if (Value->evaluateAsAbsolute(IntValue)) { - emitULEB128IntValue(IntValue); + emitULEB128IntValue(IntValue, PadTo); return; } - OS << "\t.uleb128 "; + if (!PadTo) + OS << "\t.uleb128 "; + else { + // A padding size has been specified. For now, all that is supported is a 5- + // byte LEB, which is an int32. + assert(PadTo == 5); + OS << "\t.uleb128_int32 "; + } Value->print(OS, MAI); EmitEOL(); } diff --git a/llvm/lib/MC/MCObjectStreamer.cpp b/llvm/lib/MC/MCObjectStreamer.cpp index e3d5a5a9a1327..b3bab04fc341b 100644 --- a/llvm/lib/MC/MCObjectStreamer.cpp +++ b/llvm/lib/MC/MCObjectStreamer.cpp @@ -259,13 +259,30 @@ void MCObjectStreamer::emitLabelAtPos(MCSymbol *Symbol, SMLoc Loc, Symbol->setOffset(Offset); } -void MCObjectStreamer::emitULEB128Value(const MCExpr *Value) { +void MCObjectStreamer::emitULEB128Value(const MCExpr *Value, + unsigned PadTo) { int64_t IntValue; + // Avoid fixups when possible, but still honor the requested padding. if (Value->evaluateAsAbsolute(IntValue, getAssemblerPtr())) { - emitULEB128IntValue(IntValue); + emitULEB128IntValue(IntValue, PadTo); return; } - insert(getContext().allocFragment(*Value, false)); + + if (!PadTo) { + // Emit the Value as best we can without padding or a fixup.
+ insert(getContext().allocFragment(*Value, false)); + return; + } + + // Use the proper fixup from the specific assembler backend. + const MCAsmBackend &MAB = getAssembler().getBackend(); + MCFixupKind Fixup = MAB.getULEB128Fixup(PadTo); + + // Use the given padding and fixup. + MCDataFragment *DF = getOrCreateDataFragment(); + DF->getFixups().push_back(MCFixup::create( + DF->getContents().size(), Value, Fixup)); + DF->appendContents(PadTo, 0); } void MCObjectStreamer::emitSLEB128Value(const MCExpr *Value) { diff --git a/llvm/lib/MC/MCStreamer.cpp b/llvm/lib/MC/MCStreamer.cpp index d70639b7bfe20..cce2faf144242 100644 --- a/llvm/lib/MC/MCStreamer.cpp +++ b/llvm/lib/MC/MCStreamer.cpp @@ -1323,7 +1323,8 @@ void MCStreamer::emitBinaryData(StringRef Data) { emitBytes(Data); } void MCStreamer::emitValueImpl(const MCExpr *Value, unsigned Size, SMLoc Loc) { visitUsedExpr(*Value); } -void MCStreamer::emitULEB128Value(const MCExpr *Value) {} +void MCStreamer::emitULEB128Value(const MCExpr *Value, + unsigned PadTo) {} void MCStreamer::emitSLEB128Value(const MCExpr *Value) {} void MCStreamer::emitFill(const MCExpr &NumBytes, uint64_t Value, SMLoc Loc) {} void MCStreamer::emitFill(const MCExpr &NumValues, int64_t Size, int64_t Expr, diff --git a/llvm/lib/Target/Mips/MipsDelaySlotFiller.cpp b/llvm/lib/Target/Mips/MipsDelaySlotFiller.cpp index b13394a607f6a..5c1a020299564 100644 --- a/llvm/lib/Target/Mips/MipsDelaySlotFiller.cpp +++ b/llvm/lib/Target/Mips/MipsDelaySlotFiller.cpp @@ -875,7 +875,7 @@ MipsDelaySlotFiller::selectSuccBB(MachineBasicBlock &B) const { if (B.succ_empty()) return nullptr; - // Select the successor with the larget edge weight. + // Select the successor with the largest edge weight. 
auto &Prob = getAnalysis().getMBPI(); MachineBasicBlock *S = *llvm::max_element(B.successors(), [&](const MachineBasicBlock *Dst0, diff --git a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp index 9649381f07b14..624ce3c795b8a 100644 --- a/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp +++ b/llvm/lib/Target/WebAssembly/AsmParser/WebAssemblyAsmParser.cpp @@ -1116,6 +1116,18 @@ class WebAssemblyAsmParser final : public MCTargetAsmParser { return expect(AsmToken::EndOfStatement, "EOL"); } + // A ULEB128 padded to 5 bytes. MCAsmStreamer prints this directive as + // ".uleb128_int32", so accept that spelling too for round-tripping. + if (DirectiveID.getString() == ".uleb128_i32" || + DirectiveID.getString() == ".uleb128_int32") { + const MCExpr *Val; + SMLoc End; + if (Parser.parseExpression(Val, End)) + return error("Cannot parse .uleb128_i32 expression: ", Lexer.getTok()); + Out.emitULEB128Value(Val, 5); + return expect(AsmToken::EndOfStatement, "EOL"); + } + if (DirectiveID.getString() == ".asciz") { if (checkDataSection()) return ParseStatus::Failure; diff --git a/llvm/lib/Target/WebAssembly/CMakeLists.txt b/llvm/lib/Target/WebAssembly/CMakeLists.txt index 1e83cbeac50d6..7795d6f9ec974 100644 --- a/llvm/lib/Target/WebAssembly/CMakeLists.txt +++ b/llvm/lib/Target/WebAssembly/CMakeLists.txt @@ -18,6 +18,7 @@ add_llvm_target(WebAssemblyCodeGen WebAssemblyAddMissingPrototypes.cpp WebAssemblyArgumentMove.cpp WebAssemblyAsmPrinter.cpp + WebAssemblyBranchHinting.cpp WebAssemblyCFGStackify.cpp WebAssemblyCleanCodeAfterTrap.cpp WebAssemblyCFGSort.cpp diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp index 7bc672c069476..8cbe0207e8e91 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp @@ -46,6 +46,8 @@ class WebAssemblyAsmBackend final : public MCAsmBackend { bool writeNopData(raw_ostream &OS, uint64_t Count, const MCSubtargetInfo *STI) const override; + + MCFixupKind
getULEB128Fixup(unsigned PadTo) const override; }; MCFixupKindInfo @@ -78,6 +80,12 @@ bool WebAssemblyAsmBackend::writeNopData(raw_ostream &OS, uint64_t Count, return true; } +MCFixupKind WebAssemblyAsmBackend::getULEB128Fixup(unsigned PadTo) const { + // Only 32-bit is supported for now, which is padded to 5 bytes. + assert(PadTo == 5); + return MCFixupKind(WebAssembly::fixup_uleb128_i32); +} + void WebAssemblyAsmBackend::applyFixup(const MCFragment &, const MCFixup &Fixup, const MCValue &Target, MutableArrayRef Data, diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.h b/llvm/lib/Target/WebAssembly/WebAssembly.h index 17481d77c120a..a331d755b9162 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.h +++ b/llvm/lib/Target/WebAssembly/WebAssembly.h @@ -31,6 +31,7 @@ ModulePass *createWebAssemblyFixFunctionBitcasts(); FunctionPass *createWebAssemblyOptimizeReturned(); FunctionPass *createWebAssemblyLowerRefTypesIntPtrConv(); FunctionPass *createWebAssemblyRefTypeMem2Local(); +FunctionPass *createWebAssemblyBranchHinting(); // ISel and immediate followup passes. 
FunctionPass *createWebAssemblyISelDag(WebAssemblyTargetMachine &TM, @@ -88,6 +89,7 @@ void initializeWebAssemblyRegNumberingPass(PassRegistry &); void initializeWebAssemblyRegStackifyPass(PassRegistry &); void initializeWebAssemblyReplacePhysRegsPass(PassRegistry &); void initializeWebAssemblySetP2AlignOperandsPass(PassRegistry &); +void initializeWebAssemblyBranchHintingPass(PassRegistry &); namespace WebAssembly { enum TargetIndex { diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp index c61ed3c7d5d81..54bc11b2cbd2b 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp @@ -14,6 +14,7 @@ //===----------------------------------------------------------------------===// #include "WebAssemblyAsmPrinter.h" +#include "MCTargetDesc/WebAssemblyFixupKinds.h" #include "MCTargetDesc/WebAssemblyMCExpr.h" #include "MCTargetDesc/WebAssemblyMCTargetDesc.h" #include "MCTargetDesc/WebAssemblyTargetStreamer.h" @@ -32,6 +33,7 @@ #include "llvm/BinaryFormat/Wasm.h" #include "llvm/CodeGen/Analysis.h" #include "llvm/CodeGen/AsmPrinter.h" +#include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineConstantPool.h" #include "llvm/CodeGen/MachineInstr.h" #include "llvm/CodeGen/MachineModuleInfoImpls.h" @@ -441,6 +443,7 @@ void WebAssemblyAsmPrinter::emitEndOfAsmFile(Module &M) { EmitProducerInfo(M); EmitTargetFeatures(M); EmitFunctionAttributes(M); + EmitBranchHints(M); } void WebAssemblyAsmPrinter::EmitProducerInfo(Module &M) { @@ -599,6 +602,70 @@ void WebAssemblyAsmPrinter::EmitFunctionAttributes(Module &M) { } } +void WebAssemblyAsmPrinter::EmitBranchHints(Module &M) { + if (AllFuncBranchHints.empty()) + return; + + MCSectionWasm *BranchHintSection = OutContext.getWasmSection( + ".custom_section.metadata.code.branch_hint", + SectionKind::getMetadata()); + + OutStreamer->pushSection(); + 
OutStreamer->switchSection(BranchHintSection); + + // Number of functions with hints. We pad this to 5 bytes to make the linker's + // life easier: given multiple Branch Hint sections, wasm-ld will by default + // simply concatenate them, just like any other custom section. That would end + // up with + // + // [num functions_1] : 5 byte LEB + // [..data_1..] + // [num functions_2] : 5 byte LEB + // [..data_2..] + // .. + // [num functions_N] : 5 byte LEB + // [..data_N..] + // + // To get this to the proper form the linker should emit we can trample the + // number of functions at the very start, and not emit the others: + // + // [num functions_1 + _2 + .. + _N] : 5 byte LEB + // [..data_1..] + // [..data_2..] + // .. + // [..data_N..] + // + // TODO: fix existing tests for this + OutStreamer->emitULEB128IntValue(AllFuncBranchHints.size(), 5); + + for (auto& FuncHints : AllFuncBranchHints) { + auto* FuncSymbol = getSymbol(FuncHints.F); + // The function index. We use S_None because WasmObjectWriter has + // case WebAssembly::S_FUNCINDEX: + // return wasm::R_WASM_FUNCTION_INDEX_I32; + // i.e. S_FUNCINDEX is always relocated as an I32, but we need an LEB. Using + // None gets us to pick the relocation based on the fixup, and + // MCObjectStreamer will emit a proper LEB fixup for emitULEB128Value. + OutStreamer->emitULEB128Value( + MCSymbolRefExpr::create(FuncSymbol, WebAssembly::S_None, OutContext), + 5); + + // The number of hints in the function. + OutStreamer->emitULEB128IntValue(FuncHints.Hints.size()); + + for (auto& Hint : FuncHints.Hints) { + OutStreamer->emitAbsoluteSymbolDiffAsULEB128(Hint.Label, FuncSymbol); + + // Hints are of size 1. + OutStreamer->emitULEB128IntValue(1); + // The hint itself, likely or not. 
+ OutStreamer->emitULEB128IntValue(Hint.IsLikely); + } + } + + OutStreamer->popSection(); +} + void WebAssemblyAsmPrinter::emitConstantPool() { emitDecls(*MMI->getModule()); assert(MF->getConstantPool()->getConstants().empty() && @@ -609,6 +676,94 @@ void WebAssemblyAsmPrinter::emitJumpTableInfo() { // Nothing to do; jump tables are incorporated into the instruction stream. } +std::optional WebAssemblyAsmPrinter::getBranchHint(const MachineInstr& MI) { + // WebAssemblyLowerBrUnless should have run before us, removing all BR_UNLESS, + // which makes things simpler for us here. + assert(MI.getOpcode() != WebAssembly::BR_UNLESS); + + if (MI.getOpcode() != WebAssembly::BR_IF) + return {}; + + // We only handle the simple case of our being the last terminator of the + // block. If there are other terminators after us, things may be complicated. + auto *ParentMBB = MI.getParent(); + const MachineInstr *LastTerminator = nullptr; + for (auto& Terminator : ParentMBB->terminators()) + LastTerminator = &Terminator; + if (LastTerminator != &MI) + return {}; + + // This is a BR_IF. It has two successors, and perhaps branch probability + // info between them. Finding the successor blocks is not trivial, since we + // run after CFGStackify (and even WebAssemblyInstrInfo::analyzeBranch returns + // that it cannot analyze branch targets). First, find the two successors of + // the parent block of this BR_IF. (If there are three successors, due to an + // additional unwind (EH pad) successor, give up on this hint.) + if (ParentMBB->succ_size() != 2) + return {}; + auto iter = ParentMBB->succ_begin(); + MachineBasicBlock* MBB1 = *iter; + iter++; + MachineBasicBlock* MBB2 = *iter; + + // Find the fallthrough, that is, the block right after us, where control flow + // goes if we do not branch. + auto ParentMBBI = ParentMBB->getIterator(); + ++ParentMBBI; + // A block with a BR_IF must have something after it.
+ assert(ParentMBBI != MF->end()); + auto *Fallthrough = &*ParentMBBI; + // In some corner cases, wasm's structured control flow makes it hard to infer + // control flow, like this: + // + // bb1: + // ;; successor: if.true, otherwise + // .. + // br_if $if.true + // br $otherwise + // + // ehpad: + // + // If `br $otherwise` is optimized out, it would look like we fall through + // to the ehpad (since it is the block physically after us), but in wasm's + // structured control flow that is not the case. Rather than try to find the + // true fallthrough, give up on a hint in any case where the fallthrough is + // not one of our block's successors. + if (Fallthrough != MBB1 && Fallthrough != MBB2) + return {}; + // We ruled out complex cases, so what is left is simple control flow, and the + // false destination, i.e. where we go if the br_if condition is false, is + // the fallthrough. + auto *FalseDest = Fallthrough; + // The true destination, i.e. if the condition is true, is the other one. + MachineBasicBlock *TrueDest = (FalseDest == MBB1 ? MBB2 : MBB1); + + // Find the probability of branching. + const MachineBranchProbabilityInfo *MBPI = + &getAnalysis().getMBPI(); + BranchProbability ProbTarget = MBPI->getEdgeProbability(ParentMBB, TrueDest); + + // Wasm branch hints are boolean, and each one takes space in the binary, so + // we do not want to emit hints for trivial things like 55%/45%. Err on the + // side of caution for now and focus on really powerful hints (such as those + // given by __builtin_expect), and ignore hints of 100%/0% (code leading to an + // unreachable; we emit an unreachable for them already, which is good enough + // for both toolchains and VMs). + // We detect __builtin_expect-generated hints as follows.
FIXME: this is fragile. + auto isFromExpected = [](BranchProbability Prob) { + // Such hints appear as pairs of + // 0x00106035 / 0x80000000 = 0.05% , 0x7fef9fcb / 0x80000000 = 99.95% + return (Prob.getNumerator() == 0x00106035 || Prob.getNumerator() == 0x7fef9fcb) && + Prob.getDenominator() == 0x80000000; + }; + if (!isFromExpected(ProbTarget)) + return {}; + + const BranchProbability Half = BranchProbability(1, 2); + assert(ProbTarget != Half); + return ProbTarget > Half; +} + void WebAssemblyAsmPrinter::emitFunctionBodyStart() { const Function &F = MF->getFunction(); SmallVector ResultVTs; @@ -635,10 +790,26 @@ void WebAssemblyAsmPrinter::emitFunctionBodyStart() { getTargetStreamer()->emitLocal(Locals); AsmPrinter::emitFunctionBodyStart(); + } void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) { LLVM_DEBUG(dbgs() << "EmitInstruction: " << *MI << '\n'); + + if (auto Hint = getBranchHint(*MI)) { + // Create a temp symbol for this instruction, so we can refer to it. + MCSymbol *InstructionSymbol = OutContext.createTempSymbol(); + OutStreamer->emitLabel(InstructionSymbol); + + // Stash the hint for later, on the proper function. We will emit the branch + // hints section at the end, when we know all the information we need.
+ Function *F = &MF->getFunction(); + if (AllFuncBranchHints.empty() || AllFuncBranchHints.back().F != F) { + AllFuncBranchHints.emplace_back(FuncBranchHints{F, {}}); + } + AllFuncBranchHints.back().Hints.emplace_back(BranchHint{InstructionSymbol, *Hint}); + } + WebAssembly_MC::verifyInstructionPredicates(MI->getOpcode(), Subtarget->getFeatureBits()); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h index 46063bbe0fba1..79cea2c87718a 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h @@ -28,6 +28,25 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { WebAssemblyFunctionInfo *MFI; bool signaturesEmitted = false; + // A branch hint for an instruction. We gather them all (& by function) so we + // can emit the section at the end, knowing how many hints are present (which + // must be emitted before the hints, so we can't do it in a streaming manner). + // And we must gather during emitInstruction(), as we must emit a label for + // each branch we want to annotate (so we can refer to it). + struct BranchHint { + // The location the hint refers to. + MCSymbol *Label; + // Whether the branch there is likely. 
+ bool IsLikely; + }; + + struct FuncBranchHints { + Function *F; + std::vector Hints; + }; + + std::vector AllFuncBranchHints; + public: explicit WebAssemblyAsmPrinter(TargetMachine &TM, std::unique_ptr Streamer) @@ -59,6 +78,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { void EmitProducerInfo(Module &M); void EmitTargetFeatures(Module &M); void EmitFunctionAttributes(Module &M); + void EmitBranchHints(Module &M); void emitSymbolType(const MCSymbolWasm *Sym); void emitGlobalVariable(const GlobalVariable *GV) override; void emitJumpTableInfo() override; @@ -78,6 +98,10 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { bool &InvokeDetected); MCSymbol *getOrCreateWasmSymbol(StringRef Name); void emitDecls(const Module &M); + + // See if there is a branch hint for an instruction, and if so, if it is true + // or false. + std::optional getBranchHint(const MachineInstr& MI); }; } // end namespace llvm diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyBranchHinting.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyBranchHinting.cpp new file mode 100644 index 0000000000000..7652b1ded63f1 --- /dev/null +++ b/llvm/lib/Target/WebAssembly/WebAssemblyBranchHinting.cpp @@ -0,0 +1,95 @@ +//===-- WebAssemblyBranchHinting.cpp - Filter branch hints from LLVM IR --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// Remove LLVM IR branch_weights that we do not want in wasm. Wasm branch hints +/// are boolean, and always add size to the binary, so we only want to keep +/// certainly-useful hints. Specifically, we keep hints that began as +/// __builtin_expect in the source. 
+/// +//===----------------------------------------------------------------------===// + +#include "WebAssembly.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" + + +using namespace llvm; + +#define DEBUG_TYPE "wasm-branch-hint" + +namespace { +class WebAssemblyBranchHinting final : public FunctionPass, + public InstVisitor<WebAssemblyBranchHinting> { + StringRef getPassName() const override { + return "WebAssembly Branch Hint"; + } + + bool runOnFunction(Function &F) override; + +public: + static char ID; + WebAssemblyBranchHinting() : FunctionPass(ID) {} + + void visitBranchInst(BranchInst &I); +}; +} // End anonymous namespace + +char WebAssemblyBranchHinting::ID = 0; +INITIALIZE_PASS(WebAssemblyBranchHinting, DEBUG_TYPE, + "Filter WebAssembly branch hints", + false, false) + +FunctionPass *llvm::createWebAssemblyBranchHinting() { + return new WebAssemblyBranchHinting(); +} + +void WebAssemblyBranchHinting::visitBranchInst(BranchInst &I) { + I.eraseMetadataIf([](unsigned MDKind, MDNode *Node) { + // Look for profiling metadata. + if (MDKind != LLVMContext::MD_prof) + return false; + + // Check for profiling metadata of "branch_weights". + if (Node->getNumOperands() == 0) + return false; + MDString *MDName = dyn_cast<MDString>(Node->getOperand(0)); + if (!MDName || MDName->getString() != "branch_weights") + return false; + + // This is a branch weights metadata. Keep it only if it has the "expected" + // marker, i.e. the weights began as __builtin_expect in the source. + if (Node->getNumOperands() >= 2) { + MDString *Marker = dyn_cast<MDString>(Node->getOperand(1)); + if (Marker && Marker->getString() == "expected") { + LLVM_DEBUG(dbgs() << "Keeping branch weights: " << *Node << '\n'); + return false; + } + } + + // Discard anything else.
+ LLVM_DEBUG(dbgs() << "Dropping branch weights: " << *Node << '\n'); + return true; + }); +} + +bool WebAssemblyBranchHinting::runOnFunction(Function &F) { + LLVM_DEBUG(dbgs() << "********** Filter wasm branch hints **********\n" + "********** Function: " + << F.getName() << '\n'); + + visit(F); + return true; +} diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp index adb446b20ebf5..7bd91147dc458 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyTargetMachine.cpp @@ -498,6 +498,9 @@ void WebAssemblyPassConfig::addIRPasses() { // Expand indirectbr instructions to switches. addPass(createIndirectBrExpandPass()); + // TODO flag (here and elsewhere) + addPass(createWebAssemblyBranchHinting()); + TargetPassConfig::addIRPasses(); } diff --git a/llvm/test/CodeGen/WebAssembly/branch-hints.ll b/llvm/test/CodeGen/WebAssembly/branch-hints.ll new file mode 100644 index 0000000000000..51fd390259238 --- /dev/null +++ b/llvm/test/CodeGen/WebAssembly/branch-hints.ll @@ -0,0 +1,47 @@ +; RUN: llc -mtriple=wasm32-unknown-unknown -filetype=asm -o - < %s | FileCheck %s + +define i32 @bw_bh_test(i32 %a, i32 %b) { +entry: + %1 = icmp ult i32 %a, %b + br i1 %1, label %fail, label %success, !prof !0 + +; The weights below mean we are far more likely to go to %fail and return -1. +; Codegen will emit the -1 first (the same as appearing here): + +; CHECK: i32.const -1 +; CHECK: i32.const 0 + +; Given that layout, we must emit a hint of 0, below, for the value of the hint: +; the VM can assume the condition of the br_if is likely *false*, which means we +; likely fall through to return -1. + +; CHECK: .section .custom_section.metadata.code.branch_hint,"",@ + +; Number of functions with hints (1, as a LEB padded to 5 bytes). +; CHECK-NEXT: .asciz "\201\200\200\200" + +; Function with the hint. +; CHECK-NEXT: .uleb128_int32 bw_bh_test + +; Number of hints in function.
+; CHECK-NEXT: .int8 1 + +; Offset of the hint. +; CHECK-NEXT: .uleb128 .Ltmp0-bw_bh_test + +; Size of the hint. +; CHECK-NEXT: .int8 1 + +; Value of the hint. +; CHECK-NEXT: .int8 0 + +fail: + ret i32 -1 + +success: + ret i32 0 +} + +!0 = !{!"branch_weights", !"expected", i32 2000, i32 1} + +; TODO: a test that starts as asm, and checks either disassembly/objdump or YAML output. Something like llvm/test/MC/WebAssembly/debuginfo-relocs.s or the similar tests in there. diff --git a/llvm/test/MC/WebAssembly/branch-hints.s b/llvm/test/MC/WebAssembly/branch-hints.s new file mode 100644 index 0000000000000..6c26ad07cc810 --- /dev/null +++ b/llvm/test/MC/WebAssembly/branch-hints.s @@ -0,0 +1,45 @@ +# RUN: llvm-mc -filetype=obj -triple=wasm32-unknown-unknown -o %t.o %s +# RUN: obj2yaml %t.o | FileCheck %s + + .file "c.ll" + .functype bw_bh_test (i32, i32) -> (i32) + .section .text.bw_bh_test,"",@ + .globl bw_bh_test # -- Begin function bw_bh_test + .type bw_bh_test,@function +bw_bh_test: # @bw_bh_test + .functype bw_bh_test (i32, i32) -> (i32) +# %bb.0: + block + local.get 0 + local.get 1 + i32.ge_u +.Ltmp0: + br_if 0 # 0: down to label0 +# %bb.1: # %fail + i32.const -1 + return +.LBB0_2: # %success + end_block # label0: + i32.const 0 + # fallthrough-return + end_function + # -- End function + .section .text.bw_bh_test,"",@ + .section .custom_section.metadata.code.branch_hint,"",@ + .int8 1 + .uleb128_i32 bw_bh_test + .int8 1 + .uleb128 .Ltmp0-bw_bh_test + .int8 1 + .int8 0 + .section .text.bw_bh_test,"",@ + +## Test handling of ULEB128 fields in the branch hints section. +# CHECK: Name: metadata.code.branch_hint +# CHECK-NEXT: Payload: '01808080800001080100' +# ^^ one function +# ^^^^^^^^^^ LEB of function index - +# ^^ one hint in function +# ^^ offset 8 +# ^^ hint size 1 +# ^^ hint value: 0